forked from wmpauli/knowledge_distillation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_registration.py
84 lines (65 loc) · 2.35 KB
/
model_registration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import os
import json
from azureml.core import Workspace, Run, Experiment
from azureml.core.authentication import ServicePrincipalAuthentication
# Parse command-line options. --input_dir locates the upstream metrics;
# --output_dir is accepted (default "output") but is not read anywhere else
# in this script.
parser = argparse.ArgumentParser()
for _flag, _dest in (("--input_dir", "input_dir"), ("--output_dir", "output_dir")):
    parser.add_argument(_flag, dest=_dest, default="output")
args = parser.parse_args()
print("all args: ", args)
# Connect to the Azure ML workspace described by config.json.
# If config.json carries all three service-principal fields, authenticate with
# them; if any field is missing (KeyError), fall back to auth=None so that
# Workspace.from_config chooses its default authentication (the log message
# indicates this path is meant for runs inside Azure DevOps).
with open('config.json', 'r') as f:
    config = json.load(f)

try:
    svc_pr = ServicePrincipalAuthentication(
        tenant_id=config['tenant_id'],
        service_principal_id=config['service_principal_id'],
        service_principal_password=config['service_principal_password'])
except KeyError:
    print("Getting Service Principal Authentication from Azure Devops")
    # Must bind exactly `svc_pr`: the call below reads this name, and leaving
    # it unbound would raise NameError instead of using default auth.
    svc_pr = None

ws = Workspace.from_config(auth=svc_pr)
# Load the per-run metrics written by the upstream step.
# NOTE(review): os.path.dirname drops the last component of --input_dir, so
# with the default "output" this opens ./data_metrics. This looks intended for
# the case where --input_dir is itself the path of a file named
# 'data_metrics' (dirname + 'data_metrics' then points back at it) — confirm
# against the pipeline definition before changing it.
input_dir = os.path.dirname(args.input_dir)
with open(os.path.join(input_dir, 'data_metrics')) as f:
    metrics = json.load(f)

import numpy as np

# Select the run with the lowest final validation loss.
# Presumably `metrics` maps run_id -> {metric_name: [values...]}; only the
# last entry of each run's 'val_loss' list is read — TODO confirm schema.
best_loss = np.inf
best_run_id = None
print(metrics)
for run in metrics:
    try:
        loss = float(metrics[run]['val_loss'][-1])
    except (KeyError, IndexError, TypeError, ValueError):
        # Missing, empty, or non-numeric val_loss history: skip this run
        # rather than abort the whole selection.
        print("WARNING: Could not get val_loss for run_id", run)
        continue
    if loss < best_loss:
        best_loss = loss
        best_run_id = run
print("best run", best_run_id, best_loss)
# Register the best run's 'outputs' path as a model named after the
# experiment, tagged with the run id and its final validation loss.
# Raises RuntimeError when no run produced a usable val_loss.
experiment_name = "kd_teach_the_student"
exp = Experiment(ws, name=experiment_name)

if best_run_id:
    # Resolve the Run only once a run id is known, so a missing id reaches
    # the explicit error branch instead of Run(exp, None).
    best_run = Run(exp, best_run_id)
    tags = {
        'run_id': best_run_id,
        'val_loss': metrics[best_run_id]['val_loss'][-1],
    }
    model = best_run.register_model(model_name=experiment_name,
                                    model_path='outputs',
                                    tags=tags)
else:
    message = "Could not find a model to register. Probably because no run completed"
    print(message)
    raise RuntimeError(message)