From 019f9ac3e6e25bc97fef2d5115bc28a51dded4dc Mon Sep 17 00:00:00 2001 From: Pablo Gonzalez Date: Tue, 23 Jul 2024 16:13:22 -0500 Subject: [PATCH] Correct values for accuracy delta percentile (#1793) Co-authored-by: Miro --- tools/submission/submission_checker.py | 32 +++++++++++++++++++++----- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/tools/submission/submission_checker.py b/tools/submission/submission_checker.py index 3f0b0da5e..b422cbb59 100755 --- a/tools/submission/submission_checker.py +++ b/tools/submission/submission_checker.py @@ -109,6 +109,12 @@ "llama2-70b-99" : ("TOKENS_PER_SAMPLE", 294.45*1.1), "llama2-70b-99.9" : ("TOKENS_PER_SAMPLE", 294.45*1.1) }, + "accuracy-delta-perc": { + "stable-diffusion-xl": { + "CLIP_SCORE": 1, + "FID_SCORE": 2 + } + }, "performance-sample-count": { "resnet": 1024, "retinanet": 64, @@ -270,6 +276,12 @@ "llama2-70b-99.9" : ("TOKENS_PER_SAMPLE", 294.45*1.1), "mixtral-8x7b" : ("TOKENS_PER_SAMPLE", 145.9 * 1.1) }, + "accuracy-delta-perc": { + "stable-diffusion-xl": { + "CLIP_SCORE": 1, + "FID_SCORE": 2 + } + }, "performance-sample-count": { "resnet": 1024, "retinanet": 64, @@ -672,6 +684,7 @@ def __init__( self.seeds = self.base["seeds"] self.test05_seeds = self.base["test05_seeds"] self.accuracy_target = self.base["accuracy-target"] + self.accuracy_delta_perc = self.base["accuracy-delta-perc"] self.accuracy_upper_limit = self.base.get("accuracy-upper-limit", {}) self.performance_sample_count = self.base["performance-sample-count"] self.latency_constraint = self.base.get("latency-constraint", {}) @@ -767,6 +780,18 @@ def get_min_query_count(self, model, scenario): if model not in self.min_queries: raise ValueError("model not known: " + model) return self.min_queries[model].get(scenario) + + def get_delta_perc(self, model, metric): + if model in self.accuracy_delta_perc: + if metric in self.accuracy_delta_perc[model]: + return self.accuracy_delta_perc[model][metric] + + more_accurate = model.find("99.9") + if more_accurate == -1: + required_delta_perc = 1 + else: + required_delta_perc = 0.1 + return required_delta_perc def has_new_logging_format(self): return True @@ -2470,12 +2495,6 @@ def check_compliance_acc_dir(test_dir, model, config): acc_types.append(acc_type) patterns.append(ACC_PATTERN[acc_type[0]]) acc_seen = [False for _ in acc_type] - - more_accurate = model.find("99.9") - if more_accurate == -1: - required_delta_perc = 1 - else: - required_delta_perc = 0.1 acc_baseline = { acc_type: 0 for acc_type in acc_types } @@ -2507,6 +2526,7 @@ def check_compliance_acc_dir(test_dir, model, config): is_valid = False break else: + required_delta_perc = config.get_delta_perc(model, acc_types) delta_perc = abs(1 - acc_baseline[acc_type] / acc_compliance[acc_type]) * 100 if delta_perc <= required_delta_perc: is_valid = True