Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 64 additions & 38 deletions diff_diff/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,21 +249,23 @@ def wide_to_long(
# Get other columns to preserve (not id or value columns)
other_cols = [c for c in data.columns if c != id_column and c not in value_columns]

# Build long format
records = []
for _, row in data.iterrows():
for time_val, value_col in zip(time_values, value_columns):
record = {id_column: row[id_column], time_name: time_val, value_name: row[value_col]}
# Preserve other columns
for col in other_cols:
record[col] = row[col]
records.append(record)
# Use pd.melt for better performance (vectorized)
long_df = pd.melt(
data,
id_vars=[id_column] + other_cols,
value_vars=value_columns,
var_name='_temp_var',
value_name=value_name
)

long_df = pd.DataFrame(records)
# Map column names to time values
col_to_time = dict(zip(value_columns, time_values))
long_df[time_name] = long_df['_temp_var'].map(col_to_time)
long_df = long_df.drop('_temp_var', axis=1)

# Reorder columns
# Reorder columns and sort
cols = [id_column, time_name, value_name] + other_cols
return long_df[cols]
return long_df[cols].sort_values([id_column, time_name]).reset_index(drop=True)


def balance_panel(
Expand Down Expand Up @@ -347,18 +349,21 @@ def balance_panel(
result = full_df.merge(data, on=[unit_column, time_column], how="left")

if method == "fill":
# Identify columns to fill (exclude unit and time columns)
cols_to_fill = [c for c in result.columns if c not in [unit_column, time_column]]

if fill_value is not None:
# Fill all numeric columns with fill_value
# Fill specified columns with fill_value
numeric_cols = result.select_dtypes(include=[np.number]).columns
for col in numeric_cols:
if col not in [unit_column, time_column]:
if col in cols_to_fill:
result[col] = result[col].fillna(fill_value)
else:
# Forward fill within each unit
# Forward fill within each unit for non-key columns
result = result.sort_values([unit_column, time_column])
result = result.groupby(unit_column).ffill()
result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].ffill()
# Backward fill any remaining NaN at start
result = result.groupby(unit_column).bfill()
result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].bfill()

return result

Expand Down Expand Up @@ -481,13 +486,25 @@ def validate_did_data(
errors.append("No control observations found (treatment column is all 1).")

# Check for each treatment-time combination
for t_val in [0, 1]:
for p_val in [0, 1] if len(time_vals) == 2 else time_vals[:2]:
count = len(data[(data[treatment] == t_val) & (data[time] == p_val)])
if count == 0:
if len(time_vals) == 2:
# For 2-period DiD, check all four cells
for t_val in [0, 1]:
for p_val in time_vals:
count = len(data[(data[treatment] == t_val) & (data[time] == p_val)])
if count == 0:
errors.append(
f"No observations for treatment={t_val}, time={p_val}. "
"DiD requires observations in all treatment-time cells."
)
else:
# For multi-period, check that both treatment groups exist in multiple periods
for t_val in [0, 1]:
n_periods_with_obs = data[data[treatment] == t_val][time].nunique()
if n_periods_with_obs < 2:
group_name = "Treated" if t_val == 1 else "Control"
errors.append(
f"No observations for treatment={t_val}, time={p_val}. "
"DiD requires observations in all treatment-time cells."
f"{group_name} group has observations in only {n_periods_with_obs} period(s). "
"DiD requires multiple periods per group."
)

# Panel-specific validation
Expand Down Expand Up @@ -571,24 +588,25 @@ def summarize_did_data(
("max", "max")
]).round(4)

# Add group labels
summary.index = summary.index.map(
lambda x: f"{'Treated' if x[0] == 1 else 'Control'} - "
f"{'Post' if x[1] == 1 else 'Pre'}"
if len(data[time].unique()) == 2
else f"{'Treated' if x[0] == 1 else 'Control'} - Period {x[1]}"
)

# Calculate DiD components if binary time
# Calculate time values for labeling
time_vals = sorted(data[time].unique())

# Add group labels based on sorted time values (not literal 0/1)
if len(time_vals) == 2:
pre, post = time_vals[0], time_vals[1]
pre_val, post_val = time_vals[0], time_vals[1]

def format_label(x):
treatment_label = 'Treated' if x[0] == 1 else 'Control'
time_label = 'Post' if x[1] == post_val else 'Pre'
return f"{treatment_label} - {time_label}"

summary.index = summary.index.map(format_label)

# Calculate means for each cell
treated_pre = data[(data[treatment] == 1) & (data[time] == pre)][outcome].mean()
treated_post = data[(data[treatment] == 1) & (data[time] == post)][outcome].mean()
control_pre = data[(data[treatment] == 0) & (data[time] == pre)][outcome].mean()
control_post = data[(data[treatment] == 0) & (data[time] == post)][outcome].mean()
treated_pre = data[(data[treatment] == 1) & (data[time] == pre_val)][outcome].mean()
treated_post = data[(data[treatment] == 1) & (data[time] == post_val)][outcome].mean()
control_pre = data[(data[treatment] == 0) & (data[time] == pre_val)][outcome].mean()
control_post = data[(data[treatment] == 0) & (data[time] == post_val)][outcome].mean()

# Calculate DiD
treated_diff = treated_post - treated_pre
Expand All @@ -607,6 +625,10 @@ def summarize_did_data(
index=["DiD Estimate"]
)
summary = pd.concat([summary, did_row])
else:
summary.index = summary.index.map(
lambda x: f"{'Treated' if x[0] == 1 else 'Control'} - Period {x[1]}"
)

return summary

Expand Down Expand Up @@ -776,7 +798,11 @@ def create_event_time(
df[new_column] = df[time_column] - df[treatment_time_column]

# Handle never-treated (inf or NaN in treatment time)
never_treated = df[treatment_time_column].isna() | np.isinf(df[treatment_time_column])
col = df[treatment_time_column]
if pd.api.types.is_numeric_dtype(col):
never_treated = col.isna() | np.isinf(col)
else:
never_treated = col.isna()
df.loc[never_treated, new_column] = np.nan

return df
Expand Down
14 changes: 14 additions & 0 deletions tests/test_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,20 @@ def test_fill_with_value(self):
missing_row = result[(result["unit"] == 2) & (result["period"] == 2)]
assert missing_row["y"].values[0] == 0.0

def test_fill_forward_backward(self):
    """Check that ``method="fill"`` without a fill value fills a missing period.

    The input panel has unit 1 observed in periods 1-3 and unit 2 observed
    only in periods 1 and 3. After balancing, the row for unit 2 / period 2
    is expected to exist and to carry the ``y`` value of unit 2's period 1.
    """
    panel = pd.DataFrame(
        {
            "unit": [1, 1, 1, 2, 2],
            # Unit 2 has no observation for period 2.
            "period": [1, 2, 3, 1, 3],
            "y": [10.0, 11.0, 12.0, 20.0, 22.0],
        }
    )

    balanced = balance_panel(
        panel, "unit", "period", method="fill", fill_value=None
    )

    # Two units times three periods.
    assert len(balanced) == 6

    # Exactly one row should now exist for the previously missing cell.
    is_gap_cell = (balanced["unit"] == 2) & (balanced["period"] == 2)
    gap_rows = balanced[is_gap_cell]
    assert len(gap_rows) == 1

    # 20.0 is unit 2's period-1 value, i.e. the gap was filled forward.
    assert gap_rows["y"].values[0] == 20.0

def test_error_invalid_method(self):
"""Test error with invalid method."""
df = pd.DataFrame({"unit": [1], "period": [1], "y": [10]})
Expand Down