Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/torchjd/aggregation/aligned_mtl.py
Comment thread
ValerianRey marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,14 @@ def __init__(self, weighting: _Weighting):
def forward(self, matrix: Tensor) -> Tensor:
w = self.weighting(matrix)

G = matrix.T
B = self._compute_balance_transformation(G)
M = matrix @ matrix.T
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated
B = self._compute_balance_transformation(M)
alpha = B @ w

return alpha

@staticmethod
def _compute_balance_transformation(G: Tensor) -> Tensor:
M = G.T @ G

def _compute_balance_transformation(M: Tensor) -> Tensor:
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
rank = sum(lambda_ > tol)
Expand Down
7 changes: 5 additions & 2 deletions src/torchjd/aggregation/cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,15 @@ def __init__(self, c: float, norm_eps: float):

def forward(self, matrix: Tensor) -> Tensor:
gramian = normalize(compute_gramian(matrix), self.norm_eps)
return self._compute_from_gramian(gramian)
Comment thread
PierreQuinton marked this conversation as resolved.

def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
U, S, _ = torch.svd(gramian)

reduced_matrix = U @ S.sqrt().diag()
reduced_array = reduced_matrix.cpu().detach().numpy().astype(np.float64)

dimension = matrix.shape[0]
dimension = gramian.shape[0]
reduced_g_0 = reduced_array.T @ np.ones(dimension) / dimension
sqrt_phi = self.c * np.linalg.norm(reduced_g_0, 2)

Expand All @@ -97,6 +100,6 @@ def forward(self, matrix: Tensor) -> Tensor:
# We are approximately on the pareto front
weight_array = np.zeros(dimension)

weights = torch.from_numpy(weight_array).to(device=matrix.device, dtype=matrix.dtype)
weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)

return weights
9 changes: 7 additions & 2 deletions src/torchjd/aggregation/imtl_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,13 @@ class _IMTLGWeighting(_Weighting):
"""

def forward(self, matrix: Tensor) -> Tensor:
d = torch.linalg.norm(matrix, dim=1)
v = torch.linalg.pinv(matrix @ matrix.T) @ d
gramian = matrix @ matrix.T
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated
return self._compute_from_gramian(gramian)

@staticmethod
def _compute_from_gramian(gramian: Tensor) -> Tensor:
d = torch.sqrt(torch.diagonal(gramian))
v = torch.linalg.pinv(gramian) @ d
v_sum = v.sum()

if v_sum.abs() < 1e-12:
Expand Down
16 changes: 12 additions & 4 deletions src/torchjd/aggregation/krum.py
Comment thread
ValerianRey marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,24 @@ def __init__(self, n_byzantine: int, n_selected: int):

def forward(self, matrix: Tensor) -> Tensor:
self._check_matrix_shape(matrix)
gramian = matrix @ matrix.T
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated
return self._compute_from_gramian(gramian)

distances = torch.cdist(matrix, matrix, compute_mode="donot_use_mm_for_euclid_dist")
n_closest = matrix.shape[0] - self.n_byzantine - 2
def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
gradient_norms_squared = torch.diagonal(gramian)
distances_squared = (
gradient_norms_squared.unsqueeze(0) + gradient_norms_squared.unsqueeze(1) - 2 * gramian
)
distances = torch.sqrt(distances_squared)

n_closest = gramian.shape[0] - self.n_byzantine - 2
smallest_distances, _ = torch.topk(distances, k=n_closest + 1, largest=False)
smallest_distances_excluding_self = smallest_distances[:, 1:]
scores = smallest_distances_excluding_self.sum(dim=1)

_, selected_indices = torch.topk(scores, k=self.n_selected, largest=False)
one_hot_selected_indices = F.one_hot(selected_indices, num_classes=matrix.shape[0])
weights = one_hot_selected_indices.sum(dim=0).to(dtype=matrix.dtype) / self.n_selected
one_hot_selected_indices = F.one_hot(selected_indices, num_classes=gramian.shape[0])
weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected

return weights

Expand Down
14 changes: 7 additions & 7 deletions src/torchjd/aggregation/mgda.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,14 @@ def __init__(self, epsilon: float, max_iters: int):
self.epsilon = epsilon
self.max_iters = max_iters

def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor:
gramian = compute_gramian(matrix)
device = matrix.device
dtype = matrix.dtype
def _frank_wolfe_solver(self, gramian: Tensor) -> Tensor:
device = gramian.device
dtype = gramian.dtype

alpha = torch.ones(matrix.shape[0], device=device, dtype=dtype) / matrix.shape[0]
alpha = torch.ones(gramian.shape[0], device=device, dtype=dtype) / gramian.shape[0]
for i in range(self.max_iters):
t = torch.argmin(gramian @ alpha)
e_t = torch.zeros(matrix.shape[0], device=device, dtype=dtype)
e_t = torch.zeros(gramian.shape[0], device=device, dtype=dtype)
e_t[t] = 1.0
a = alpha @ (gramian @ e_t)
b = alpha @ (gramian @ alpha)
Expand All @@ -81,5 +80,6 @@ def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor:
return alpha

def forward(self, matrix: Tensor) -> Tensor:
weights = self._frank_wolfe_solver(matrix)
gramian = compute_gramian(matrix)
weights = self._frank_wolfe_solver(gramian)
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated
return weights
9 changes: 6 additions & 3 deletions src/torchjd/aggregation/pcgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ class _PCGradWeighting(_Weighting):

def forward(self, matrix: Tensor) -> Tensor:
# Pre-compute the inner products
inner_products = matrix @ matrix.T
gramian = matrix @ matrix.T
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated
return self._compute_from_gramian(gramian)

@staticmethod
def _compute_from_gramian(inner_products: Tensor) -> Tensor:
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
device = matrix.device
dtype = matrix.dtype
device = inner_products.device
dtype = inner_products.dtype
cpu = torch.device("cpu")
inner_products = inner_products.to(device=cpu)

Expand Down