-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
166 lines (133 loc) · 5.63 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import logging
from logging import LogRecord
from io import StringIO
import sys
from pathlib import Path
from mlspeclib import MLObject
import base64
# Extend the import search path: the relative "src" directory (only when it
# exists), then the current working directory and its parent.
if Path("src").exists():
    sys.path.append(str(Path("src")))
for _search_root in (Path.cwd(), Path.cwd().parent):
    sys.path.append(str(_search_root))
class KnownException(Exception):
    """Exception type for failures this module detects and reports itself.

    It adds nothing to ``Exception``; it exists so such failures can be
    told apart from unexpected errors by type.
    """
def report_found_params(expected_params: list, offered_params: dict) -> None:
    """Check that every expected parameter was supplied with a value.

    Args:
        expected_params: Names that must be present in ``offered_params``.
        offered_params: Mapping of parameter name to supplied value.

    Raises:
        KnownException: On the first expected name that is missing from
            ``offered_params`` or mapped to ``None``.
    """
    log = setupLogger().get_root_logger()

    for name in expected_params:
        # A missing key and an explicit None are treated the same way.
        if offered_params.get(name) is None:
            raise KnownException(f"No parameter set for {name}.")
        log.debug(f"Found value for {name}.")
def raise_schema_mismatch(
    expected_type: str, expected_version: str, actual_type: str, actual_version: str
):
    """Raise a ``KnownException`` describing a schema type/version mismatch.

    Args:
        expected_type: Schema type that was required.
        expected_version: Schema version that was required.
        actual_type: Schema type that was actually found.
        actual_version: Schema version that was actually found.

    Raises:
        KnownException: Always; the message lists the expected and actual
            type and version, one per line.
    """
    # The message is assembled line by line so that it carries no source-code
    # indentation and no stray trailing characters (the previous triple-quoted
    # literal ended the message with a leftover '")').
    raise KnownException(
        "Actual data does not match the expected schema and version:\n"
        f"Expected Type: {expected_type}\n"
        f"Actual Type: {actual_type}\n"
        f"Expected Version: {expected_version}\n"
        f"Actual Version: {actual_version}"
    )
# TODO: Think about moving the logger to a library of some kind so that it can be reused with this signature across derived containers
class setupLogger:
    """Configure the root logger and expose it together with an in-memory buffer.

    Creating an instance sets the root logger's level and makes sure three
    named handlers are attached to it:

    * ``buffer.logger`` writes formatted records to a ``StringIO`` buffer.
    * ``stdout.logger`` writes formatted records to ``sys.stdout``.
    * ``setoutput.logger`` writes, unformatted, only the records accepted by
      ``filter_for_outputs`` (messages starting with ``::set-output``).

    The handlers are installed once; later instances find the existing
    ``buffer.logger`` handler by name and share its buffer.
    """

    _rootLogger = None
    _buffer = None

    def __init__(self, debug=False):
        """Set the root log level and attach the handlers if they are missing.

        Args:
            debug: When true, the root logger and any handlers created by this
                call use ``logging.DEBUG``; otherwise ``logging.WARN``.
        """
        log_level = logging.DEBUG if debug else logging.WARN

        self._rootLogger = logging.getLogger()
        self._rootLogger.setLevel(log_level)

        # Look for *our* buffer handler instead of asking whether the root
        # logger has any handler at all: other code (logging.basicConfig, a
        # test runner, ...) may have attached handlers first, and that used to
        # make construction fail with SystemError.
        existing_buffer_handler = next(
            (
                handler
                for handler in self._rootLogger.handlers
                if handler.name == "buffer.logger"
            ),
            None,
        )
        if existing_buffer_handler is not None:
            # Already configured by an earlier instance: share its buffer.
            self._buffer = existing_buffer_handler.stream
            return

        formatter = logging.Formatter("::%(levelname)s - %(message)s")

        self._buffer = StringIO()
        buffer_handler = logging.StreamHandler(self._buffer)
        buffer_handler.setLevel(log_level)
        buffer_handler.setFormatter(formatter)
        buffer_handler.set_name("buffer.logger")
        self._rootLogger.addHandler(buffer_handler)

        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setLevel(log_level)
        stdout_handler.setFormatter(formatter)
        stdout_handler.set_name("stdout.logger")
        self._rootLogger.addHandler(stdout_handler)

        # Emits the raw message (no level prefix) and accepts every level,
        # but only for records that pass filter_for_outputs.
        set_output_handler = logging.StreamHandler(sys.stdout)
        set_output_handler.setLevel(logging.NOTSET)
        set_output_handler.setFormatter(logging.Formatter("%(message)s"))
        set_output_handler.addFilter(self.filter_for_outputs)
        set_output_handler.set_name("setoutput.logger")
        self._rootLogger.addHandler(set_output_handler)

    def get_loggers(self):
        """Return the ``(root_logger, buffer)`` pair."""
        return (self._rootLogger, self._buffer)

    def get_root_logger(self):
        """Return the configured root logger."""
        return self._rootLogger

    def get_buffer(self):
        """Return the stream that the ``buffer.logger`` handler writes to."""
        return self._buffer

    def print_and_log(self, variable_name, variable_value):
        """Print a ``::set-output`` line for a variable and log it at debug level.

        Args:
            variable_name: Name of the output variable.
            variable_value: Value of the output variable; it is interpolated
                into the message with normal f-string formatting.

        Returns:
            The ``::set-output name=<name>::<value>`` message that was printed.
        """
        # echo "::set-output name=time::$time"
        output_message = f"::set-output name={variable_name}::{variable_value}"
        print(output_message)

        # Values without a length (int, bool, None, ...) used to raise
        # TypeError here, after the output line had already been printed;
        # fall back to the length of their string form.
        try:
            value_length = len(variable_value)
        except TypeError:
            value_length = len(str(variable_value))
        print(f"{variable_name} - Length: {value_length}")

        self._rootLogger.debug(output_message)
        return output_message

    @staticmethod
    def filter_for_outputs(record: LogRecord):
        """Return True only for records whose message starts with ``::set-output``."""
        return str(record.msg).startswith("::set-output")
def verify_result_contract(
    result_object: MLObject,
    expected_schema_type,
    expected_schema_version,
    step_name: str,
):
    """Rebuild ``result_object`` through ``MLObject`` and check its schema.

    The output of ``result_object.dict_without_internal_variables()`` is passed
    to ``MLObject.create_object_from_string``. Verification fails if that call
    reports errors, or if the resulting object's schema type or version
    differs from the expected values.

    Args:
        result_object: The step result to verify.
        expected_schema_type: Schema type the result must have.
        expected_schema_version: Schema version the result must have.
        step_name: Name of the step, used only in log and error messages.

    Returns:
        ``True`` when the object was created without errors and matches the
        expected schema type and version.

    Raises:
        ValueError: If object creation reports errors, or on a schema
            type/version mismatch.
    """
    rootLogger = setupLogger().get_root_logger()

    (contract_object, errors) = MLObject.create_object_from_string(
        result_object.dict_without_internal_variables()
    )

    if errors:
        error_string = (
            f"Error verifying result object for '{step_name}.output': {errors}"
        )
        rootLogger.debug(error_string)
        raise ValueError(error_string)

    type_matches = contract_object.schema_type == expected_schema_type
    version_matches = contract_object.schema_version == expected_schema_version
    if not (type_matches and version_matches):
        # Assembled line by line so the message carries no source-code
        # indentation and no stray trailing characters (the previous
        # triple-quoted literal ended the message with a leftover '")').
        error_string = (
            "Actual data does not match the expected schema and version:\n"
            f"Expected Type: {expected_schema_type}\n"
            f"Actual Type: {contract_object.schema_type}\n"
            f"Expected Version: {expected_schema_version}\n"
            f"Actual Version: {contract_object.schema_version}"
        )
        rootLogger.debug(error_string)
        raise ValueError(error_string)

    rootLogger.debug(
        f"Successfully loaded and validated contract object: {contract_object.schema_type} on step {step_name}.output"
    )
    return True
def encode_raw_object_for_db(mlobject):
    """Serialize ``mlobject`` to a URL-safe base64 string.

    Pipeline: object -> dict -> YAML text -> UTF-8 bytes -> URL-safe base64
    -> ``str``.
    """
    # NOTE(review): convert_dict_to_yaml is not imported or defined in the
    # visible part of this module -- confirm where it is expected to come from.
    yaml_text = convert_dict_to_yaml(mlobject.dict_without_internal_variables())
    encoded_bytes = base64.urlsafe_b64encode(yaml_text.encode("utf-8"))
    return encoded_bytes.decode("utf-8")
def decode_raw_object_from_db(s: str):
    """Decode a URL-safe base64 value and return the raw decoded bytes.

    The result is returned exactly as ``base64.urlsafe_b64decode`` produces
    it (``bytes``); no text decoding or YAML parsing is done here.
    """
    return base64.urlsafe_b64decode(s)