Skip to content

Commit

Permalink
Correct values for accuracy delta percentile (#1793)
Browse files Browse the repository at this point in the history
Co-authored-by: Miro <[email protected]>
  • Loading branch information
pgmpablo157321 and mrmhodak authored Jul 23, 2024
1 parent 03657ab commit 019f9ac
Showing 1 changed file with 26 additions and 6 deletions.
32 changes: 26 additions & 6 deletions tools/submission/submission_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@
"llama2-70b-99" : ("TOKENS_PER_SAMPLE", 294.45*1.1),
"llama2-70b-99.9" : ("TOKENS_PER_SAMPLE", 294.45*1.1)
},
"accuracy-delta-perc": {
"stable-diffusion-xl": {
"CLIP_SCORE": 1,
"FID_SCORE": 2
}
},
"performance-sample-count": {
"resnet": 1024,
"retinanet": 64,
Expand Down Expand Up @@ -270,6 +276,12 @@
"llama2-70b-99.9" : ("TOKENS_PER_SAMPLE", 294.45*1.1),
"mixtral-8x7b" : ("TOKENS_PER_SAMPLE", 145.9 * 1.1)
},
"accuracy-delta-perc": {
"stable-diffusion-xl": {
"CLIP_SCORE": 1,
"FID_SCORE": 2
}
},
"performance-sample-count": {
"resnet": 1024,
"retinanet": 64,
Expand Down Expand Up @@ -672,6 +684,7 @@ def __init__(
self.seeds = self.base["seeds"]
self.test05_seeds = self.base["test05_seeds"]
self.accuracy_target = self.base["accuracy-target"]
self.accuracy_delta_perc = self.base["accuracy-delta-perc"]
self.accuracy_upper_limit = self.base.get("accuracy-upper-limit", {})
self.performance_sample_count = self.base["performance-sample-count"]
self.latency_constraint = self.base.get("latency-constraint", {})
Expand Down Expand Up @@ -767,6 +780,18 @@ def get_min_query_count(self, model, scenario):
if model not in self.min_queries:
raise ValueError("model not known: " + model)
return self.min_queries[model].get(scenario)

def get_delta_perc(self, model, metric):
if model in self.accuracy_delta_perc:
if metric in self.accuracy_delta_perc[model]:
return self.accuracy_delta_perc[model][metric]

more_accurate = model.find("99.9")
if more_accurate == -1:
required_delta_perc = 1
else:
required_delta_perc = 0.1
return required_delta_perc

def has_new_logging_format(self):
return True
Expand Down Expand Up @@ -2470,12 +2495,6 @@ def check_compliance_acc_dir(test_dir, model, config):
acc_types.append(acc_type)
patterns.append(ACC_PATTERN[acc_type[0]])
acc_seen = [False for _ in acc_type]

more_accurate = model.find("99.9")
if more_accurate == -1:
required_delta_perc = 1
else:
required_delta_perc = 0.1
acc_baseline = {
acc_type: 0 for acc_type in acc_types
}
Expand Down Expand Up @@ -2507,6 +2526,7 @@ def check_compliance_acc_dir(test_dir, model, config):
is_valid = False
break
else:
required_delta_perc = config.get_delta_perc(model, acc_types)
delta_perc = abs(1 - acc_baseline[acc_type] / acc_compliance[acc_type]) * 100
if delta_perc <= required_delta_perc:
is_valid = True
Expand Down

0 comments on commit 019f9ac

Please sign in to comment.