diff --git a/src/openfermion/contrib/representability/_multitensor.py b/src/openfermion/contrib/representability/_multitensor.py index e63b0805b..53f77e13d 100644 --- a/src/openfermion/contrib/representability/_multitensor.py +++ b/src/openfermion/contrib/representability/_multitensor.py @@ -23,7 +23,7 @@ def __iter__(self): class MultiTensor(object): - def __init__(self, tensors, dual_basis=DualBasis()): + def __init__(self, tensors, dual_basis=None): """ A collection of tensor objects with maps from name to tensor @@ -46,6 +46,8 @@ def __init__(self, tensors, dual_basis=DualBasis()): self.off_set_map = self.make_offset_dict(self.tensors) # An iterable object that provides access to the dual basis elements + if dual_basis is None: + dual_basis = DualBasis() self.dual_basis = dual_basis self.vec_dim = sum([vec.size for vec in self.tensors]) @@ -81,8 +83,7 @@ def add_dual_elements(self, dual_element): if not isinstance(dual_element, DualBasisElement): raise TypeError("dual_element variable needs to be a DualBasisElement type") - # we should extend TMap to add - self.dual_basis.elements.extend(dual_element) + self.dual_basis.elements.append(dual_element) def synthesize_dual_basis(self): """ @@ -93,7 +94,7 @@ def synthesize_dual_basis(self): :returns: sparse matrix """ - # go throught the dual basis list and synthesize each element + # go through the dual basis list and synthesize each element dual_row_indices = [] dual_col_indices = [] dual_data_values = [] diff --git a/src/openfermion/contrib/representability/_multitensor_test.py b/src/openfermion/contrib/representability/_multitensor_test.py index 620fa07e3..40fb398c1 100644 --- a/src/openfermion/contrib/representability/_multitensor_test.py +++ b/src/openfermion/contrib/representability/_multitensor_test.py @@ -94,6 +94,33 @@ def test_add_dualelement(): mt.add_dual_elements(dbe) assert len(mt.dual_basis) == 1 + dbe2 = DualBasisElement() + dbe2.add_element('b', (0, 1, 2), 5) + mt.add_dual_elements(dbe2) + assert len(mt.dual_basis) == 2 + + A, bias, scalar = mt.synthesize_dual_basis() + assert A.shape[0] == 2 + # Verify that the elements are correctly added as DualBasisElements + # and not as their internal tuples (which would cause synthesis to fail). + assert isinstance(mt.dual_basis[0], DualBasisElement) + assert isinstance(mt.dual_basis[1], DualBasisElement) + + +def test_multitensor_init_isolation(): + # Test that different MultiTensor instances don't share the same dual_basis. + a = np.random.random((2, 2)) + at = Tensor(tensor=a, name='a') + mt1 = MultiTensor([at]) + mt2 = MultiTensor([at]) + + dbe = DualBasisElement() + dbe.add_element('a', (0, 0), 1.0) + mt1.add_dual_elements(dbe) + + assert len(mt1.dual_basis) == 1 + assert len(mt2.dual_basis) == 0 + def test_synthesis_element(): a = np.random.random((5, 5))