-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsagemaker_query_drift.py
84 lines (75 loc) · 3.04 KB
/
sagemaker_query_drift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import boto3
import logging
import os
import re
import json
from urllib.parse import urlparse
# getLogger() with no name returns the root logger; raise it to INFO so the
# logger.info(...) calls below are emitted.
# NOTE(review): presumably the runtime's log handler is attached to the root
# logger, which is why the root (not a named logger) is configured — confirm.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# One SageMaker and one S3 client, created once at import time and shared by
# every function in this module (created in the same order as before).
sm_client, s3_client = (boto3.client(service) for service in ("sagemaker", "s3"))
def get_processing_job(processing_job_name):
    """Describe a SageMaker processing job and locate its first S3 output.

    Args:
        processing_job_name: Name of the processing job to describe.

    Returns:
        A ``(status, exit_message, result_bucket, result_path)`` tuple:

        * ``status`` -- the job's ``ProcessingJobStatus``.
        * ``exit_message`` -- the job's ``ExitMessage``, or ``None`` when the
          response does not include one.
        * ``result_bucket`` / ``result_path`` -- bucket name and path parsed
          from the first output's ``S3Uri``. The path comes from
          ``urlparse(...).path`` and therefore keeps its leading ``"/"``.

    Raises:
        KeyError, IndexError: if the response has no usable
            ``ProcessingOutputConfig.Outputs[0].S3Output.S3Uri``.
        Any exception raised by ``sm_client.describe_processing_job``.
    """
    response = sm_client.describe_processing_job(ProcessingJobName=processing_job_name)
    status = response["ProcessingJobStatus"]
    # ExitMessage is optional in the DescribeProcessingJob response; indexing it
    # directly raised KeyError for jobs that exit without a message.
    exit_message = response.get("ExitMessage")
    s3_result_uri = response["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
    parsed_uri = urlparse(s3_result_uri)
    return status, exit_message, parsed_uri.netloc, parsed_uri.path
def get_s3_results_json(result_bucket, result_path, filename):
    """Download an object from S3 and decode it as JSON.

    Args:
        result_bucket: Name of the S3 bucket to read from.
        result_path: Key prefix inside the bucket. A leading ``"/"`` (as left
            by ``urlparse(...).path``) is stripped before building the key.
        filename: Object name appended to ``result_path``.

    Returns:
        The decoded JSON document.

    Raises:
        Any exception raised by ``s3_client.get_object``.
        json.JSONDecodeError: if the object body is not valid JSON.
    """
    # S3 keys are always "/"-separated; os.path.join would use "\\" on Windows.
    import posixpath

    key = posixpath.join(result_path.lstrip("/"), filename)
    s3_object = s3_client.get_object(Bucket=result_bucket, Key=key)
    return json.loads(s3_object["Body"].read())
def get_baseline_drift(feature):
    """Yield the baseline-drift entries found in a constraint-violations report.

    Args:
        feature: The parsed ``constraint_violations.json`` document (despite
            the parameter name, the caller passes the whole report). Expected to
            be a mapping whose optional ``"violations"`` entry is a list of
            violation dicts.

    Yields:
        ``{"feature": <feature_name>, "drift": <float>, "threshold": <float>}``
        for each violation whose ``constraint_check_type`` is
        ``"baseline_drift_check"`` and whose ``description`` contains
        ``distance: <number> exceeds threshold: <number>``. Other violation
        types, and descriptions without that pattern, are skipped.
    """
    # Capture only the numeric tokens so trailing text after the threshold
    # (e.g. a final ".") can no longer make float() raise and abort iteration.
    number = r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)"
    drift_pattern = re.compile(
        r"distance:\s*" + number + r"\s+exceeds threshold:\s*" + number
    )
    # Root logger: the same logger object the module configures at import.
    log = logging.getLogger()

    for violation in feature.get("violations") or []:
        if violation.get("constraint_check_type") != "baseline_drift_check":
            continue
        description = violation.get("description") or ""
        log.info("Baseline drift violation: %s", description)
        found = drift_pattern.search(description)
        if found:
            yield {
                "feature": violation.get("feature_name"),
                "drift": float(found.group(1)),
                "threshold": float(found.group(2)),
            }
def lambda_handler(event, context):
    """Report a SageMaker processing job's status and any baseline drift it found.

    Args:
        event: Invocation payload; must contain the key ``"ProcessingJobName"``.
        context: Invocation context object; not used.

    Returns:
        ``{"statusCode": 200, "results": {...}}`` holding the job name, its
        status, its exit message and ``BaselineDrift``. For a ``"Completed"``
        job, ``constraint_violations.json`` under the job's output prefix is
        read. ``BaselineDrift`` becomes the list of parsed drift entries. The
        status is reported as ``"CompletedWithViolations"`` only when the
        report lists at least one violation. ``BaselineDrift`` stays ``None``
        when the report is missing or cannot be read. If describing the job
        fails, returns ``{"statusCode": 500, "error": "Failed to read
        processing status!"}``.

    Raises:
        KeyError: if ``event`` has no ``"ProcessingJobName"`` key.
    """
    if "ProcessingJobName" not in event:
        raise KeyError(
            "ProcessingJobName key not found in event: {}.".format(json.dumps(event))
        )
    job_name = event["ProcessingJobName"]

    try:
        status, exit_message, result_bucket, result_path = get_processing_job(job_name)
        logger.info("Processing job: %s has status:%s.", job_name, status)

        drift = None
        if status == "Completed":
            try:
                report = get_s3_results_json(
                    result_bucket, result_path, "constraint_violations.json"
                )
                parsed_drift = list(get_baseline_drift(report))
                has_violations = bool(report.get("violations"))
            except s3_client.exceptions.NoSuchKey:
                # No report was written for this job: nothing to flag.
                logger.info("No violations")
            except Exception:
                # Best effort, as before: still report the job status, but log
                # the real failure instead of claiming there were no violations.
                logger.exception(
                    "Could not read constraint violations for processing job %s.",
                    job_name,
                )
            else:
                # Status and drift are updated together only after the whole
                # report parsed, so they cannot disagree with each other.
                drift = parsed_drift
                if has_violations:
                    status = "CompletedWithViolations"
                    logger.info("Has violations")
                else:
                    logger.info("No violations")

        return {
            "statusCode": 200,
            "results": {
                "ProcessingJobName": job_name,
                "ProcessingJobStatus": status,
                "ExitMessage": exit_message,
                "BaselineDrift": drift,
            },
        }
    except Exception:
        message = "Failed to read processing status!"
        logger.exception(message)
        return {"statusCode": 500, "error": message}