1616
1717import numpy as np
1818
19- from cirq import protocols , value
19+ from cirq import _compat , protocols , value
2020from cirq .ops import raw_types
2121
2222if TYPE_CHECKING :
@@ -40,6 +40,7 @@ def __init__(
4040 key : Union [str , 'cirq.MeasurementKey' ] = '' ,
4141 invert_mask : Tuple [bool , ...] = (),
4242 qid_shape : Tuple [int , ...] = None ,
43+ confusion_map : Optional [Dict [Tuple [int , ...], np .ndarray ]] = None ,
4344 ) -> None :
4445 """Inits MeasurementGate.
4546
@@ -52,10 +53,15 @@ def __init__(
5253 Qubits with indices past the end of the mask are not flipped.
5354 qid_shape: Specifies the dimension of each qid the measurement
5455 applies to. The default is 2 for every qubit.
56+ confusion_map: A map of qubit index sets (using indices in the
57+ operation generated from this gate) to the 2D confusion matrix
58+ for those qubits. Indices not included use the identity.
59+ Applied before invert_mask if both are provided.
5560
5661 Raises:
57- ValueError: If the length of invert_mask is greater than num_qubits.
58- or if the length of qid_shape doesn't equal num_qubits.
62+ ValueError: If invert_mask or confusion_map have indices
63+ greater than the available qubit indices, or if the length of
64+ qid_shape doesn't equal num_qubits.
5965 """
6066 if qid_shape is None :
6167 if num_qubits is None :
@@ -74,6 +80,9 @@ def __init__(
7480 self ._invert_mask = invert_mask or ()
7581 if self .invert_mask is not None and len (self .invert_mask ) > self .num_qubits ():
7682 raise ValueError ('len(invert_mask) > num_qubits' )
83+ self ._confusion_map = confusion_map or {}
84+ if any (x >= self .num_qubits () for idx in self ._confusion_map for x in idx ):
85+ raise ValueError ('Confusion matrices have index out of bounds.' )
7786
7887 @property
7988 def key (self ) -> str :
@@ -87,6 +96,10 @@ def mkey(self) -> 'cirq.MeasurementKey':
8796 def invert_mask (self ) -> Tuple [bool , ...]:
8897 return self ._invert_mask
8998
99+ @property
100+ def confusion_map (self ) -> Dict [Tuple [int , ...], np .ndarray ]:
101+ return self ._confusion_map
102+
90103 def _qid_shape_ (self ) -> Tuple [int , ...]:
91104 return self ._qid_shape
92105
@@ -98,7 +111,11 @@ def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'MeasurementGate':
98111 if key == self .key :
99112 return self
100113 return MeasurementGate (
101- self .num_qubits (), key = key , invert_mask = self .invert_mask , qid_shape = self ._qid_shape
114+ self .num_qubits (),
115+ key = key ,
116+ invert_mask = self .invert_mask ,
117+ qid_shape = self ._qid_shape ,
118+ confusion_map = self .confusion_map ,
102119 )
103120
104121 def _with_key_path_ (self , path : Tuple [str , ...]):
@@ -116,14 +133,22 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
116133 return self .with_key (protocols .with_measurement_key_mapping (self .mkey , key_map ))
117134
118135 def with_bits_flipped (self , * bit_positions : int ) -> 'MeasurementGate' :
119- """Toggles whether or not the measurement inverts various outputs."""
136+ """Toggles whether or not the measurement inverts various outputs.
137+
138+ This only affects the invert_mask, which is applied after confusion
139+ matrices if any are defined.
140+ """
120141 old_mask = self .invert_mask or ()
121142 n = max (len (old_mask ) - 1 , * bit_positions ) + 1
122143 new_mask = [k < len (old_mask ) and old_mask [k ] for k in range (n )]
123144 for b in bit_positions :
124145 new_mask [b ] = not new_mask [b ]
125146 return MeasurementGate (
126- self .num_qubits (), key = self .key , invert_mask = tuple (new_mask ), qid_shape = self ._qid_shape
147+ self .num_qubits (),
148+ key = self .key ,
149+ invert_mask = tuple (new_mask ),
150+ qid_shape = self ._qid_shape ,
151+ confusion_map = self .confusion_map ,
127152 )
128153
129154 def full_invert_mask (self ) -> Tuple [bool , ...]:
@@ -166,12 +191,17 @@ def _circuit_diagram_info_(
166191 self , args : 'cirq.CircuitDiagramInfoArgs'
167192 ) -> 'cirq.CircuitDiagramInfo' :
168193 symbols = ['M' ] * self .num_qubits ()
169-
170- # Show which output bits are negated.
171- if self .invert_mask :
172- for i , b in enumerate (self .invert_mask ):
173- if b :
174- symbols [i ] = '!M'
194+ flipped_indices = {i for i , x in enumerate (self .full_invert_mask ()) if x }
195+ confused_indices = {x for idxs in self .confusion_map for x in idxs }
196+
197+ # Show which output bits are negated and/or confused.
198+ for i in range (self .num_qubits ()):
199+ prefix = ''
200+ if i in flipped_indices :
201+ prefix += '!'
202+ if i in confused_indices :
203+ prefix += '?'
204+ symbols [i ] = prefix + symbols [i ]
175205
176206 # Mention the measurement key.
177207 label_map = args .label_map or {}
@@ -184,7 +214,7 @@ def _circuit_diagram_info_(
184214 return protocols .CircuitDiagramInfo (symbols )
185215
186216 def _qasm_ (self , args : 'cirq.QasmArgs' , qubits : Tuple ['cirq.Qid' , ...]) -> Optional [str ]:
187- if not all (d == 2 for d in self ._qid_shape ):
217+ if self . confusion_map or not all (d == 2 for d in self ._qid_shape ):
188218 return NotImplemented
189219 args .validate_version ('2.0' )
190220 invert_mask = self .invert_mask
@@ -202,7 +232,7 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio
202232 def _quil_ (
203233 self , qubits : Tuple ['cirq.Qid' , ...], formatter : 'cirq.QuilFormatter'
204234 ) -> Optional [str ]:
205- if not all (d == 2 for d in self ._qid_shape ):
235+ if self . confusion_map or not all (d == 2 for d in self ._qid_shape ):
206236 return NotImplemented
207237 invert_mask = self .invert_mask
208238 if len (invert_mask ) < len (qubits ):
@@ -222,28 +252,39 @@ def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
222252 args .append (f'key={ self .mkey !r} ' )
223253 if self .invert_mask :
224254 args .append (f'invert_mask={ self .invert_mask !r} ' )
255+ if self .confusion_map :
256+ proper_map_str = ', ' .join (
257+ f"{ k !r} : { _compat .proper_repr (v )} " for k , v in self .confusion_map .items ()
258+ )
259+ args .append (f'confusion_map={{{ proper_map_str } }}' )
225260 arg_list = ', ' .join (args )
226261 return f'cirq.measure({ arg_list } )'
227262
228263 def __repr__ (self ):
229- qid_shape_arg = ''
264+ args = [ f' { self . num_qubits ()!r } ' , f' { self . mkey !r } ' , f' { self . invert_mask } ' ]
230265 if any (d != 2 for d in self ._qid_shape ):
231- qid_shape_arg = f', { self ._qid_shape !r} '
232- return (
233- f'cirq.MeasurementGate('
234- f'{ self .num_qubits ()!r} , '
235- f'{ self .mkey !r} , '
236- f'{ self .invert_mask } '
237- f'{ qid_shape_arg } )'
238- )
266+ args .append (f'qid_shape={ self ._qid_shape !r} ' )
267+ if self .confusion_map :
268+ proper_map_str = ', ' .join (
269+ f"{ k !r} : { _compat .proper_repr (v )} " for k , v in self .confusion_map .items ()
270+ )
271+ args .append (f'confusion_map={{{ proper_map_str } }}' )
272+ return f'cirq.MeasurementGate({ ", " .join (args )} )'
239273
240274 def _value_equality_values_ (self ) -> Any :
241- return self .key , self .invert_mask , self ._qid_shape
275+ hashable_cmap = frozenset (
276+ (idxs , tuple (v for _ , v in np .ndenumerate (cmap )))
277+ for idxs , cmap in self ._confusion_map .items ()
278+ )
279+ return self .key , self .invert_mask , self ._qid_shape , hashable_cmap
242280
243281 def _json_dict_ (self ) -> Dict [str , Any ]:
244- other = {}
282+ other : Dict [ str , Any ] = {}
245283 if not all (d == 2 for d in self ._qid_shape ):
246284 other ['qid_shape' ] = self ._qid_shape
285+ if self .confusion_map :
286+ json_cmap = [(k , v .tolist ()) for k , v in self .confusion_map .items ()]
287+ other ['confusion_map' ] = json_cmap
247288 return {
248289 'num_qubits' : len (self ._qid_shape ),
249290 'key' : self .key ,
@@ -252,12 +293,15 @@ def _json_dict_(self) -> Dict[str, Any]:
252293 }
253294
254295 @classmethod
255- def _from_json_dict_ (cls , num_qubits , key , invert_mask , qid_shape = None , ** kwargs ):
296+ def _from_json_dict_ (
297+ cls , num_qubits , key , invert_mask , qid_shape = None , confusion_map = None , ** kwargs
298+ ):
256299 return cls (
257300 num_qubits = num_qubits ,
258301 key = value .MeasurementKey .parse_serialized (key ),
259302 invert_mask = tuple (invert_mask ),
260303 qid_shape = None if qid_shape is None else tuple (qid_shape ),
304+ confusion_map = {tuple (k ): np .array (v ) for k , v in confusion_map or []},
261305 )
262306
263307 def _has_stabilizer_effect_ (self ) -> Optional [bool ]:
@@ -268,7 +312,7 @@ def _act_on_(self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq
268312
269313 if not isinstance (sim_state , SimulationState ):
270314 return NotImplemented
271- sim_state .measure (qubits , self .key , self .full_invert_mask ())
315+ sim_state .measure (qubits , self .key , self .full_invert_mask (), self . confusion_map )
272316 return True
273317
274318
0 commit comments