diff --git a/README.md b/README.md index d478be52..156e372d 100644 --- a/README.md +++ b/README.md @@ -235,6 +235,8 @@ results.r_squared ### Parallel Trends +**Simple slope-based test:** + ```python from diff_diff.utils import check_parallel_trends @@ -248,7 +250,51 @@ trends = check_parallel_trends( print(f"Treated trend: {trends['treated_trend']:.4f}") print(f"Control trend: {trends['control_trend']:.4f}") print(f"Difference p-value: {trends['p_value']:.4f}") -print(f"Parallel trends plausible: {trends['parallel_trends_plausible']}") +``` + +**Robust distributional test (Wasserstein distance):** + +```python +from diff_diff.utils import check_parallel_trends_robust + +results = check_parallel_trends_robust( + data, + outcome='outcome', + time='period', + treatment_group='treated', + unit='firm_id', # Unit identifier for panel data + pre_periods=[2018, 2019], # Pre-treatment periods + n_permutations=1000 # Permutations for p-value +) + +print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}") +print(f"Wasserstein p-value: {results['wasserstein_p_value']:.4f}") +print(f"KS test p-value: {results['ks_p_value']:.4f}") +print(f"Parallel trends plausible: {results['parallel_trends_plausible']}") +``` + +The Wasserstein (Earth Mover's) distance compares the full distribution of outcome changes, not just means. 
def check_parallel_trends_robust(
    data: pd.DataFrame,
    outcome: str,
    time: str,
    treatment_group: str,
    unit: str = None,
    pre_periods: list = None,
    n_permutations: int = 1000,
    seed: int = None
) -> dict:
    """
    Perform robust parallel trends testing using distributional comparisons.

    Compares the full distribution of pre-treatment outcome changes between
    treated and control groups using the Wasserstein (Earth Mover's)
    distance with permutation-based inference, plus a two-sample
    Kolmogorov-Smirnov test.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with repeated observations over time.
    outcome : str
        Name of outcome variable column.
    time : str
        Name of time period column.
    treatment_group : str
        Name of treatment group indicator column (0/1).
    unit : str, optional
        Name of unit identifier column. If provided, computes unit-level
        changes. Otherwise uses period-level mean changes.
    pre_periods : list, optional
        List of pre-treatment time periods. If None, uses the first half
        of the observed periods.
    n_permutations : int, default=1000
        Number of group-label permutations used for the p-value.
    seed : int, optional
        Random seed for reproducibility. A local generator is used, so
        the global NumPy random state is never modified.

    Returns
    -------
    dict
        Dictionary containing:
        - wasserstein_distance: Wasserstein distance between group distributions
        - wasserstein_normalized: distance divided by pooled std of changes
        - wasserstein_p_value: permutation-based p-value
        - ks_statistic / ks_p_value: Kolmogorov-Smirnov test results
        - mean_difference: difference in mean changes
        - variance_ratio: ratio of change variances (treated / control)
        - n_treated / n_control: number of change observations per group
        - treated_changes / control_changes: arrays of outcome changes
        - parallel_trends_plausible: boolean assessment

    Notes
    -----
    The Wasserstein distance measures the minimum "cost" of transforming
    one distribution into another, capturing differences in the entire
    distribution shape rather than just the mean.

    The permutation p-value uses the add-one correction
    (1 + #{permuted >= observed}) / (1 + n_permutations), so it can never
    be exactly zero and remains a valid p-value (Phipson & Smyth, 2010).
    """
    # Local generator: reproducible without clobbering np.random's state.
    rng = np.random.default_rng(seed)

    # Default pre-treatment window: the first half of observed periods.
    if pre_periods is None:
        all_periods = sorted(data[time].unique())
        mid_point = len(all_periods) // 2
        pre_periods = all_periods[:mid_point]

    pre_data = data[data[time].isin(pre_periods)].copy()

    # Outcome changes (unit-level first differences, or mean-change series).
    treated_changes, control_changes = _compute_outcome_changes(
        pre_data, outcome, time, treatment_group, unit
    )

    # Need at least two changes per group for a distributional comparison.
    if len(treated_changes) < 2 or len(control_changes) < 2:
        return {
            "wasserstein_distance": np.nan,
            "wasserstein_p_value": np.nan,
            "ks_statistic": np.nan,
            "ks_p_value": np.nan,
            "mean_difference": np.nan,
            "variance_ratio": np.nan,
            "treated_changes": treated_changes,
            "control_changes": control_changes,
            "parallel_trends_plausible": None,
            "error": "Insufficient data for comparison",
        }

    # Observed Wasserstein distance between the two change distributions.
    wasserstein_dist = stats.wasserstein_distance(treated_changes, control_changes)

    # Permutation test: reshuffle group labels and recompute the distance.
    all_changes = np.concatenate([treated_changes, control_changes])
    n_treated = len(treated_changes)
    n_total = len(all_changes)

    permuted_distances = np.zeros(n_permutations)
    for i in range(n_permutations):
        perm_idx = rng.permutation(n_total)
        perm_treated = all_changes[perm_idx[:n_treated]]
        perm_control = all_changes[perm_idx[n_treated:]]
        permuted_distances[i] = stats.wasserstein_distance(perm_treated, perm_control)

    # Add-one corrected p-value: counts the observed statistic as one of
    # the permutations, so the p-value is in (0, 1] rather than [0, 1].
    wasserstein_p = (1.0 + np.sum(permuted_distances >= wasserstein_dist)) / (
        1.0 + n_permutations
    )

    # Kolmogorov-Smirnov two-sample test as a complementary check.
    ks_stat, ks_p = stats.ks_2samp(treated_changes, control_changes)

    # Additional summary statistics.
    mean_diff = np.mean(treated_changes) - np.mean(control_changes)
    var_treated = np.var(treated_changes, ddof=1)
    var_control = np.var(control_changes, ddof=1)
    var_ratio = var_treated / var_control if var_control > 0 else np.nan

    # Normalize by the pooled std so the distance is scale-free.
    pooled_std = np.std(all_changes, ddof=1)
    wasserstein_normalized = wasserstein_dist / pooled_std if pooled_std > 0 else np.nan

    # Plausible when the permutation test does not reject AND the
    # normalized distance is small (< 0.2 as a rule of thumb; the size
    # criterion is skipped when the normalization is undefined).
    plausible = bool(
        wasserstein_p > 0.05 and
        (wasserstein_normalized < 0.2 if not np.isnan(wasserstein_normalized) else True)
    )

    return {
        "wasserstein_distance": wasserstein_dist,
        "wasserstein_normalized": wasserstein_normalized,
        "wasserstein_p_value": wasserstein_p,
        "ks_statistic": ks_stat,
        "ks_p_value": ks_p,
        "mean_difference": mean_diff,
        "variance_ratio": var_ratio,
        "n_treated": len(treated_changes),
        "n_control": len(control_changes),
        "treated_changes": treated_changes,
        "control_changes": control_changes,
        "parallel_trends_plausible": plausible,
    }


def _compute_outcome_changes(
    data: pd.DataFrame,
    outcome: str,
    time: str,
    treatment_group: str,
    unit: str = None
) -> tuple:
    """
    Compute period-to-period outcome changes for treated and control groups.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome variable column.
    time : str
        Time period column.
    treatment_group : str
        Treatment group indicator column.
    unit : str, optional
        Unit identifier column. If given, changes are first differences
        within each unit; otherwise differences of group period means.

    Returns
    -------
    tuple
        (treated_changes, control_changes) as float numpy arrays.
    """
    if unit is not None:
        # Unit-level changes: first-difference the outcome within each unit.
        data_sorted = data.sort_values([unit, time])
        data_sorted["_outcome_change"] = data_sorted.groupby(unit)[outcome].diff()

        # Drop the NaN produced for each unit's first observed period.
        changes_data = data_sorted.dropna(subset=["_outcome_change"])

        treated_changes = changes_data[
            changes_data[treatment_group] == 1
        ]["_outcome_change"].values

        control_changes = changes_data[
            changes_data[treatment_group] == 0
        ]["_outcome_change"].values
    else:
        # Aggregate changes: differences of group mean outcomes between
        # consecutive periods (groupby sorts period keys ascending).
        treated_data = data[data[treatment_group] == 1]
        control_data = data[data[treatment_group] == 0]

        treated_means = treated_data.groupby(time)[outcome].mean()
        control_means = control_data.groupby(time)[outcome].mean()

        treated_changes = np.diff(treated_means.values)
        control_changes = np.diff(control_means.values)

    return treated_changes.astype(float), control_changes.astype(float)
+ """ + if unit is not None: + # Unit-level changes: compute change for each unit across periods + data_sorted = data.sort_values([unit, time]) + data_sorted["_outcome_change"] = data_sorted.groupby(unit)[outcome].diff() + + # Remove NaN from first period of each unit + changes_data = data_sorted.dropna(subset=["_outcome_change"]) + + treated_changes = changes_data[ + changes_data[treatment_group] == 1 + ]["_outcome_change"].values + + control_changes = changes_data[ + changes_data[treatment_group] == 0 + ]["_outcome_change"].values + else: + # Aggregate changes: compute mean change per period per group + periods = sorted(data[time].unique()) + + treated_data = data[data[treatment_group] == 1] + control_data = data[data[treatment_group] == 0] + + # Compute period means + treated_means = treated_data.groupby(time)[outcome].mean() + control_means = control_data.groupby(time)[outcome].mean() + + # Compute changes between consecutive periods + treated_changes = np.diff(treated_means.values) + control_changes = np.diff(control_means.values) + + return treated_changes.astype(float), control_changes.astype(float) + + +def equivalence_test_trends( + data: pd.DataFrame, + outcome: str, + time: str, + treatment_group: str, + unit: str = None, + pre_periods: list = None, + equivalence_margin: float = None +) -> dict: + """ + Perform equivalence testing (TOST) for parallel trends. + + Tests whether the difference in trends is practically equivalent to zero + using Two One-Sided Tests (TOST) procedure. + + Parameters + ---------- + data : pd.DataFrame + Panel data. + outcome : str + Name of outcome variable column. + time : str + Name of time period column. + treatment_group : str + Name of treatment group indicator column. + unit : str, optional + Name of unit identifier column. + pre_periods : list, optional + List of pre-treatment time periods. + equivalence_margin : float, optional + The margin for equivalence (delta). 
If None, uses 0.5 * pooled SD + of outcome changes as a default. + + Returns + ------- + dict + Dictionary containing: + - mean_difference: Difference in mean changes + - equivalence_margin: The margin used + - lower_p_value: P-value for lower bound test + - upper_p_value: P-value for upper bound test + - tost_p_value: Maximum of the two p-values + - equivalent: Boolean indicating equivalence at alpha=0.05 + """ + # Get pre-treatment periods + if pre_periods is None: + all_periods = sorted(data[time].unique()) + mid_point = len(all_periods) // 2 + pre_periods = all_periods[:mid_point] + + pre_data = data[data[time].isin(pre_periods)].copy() + + # Compute outcome changes + treated_changes, control_changes = _compute_outcome_changes( + pre_data, outcome, time, treatment_group, unit + ) + + if len(treated_changes) < 2 or len(control_changes) < 2: + return { + "mean_difference": np.nan, + "equivalence_margin": np.nan, + "lower_p_value": np.nan, + "upper_p_value": np.nan, + "tost_p_value": np.nan, + "equivalent": None, + "error": "Insufficient data", + } + + # Compute statistics + mean_diff = np.mean(treated_changes) - np.mean(control_changes) + se_diff = np.sqrt( + np.var(treated_changes, ddof=1) / len(treated_changes) + + np.var(control_changes, ddof=1) / len(control_changes) + ) + + # Set equivalence margin if not provided + if equivalence_margin is None: + pooled_changes = np.concatenate([treated_changes, control_changes]) + equivalence_margin = 0.5 * np.std(pooled_changes, ddof=1) + + # Degrees of freedom (Welch-Satterthwaite approximation) + var_t = np.var(treated_changes, ddof=1) + var_c = np.var(control_changes, ddof=1) + n_t = len(treated_changes) + n_c = len(control_changes) + + df = ((var_t/n_t + var_c/n_c)**2 / + ((var_t/n_t)**2/(n_t-1) + (var_c/n_c)**2/(n_c-1))) + + # TOST: Two one-sided tests + # Test 1: H0: diff <= -margin vs H1: diff > -margin + t_lower = (mean_diff - (-equivalence_margin)) / se_diff + p_lower = stats.t.sf(t_lower, df) + + # Test 2: H0: 
diff >= margin vs H1: diff < margin + t_upper = (mean_diff - equivalence_margin) / se_diff + p_upper = stats.t.cdf(t_upper, df) + + # TOST p-value is the maximum of the two + tost_p = max(p_lower, p_upper) + + return { + "mean_difference": mean_diff, + "se_difference": se_diff, + "equivalence_margin": equivalence_margin, + "lower_t_stat": t_lower, + "upper_t_stat": t_upper, + "lower_p_value": p_lower, + "upper_p_value": p_upper, + "tost_p_value": tost_p, + "degrees_of_freedom": df, + "equivalent": bool(tost_p < 0.05), + } diff --git a/tests/test_estimators.py b/tests/test_estimators.py index d3348823..f8b70e49 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -452,3 +452,239 @@ def test_covariates_with_fixed_effects(self, panel_data_with_fe): assert results is not None assert "size" in results.coefficients + + +class TestParallelTrendsRobust: + """Tests for robust parallel trends checking.""" + + @pytest.fixture + def parallel_trends_data(self): + """Create panel data where parallel trends holds.""" + np.random.seed(42) + n_units = 100 + n_periods = 6 # 3 pre, 3 post + + data = [] + for unit in range(n_units): + is_treated = unit < n_units // 2 + unit_effect = np.random.normal(0, 2) + + for period in range(n_periods): + # Common trend for both groups + time_effect = period * 1.5 + + y = 10.0 + unit_effect + time_effect + + # Treatment effect only in post period (period >= 3) + if is_treated and period >= 3: + y += 5.0 + + y += np.random.normal(0, 0.5) + + data.append({ + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + }) + + return pd.DataFrame(data) + + @pytest.fixture + def non_parallel_trends_data(self): + """Create panel data where parallel trends is violated.""" + np.random.seed(42) + n_units = 100 + n_periods = 6 + + data = [] + for unit in range(n_units): + is_treated = unit < n_units // 2 + unit_effect = np.random.normal(0, 2) + + for period in range(n_periods): + # Different trends for treated vs 
class TestParallelTrendsRobust:
    """Tests for robust parallel trends checking."""

    @pytest.fixture
    def parallel_trends_data(self):
        """Panel data in which treated and control share a common trend."""
        np.random.seed(42)
        n_units = 100
        n_periods = 6  # 3 pre-treatment, 3 post-treatment

        rows = []
        for uid in range(n_units):
            in_treatment = uid < n_units // 2
            baseline_shift = np.random.normal(0, 2)

            for t in range(n_periods):
                # Identical time trend for both groups.
                level = 10.0 + baseline_shift + t * 1.5

                # Treatment effect applies only post-treatment (t >= 3).
                if in_treatment and t >= 3:
                    level += 5.0

                level += np.random.normal(0, 0.5)

                rows.append({
                    "unit": uid,
                    "period": t,
                    "treated": int(in_treatment),
                    "outcome": level,
                })

        return pd.DataFrame(rows)

    @pytest.fixture
    def non_parallel_trends_data(self):
        """Panel data in which the groups follow diverging trends."""
        np.random.seed(42)
        n_units = 100
        n_periods = 6

        rows = []
        for uid in range(n_units):
            in_treatment = uid < n_units // 2
            baseline_shift = np.random.normal(0, 2)

            for t in range(n_periods):
                # Treated units trend steeper than control units.
                trend = t * 3.0 if in_treatment else t * 1.0

                level = 10.0 + baseline_shift + trend

                # Treatment effect applies only post-treatment (t >= 3).
                if in_treatment and t >= 3:
                    level += 5.0

                level += np.random.normal(0, 0.5)

                rows.append({
                    "unit": uid,
                    "period": t,
                    "treated": int(in_treatment),
                    "outcome": level,
                })

        return pd.DataFrame(rows)

    def test_wasserstein_parallel_trends_valid(self, parallel_trends_data):
        """Wasserstein check should not reject when trends are parallel."""
        from diff_diff.utils import check_parallel_trends_robust

        res = check_parallel_trends_robust(
            parallel_trends_data,
            outcome="outcome",
            time="period",
            treatment_group="treated",
            unit="unit",
            pre_periods=[0, 1, 2],
            seed=42
        )

        assert "wasserstein_distance" in res
        assert "wasserstein_p_value" in res
        assert "ks_statistic" in res
        # Parallel trends should yield a high permutation p-value.
        assert res["wasserstein_p_value"] > 0.05
        assert res["parallel_trends_plausible"] is True

    def test_wasserstein_parallel_trends_violated(self, non_parallel_trends_data):
        """Wasserstein check should flag clearly diverging trends."""
        from diff_diff.utils import check_parallel_trends_robust

        res = check_parallel_trends_robust(
            non_parallel_trends_data,
            outcome="outcome",
            time="period",
            treatment_group="treated",
            unit="unit",
            pre_periods=[0, 1, 2],
            seed=42
        )

        # A violation shows up as a positive distance and an implausible
        # overall assessment (low p-value and/or large normalized distance).
        assert res["wasserstein_distance"] > 0
        assert res["parallel_trends_plausible"] is False

    def test_wasserstein_returns_changes(self, parallel_trends_data):
        """The per-group change arrays should be part of the result."""
        from diff_diff.utils import check_parallel_trends_robust

        res = check_parallel_trends_robust(
            parallel_trends_data,
            outcome="outcome",
            time="period",
            treatment_group="treated",
            unit="unit",
            pre_periods=[0, 1, 2],
            seed=42
        )

        assert "treated_changes" in res
        assert "control_changes" in res
        assert len(res["treated_changes"]) > 0
        assert len(res["control_changes"]) > 0

    def test_wasserstein_without_unit(self, parallel_trends_data):
        """The check should also work on aggregate (no-unit) data."""
        from diff_diff.utils import check_parallel_trends_robust

        res = check_parallel_trends_robust(
            parallel_trends_data,
            outcome="outcome",
            time="period",
            treatment_group="treated",
            pre_periods=[0, 1, 2],
            seed=42
        )

        assert "wasserstein_distance" in res
        assert not np.isnan(res["wasserstein_distance"])

    def test_equivalence_test_parallel(self, parallel_trends_data):
        """TOST should declare equivalence when trends are parallel."""
        from diff_diff.utils import equivalence_test_trends

        res = equivalence_test_trends(
            parallel_trends_data,
            outcome="outcome",
            time="period",
            treatment_group="treated",
            unit="unit",
            pre_periods=[0, 1, 2]
        )

        assert "tost_p_value" in res
        assert "equivalent" in res
        assert "equivalence_margin" in res
        # Parallel trends should be found practically equivalent.
        assert res["equivalent"] is True

    def test_equivalence_test_non_parallel(self, non_parallel_trends_data):
        """TOST should not declare equivalence for diverging trends."""
        from diff_diff.utils import equivalence_test_trends

        res = equivalence_test_trends(
            non_parallel_trends_data,
            outcome="outcome",
            time="period",
            treatment_group="treated",
            unit="unit",
            pre_periods=[0, 1, 2]
        )

        assert res["equivalent"] is False

    def test_equivalence_test_custom_margin(self, parallel_trends_data):
        """A user-supplied margin should be used verbatim."""
        from diff_diff.utils import equivalence_test_trends

        res = equivalence_test_trends(
            parallel_trends_data,
            outcome="outcome",
            time="period",
            treatment_group="treated",
            unit="unit",
            pre_periods=[0, 1, 2],
            equivalence_margin=0.1  # Very tight margin
        )

        assert res["equivalence_margin"] == 0.1

    def test_ks_test_included(self, parallel_trends_data):
        """KS statistic and p-value should be present and in [0, 1]."""
        from diff_diff.utils import check_parallel_trends_robust

        res = check_parallel_trends_robust(
            parallel_trends_data,
            outcome="outcome",
            time="period",
            treatment_group="treated",
            unit="unit",
            pre_periods=[0, 1, 2],
            seed=42
        )

        assert "ks_statistic" in res
        assert "ks_p_value" in res
        assert 0 <= res["ks_statistic"] <= 1
        assert 0 <= res["ks_p_value"] <= 1

    def test_variance_ratio(self, parallel_trends_data):
        """The variance ratio should be computed and positive."""
        from diff_diff.utils import check_parallel_trends_robust

        res = check_parallel_trends_robust(
            parallel_trends_data,
            outcome="outcome",
            time="period",
            treatment_group="treated",
            unit="unit",
            pre_periods=[0, 1, 2],
            seed=42
        )

        assert "variance_ratio" in res
        assert res["variance_ratio"] > 0