Skip to content

Commit d6b559e

Browse files
sdpythonCopilot
andauthored
Update _doc/examples/ml/plot_template_data.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 0de836b commit d6b559e

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

_doc/examples/ml/plot_template_data.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,19 +94,24 @@ def split_train_test(table, cible):
9494

9595
def make_pipeline(table, cible):
9696
vars = [c for c in table.columns if c != cible]
97-
num_cols = ["Capacité de l’établissement par formation"]
97+
# Candidate numeric feature; include it only if it exists in the table to avoid KeyError.
98+
numeric_feature = "Capacité de l’établissement par formation"
99+
num_cols = [numeric_feature] if numeric_feature in table.columns else []
98100
cat_cols = [c for c in vars if c not in num_cols]
99101

102+
transformers = []
103+
if num_cols:
104+
transformers.append(("num", StandardScaler(), num_cols))
105+
if cat_cols:
106+
transformers.append(
107+
("cats", OneHotEncoder(handle_unknown="ignore"), cat_cols)
108+
)
109+
100110
model = Pipeline(
101111
[
102112
(
103113
"preprocessing",
104-
ColumnTransformer(
105-
[
106-
("num", StandardScaler(), num_cols),
107-
("cats", OneHotEncoder(handle_unknown="ignore"), cat_cols),
108-
]
109-
),
114+
ColumnTransformer(transformers),
110115
),
111116
("regressor", HistGradientBoostingRegressor()),
112117
]

0 commit comments

Comments
 (0)