44import pandas as pd
55import z3
66from scipy import stats
7+ import itertools
78
89from causal_testing .specification .scenario import Scenario
910from causal_testing .specification .variable import Variable
1011from causal_testing .testing .causal_test_case import CausalTestCase
1112from causal_testing .testing .causal_test_outcome import CausalTestOutcome
1213
14+ from enum import Enum
15+
1316logger = logging .getLogger (__name__ )
1417
1518
@@ -24,23 +27,25 @@ def __init__(
2427 self ,
2528 scenario : Scenario ,
2629 intervention_constraints : set [z3 .ExprRef ],
27- treatment_variables : set [ Variable ] ,
30+ treatment_variable : Variable ,
2831 expected_causal_effect : dict [Variable :CausalTestOutcome ],
2932 effect_modifiers : set [Variable ] = None ,
3033 estimate_type : str = "ate" ,
34+ effect : str = "total" ,
3135 ):
32- assert treatment_variables . issubset ( scenario .variables .values () ), (
36+ assert treatment_variable in scenario .variables .values (), (
3337 "Treatment variables must be a subset of variables."
34- + f" Instead got:\n treatment_variables= { treatment_variables } \n variables={ scenario .variables } "
38+ + f" Instead got:\n treatment_variable= { treatment_variable } \n variables={ scenario .variables } "
3539 )
3640
3741 assert len (expected_causal_effect ) == 1 , "We currently only support tests with one causal outcome"
3842
3943 self .scenario = scenario
4044 self .intervention_constraints = intervention_constraints
41- self .treatment_variables = treatment_variables
45+ self .treatment_variable = treatment_variable
4246 self .expected_causal_effect = expected_causal_effect
4347 self .estimate_type = estimate_type
48+ self .effect = effect
4449
4550 if effect_modifiers is not None :
4651 self .effect_modifiers = effect_modifiers
@@ -100,7 +105,12 @@ def _generate_concrete_tests(
100105 for c in self .intervention_constraints :
101106 optimizer .assert_and_track (c , str (c ))
102107
103- optimizer .add_soft ([self .scenario .variables [v ].z3 == row [v ] for v in run_columns ])
108+ for v in run_columns :
109+ optimizer .add_soft (
110+ self .scenario .variables [v ].z3
111+ == self .scenario .variables [v ].z3_val (self .scenario .variables [v ].z3 , row [v ])
112+ )
113+
104114 if optimizer .check () == z3 .unsat :
105115 logger .warning (
106116 "Satisfiability of test case was unsat.\n " "Constraints \n %s \n Unsat core %s" ,
@@ -110,14 +120,15 @@ def _generate_concrete_tests(
110120 model = optimizer .model ()
111121
112122 concrete_test = CausalTestCase (
113- control_input_configuration = {v : v .cast (model [v .z3 ]) for v in self .treatment_variables },
123+ control_input_configuration = {v : v .cast (model [v .z3 ]) for v in [ self .treatment_variable ] },
114124 treatment_input_configuration = {
115- v : v .cast (model [self .scenario .treatment_variables [v .name ].z3 ]) for v in self .treatment_variables
125+ v : v .cast (model [self .scenario .treatment_variables [v .name ].z3 ]) for v in [ self .treatment_variable ]
116126 },
117127 expected_causal_effect = list (self .expected_causal_effect .values ())[0 ],
118128 outcome_variables = list (self .expected_causal_effect .keys ()),
119129 estimate_type = self .estimate_type ,
120130 effect_modifier_configuration = {v : v .cast (model [v .z3 ]) for v in self .effect_modifiers },
131+ effect = self .effect ,
121132 )
122133
123134 for v in self .scenario .inputs ():
@@ -128,19 +139,20 @@ def _generate_concrete_tests(
128139 + f"{ constraints } \n Using value { v .cast (model [v .z3 ])} instead in test\n { concrete_test } "
129140 )
130141
131- concrete_tests .append (concrete_test )
132- # Control run
133- control_run = {
134- v .name : v .cast (model [v .z3 ]) for v in self .scenario .variables .values () if v .name in run_columns
135- }
136- control_run ["bin" ] = index
137- runs .append (control_run )
138- # Treatment run
139- if rct :
140- treatment_run = control_run .copy ()
141- treatment_run .update ({k .name : v for k , v in concrete_test .treatment_input_configuration .items ()})
142- treatment_run ["bin" ] = index
143- runs .append (treatment_run )
142+ if not any ([vars (t ) == vars (concrete_test ) for t in concrete_tests ]):
143+ concrete_tests .append (concrete_test )
144+ # Control run
145+ control_run = {
146+ v .name : v .cast (model [v .z3 ]) for v in self .scenario .variables .values () if v .name in run_columns
147+ }
148+ control_run ["bin" ] = index
149+ runs .append (control_run )
150+ # Treatment run
151+ if rct :
152+ treatment_run = control_run .copy ()
153+ treatment_run .update ({k .name : v for k , v in concrete_test .treatment_input_configuration .items ()})
154+ treatment_run ["bin" ] = index
155+ runs .append (treatment_run )
144156
145157 return concrete_tests , pd .DataFrame (runs , columns = run_columns + ["bin" ])
146158
@@ -176,9 +188,12 @@ def generate_concrete_tests(
176188 runs = pd .DataFrame ()
177189 ks_stats = []
178190
191+ pre_break = False
179192 for i in range (hard_max ):
180193 concrete_tests_ , runs_ = self ._generate_concrete_tests (sample_size , rct , seed + i )
181- concrete_tests += concrete_tests_
194+ for t_ in concrete_tests_ :
195+ if not any ([vars (t_ ) == vars (t ) for t in concrete_tests ]):
196+ concrete_tests .append (t_ )
182197 runs = pd .concat ([runs , runs_ ])
183198 assert concrete_tests_ not in concrete_tests , "Duplicate entries unlikely unless something went wrong"
184199
@@ -205,14 +220,32 @@ def generate_concrete_tests(
205220 for var in effect_modifier_configs .columns
206221 }
207222 )
208- if target_ks_score and all ((stat <= target_ks_score for stat in ks_stats .values ())):
223+ control_values = [test .control_input_configuration [self .treatment_variable ] for test in concrete_tests ]
224+ treatment_values = [test .treatment_input_configuration [self .treatment_variable ] for test in concrete_tests ]
225+
226+ if self .treatment_variable .datatype is bool and set ([(True , False ), (False , True )]).issubset (
227+ set (zip (control_values , treatment_values ))
228+ ):
229+ pre_break = True
230+ break
231+ if issubclass (self .treatment_variable .datatype , Enum ) and set (
232+ {
233+ (x , y )
234+ for x , y in itertools .product (self .treatment_variable .datatype , self .treatment_variable .datatype )
235+ if x != y
236+ }
237+ ).issubset (zip (control_values , treatment_values )):
238+ pre_break = True
239+ break
240+ elif target_ks_score and all ((stat <= target_ks_score for stat in ks_stats .values ())):
241+ pre_break = True
209242 break
210243
211- if target_ks_score is not None and not all (( stat <= target_ks_score for stat in ks_stats . values ())) :
244+ if target_ks_score is not None and not pre_break :
212245 logger .error (
213- "Hard max of %s reached but could not achieve target ks_score of %s. Got %s." ,
214- hard_max ,
246+ "Hard max reached but could not achieve target ks_score of %s. Got %s. Generated %s distinct tests" ,
215247 target_ks_score ,
216248 ks_stats ,
249+ len (concrete_tests ),
217250 )
218251 return concrete_tests , runs
0 commit comments