Skip to content

Commit 2728816

Browse files
Fix zero num_features (#374)
* use None internally to store 0 num_features * use super set_params * Update src/squlearn/encoding_circuit/layered_encoding_circuit.py Co-authored-by: Florian Wieland <114916947+ProfessorNova@users.noreply.github.com> --------- Co-authored-by: Florian Wieland <114916947+ProfessorNova@users.noreply.github.com>
1 parent 4ca30bd commit 2728816

3 files changed

Lines changed: 6 additions & 29 deletions

File tree

src/squlearn/encoding_circuit/circuit_library/random_encoding_circuit.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -402,17 +402,7 @@ def set_params(self, **params):
402402
Args:
403403
params: Hyper-parameters and their values, e.g. ``num_qubits=2``.
404404
"""
405-
valid_params = self.get_params()
406-
for key, value in params.items():
407-
if key not in valid_params:
408-
raise ValueError(
409-
f"Invalid parameter {key!r}. "
410-
f"Valid parameters are {sorted(valid_params)!r}."
411-
)
412-
try:
413-
setattr(self, key, value)
414-
except:
415-
setattr(self, "_" + key, value)
405+
super().set_params(**params)
416406

417407
# Reset the random configuration
418408
self._is_config_available = False

src/squlearn/encoding_circuit/encoding_circuit_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class EncodingCircuitBase(ABC):
2222

2323
def __init__(self, num_qubits: int, num_features: int = None) -> None:
2424
self._num_qubits = num_qubits
25-
self._num_features = num_features
25+
self._num_features = num_features if num_features != 0 else None
2626

2727
if num_features is not None:
2828
warnings.warn(
@@ -213,6 +213,8 @@ def set_params(self, **params) -> EncodingCircuitBase:
213213
f"Invalid parameter {key!r}. "
214214
f"Valid parameters are {sorted(valid_params)!r}."
215215
)
216+
if key == "num_features" and value == 0:
217+
value = None
216218
try:
217219
setattr(self, key, value)
218220
except:

src/squlearn/encoding_circuit/layered_encoding_circuit.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,23 +2158,8 @@ def get_params(self, deep: bool = True) -> dict:
21582158

21592159
return params
21602160

2161-
def set_params(self, **params) -> None:
2162-
if "encoding_circuit_str" in params:
2163-
self._encoding_circuit_str = params["encoding_circuit_str"]
2164-
2165-
valid_params = self.get_params()
2166-
for key, value in params.items():
2167-
if key not in valid_params:
2168-
raise ValueError(
2169-
f"Invalid parameter {key!r}. "
2170-
f"Valid parameters are {sorted(valid_params)!r}."
2171-
)
2172-
2173-
if "num_features" in params:
2174-
self._num_features = params["num_features"]
2175-
2176-
if "num_qubits" in params:
2177-
self._num_qubits = params["num_qubits"]
2161+
def set_params(self, **params):
2162+
super().set_params(**params)
21782163

21792164
dict_layered_pqc = {}
21802165
for key in params.keys():

0 commit comments

Comments
 (0)