From 05d967883ebd2a71ed64e9a955d86e67c4ccb961 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 29 Mar 2025 10:51:00 +0100 Subject: [PATCH 1/9] Make dependence on gramian explicit in AlignedMTL --- src/torchjd/aggregation/aligned_mtl.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/torchjd/aggregation/aligned_mtl.py b/src/torchjd/aggregation/aligned_mtl.py index 6f54a5adf..b707a3306 100644 --- a/src/torchjd/aggregation/aligned_mtl.py +++ b/src/torchjd/aggregation/aligned_mtl.py @@ -91,16 +91,14 @@ def __init__(self, weighting: _Weighting): def forward(self, matrix: Tensor) -> Tensor: w = self.weighting(matrix) - G = matrix.T - B = self._compute_balance_transformation(G) + M = matrix @ matrix.T + B = self._compute_balance_transformation(M) alpha = B @ w return alpha @staticmethod - def _compute_balance_transformation(G: Tensor) -> Tensor: - M = G.T @ G - + def _compute_balance_transformation(M: Tensor) -> Tensor: lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig tol = torch.max(lambda_) * len(M) * torch.finfo().eps rank = sum(lambda_ > tol) From 1467f46ed649bda9698dafd7145c540c7ec2f236 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 29 Mar 2025 10:51:13 +0100 Subject: [PATCH 2/9] Make dependence on gramian explicit in CAGrad --- src/torchjd/aggregation/cagrad.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/cagrad.py b/src/torchjd/aggregation/cagrad.py index c6c66788c..58a9385a4 100644 --- a/src/torchjd/aggregation/cagrad.py +++ b/src/torchjd/aggregation/cagrad.py @@ -73,12 +73,15 @@ def __init__(self, c: float, norm_eps: float): def forward(self, matrix: Tensor) -> Tensor: gramian = normalize(compute_gramian(matrix), self.norm_eps) + return self._compute_from_gramian(gramian) + + def _compute_from_gramian(self, gramian: Tensor) -> Tensor: U, S, _ = torch.svd(gramian) reduced_matrix = U @ S.sqrt().diag() reduced_array = 
reduced_matrix.cpu().detach().numpy().astype(np.float64) - dimension = matrix.shape[0] + dimension = gramian.shape[0] reduced_g_0 = reduced_array.T @ np.ones(dimension) / dimension sqrt_phi = self.c * np.linalg.norm(reduced_g_0, 2) @@ -97,6 +100,6 @@ def forward(self, matrix: Tensor) -> Tensor: # We are approximately on the pareto front weight_array = np.zeros(dimension) - weights = torch.from_numpy(weight_array).to(device=matrix.device, dtype=matrix.dtype) + weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype) return weights From bc3a1f94855753edb5cb05d9fd04fc1ac8988606 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 29 Mar 2025 10:51:24 +0100 Subject: [PATCH 3/9] Make dependence on gramian explicit in IMTL-G --- src/torchjd/aggregation/imtl_g.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/imtl_g.py b/src/torchjd/aggregation/imtl_g.py index f4fb82322..1c8c5a7cd 100644 --- a/src/torchjd/aggregation/imtl_g.py +++ b/src/torchjd/aggregation/imtl_g.py @@ -38,8 +38,13 @@ class _IMTLGWeighting(_Weighting): """ def forward(self, matrix: Tensor) -> Tensor: - d = torch.linalg.norm(matrix, dim=1) - v = torch.linalg.pinv(matrix @ matrix.T) @ d + gramian = matrix @ matrix.T + return self._compute_from_gramian(gramian) + + @staticmethod + def _compute_from_gramian(gramian: Tensor) -> Tensor: + d = torch.sqrt(torch.diagonal(gramian)) + v = torch.linalg.pinv(gramian) @ d v_sum = v.sum() if v_sum.abs() < 1e-12: From b0ba0168bd473dfdc84e21780846bdb7ddb5b503 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 29 Mar 2025 10:51:36 +0100 Subject: [PATCH 4/9] Make dependence on gramian explicit in Krum --- src/torchjd/aggregation/krum.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/torchjd/aggregation/krum.py b/src/torchjd/aggregation/krum.py index acbe54102..401eaf5b9 100644 --- a/src/torchjd/aggregation/krum.py +++ 
b/src/torchjd/aggregation/krum.py @@ -80,16 +80,24 @@ def __init__(self, n_byzantine: int, n_selected: int): def forward(self, matrix: Tensor) -> Tensor: self._check_matrix_shape(matrix) + gramian = matrix @ matrix.T + return self._compute_from_gramian(gramian) - distances = torch.cdist(matrix, matrix, compute_mode="donot_use_mm_for_euclid_dist") - n_closest = matrix.shape[0] - self.n_byzantine - 2 + def _compute_from_gramian(self, gramian: Tensor) -> Tensor: + gradient_norms_squared = torch.diagonal(gramian) + distances_squared = ( + gradient_norms_squared.unsqueeze(0) + gradient_norms_squared.unsqueeze(1) - 2 * gramian + ) + distances = torch.sqrt(distances_squared) + + n_closest = gramian.shape[0] - self.n_byzantine - 2 smallest_distances, _ = torch.topk(distances, k=n_closest + 1, largest=False) smallest_distances_excluding_self = smallest_distances[:, 1:] scores = smallest_distances_excluding_self.sum(dim=1) _, selected_indices = torch.topk(scores, k=self.n_selected, largest=False) - one_hot_selected_indices = F.one_hot(selected_indices, num_classes=matrix.shape[0]) - weights = one_hot_selected_indices.sum(dim=0).to(dtype=matrix.dtype) / self.n_selected + one_hot_selected_indices = F.one_hot(selected_indices, num_classes=gramian.shape[0]) + weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected return weights From 9696c0d6cdeb9883c7cb4d95ccbbdc39d494024c Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 29 Mar 2025 10:51:45 +0100 Subject: [PATCH 5/9] Make dependence on gramian explicit in MGDA --- src/torchjd/aggregation/mgda.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/torchjd/aggregation/mgda.py b/src/torchjd/aggregation/mgda.py index f2548ce2f..abb21e8b1 100644 --- a/src/torchjd/aggregation/mgda.py +++ b/src/torchjd/aggregation/mgda.py @@ -56,15 +56,14 @@ def __init__(self, epsilon: float, max_iters: int): self.epsilon = epsilon self.max_iters = max_iters - def 
_frank_wolfe_solver(self, matrix: Tensor) -> Tensor: - gramian = compute_gramian(matrix) - device = matrix.device - dtype = matrix.dtype + def _frank_wolfe_solver(self, gramian: Tensor) -> Tensor: + device = gramian.device + dtype = gramian.dtype - alpha = torch.ones(matrix.shape[0], device=device, dtype=dtype) / matrix.shape[0] + alpha = torch.ones(gramian.shape[0], device=device, dtype=dtype) / gramian.shape[0] for i in range(self.max_iters): t = torch.argmin(gramian @ alpha) - e_t = torch.zeros(matrix.shape[0], device=device, dtype=dtype) + e_t = torch.zeros(gramian.shape[0], device=device, dtype=dtype) e_t[t] = 1.0 a = alpha @ (gramian @ e_t) b = alpha @ (gramian @ alpha) @@ -81,5 +80,6 @@ def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor: return alpha def forward(self, matrix: Tensor) -> Tensor: - weights = self._frank_wolfe_solver(matrix) + gramian = compute_gramian(matrix) + weights = self._frank_wolfe_solver(gramian) return weights From e44d73873db1d439f3706daa75be5c6275f73595 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 29 Mar 2025 10:51:53 +0100 Subject: [PATCH 6/9] Make dependence on gramian explicit in PCGrad --- src/torchjd/aggregation/pcgrad.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/torchjd/aggregation/pcgrad.py b/src/torchjd/aggregation/pcgrad.py index c110aa98e..f340c7cb2 100644 --- a/src/torchjd/aggregation/pcgrad.py +++ b/src/torchjd/aggregation/pcgrad.py @@ -41,11 +41,14 @@ class _PCGradWeighting(_Weighting): def forward(self, matrix: Tensor) -> Tensor: # Pre-compute the inner products - inner_products = matrix @ matrix.T + gramian = matrix @ matrix.T + return self._compute_from_gramian(gramian) + @staticmethod + def _compute_from_gramian(inner_products: Tensor) -> Tensor: # Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration - device = matrix.device - dtype = matrix.dtype + device = inner_products.device + dtype = inner_products.dtype cpu = 
torch.device("cpu") inner_products = inner_products.to(device=cpu) From 21a120067ace9428ded01fafbc8956d29f2ef694 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 25 Apr 2025 14:13:24 +0200 Subject: [PATCH 7/9] Use compute_gramian and rename inner_products to gramian --- src/torchjd/aggregation/aligned_mtl.py | 3 ++- src/torchjd/aggregation/imtl_g.py | 3 ++- src/torchjd/aggregation/krum.py | 3 ++- src/torchjd/aggregation/pcgrad.py | 17 +++++++++-------- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/torchjd/aggregation/aligned_mtl.py b/src/torchjd/aggregation/aligned_mtl.py index b707a3306..e8df27fe9 100644 --- a/src/torchjd/aggregation/aligned_mtl.py +++ b/src/torchjd/aggregation/aligned_mtl.py @@ -28,6 +28,7 @@ import torch from torch import Tensor +from ._gramian_utils import compute_gramian from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting from .bases import _WeightedAggregator, _Weighting from .mean import _MeanWeighting @@ -91,7 +92,7 @@ def __init__(self, weighting: _Weighting): def forward(self, matrix: Tensor) -> Tensor: w = self.weighting(matrix) - M = matrix @ matrix.T + M = compute_gramian(matrix) B = self._compute_balance_transformation(M) alpha = B @ w diff --git a/src/torchjd/aggregation/imtl_g.py b/src/torchjd/aggregation/imtl_g.py index 1c8c5a7cd..b807df6d0 100644 --- a/src/torchjd/aggregation/imtl_g.py +++ b/src/torchjd/aggregation/imtl_g.py @@ -1,6 +1,7 @@ import torch from torch import Tensor +from ._gramian_utils import compute_gramian from .bases import _WeightedAggregator, _Weighting @@ -38,7 +39,7 @@ class _IMTLGWeighting(_Weighting): """ def forward(self, matrix: Tensor) -> Tensor: - gramian = matrix @ matrix.T + gramian = compute_gramian(matrix) return self._compute_from_gramian(gramian) @staticmethod diff --git a/src/torchjd/aggregation/krum.py b/src/torchjd/aggregation/krum.py index 401eaf5b9..8e39045e0 100644 --- a/src/torchjd/aggregation/krum.py +++ 
b/src/torchjd/aggregation/krum.py @@ -2,6 +2,7 @@ from torch import Tensor from torch.nn import functional as F +from ._gramian_utils import compute_gramian from .bases import _WeightedAggregator, _Weighting @@ -80,7 +81,7 @@ def __init__(self, n_byzantine: int, n_selected: int): def forward(self, matrix: Tensor) -> Tensor: self._check_matrix_shape(matrix) - gramian = matrix @ matrix.T + gramian = compute_gramian(matrix) return self._compute_from_gramian(gramian) def _compute_from_gramian(self, gramian: Tensor) -> Tensor: diff --git a/src/torchjd/aggregation/pcgrad.py b/src/torchjd/aggregation/pcgrad.py index f340c7cb2..1b448fe32 100644 --- a/src/torchjd/aggregation/pcgrad.py +++ b/src/torchjd/aggregation/pcgrad.py @@ -1,6 +1,7 @@ import torch from torch import Tensor +from ._gramian_utils import compute_gramian from .bases import _WeightedAggregator, _Weighting @@ -41,18 +42,18 @@ class _PCGradWeighting(_Weighting): def forward(self, matrix: Tensor) -> Tensor: # Pre-compute the inner products - gramian = matrix @ matrix.T + gramian = compute_gramian(matrix) return self._compute_from_gramian(gramian) @staticmethod - def _compute_from_gramian(inner_products: Tensor) -> Tensor: + def _compute_from_gramian(gramian: Tensor) -> Tensor: # Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration - device = inner_products.device - dtype = inner_products.dtype + device = gramian.device + dtype = gramian.dtype cpu = torch.device("cpu") - inner_products = inner_products.to(device=cpu) + gramian = gramian.to(device=cpu) - dimension = inner_products.shape[0] + dimension = gramian.shape[0] weights = torch.zeros(dimension, device=cpu, dtype=dtype) for i in range(dimension): @@ -65,10 +66,10 @@ def _compute_from_gramian(inner_products: Tensor) -> Tensor: continue # Compute the inner product between g_i^{PC} and g_j - inner_product = inner_products[j] @ current_weights + inner_product = gramian[j] @ current_weights if inner_product < 0.0: - 
current_weights[j] -= inner_product / (inner_products[j, j]) + current_weights[j] -= inner_product / (gramian[j, j]) weights = weights + current_weights From 6f8f3c8edef05ff4e2e548d7cd53f9ad08cb85d7 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 26 Apr 2025 09:17:25 +0200 Subject: [PATCH 8/9] Move normalization of gramian into _from_gramian --- src/torchjd/aggregation/cagrad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/cagrad.py b/src/torchjd/aggregation/cagrad.py index 58a9385a4..93075fc0c 100644 --- a/src/torchjd/aggregation/cagrad.py +++ b/src/torchjd/aggregation/cagrad.py @@ -72,11 +72,11 @@ def __init__(self, c: float, norm_eps: float): self.norm_eps = norm_eps def forward(self, matrix: Tensor) -> Tensor: - gramian = normalize(compute_gramian(matrix), self.norm_eps) + gramian = compute_gramian(matrix) return self._compute_from_gramian(gramian) def _compute_from_gramian(self, gramian: Tensor) -> Tensor: - U, S, _ = torch.svd(gramian) + U, S, _ = torch.svd(normalize(gramian, self.norm_eps)) reduced_matrix = U @ S.sqrt().diag() reduced_array = reduced_matrix.cpu().detach().numpy().astype(np.float64) From 19a2a4dd19f85671ba80c2c89357c6c4d4146c99 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 26 Apr 2025 09:20:32 +0200 Subject: [PATCH 9/9] Rename Frank-Wolfe solver into _from_gramian --- src/torchjd/aggregation/mgda.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/mgda.py b/src/torchjd/aggregation/mgda.py index abb21e8b1..2a2f05e18 100644 --- a/src/torchjd/aggregation/mgda.py +++ b/src/torchjd/aggregation/mgda.py @@ -56,7 +56,12 @@ def __init__(self, epsilon: float, max_iters: int): self.epsilon = epsilon self.max_iters = max_iters - def _frank_wolfe_solver(self, gramian: Tensor) -> Tensor: + def _compute_from_gramian(self, gramian: Tensor) -> Tensor: + """ + This is the Frank-Wolfe solver in Algorithm 2 of `Multi-Task Learning as 
Multi-Objective
+        Optimization
+        <https://arxiv.org/pdf/1810.04650.pdf>`_.
+        """
         device = gramian.device
         dtype = gramian.dtype
 
@@ -81,5 +86,5 @@ def _frank_wolfe_solver(self, gramian: Tensor) -> Tensor:
 
     def forward(self, matrix: Tensor) -> Tensor:
         gramian = compute_gramian(matrix)
-        weights = self._frank_wolfe_solver(gramian)
+        weights = self._compute_from_gramian(gramian)
         return weights