Skip to content

Commit 586b3f6

Browse files
authored
Change a list of if statements into match (#7204)
* Change a list of if statements into match - Match is supported as of python 3.10. - This changes a chain of if statements in arg_func_langs to a match statement. * Fix coverage
1 parent dbde5ab commit 586b3f6

2 files changed

Lines changed: 64 additions & 56 deletions

File tree

cirq-google/cirq_google/serialization/arg_func_langs.py

Lines changed: 58 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -224,26 +224,26 @@ def float_arg_from_proto(
224224
ValueError: If the float arg proto is invalid.
225225
"""
226226
which = arg_proto.WhichOneof('arg')
227-
if which == 'float_value':
228-
result = float(arg_proto.float_value)
229-
if round(result) == result:
230-
result = int(result)
231-
return result
232-
elif which == 'symbol':
233-
return sympy.Symbol(arg_proto.symbol)
234-
elif which == 'func':
235-
func = _arg_func_from_proto(arg_proto.func, required_arg_name=required_arg_name)
236-
if func is None and required_arg_name is not None:
237-
raise ValueError(
238-
f'Arg {arg_proto.func} could not be processed for {required_arg_name}.'
239-
)
240-
return cast(FLOAT_ARG_LIKE, func)
241-
elif which is None:
242-
if required_arg_name is not None:
243-
raise ValueError(f'Arg {required_arg_name} is missing.')
244-
return None
245-
else:
246-
raise ValueError(f'unrecognized argument type ({which}).')
227+
match which:
228+
case 'float_value':
229+
result = float(arg_proto.float_value)
230+
if round(result) == result:
231+
result = int(result)
232+
return result
233+
case 'symbol':
234+
return sympy.Symbol(arg_proto.symbol)
235+
case 'func':
236+
func = _arg_func_from_proto(arg_proto.func, required_arg_name=required_arg_name)
237+
if func is None and required_arg_name is not None:
238+
raise ValueError( # pragma: nocover
239+
f'Arg {arg_proto.func} could not be processed for {required_arg_name}.'
240+
)
241+
return cast(FLOAT_ARG_LIKE, func)
242+
case None:
243+
if required_arg_name is not None:
244+
raise ValueError(f'Arg {required_arg_name} is missing.')
245+
return None
246+
raise ValueError(f'unrecognized argument type ({which}).')
247247

248248

249249
def arg_from_proto(
@@ -268,42 +268,44 @@ def arg_from_proto(
268268
"""
269269

270270
which = arg_proto.WhichOneof('arg')
271-
if which == 'arg_value':
272-
arg_value = arg_proto.arg_value
273-
which_val = arg_value.WhichOneof('arg_value')
274-
if which_val == 'float_value' or which_val == 'double_value':
275-
if which_val == 'double_value':
276-
result = float(arg_value.double_value)
277-
else:
278-
result = float(arg_value.float_value)
279-
if math.ceil(result) == math.floor(result):
280-
result = int(result)
281-
return result
282-
if which_val == 'bool_value':
283-
return bool(arg_value.bool_value)
284-
if which_val == 'bool_values':
285-
return list(arg_value.bool_values.values)
286-
if which_val == 'string_value':
287-
return str(arg_value.string_value)
288-
if which_val == 'int64_values':
289-
return [int(v) for v in arg_value.int64_values.values]
290-
if which_val == 'double_values':
291-
return [float(v) for v in arg_value.double_values.values]
292-
if which_val == 'string_values':
293-
return [str(v) for v in arg_value.string_values.values]
294-
if which_val == 'value_with_unit':
295-
return tunits.Value.from_proto(arg_value.value_with_unit)
296-
if which_val == 'bytes_value':
297-
return bytes(arg_value.bytes_value)
298-
raise ValueError(f'Unrecognized value type: {which_val!r}')
299-
300-
if which == 'symbol':
301-
return sympy.Symbol(arg_proto.symbol)
302-
303-
if which == 'func':
304-
func = _arg_func_from_proto(arg_proto.func, required_arg_name=required_arg_name)
305-
if func is not None:
306-
return func
271+
match which:
272+
case 'arg_value':
273+
arg_value = arg_proto.arg_value
274+
which_val = arg_value.WhichOneof('arg_value')
275+
match which_val:
276+
case 'float_value':
277+
result = float(arg_value.float_value)
278+
if math.ceil(result) == math.floor(result):
279+
return int(result)
280+
return result
281+
case 'double_value':
282+
result = float(arg_value.double_value)
283+
if math.ceil(result) == math.floor(result):
284+
return int(result)
285+
return result
286+
case 'bool_value':
287+
return bool(arg_value.bool_value)
288+
case 'bool_values':
289+
return list(arg_value.bool_values.values)
290+
case 'string_value':
291+
return str(arg_value.string_value)
292+
case 'int64_values':
293+
return [int(v) for v in arg_value.int64_values.values]
294+
case 'double_values':
295+
return [float(v) for v in arg_value.double_values.values]
296+
case 'string_values':
297+
return [str(v) for v in arg_value.string_values.values]
298+
case 'value_with_unit':
299+
return tunits.Value.from_proto(arg_value.value_with_unit)
300+
case 'bytes_value':
301+
return bytes(arg_value.bytes_value)
302+
raise ValueError(f'Unrecognized value type: {which_val!r}') # pragma: nocover
303+
case 'symbol':
304+
return sympy.Symbol(arg_proto.symbol)
305+
case 'func':
306+
func = _arg_func_from_proto(arg_proto.func, required_arg_name=required_arg_name)
307+
if func is not None:
308+
return func
307309

308310
if required_arg_name is not None:
309311
raise ValueError(

cirq-google/cirq_google/serialization/arg_func_langs_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def _json_format_kwargs() -> Dict[str, bool]:
5555
'value,proto',
5656
[
5757
(1.0, {'arg_value': {'float_value': 1.0}}),
58+
(1.5, {'arg_value': {'float_value': 1.5}}),
5859
(1, {'arg_value': {'float_value': 1.0}}),
5960
(b'abcdef', {'arg_value': {'bytes_value': base64.b64encode(b'abcdef').decode("ascii")}}),
6061
('abc', {'arg_value': {'string_value': 'abc'}}),
@@ -111,6 +112,10 @@ def test_double_value():
111112
msg.arg_value.double_value = 1.0
112113
parsed = arg_from_proto(msg)
113114
assert parsed == 1
115+
msg = v2.program_pb2.Arg()
116+
msg.arg_value.double_value = 1.5
117+
parsed = arg_from_proto(msg)
118+
assert parsed == 1.5
114119

115120

116121
def test_serialize_sympy_constants():
@@ -177,6 +182,7 @@ def test_missing_required_arg():
177182
with pytest.raises(ValueError, match='unrecognized argument type'):
178183
_ = arg_from_proto(v2.program_pb2.Arg(), required_arg_name='blah')
179184
assert arg_from_proto(v2.program_pb2.Arg()) is None
185+
assert float_arg_from_proto(v2.program_pb2.FloatArg()) is None
180186

181187

182188
def test_invalid_float_arg():

0 commit comments

Comments
 (0)