-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
154 lines (133 loc) · 6.73 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import pickle
import constants
from tasks.arithmetics.binary_average_sum.task import AverageSumTask
from tasks.operators.mta.task import MTATask
def _append_heat_map(log_file, key, seq_labels, seq_outputs, seq_inputs):
    """Append one heat-map record under `key` to the pickled log at `log_file`.

    Loads the existing dict-of-lists log, appends the record, and writes the
    log back. Files are opened with `with` so handles are always closed.
    Creates the list for `key` if it does not exist yet.
    """
    with open(log_file, "rb") as f:
        log = pickle.load(f)
    if log.get(key) is None:
        log[key] = []
    log[key].append({
        'labels': seq_labels,
        'outputs': seq_outputs,
        'inputs': seq_inputs
    })
    with open(log_file, "wb") as f:
        pickle.dump(log, f)


def run_eval(sess, model, inputs_placeholder, outputs_placeholder, max_seq_len_placeholder, data_generator, args,
             target_point, labels, outputs, inputs, batches, store_heat_maps=False, generalization_num=None):
    """Evaluate `model` on pre-generated `batches` and return mean loss/error.

    For every `(seq_len, inputs, labels)` batch a single `sess.run` computes
    the model loss and outputs; the per-sequence error comes from
    `data_generator.error_per_seq`. When `store_heat_maps` is true, the first
    sequence of each batch is appended to a pickled heat-map log — keyed by
    `target_point` (regular eval) or by `generalization_num` (generalization
    eval, which uses a separate log file).

    NOTE(review): the `labels`, `outputs` and `inputs` parameters are never
    read here — they are kept for caller compatibility.

    Returns:
        (mean_loss, mean_error) averaged over the batches; (0.0, 0.0) when
        `batches` is empty (avoids ZeroDivisionError).
    """
    num_batches = len(batches)
    if num_batches == 0:
        return 0.0, 0.0
    task_loss = 0.0
    task_error = 0.0
    # Local names deliberately do NOT shadow the `inputs`/`labels`/`outputs`
    # parameters (the original code shadowed all three).
    for seq_len, batch_inputs, batch_labels in batches:
        batch_loss, batch_outputs = sess.run([model.loss, model.outputs],
                                             feed_dict={
                                                 inputs_placeholder: batch_inputs,
                                                 outputs_placeholder: batch_labels,
                                                 max_seq_len_placeholder: seq_len
                                             })
        task_loss += batch_loss
        task_error += data_generator.error_per_seq(batch_labels, batch_outputs, args.batch_size)
        if store_heat_maps:
            if generalization_num is None:
                _append_heat_map(constants.HEAD_LOG_FILE, target_point,
                                 batch_labels[0], batch_outputs[0], batch_inputs[0])
            else:
                _append_heat_map(constants.GENERALIZATION_HEAD_LOG_FILE, generalization_num,
                                 batch_labels[0], batch_outputs[0], batch_inputs[0])
    return task_loss / float(num_batches), task_error / float(num_batches)
def eval_performance(sess, data_generator, args, model, target_point, labels, outputs, inputs, inputs_placeholder,
                     outputs_placeholder, max_seq_len_placeholder, curriculum_point, store_heat_maps=False,
                     skip_multi_task=False):
    """Evaluate the model on up to three distributions and return their metrics.

    Runs three evaluations via `run_eval`:
      1. the target task (half of `eval_batch_size`, curriculum 'none'),
      2. the multi-task mix (full `eval_batch_size`, curriculum
         'deterministic_uniform') — skipped when `skip_multi_task` is true,
      3. the current curriculum point (quarter of `eval_batch_size`,
         curriculum 'naive') — skipped when `curriculum_point` is None.
    Skipped evaluations yield None metrics.

    Returns:
        (target_task_error, target_task_loss,
         multi_task_error, multi_task_loss,
         curriculum_point_error, curriculum_point_loss)
    """
    generator_args = dict(
        # `//` keeps the arithmetic in integers (the original round-tripped
        # through float division and int()).
        num_batches=(args.eval_batch_size // 2) // args.batch_size,
        batch_size=args.batch_size,
        bits_per_vector=args.num_bits_per_vector,
        curriculum_point=None,
        max_seq_len=args.max_seq_len,
        curriculum='none',
        pad_to_max_seq_len=args.pad_to_max_seq_len
    )
    # Task-specific generator options; both tasks take `numbers_quantity`.
    if args.task in (AverageSumTask.name, MTATask.name):
        generator_args['numbers_quantity'] = args.num_experts
    if args.task == MTATask.name:
        generator_args['cli_mode'] = True
        generator_args['two_tuple_weight_precision'] = args.two_tuple_weight_precision
        generator_args['two_tuple_alpha_precision'] = args.two_tuple_alpha_precision
        generator_args['two_tuple_largest_scale_size'] = args.two_tuple_largest_scale_size
        generator_args['mta_encoding'] = args.mta_encoding
    batches = data_generator.generate_batches(**generator_args)
    # target task
    target_task_loss, target_task_error = run_eval(sess, model,
                                                   inputs_placeholder,
                                                   outputs_placeholder,
                                                   max_seq_len_placeholder,
                                                   data_generator,
                                                   args,
                                                   target_point,
                                                   labels,
                                                   outputs,
                                                   inputs,
                                                   batches,
                                                   store_heat_maps=store_heat_maps)
    # multi-task
    multi_task_loss = None
    multi_task_error = None
    if not skip_multi_task:
        batches = data_generator.generate_batches(
            num_batches=args.eval_batch_size // args.batch_size,
            batch_size=args.batch_size,
            bits_per_vector=args.num_bits_per_vector,
            curriculum_point=None,
            max_seq_len=args.max_seq_len,
            curriculum='deterministic_uniform',
            pad_to_max_seq_len=args.pad_to_max_seq_len
        )
        multi_task_loss, multi_task_error = run_eval(sess, model, inputs_placeholder, outputs_placeholder,
                                                     max_seq_len_placeholder, data_generator, args, target_point,
                                                     labels,
                                                     outputs, inputs, batches)
    # curriculum point
    print(f'Current curriculum point: {curriculum_point}')
    if curriculum_point is not None:
        batches = data_generator.generate_batches(
            num_batches=(args.eval_batch_size // 4) // args.batch_size,
            batch_size=args.batch_size,
            bits_per_vector=args.num_bits_per_vector,
            curriculum_point=curriculum_point,
            max_seq_len=args.max_seq_len,
            curriculum='naive',
            pad_to_max_seq_len=args.pad_to_max_seq_len
        )
        curriculum_point_loss, curriculum_point_error = run_eval(sess, model, inputs_placeholder, outputs_placeholder,
                                                                 max_seq_len_placeholder, data_generator, args,
                                                                 target_point, labels, outputs, inputs, batches)
    else:
        curriculum_point_error = curriculum_point_loss = None
    return target_task_error, target_task_loss, multi_task_error, multi_task_loss, curriculum_point_error, curriculum_point_loss
def eval_generalization(sess, model, inputs_placeholder, outputs_placeholder, max_seq_len_placeholder, data_generator,
                        args, target_point, labels, outputs, inputs):
    """Measure generalization to longer-than-trained sequence lengths.

    Only the 'copy' and 'associative_recall' tasks have generalization
    lengths defined; any other task yields an empty result list. Each
    length is evaluated on 6 freshly generated batches (curriculum 'naive',
    no padding), and heat maps are stored when `args.verbose` is set.

    Returns:
        List of per-length error values, in the same order as the lengths.
    """
    # Per-task generalization sequence lengths; unknown tasks get no runs.
    task_seq_lens = {
        'copy': [40, 60, 80, 100, 120],
        'associative_recall': [7, 8, 9, 10, 11, 12],
    }
    errors = []
    for seq_len in task_seq_lens.get(args.task, []):
        eval_batches = data_generator.generate_batches(
            6,
            args.batch_size,
            bits_per_vector=args.num_bits_per_vector,
            curriculum_point=seq_len,
            max_seq_len=args.max_seq_len,
            curriculum='naive',
            pad_to_max_seq_len=False
        )
        _, seq_error = run_eval(sess, model, inputs_placeholder, outputs_placeholder, max_seq_len_placeholder,
                                data_generator, args, target_point, labels, outputs, inputs, eval_batches,
                                store_heat_maps=args.verbose, generalization_num=seq_len)
        errors.append(seq_error)
    return errors