Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ results.r_squared

### Parallel Trends

**Simple slope-based test:**

```python
from diff_diff.utils import check_parallel_trends

Expand All @@ -248,7 +250,51 @@ trends = check_parallel_trends(
print(f"Treated trend: {trends['treated_trend']:.4f}")
print(f"Control trend: {trends['control_trend']:.4f}")
print(f"Difference p-value: {trends['p_value']:.4f}")
print(f"Parallel trends plausible: {trends['parallel_trends_plausible']}")
```

**Robust distributional test (Wasserstein distance):**

```python
from diff_diff.utils import check_parallel_trends_robust

results = check_parallel_trends_robust(
data,
outcome='outcome',
time='period',
treatment_group='treated',
unit='firm_id', # Unit identifier for panel data
pre_periods=[2018, 2019], # Pre-treatment periods
n_permutations=1000 # Permutations for p-value
)

print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}")
print(f"Wasserstein p-value: {results['wasserstein_p_value']:.4f}")
print(f"KS test p-value: {results['ks_p_value']:.4f}")
print(f"Parallel trends plausible: {results['parallel_trends_plausible']}")
```

The Wasserstein (Earth Mover's) distance compares the full distribution of outcome changes, not just means. This is more robust to:
- Non-normal distributions
- Heterogeneous effects across units
- Outliers

**Equivalence testing (TOST):**

```python
from diff_diff.utils import equivalence_test_trends

results = equivalence_test_trends(
data,
outcome='outcome',
time='period',
treatment_group='treated',
unit='firm_id',
equivalence_margin=0.5 # Define "practically equivalent"
)

print(f"Mean difference: {results['mean_difference']:.4f}")
print(f"TOST p-value: {results['tost_p_value']:.4f}")
print(f"Trends equivalent: {results['equivalent']}")
```

## API Reference
Expand Down
331 changes: 331 additions & 0 deletions diff_diff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,334 @@ def compute_trend(group_data):
"p_value": p_value,
"parallel_trends_plausible": p_value > 0.05 if not np.isnan(p_value) else None,
}


def check_parallel_trends_robust(
data: pd.DataFrame,
outcome: str,
time: str,
treatment_group: str,
unit: str = None,
pre_periods: list = None,
n_permutations: int = 1000,
seed: int = None
) -> dict:
"""
Perform robust parallel trends testing using distributional comparisons.

Uses the Wasserstein (Earth Mover's) distance to compare the full
distribution of outcome changes between treated and control groups,
with permutation-based inference.

Parameters
----------
data : pd.DataFrame
Panel data with repeated observations over time.
outcome : str
Name of outcome variable column.
time : str
Name of time period column.
treatment_group : str
Name of treatment group indicator column (0/1).
unit : str, optional
Name of unit identifier column. If provided, computes unit-level
changes. Otherwise uses observation-level data.
pre_periods : list, optional
List of pre-treatment time periods. If None, uses first half of periods.
n_permutations : int, default=1000
Number of permutations for computing p-value.
seed : int, optional
Random seed for reproducibility.

Returns
-------
dict
Dictionary containing:
- wasserstein_distance: Wasserstein distance between group distributions
- wasserstein_p_value: Permutation-based p-value
- ks_statistic: Kolmogorov-Smirnov test statistic
- ks_p_value: KS test p-value
- mean_difference: Difference in mean changes
- variance_ratio: Ratio of variances in changes
- treated_changes: Array of outcome changes for treated
- control_changes: Array of outcome changes for control
- parallel_trends_plausible: Boolean assessment

Examples
--------
>>> results = check_parallel_trends_robust(
... data, outcome='sales', time='year',
... treatment_group='treated', unit='firm_id'
... )
>>> print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}")
>>> print(f"P-value: {results['wasserstein_p_value']:.4f}")

Notes
-----
The Wasserstein distance (Earth Mover's Distance) measures the minimum
"cost" of transforming one distribution into another. Unlike simple
mean comparisons, it captures differences in the entire distribution
shape, making it more robust to non-normal data and heterogeneous effects.

A small Wasserstein distance and high p-value suggest the distributions
of pre-treatment changes are similar, supporting the parallel trends
assumption.
"""
if seed is not None:
np.random.seed(seed)

# Identify pre-treatment periods
if pre_periods is None:
all_periods = sorted(data[time].unique())
mid_point = len(all_periods) // 2
pre_periods = all_periods[:mid_point]

pre_data = data[data[time].isin(pre_periods)].copy()

# Compute outcome changes
treated_changes, control_changes = _compute_outcome_changes(
pre_data, outcome, time, treatment_group, unit
)

if len(treated_changes) < 2 or len(control_changes) < 2:
return {
"wasserstein_distance": np.nan,
"wasserstein_p_value": np.nan,
"ks_statistic": np.nan,
"ks_p_value": np.nan,
"mean_difference": np.nan,
"variance_ratio": np.nan,
"treated_changes": treated_changes,
"control_changes": control_changes,
"parallel_trends_plausible": None,
"error": "Insufficient data for comparison",
}

# Compute Wasserstein distance
wasserstein_dist = stats.wasserstein_distance(treated_changes, control_changes)

# Permutation test for Wasserstein distance
all_changes = np.concatenate([treated_changes, control_changes])
n_treated = len(treated_changes)
n_total = len(all_changes)

permuted_distances = np.zeros(n_permutations)
for i in range(n_permutations):
perm_idx = np.random.permutation(n_total)
perm_treated = all_changes[perm_idx[:n_treated]]
perm_control = all_changes[perm_idx[n_treated:]]
permuted_distances[i] = stats.wasserstein_distance(perm_treated, perm_control)

# P-value: proportion of permuted distances >= observed
wasserstein_p = np.mean(permuted_distances >= wasserstein_dist)

# Kolmogorov-Smirnov test
ks_stat, ks_p = stats.ks_2samp(treated_changes, control_changes)

# Additional summary statistics
mean_diff = np.mean(treated_changes) - np.mean(control_changes)
var_treated = np.var(treated_changes, ddof=1)
var_control = np.var(control_changes, ddof=1)
var_ratio = var_treated / var_control if var_control > 0 else np.nan

# Normalized Wasserstein (relative to pooled std)
pooled_std = np.std(all_changes, ddof=1)
wasserstein_normalized = wasserstein_dist / pooled_std if pooled_std > 0 else np.nan

# Assessment: parallel trends plausible if p-value > 0.05
# and normalized Wasserstein is small (< 0.2 as rule of thumb)
plausible = bool(
wasserstein_p > 0.05 and
(wasserstein_normalized < 0.2 if not np.isnan(wasserstein_normalized) else True)
)

return {
"wasserstein_distance": wasserstein_dist,
"wasserstein_normalized": wasserstein_normalized,
"wasserstein_p_value": wasserstein_p,
"ks_statistic": ks_stat,
"ks_p_value": ks_p,
"mean_difference": mean_diff,
"variance_ratio": var_ratio,
"n_treated": len(treated_changes),
"n_control": len(control_changes),
"treated_changes": treated_changes,
"control_changes": control_changes,
"parallel_trends_plausible": plausible,
}


def _compute_outcome_changes(
data: pd.DataFrame,
outcome: str,
time: str,
treatment_group: str,
unit: str = None
) -> tuple:
"""
Compute period-to-period outcome changes for treated and control groups.

Parameters
----------
data : pd.DataFrame
Panel data.
outcome : str
Outcome variable column.
time : str
Time period column.
treatment_group : str
Treatment group indicator column.
unit : str, optional
Unit identifier column.

Returns
-------
tuple
(treated_changes, control_changes) as numpy arrays.
"""
if unit is not None:
# Unit-level changes: compute change for each unit across periods
data_sorted = data.sort_values([unit, time])
data_sorted["_outcome_change"] = data_sorted.groupby(unit)[outcome].diff()

# Remove NaN from first period of each unit
changes_data = data_sorted.dropna(subset=["_outcome_change"])

treated_changes = changes_data[
changes_data[treatment_group] == 1
]["_outcome_change"].values

control_changes = changes_data[
changes_data[treatment_group] == 0
]["_outcome_change"].values
else:
# Aggregate changes: compute mean change per period per group
periods = sorted(data[time].unique())

treated_data = data[data[treatment_group] == 1]
control_data = data[data[treatment_group] == 0]

# Compute period means
treated_means = treated_data.groupby(time)[outcome].mean()
control_means = control_data.groupby(time)[outcome].mean()

# Compute changes between consecutive periods
treated_changes = np.diff(treated_means.values)
control_changes = np.diff(control_means.values)

return treated_changes.astype(float), control_changes.astype(float)


def equivalence_test_trends(
data: pd.DataFrame,
outcome: str,
time: str,
treatment_group: str,
unit: str = None,
pre_periods: list = None,
equivalence_margin: float = None
) -> dict:
"""
Perform equivalence testing (TOST) for parallel trends.

Tests whether the difference in trends is practically equivalent to zero
using Two One-Sided Tests (TOST) procedure.

Parameters
----------
data : pd.DataFrame
Panel data.
outcome : str
Name of outcome variable column.
time : str
Name of time period column.
treatment_group : str
Name of treatment group indicator column.
unit : str, optional
Name of unit identifier column.
pre_periods : list, optional
List of pre-treatment time periods.
equivalence_margin : float, optional
The margin for equivalence (delta). If None, uses 0.5 * pooled SD
of outcome changes as a default.

Returns
-------
dict
Dictionary containing:
- mean_difference: Difference in mean changes
- equivalence_margin: The margin used
- lower_p_value: P-value for lower bound test
- upper_p_value: P-value for upper bound test
- tost_p_value: Maximum of the two p-values
- equivalent: Boolean indicating equivalence at alpha=0.05
"""
# Get pre-treatment periods
if pre_periods is None:
all_periods = sorted(data[time].unique())
mid_point = len(all_periods) // 2
pre_periods = all_periods[:mid_point]

pre_data = data[data[time].isin(pre_periods)].copy()

# Compute outcome changes
treated_changes, control_changes = _compute_outcome_changes(
pre_data, outcome, time, treatment_group, unit
)

if len(treated_changes) < 2 or len(control_changes) < 2:
return {
"mean_difference": np.nan,
"equivalence_margin": np.nan,
"lower_p_value": np.nan,
"upper_p_value": np.nan,
"tost_p_value": np.nan,
"equivalent": None,
"error": "Insufficient data",
}

# Compute statistics
mean_diff = np.mean(treated_changes) - np.mean(control_changes)
se_diff = np.sqrt(
np.var(treated_changes, ddof=1) / len(treated_changes) +
np.var(control_changes, ddof=1) / len(control_changes)
)

# Set equivalence margin if not provided
if equivalence_margin is None:
pooled_changes = np.concatenate([treated_changes, control_changes])
equivalence_margin = 0.5 * np.std(pooled_changes, ddof=1)

# Degrees of freedom (Welch-Satterthwaite approximation)
var_t = np.var(treated_changes, ddof=1)
var_c = np.var(control_changes, ddof=1)
n_t = len(treated_changes)
n_c = len(control_changes)

df = ((var_t/n_t + var_c/n_c)**2 /
((var_t/n_t)**2/(n_t-1) + (var_c/n_c)**2/(n_c-1)))

# TOST: Two one-sided tests
# Test 1: H0: diff <= -margin vs H1: diff > -margin
t_lower = (mean_diff - (-equivalence_margin)) / se_diff
p_lower = stats.t.sf(t_lower, df)

# Test 2: H0: diff >= margin vs H1: diff < margin
t_upper = (mean_diff - equivalence_margin) / se_diff
p_upper = stats.t.cdf(t_upper, df)

# TOST p-value is the maximum of the two
tost_p = max(p_lower, p_upper)

return {
"mean_difference": mean_diff,
"se_difference": se_diff,
"equivalence_margin": equivalence_margin,
"lower_t_stat": t_lower,
"upper_t_stat": t_upper,
"lower_p_value": p_lower,
"upper_p_value": p_upper,
"tost_p_value": tost_p,
"degrees_of_freedom": df,
"equivalent": bool(tost_p < 0.05),
}
Loading