-
Notifications
You must be signed in to change notification settings - Fork 57
Expand file tree
/
Copy pathtree.py
More file actions
76 lines (61 loc) · 2 KB
/
tree.py
File metadata and controls
76 lines (61 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import pickle
from statistics import stdev
import numpy as np
import hgm_utils
from hgm_utils import eval_agent, sample_child
def get_num_total_evals():
    """Return the total number of evaluations recorded across the tree.

    Sums ``num_evals`` over the node registered under id 0 in
    ``hgm_utils.nodes`` and all of its descendants. Id 0 is the first id
    handed out when nodes are auto-numbered; presumably this is the root.

    Raises:
        KeyError: if no node is registered under id 0 (assuming
            ``hgm_utils.nodes`` is a dict, as ``Node.__init__`` suggests).
    """
    root = hgm_utils.nodes[0]
    # Aggregate through get_sub_tree, which every Node provides, rather than
    # relying on a separate summing helper.
    return sum(root.get_sub_tree(lambda node: node.num_evals))
class Node:
    """A node of the search tree: one commit plus the utility scores measured for it.

    On construction every node registers itself in the shared
    ``hgm_utils.nodes`` mapping under its ``id``.
    """

    def __init__(
        self,
        commit_id,
        utility_measures=None,
        parent_id=None,
        id=None,
    ):
        """Create a node and register it in ``hgm_utils.nodes``.

        Args:
            commit_id: Identifier of the commit this node represents.
            utility_measures: Optional initial list of utility scores. A falsy
                value (``None`` or an empty list) starts the node with a fresh
                empty list.
            parent_id: ``id`` of the parent node, or ``None``.
            id: Explicit node id. When ``None``, the id is the current number
                of registered nodes.
        """
        self.commit_id = commit_id
        self.children = []
        if utility_measures:
            self.utility_measures = utility_measures
        else:
            self.utility_measures = []
        self.parent_id = parent_id
        if id is None:
            # NOTE(review): len()-based ids can collide with an existing entry
            # if nodes were ever registered with explicit, non-contiguous ids.
            self.id = len(hgm_utils.nodes)
        else:
            self.id = id
        hgm_utils.nodes[self.id] = self

    def get_sub_tree(self, fn=lambda self: self):
        """Return ``fn(node)`` for this node and every descendant, in pre-order.

        Args:
            fn: Function applied to each node; defaults to the identity, so the
                nodes themselves are returned.

        Returns:
            A new list whose first element is ``fn(self)``.
        """
        nodes_list = [fn(self)]
        for child in self.children:
            nodes_list.extend(child.get_sub_tree(fn))
        return nodes_list

    def get_sum(self, fn):
        """Return the sum of ``fn(node)`` over this node and all its descendants."""
        return sum(self.get_sub_tree(fn))

    def get_pseudo_decendant_evals(self, num_pseudo):
        """Return this node's own contribution to the aggregated evaluations.

        With fewer than ``num_pseudo`` evaluations, the actual scores are
        returned. Otherwise the node contributes ``num_pseudo`` copies of its
        mean utility, which caps its weight.

        Returns:
            A new list that is safe for the caller to mutate.
        """
        if self.num_evals < num_pseudo:
            # Copy: returning the attribute itself let callers (e.g.
            # get_decendant_evals) append into this node's own scores.
            return list(self.utility_measures)
        return [self.mean_utility] * num_pseudo

    def get_decendant_evals(self, num_pseudo=10):
        """Collect this node's (pseudo) evaluations plus those of all descendants.

        Args:
            num_pseudo: Cap passed to ``get_pseudo_decendant_evals`` for this
                node's own contribution. Descendants always contribute their
                raw scores.

        Returns:
            A new list. No node's ``utility_measures`` is modified.
        """
        decendant_evals = self.get_pseudo_decendant_evals(num_pseudo)
        for measures in self.get_sub_tree(lambda node: node.utility_measures)[1:]:
            decendant_evals.extend(measures)
        return decendant_evals

    @property
    def num_evals(self):
        """Number of utility scores recorded for this node."""
        return len(self.utility_measures)

    @property
    def mean_utility(self):
        """Mean of the recorded utility scores, or ``np.inf`` when there are none."""
        if self.num_evals == 0:
            return np.inf
        return np.sum(self.utility_measures) / self.num_evals

    def add_child(self, child):
        """Append ``child`` to this node's children."""
        self.children.append(child)

    def save_as_dict(self):
        """Return a summary dict with the commit, ids, mean utility and eval count."""
        return {
            "commit_id": self.commit_id,
            "id": self.id,
            "parent_id": self.parent_id,
            "mean_utility": self.mean_utility,
            "num_evals": self.num_evals,
        }