|
5 | 5 | from causal_testing.specification.causal_specification import Scenario |
6 | 6 | from causal_testing.specification.variable import Input, Output, Meta |
7 | 7 | from scipy.stats import uniform, rv_discrete |
| 8 | +from enum import Enum |
| 9 | +import random |
8 | 10 | from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent |
9 | 11 |
|
10 | 12 |
|
11 | 13 | class TestObservationalDataCollector(unittest.TestCase): |
12 | 14 | def setUp(self) -> None: |
| 15 | + class Color(Enum): |
| 16 | + RED = "RED" |
| 17 | + GREEN = "GREEN" |
| 18 | + BLUE = "BLUE" |
| 19 | + |
13 | 20 | temp_dir_path = create_temp_dir_if_non_existent() |
14 | 21 | self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot") |
15 | 22 | self.observational_df_path = os.path.join(temp_dir_path, "observational_data.csv") |
16 | 23 | # Y = 3*X1 + X2*X3 + 10 |
17 | | - self.observational_df = pd.DataFrame({"X1": [1, 2, 3, 4], "X2": [5, 6, 7, 8], "X3": [10, 20, 30, 40]}) |
18 | | - self.observational_df["Y"] = self.observational_df.apply( |
| 24 | + self.observational_df = pd.DataFrame( |
| 25 | + {"X1": [1, 2, 3, 4], "X2": [5, 6, 7, 8], "X3": [10, 20, 30, 40], "Y2": ["RED", "GREEN", "BLUE", "BLUE"]} |
| 26 | + ) |
| 27 | + self.observational_df["Y1"] = self.observational_df.apply( |
19 | 28 | lambda row: (3 * row.X1) + (row.X2 * row.X3) + 10, axis=1 |
20 | 29 | ) |
21 | 30 | self.observational_df.to_csv(self.observational_df_path) |
| 31 | + self.observational_df["Y2"] = [Color[x] for x in self.observational_df["Y2"]] |
22 | 32 | self.X1 = Input("X1", int, uniform(1, 4)) |
23 | 33 | self.X2 = Input("X2", int, rv_discrete(values=([7], [1]))) |
24 | 34 | self.X3 = Input("X3", int, uniform(10, 40)) |
25 | 35 | self.X4 = Input("X4", int, rv_discrete(values=([10], [1]))) |
26 | | - self.Y = Output("Y", int) |
| 36 | + self.Y1 = Output("Y1", int) |
| 37 | + self.Y2 = Output("Y2", Color) |
27 | 38 |
|
28 | 39 | def test_not_all_variables_in_data(self): |
29 | 40 | scenario = Scenario({self.X1, self.X2, self.X3, self.X4}) |
30 | 41 | observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path) |
31 | 42 | self.assertRaises(IndexError, observational_data_collector.collect_data) |
32 | 43 |
|
33 | 44 | def test_all_variables_in_data(self): |
34 | | - scenario = Scenario({self.X1, self.X2, self.X3, self.Y}) |
| 45 | + scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2}) |
35 | 46 | observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path) |
36 | 47 | df = observational_data_collector.collect_data(index_col=0) |
37 | | - assert df.equals(self.observational_df), f"{df}\nwas not equal to\n{self.observational_df}" |
| 48 | + assert df.equals(self.observational_df), f"\n{df}\nwas not equal to\n{self.observational_df}" |
38 | 49 |
|
39 | 50 | def test_data_constraints(self): |
40 | | - scenario = Scenario({self.X1, self.X2, self.X3, self.Y}, {self.X1.z3 > 2}) |
| 51 | + scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2}, {self.X1.z3 > 2}) |
41 | 52 | observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path) |
42 | 53 | df = observational_data_collector.collect_data(index_col=0) |
43 | 54 | expected = self.observational_df.loc[[2, 3]] |
44 | | - assert df.equals(expected), f"{df}\nwas not equal to\n{expected}" |
| 55 | + assert df.equals(expected), f"\n{df}\nwas not equal to\n{expected}" |
45 | 56 |
|
46 | 57 | def test_meta_population(self): |
47 | 58 | def populate_m(data): |
|
0 commit comments