Skip to content

Commit 958dbb2

Browse files
authored
Merge pull request #192 from neulab/add_confidence
Add confidence interval calc to CLI

Former-commit-id: e23c9f9
2 parents 8df8ebf + ad3866c commit 958dbb2

File tree

4 files changed

+13
-4
lines changed

4 files changed

+13
-4
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1 1
repos:
2 2
- repo: https://github.com/python/black.git
3-
rev: 22.1.0
3+
rev: 22.3.0
4 4
hooks:
5 5
- id: black
6 6
files: '\.py$'

explainaboard/explainaboard_main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ def main():
63 63
help="multiple metrics should be separated by space",
64 64
)
65 65

66+
parser.add_argument(
67+
'--conf_value',
68+
type=float,
69+
required=False,
70+
default=0.05,
71+
help="the p-value with which to calculate the confidence interval",
72+
)
73+
66 74
args = parser.parse_args()
67 75

68 76
dataset = args.dataset
@@ -110,6 +118,7 @@ def main():
110 118
"task_name": task,
111 119
"reload_stat": reload_stat,
112 120
"user_defined_features_configs": loaders[0].user_defined_features_configs,
121+
"conf_value": args.conf_value,
113 122
}
114 123
if metric_names is not None:
115 124
metadata["metric_names"] = metric_names

explainaboard/info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class SysOutputInfo:
87 87
metric_names: Optional[list[str]] = None
88 88
reload_stat: bool = True
89 89
is_print_case: bool = True
90-
is_print_confidence_interval: bool = False
90+
conf_value: float = 0.05
91 91
# language : str = "English"
92 92

93 93
# set later

explainaboard/processors/processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def get_bucket_performance(
404 404
bucket_stats = metric_stat.filter(sample_ids)
405 405
metric_result = metric_func.evaluate_from_stats(
406 406
bucket_stats,
407-
conf_value=0.05 if sys_info.is_print_confidence_interval else None,
407+
conf_value=sys_info.conf_value,
408 408
)
409 409

410 410
conf_low, conf_high = (
@@ -455,7 +455,7 @@ def get_overall_performance(
455 455
):
456 456
metric_result = metric_func.evaluate_from_stats(
457 457
metric_stat,
458-
conf_value=0.05 if sys_info.is_print_confidence_interval else None,
458+
conf_value=sys_info.conf_value,
459 459
)
460 460

461 461
conf_low, conf_high = (

0 commit comments

Comments
 (0)