diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py index 04b62370..7a453a5d 100644 --- a/causalpy/experiments/diff_in_diff.py +++ b/causalpy/experiments/diff_in_diff.py @@ -30,7 +30,12 @@ ) from causalpy.plot_utils import plot_xY from causalpy.pymc_models import PyMCModel -from causalpy.utils import _is_variable_dummy_coded, convert_to_string, round_num +from causalpy.utils import ( + _is_variable_dummy_coded, + convert_to_string, + get_interaction_terms, + round_num, +) from .base import BaseExperiment @@ -52,6 +57,8 @@ class DifferenceInDifferences(BaseExperiment): Name of the data column for the time variable :param group_variable_name: Name of the data column for the group variable + :param post_treatment_variable_name: + Name of the data column indicating post-treatment period (default: "post_treatment") :param model: A PyMC model for difference in differences @@ -84,6 +91,7 @@ def __init__( formula: str, time_variable_name: str, group_variable_name: str, + post_treatment_variable_name: str = "post_treatment", model=None, **kwargs, ) -> None: @@ -95,6 +103,7 @@ def __init__( self.formula = formula self.time_variable_name = time_variable_name self.group_variable_name = group_variable_name + self.post_treatment_variable_name = post_treatment_variable_name self.input_validation() y, X = dmatrices(formula, self.data) @@ -128,6 +137,12 @@ def __init__( } self.model.fit(X=self.X, y=self.y, coords=COORDS) elif isinstance(self.model, RegressorMixin): + # For scikit-learn models, automatically set fit_intercept=False + # This ensures the intercept is included in the coefficients array rather than being a separate intercept_ attribute + # without this, the intercept is not included in the coefficients array hence would be displayed as 0 in the model summary + # TODO: later, this should be handled in ScikitLearnAdaptor itself + if hasattr(self.model, "fit_intercept"): + self.model.fit_intercept = False self.model.fit(X=self.X, 
y=self.y) else: raise ValueError("Model type not recognized") @@ -173,7 +188,7 @@ def __init__( # just the treated group .query(f"{self.group_variable_name} == 1") # just the treatment period(s) - .query("post_treatment == True") + .query(f"{self.post_treatment_variable_name} == True") # drop the outcome variable .drop(self.outcome_variable_name, axis=1) # We may have multiple units per time point, we only want one time point @@ -189,7 +204,10 @@ def __init__( # INTERVENTION: set the interaction term between the group and the # post_treatment variable to zero. This is the counterfactual. for i, label in enumerate(self.labels): - if "post_treatment" in label and self.group_variable_name in label: + if ( + self.post_treatment_variable_name in label + and self.group_variable_name in label + ): new_x.iloc[:, i] = 0 self.y_pred_counterfactual = self.model.predict(np.asarray(new_x)) @@ -198,31 +216,44 @@ def __init__( # This is the coefficient on the interaction term coeff_names = self.model.idata.posterior.coords["coeffs"].data for i, label in enumerate(coeff_names): - if "post_treatment" in label and self.group_variable_name in label: + if ( + self.post_treatment_variable_name in label + and self.group_variable_name in label + ): self.causal_impact = self.model.idata.posterior["beta"].isel( {"coeffs": i} ) elif isinstance(self.model, RegressorMixin): # This is the coefficient on the interaction term - # TODO: CHECK FOR CORRECTNESS - self.causal_impact = ( - self.y_pred_treatment[1] - self.y_pred_counterfactual[0] - ).item() + # Store the coefficient into dictionary {intercept:value} + coef_map = dict(zip(self.labels, self.model.get_coeffs())) + # Create and find the interaction term based on the values user provided + interaction_term = ( + f"{self.group_variable_name}:{self.post_treatment_variable_name}" + ) + matched_key = next((k for k in coef_map if interaction_term in k), None) + att = coef_map.get(matched_key) + self.causal_impact = att else: raise 
def _validate_formula_interaction_terms(self):
    """
    Reject formulas whose interaction structure is not supported.

    The interaction terms of ``self.formula`` are obtained from
    ``get_interaction_terms``. Two checks are then applied, in this order:

    1. No single term may combine more than two variables.
    2. The formula may contain at most one interaction term.

    Raises
    ------
    FormulaException
        If any interaction term contains two or more interaction operators
        (``*`` or ``:``), or if more than one interaction term is present.
    """
    found_terms = get_interaction_terms(self.formula)

    # An interaction of k variables is written with k - 1 operators, so two
    # or more operators in one term means three or more variables
    # (e.g. a*b*c or a:b:c).
    for candidate in found_terms:
        operator_count = candidate.count("*") + candidate.count(":")
        if operator_count < 2:
            continue
        raise FormulaException(
            f"Formula contains interaction term with more than 2 variables: {candidate}. "
            "Three-way or higher-order interactions are not supported as they complicate interpretation of the causal effect."
        )

    # Zero or one interaction term is acceptable; nothing more to check.
    n_terms = len(found_terms)
    if n_terms <= 1:
        return

    raise FormulaException(
        f"Formula contains {n_terms} interaction terms: {found_terms}. "
        "Multiple interaction terms are not currently supported as they complicate interpretation of the causal effect."
    )
with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df_with_custom, + formula="y ~ 1 + group*post_treatment", # Formula uses 'post_treatment' + time_variable_name="t", + group_variable_name="group", + post_treatment_variable_name="custom_post", # But user specifies 'custom_post' + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 4: Default post_treatment_variable_name but formula uses different name + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group*custom_post", # Formula uses 'custom_post' + time_variable_name="t", + group_variable_name="group", + # post_treatment_variable_name defaults to "post_treatment" + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 5: Repeated interaction terms (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group*post_treatment + group*post_treatment", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 6: Three-way interactions using * (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group*post_treatment*male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 7: Three-way interactions using : (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group:post_treatment:male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 8: Multiple different interaction terms using * (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + 
group*post_treatment + group*male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 9: Multiple different interaction terms using : (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group:post_treatment + group:male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 10: Mixed issues - multiple terms + three-way interaction (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group*post_treatment + group:post_treatment:male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + def test_did_validation_post_treatment_data(): """Test that we get a DataException if do not include post_treatment in the data""" @@ -91,6 +185,27 @@ def test_did_validation_post_treatment_data(): model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), ) + # Test 2: Custom post_treatment_variable_name but column doesn't exist in data + df_with_post = pd.DataFrame( + { + "group": [0, 0, 1, 1], + "t": [0, 1, 0, 1], + "unit": [0, 0, 1, 1], + "post_treatment": [0, 1, 0, 1], # Data has 'post_treatment' + "y": [1, 2, 3, 4], + } + ) + + with pytest.raises(DataException): + _ = cp.DifferenceInDifferences( + df_with_post, + formula="y ~ 1 + group*custom_post", # Formula uses 'custom_post' + time_variable_name="t", + group_variable_name="group", + post_treatment_variable_name="custom_post", # User specifies 'custom_post' + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + def test_did_validation_unit_data(): """Test that we get a DataException if do not include unit in the data""" diff --git a/causalpy/tests/test_utils.py b/causalpy/tests/test_utils.py 
index 8dc95590..a2dea702 100644 --- a/causalpy/tests/test_utils.py +++ b/causalpy/tests/test_utils.py @@ -17,7 +17,12 @@ import pandas as pd -from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels, round_num +from causalpy.utils import ( + _is_variable_dummy_coded, + _series_has_2_levels, + get_interaction_terms, + round_num, +) def test_dummy_coding(): @@ -57,3 +62,43 @@ def test_round_num(): assert round_num(123.456, 5) == "123.46" assert round_num(123.456, 6) == "123.456" assert round_num(123.456, 7) == "123.456" + + +def test_get_interaction_terms(): + """Test if the function to extract interaction terms from formulas works correctly""" + # No interaction terms + assert get_interaction_terms("y ~ x1 + x2 + x3") == [] + assert get_interaction_terms("y ~ 1 + x1 + x2") == [] + + # Single interaction term with '*' + assert get_interaction_terms("y ~ x1 + x2*x3") == ["x2*x3"] + assert get_interaction_terms("y ~ 1 + group*post_treatment") == [ + "group*post_treatment" + ] + + # Single interaction term with ':' + assert get_interaction_terms("y ~ x1 + x2:x3") == ["x2:x3"] + assert get_interaction_terms("y ~ 1 + group:post_treatment") == [ + "group:post_treatment" + ] + + # Multiple interaction terms + assert get_interaction_terms("y ~ x1*x2 + x3*x4") == ["x1*x2", "x3*x4"] + assert get_interaction_terms("y ~ a:b + c*d") == ["a:b", "c*d"] + + # Three-way interaction + assert get_interaction_terms("y ~ x1*x2*x3") == ["x1*x2*x3"] + assert get_interaction_terms("y ~ a:b:c") == ["a:b:c"] + + # Formula with spaces (should be handled correctly) + assert get_interaction_terms("y ~ x1 + x2 * x3") == ["x2*x3"] + assert get_interaction_terms("y ~ 1 + group * post_treatment") == [ + "group*post_treatment" + ] + + # Mixed main effects and interactions + assert get_interaction_terms("y ~ 1 + x1 + x2 + x1*x2") == ["x1*x2"] + assert get_interaction_terms("y ~ x1 + x2*x3 + x4") == ["x2*x3"] + + # Formula with subtraction (edge case) + assert get_interaction_terms("y 
def get_interaction_terms(formula: str) -> list[str]:
    """
    Extract interaction terms from a statistical model formula.

    The right-hand side of the formula is split into additive terms at every
    ``+`` or ``-`` that sits outside parentheses/brackets, and every term that
    contains ``*`` or ``:`` is reported.

    Parameters
    ----------
    formula : str
        A statistical model formula string (e.g., "y ~ x1 + x2*x3"). A string
        without ``~`` is treated as a right-hand side on its own.

    Returns
    -------
    list[str]
        The interaction terms (those containing '*' or ':') in order of
        appearance, with all whitespace removed. Repeated terms are kept.

    Notes
    -----
    The check is purely textual: a ``*`` or ``:`` anywhere in a term, including
    inside a function call such as ``I(x1*x2)``, marks the term as an
    interaction.

    Examples
    --------
    >>> get_interaction_terms("y ~ 1 + x1 + x2*x3")
    ['x2*x3']
    >>> get_interaction_terms("y ~ x1:x2 + x3")
    ['x1:x2']
    >>> get_interaction_terms("y ~ x1 + x2 + x3")
    []
    >>> get_interaction_terms("y ~ (a + b)*c")
    ['(a+b)*c']
    """
    # Strip every kind of whitespace (spaces, tabs, newlines) so multi-line
    # formulas are handled the same way as single-line ones.
    compact = re.sub(r"\s+", "", formula)

    # Take the right-hand side; fall back to the whole string when there is
    # no '~' instead of failing with an IndexError.
    parts = compact.split("~")
    rhs = parts[1] if len(parts) > 1 else parts[0]

    # Split on '+' / '-' only at nesting depth 0, so that grouped expressions
    # and call arguments such as "(a+b)*c" or 'Q("pre-post")' stay in one term.
    terms = []
    current = []
    depth = 0
    for char in rhs:
        if char in "([":
            depth += 1
        elif char in ")]":
            # Clamp at zero so a stray closing bracket cannot disable splitting.
            depth = max(depth - 1, 0)
        if depth == 0 and char in "+-":
            terms.append("".join(current))
            current = []
        else:
            current.append(char)
    terms.append("".join(current))

    return [term for term in terms if "*" in term or ":" in term]