We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
_AggregateMatrices
1 parent fe1686b commit 3f78d6dCopy full SHA for 3f78d6d
1 file changed
src/torchjd/autojac/_transform/aggregate.py
@@ -32,7 +32,7 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
32
33
class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]):
34
def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]):
35
- self.key_order = OrderedSet(key_order)
+ self.key_order = key_order
36
self.aggregator = aggregator
37
38
def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
0 commit comments