Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/torchjd/aggregation/pcgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ class _PCGradWeighting(_Weighting):

def forward(self, matrix: Tensor) -> Tensor:
# Pre-compute the inner products
inner_products = matrix @ matrix.T
gramian = matrix @ matrix.T
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated
return self._compute_from_gramian(gramian)

@staticmethod
def _compute_from_gramian(inner_products: Tensor) -> Tensor:
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
device = matrix.device
dtype = matrix.dtype
device = inner_products.device
dtype = inner_products.dtype
cpu = torch.device("cpu")
inner_products = inner_products.to(device=cpu)

Expand Down