Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions diff_diff/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def fit(
# Validate inputs
self._validate_data(data, outcome, treatment, time, covariates)

# Validate binary variables BEFORE any transformations
validate_binary(data[treatment].values, "treatment")
validate_binary(data[time].values, "time")

# Validate fixed effects and absorb columns
if fixed_effects:
for fe in fixed_effects:
Expand All @@ -186,6 +190,9 @@ def fit(

if absorb:
# Apply within-transformation for each absorbed variable
# Only demean outcome and covariates, NOT treatment/time indicators
# Treatment is typically time-invariant (within unit), and time is
# unit-invariant, so demeaning them would create multicollinearity
vars_to_demean = [outcome] + (covariates or [])
for ab_var in absorb:
n_absorbed_effects += working_data[ab_var].nunique() - 1
Expand All @@ -194,15 +201,11 @@ def fit(
working_data[var] = working_data[var] - group_means
absorbed_vars.append(ab_var)

# Extract variables
# Extract variables (may be demeaned if absorb was used)
y = working_data[outcome].values.astype(float)
d = working_data[treatment].values.astype(float)
t = working_data[time].values.astype(float)

# Validate binary variables
validate_binary(d, "treatment")
validate_binary(t, "time")

# Create interaction term
dt = d * t

Expand All @@ -220,7 +223,8 @@ def fit(
if fixed_effects:
for fe in fixed_effects:
# Create dummies, drop first category to avoid multicollinearity
dummies = pd.get_dummies(data[fe], prefix=fe, drop_first=True)
# Use working_data to be consistent with absorbed FE if both are used
dummies = pd.get_dummies(working_data[fe], prefix=fe, drop_first=True)
for col in dummies.columns:
X = np.column_stack([X, dummies[col].values.astype(float)])
var_names.append(col)
Expand Down Expand Up @@ -298,7 +302,21 @@ def _fit_ols(self, X: np.ndarray, y: np.ndarray) -> tuple:
-------
tuple
(coefficients, residuals, fitted_values, r_squared)

Raises
------
ValueError
If design matrix is rank-deficient (perfect multicollinearity).
"""
# Check for rank deficiency (perfect multicollinearity)
rank = np.linalg.matrix_rank(X)
if rank < X.shape[1]:
raise ValueError(
f"Design matrix is rank-deficient (rank {rank} < {X.shape[1]} columns). "
"This indicates perfect multicollinearity. Check your fixed effects "
"and covariates for linear dependencies."
)

# Solve the least-squares problem min ||Xβ - y||² with lstsq; for a full-column-rank X this equals β = (X'X)^(-1) X'y
coefficients = np.linalg.lstsq(X, y, rcond=None)[0]

Expand Down
94 changes: 74 additions & 20 deletions diff_diff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,18 @@ def compute_trend(group_data):
mean_t = np.mean(time_norm)
mean_y = np.mean(outcome_values)

slope = np.sum((time_norm - mean_t) * (outcome_values - mean_y)) / np.sum((time_norm - mean_t) ** 2)
# Check for zero variance in time (all same time period)
time_var = np.sum((time_norm - mean_t) ** 2)
if time_var == 0:
return np.nan, np.nan

slope = np.sum((time_norm - mean_t) * (outcome_values - mean_y)) / time_var

# Compute standard error of slope
y_hat = mean_y + slope * (time_norm - mean_t)
residuals = outcome_values - y_hat
mse = np.sum(residuals ** 2) / (n - 2)
se_slope = np.sqrt(mse / np.sum((time_norm - mean_t) ** 2))
se_slope = np.sqrt(mse / time_var)

return slope, se_slope

Expand Down Expand Up @@ -258,7 +263,8 @@ def check_parallel_trends_robust(
unit: str = None,
pre_periods: list = None,
n_permutations: int = 1000,
seed: int = None
seed: int = None,
wasserstein_threshold: float = 0.2
) -> dict:
"""
Perform robust parallel trends testing using distributional comparisons.
Expand Down Expand Up @@ -286,6 +292,9 @@ def check_parallel_trends_robust(
Number of permutations for computing p-value.
seed : int, optional
Random seed for reproducibility.
wasserstein_threshold : float, default=0.2
Threshold for normalized Wasserstein distance. Values below this
threshold (combined with p > 0.05) suggest parallel trends are plausible.

Returns
-------
Expand Down Expand Up @@ -321,8 +330,8 @@ def check_parallel_trends_robust(
of pre-treatment changes are similar, supporting the parallel trends
assumption.
"""
if seed is not None:
np.random.seed(seed)
# Use local RNG to avoid affecting global random state
rng = np.random.default_rng(seed)

# Identify pre-treatment periods
if pre_periods is None:
Expand Down Expand Up @@ -361,7 +370,7 @@ def check_parallel_trends_robust(

permuted_distances = np.zeros(n_permutations)
for i in range(n_permutations):
perm_idx = np.random.permutation(n_total)
perm_idx = rng.permutation(n_total)
perm_treated = all_changes[perm_idx[:n_treated]]
perm_control = all_changes[perm_idx[n_treated:]]
permuted_distances[i] = stats.wasserstein_distance(perm_treated, perm_control)
Expand All @@ -383,10 +392,10 @@ def check_parallel_trends_robust(
wasserstein_normalized = wasserstein_dist / pooled_std if pooled_std > 0 else np.nan

# Assessment: parallel trends plausible if p-value > 0.05
# and normalized Wasserstein is small (< 0.2 as rule of thumb)
# and normalized Wasserstein is small (below threshold)
plausible = bool(
wasserstein_p > 0.05 and
(wasserstein_normalized < 0.2 if not np.isnan(wasserstein_normalized) else True)
(wasserstein_normalized < wasserstein_threshold if not np.isnan(wasserstein_normalized) else True)
)

return {
Expand Down Expand Up @@ -523,37 +532,82 @@ def equivalence_test_trends(
pre_data, outcome, time, treatment_group, unit
)

# Need at least 2 observations in each group to compute a sample variance (ddof=1);
# this also keeps the (n - 1) terms in the degrees-of-freedom calculation positive
if len(treated_changes) < 2 or len(control_changes) < 2:
return {
"mean_difference": np.nan,
"se_difference": np.nan,
"equivalence_margin": np.nan,
"lower_t_stat": np.nan,
"upper_t_stat": np.nan,
"lower_p_value": np.nan,
"upper_p_value": np.nan,
"tost_p_value": np.nan,
"degrees_of_freedom": np.nan,
"equivalent": None,
"error": "Insufficient data",
"error": "Insufficient data (need at least 2 observations per group)",
}

# Compute statistics
var_t = np.var(treated_changes, ddof=1)
var_c = np.var(control_changes, ddof=1)
n_t = len(treated_changes)
n_c = len(control_changes)

mean_diff = np.mean(treated_changes) - np.mean(control_changes)
se_diff = np.sqrt(
np.var(treated_changes, ddof=1) / len(treated_changes) +
np.var(control_changes, ddof=1) / len(control_changes)
)

# Handle zero variance case
if var_t == 0 and var_c == 0:
return {
"mean_difference": mean_diff,
"se_difference": 0.0,
"equivalence_margin": np.nan,
"lower_t_stat": np.nan,
"upper_t_stat": np.nan,
"lower_p_value": np.nan,
"upper_p_value": np.nan,
"tost_p_value": np.nan,
"degrees_of_freedom": np.nan,
"equivalent": None,
"error": "Zero variance in both groups - cannot perform t-test",
}

se_diff = np.sqrt(var_t / n_t + var_c / n_c)

# Handle zero SE case (cannot divide by zero in t-stat calculation)
if se_diff == 0:
return {
"mean_difference": mean_diff,
"se_difference": 0.0,
"equivalence_margin": np.nan,
"lower_t_stat": np.nan,
"upper_t_stat": np.nan,
"lower_p_value": np.nan,
"upper_p_value": np.nan,
"tost_p_value": np.nan,
"degrees_of_freedom": np.nan,
"equivalent": None,
"error": "Zero standard error - cannot perform t-test",
}

# Set equivalence margin if not provided
if equivalence_margin is None:
pooled_changes = np.concatenate([treated_changes, control_changes])
equivalence_margin = 0.5 * np.std(pooled_changes, ddof=1)

# Degrees of freedom (Welch-Satterthwaite approximation)
var_t = np.var(treated_changes, ddof=1)
var_c = np.var(control_changes, ddof=1)
n_t = len(treated_changes)
n_c = len(control_changes)

df = ((var_t/n_t + var_c/n_c)**2 /
((var_t/n_t)**2/(n_t-1) + (var_c/n_c)**2/(n_c-1)))
# A group with zero variance contributes nothing to the denominator; guard against the denominator itself being zero
numerator = (var_t/n_t + var_c/n_c)**2
denom_t = (var_t/n_t)**2/(n_t-1) if var_t > 0 else 0
denom_c = (var_c/n_c)**2/(n_c-1) if var_c > 0 else 0
denominator = denom_t + denom_c

if denominator == 0:
# Fall back to minimum of n_t-1 and n_c-1 when both denominator terms are zero (e.g. the squared variance terms underflow)
df = min(n_t - 1, n_c - 1)
else:
df = numerator / denominator

# TOST: Two one-sided tests
# Test 1: H0: diff <= -margin vs H1: diff > -margin
Expand Down
Loading