diff --git a/diff_diff/prep.py b/diff_diff/prep.py index 4406d827..99109fa7 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -249,21 +249,23 @@ def wide_to_long( # Get other columns to preserve (not id or value columns) other_cols = [c for c in data.columns if c != id_column and c not in value_columns] - # Build long format - records = [] - for _, row in data.iterrows(): - for time_val, value_col in zip(time_values, value_columns): - record = {id_column: row[id_column], time_name: time_val, value_name: row[value_col]} - # Preserve other columns - for col in other_cols: - record[col] = row[col] - records.append(record) + # Use pd.melt for better performance (vectorized) + long_df = pd.melt( + data, + id_vars=[id_column] + other_cols, + value_vars=value_columns, + var_name='_temp_var', + value_name=value_name + ) - long_df = pd.DataFrame(records) + # Map column names to time values + col_to_time = dict(zip(value_columns, time_values)) + long_df[time_name] = long_df['_temp_var'].map(col_to_time) + long_df = long_df.drop('_temp_var', axis=1) - # Reorder columns + # Reorder columns and sort cols = [id_column, time_name, value_name] + other_cols - return long_df[cols] + return long_df[cols].sort_values([id_column, time_name]).reset_index(drop=True) def balance_panel( @@ -347,18 +349,21 @@ def balance_panel( result = full_df.merge(data, on=[unit_column, time_column], how="left") if method == "fill": + # Identify columns to fill (exclude unit and time columns) + cols_to_fill = [c for c in result.columns if c not in [unit_column, time_column]] + if fill_value is not None: - # Fill all numeric columns with fill_value + # Fill specified columns with fill_value numeric_cols = result.select_dtypes(include=[np.number]).columns for col in numeric_cols: - if col not in [unit_column, time_column]: + if col in cols_to_fill: result[col] = result[col].fillna(fill_value) else: - # Forward fill within each unit + # Forward fill within each unit for non-key columns result = result.sort_values([unit_column, time_column]) - result = result.groupby(unit_column).ffill() + result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].ffill() # Backward fill any remaining NaN at start - result = result.groupby(unit_column).bfill() + result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].bfill() return result @@ -481,13 +486,25 @@ def validate_did_data( errors.append("No control observations found (treatment column is all 1).") # Check for each treatment-time combination - for t_val in [0, 1]: - for p_val in [0, 1] if len(time_vals) == 2 else time_vals[:2]: - count = len(data[(data[treatment] == t_val) & (data[time] == p_val)]) - if count == 0: + if len(time_vals) == 2: + # For 2-period DiD, check all four cells + for t_val in [0, 1]: + for p_val in time_vals: + count = len(data[(data[treatment] == t_val) & (data[time] == p_val)]) + if count == 0: + errors.append( + f"No observations for treatment={t_val}, time={p_val}. " + "DiD requires observations in all treatment-time cells." + ) + else: + # For multi-period, check that both treatment groups exist in multiple periods + for t_val in [0, 1]: + n_periods_with_obs = data[data[treatment] == t_val][time].nunique() + if n_periods_with_obs < 2: + group_name = "Treated" if t_val == 1 else "Control" errors.append( - f"No observations for treatment={t_val}, time={p_val}. " - "DiD requires observations in all treatment-time cells." + f"{group_name} group has observations in only {n_periods_with_obs} period(s). " + "DiD requires multiple periods per group." ) # Panel-specific validation @@ -571,24 +588,25 @@ def summarize_did_data( ("max", "max") ]).round(4) - # Add group labels - summary.index = summary.index.map( - lambda x: f"{'Treated' if x[0] == 1 else 'Control'} - " - f"{'Post' if x[1] == 1 else 'Pre'}" - if len(data[time].unique()) == 2 - else f"{'Treated' if x[0] == 1 else 'Control'} - Period {x[1]}" - ) - - # Calculate DiD components if binary time + # Calculate time values for labeling time_vals = sorted(data[time].unique()) + + # Add group labels based on sorted time values (not literal 0/1) if len(time_vals) == 2: - pre, post = time_vals[0], time_vals[1] + pre_val, post_val = time_vals[0], time_vals[1] + + def format_label(x): + treatment_label = 'Treated' if x[0] == 1 else 'Control' + time_label = 'Post' if x[1] == post_val else 'Pre' + return f"{treatment_label} - {time_label}" + + summary.index = summary.index.map(format_label) # Calculate means for each cell - treated_pre = data[(data[treatment] == 1) & (data[time] == pre)][outcome].mean() - treated_post = data[(data[treatment] == 1) & (data[time] == post)][outcome].mean() - control_pre = data[(data[treatment] == 0) & (data[time] == pre)][outcome].mean() - control_post = data[(data[treatment] == 0) & (data[time] == post)][outcome].mean() + treated_pre = data[(data[treatment] == 1) & (data[time] == pre_val)][outcome].mean() + treated_post = data[(data[treatment] == 1) & (data[time] == post_val)][outcome].mean() + control_pre = data[(data[treatment] == 0) & (data[time] == pre_val)][outcome].mean() + control_post = data[(data[treatment] == 0) & (data[time] == post_val)][outcome].mean() # Calculate DiD treated_diff = treated_post - treated_pre @@ -607,6 +625,10 @@ def summarize_did_data( index=["DiD Estimate"] ) summary = pd.concat([summary, did_row]) + else: + summary.index = summary.index.map( + lambda x: f"{'Treated' if x[0] == 1 else 'Control'} - Period {x[1]}" + ) return summary @@ -776,7 +798,11 @@ def create_event_time( df[new_column] = df[time_column] - df[treatment_time_column] # Handle never-treated (inf or NaN in treatment time) - never_treated = df[treatment_time_column].isna() | np.isinf(df[treatment_time_column]) + col = df[treatment_time_column] + if pd.api.types.is_numeric_dtype(col): + never_treated = col.isna() | np.isinf(col) + else: + never_treated = col.isna() df.loc[never_treated, new_column] = np.nan return df diff --git a/tests/test_prep.py b/tests/test_prep.py index e8d87b25..5cc1c615 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -229,6 +229,20 @@ def test_fill_with_value(self): missing_row = result[(result["unit"] == 2) & (result["period"] == 2)] assert missing_row["y"].values[0] == 0.0 + def test_fill_forward_backward(self): + """Test fill method with forward/backward fill.""" + df = pd.DataFrame({ + "unit": [1, 1, 1, 2, 2], + "period": [1, 2, 3, 1, 3], # Unit 2 missing period 2 + "y": [10.0, 11.0, 12.0, 20.0, 22.0] + }) + result = balance_panel(df, "unit", "period", method="fill", fill_value=None) + assert len(result) == 6 + # Check that unit 2, period 2 was filled + filled_row = result[(result["unit"] == 2) & (result["period"] == 2)] + assert len(filled_row) == 1 + assert filled_row["y"].values[0] == 20.0 # Forward filled from period 1 + def test_error_invalid_method(self): """Test error with invalid method.""" df = pd.DataFrame({"unit": [1], "period": [1], "y": [10]})