@@ -397,11 +397,12 @@ def flatten_once(
397397 # pylint: disable=protected-access
398398 bb ._i = max (binst .i for binst in self .bloq_instances ) + 1
399399
400- soq_map : List [Tuple [SoquetT , SoquetT ]] = []
400+ # soq_map: List[Tuple[SoquetT, SoquetT]] = []
401+ flat_soq_map : Dict [Soquet , Soquet ] = {}
401402 new_out_soqs : Tuple [SoquetT , ...]
402403 did_work = False
403404 for binst , in_soqs , old_out_soqs in self .iter_bloqsoqs ():
404- in_soqs = _map_soqs (in_soqs , soq_map ) # update `in_soqs` from old to new.
405+ in_soqs = _map_flat_soqs (in_soqs , flat_soq_map ) # update `in_soqs` from old to new.
405406 if pred (binst ):
406407 try :
407408 new_out_soqs = bb .add_from (binst .bloq , ** in_soqs )
@@ -416,12 +417,13 @@ def flatten_once(
416417 # pylint: disable=protected-access
417418 new_out_soqs = tuple (soq for _ , soq in bb ._add_binst (binst , in_soqs = in_soqs ))
418419
419- soq_map .extend (zip (old_out_soqs , new_out_soqs ))
420+ _update_flat_soq_map (zip (old_out_soqs , new_out_soqs ), flat_soq_map )
421+ # soq_map.extend(zip(old_out_soqs, new_out_soqs))
420422
421423 if not did_work :
422424 raise DidNotFlattenAnythingError ()
423425
424- fsoqs = _map_soqs (self .final_soqs (), soq_map )
426+ fsoqs = _map_flat_soqs (self .final_soqs (), flat_soq_map )
425427 return bb .finalize (** fsoqs )
426428
427429 def flatten (
@@ -865,6 +867,40 @@ def _map_soqs(soqs: SoquetT) -> SoquetT:
865867 return {name : _map_soqs (soqs ) for name , soqs in soqs .items ()}
866868
867869
870+ def _map_flat_soqs (
871+ soqs : Dict [str , SoquetT ], flat_soq_map : Dict [Soquet , Soquet ]
872+ ) -> Dict [str , SoquetT ]:
873+
874+ # use vectorize to use the flat mapping.
875+ def _map_soq (soq : Soquet ) -> Soquet :
876+ # Helper function to map an individual soquet.
877+ return flat_soq_map .get (soq , soq )
878+
879+ # Use `vectorize` to call `_map_soq` on each element of the array.
880+ vmap = np .vectorize (_map_soq , otypes = [object ])
881+
882+ def _map_soqs (soqs : SoquetT ) -> SoquetT :
883+ if isinstance (soqs , Soquet ):
884+ return _map_soq (soqs )
885+ return vmap (soqs )
886+
887+ return {name : _map_soqs (soqs ) for name , soqs in soqs .items ()}
888+
889+
890+ def _update_flat_soq_map (soq_map : Iterable [Tuple [SoquetT , SoquetT ]], flat_soq_map ):
891+ for old_soqs , new_soqs in soq_map :
892+ if isinstance (old_soqs , Soquet ):
893+ assert isinstance (new_soqs , Soquet ), new_soqs
894+ flat_soq_map [old_soqs ] = new_soqs
895+ continue
896+
897+ assert isinstance (old_soqs , np .ndarray ), old_soqs
898+ assert isinstance (new_soqs , np .ndarray ), new_soqs
899+ assert old_soqs .shape == new_soqs .shape , (old_soqs .shape , new_soqs .shape )
900+ for o , n in zip (old_soqs .reshape (- 1 ), new_soqs .reshape (- 1 )):
901+ flat_soq_map [o ] = n
902+
903+
868904class BloqBuilder :
869905 """A builder class for composing bloqs.
870906
0 commit comments