-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathkrum.py
More file actions
116 lines (93 loc) · 4.38 KB
/
krum.py
File metadata and controls
116 lines (93 loc) · 4.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
from torch import Tensor
from torch.nn import functional as F
from .bases import _WeightedAggregator, _Weighting
class Krum(_WeightedAggregator):
    """
    :class:`~torchjd.aggregation.bases.Aggregator` for adversarial federated learning, as defined
    in `Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent
    <https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf>`_.

    The aggregation is delegated to a (Multi-)Krum weighting: each selected row receives the same
    weight ``1 / n_selected`` and every other row receives weight ``0``.

    :param n_byzantine: The number of rows of the input matrix that can come from an adversarial
        source.
    :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1.

    .. admonition::
        Example

        Use Multi-Krum to aggregate a matrix with 1 adversarial row.

        >>> from torch import tensor
        >>> from torchjd.aggregation import Krum
        >>>
        >>> A = Krum(n_byzantine=1, n_selected=4)
        >>> J = tensor([
        ...     [1., 1., 1.],
        ...     [1., 0., 1.],
        ...     [75., -666., 23],  # adversarial row
        ...     [1., 2., 3.],
        ...     [2., 0., 1.],
        ... ])
        >>>
        >>> A(J)
        tensor([1.2500, 0.7500, 1.5000])
    """

    def __init__(self, n_byzantine: int, n_selected: int = 1):
        # Parameter validation happens inside the weighting's constructor.
        krum_weighting = _KrumWeighting(n_byzantine=n_byzantine, n_selected=n_selected)
        super().__init__(weighting=krum_weighting)

    def __repr__(self) -> str:
        w = self.weighting
        return f"{type(self).__name__}(n_byzantine={w.n_byzantine}, n_selected={w.n_selected})"

    def __str__(self) -> str:
        w = self.weighting
        return f"Krum{w.n_byzantine}-{w.n_selected}"
class _KrumWeighting(_Weighting):
    """
    :class:`~torchjd.aggregation.bases._Weighting` that extracts weights using the
    (Multi-)Krum aggregation rule, as defined in `Machine Learning with Adversaries: Byzantine
    Tolerant Gradient Descent
    <https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf>`_.

    :param n_byzantine: The number of rows of the input matrix that can come from an adversarial
        source.
    :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1.
    :raises ValueError: If ``n_byzantine`` is negative or ``n_selected`` is smaller than 1.
    """

    def __init__(self, n_byzantine: int, n_selected: int = 1):
        super().__init__()
        if n_byzantine < 0:
            raise ValueError(
                "Parameter `n_byzantine` should be a non-negative integer. Found `n_byzantine = "
                f"{n_byzantine}`."
            )

        if n_selected < 1:
            raise ValueError(
                "Parameter `n_selected` should be a positive integer. Found `n_selected = "
                f"{n_selected}`."
            )

        self.n_byzantine = n_byzantine
        self.n_selected = n_selected

    def forward(self, matrix: Tensor) -> Tensor:
        """
        Compute the (Multi-)Krum weights of the rows of ``matrix``.

        :param matrix: The matrix whose rows are weighted. It must have at least
            ``n_byzantine + 3`` rows and at least ``n_selected`` rows.
        :returns: A vector with one entry per row of ``matrix``: ``1 / n_selected`` for each
            selected row and ``0`` for every other row.
        :raises ValueError: If ``matrix`` has too few rows.
        """
        self._check_matrix_shape(matrix)
        gramian = matrix @ matrix.T
        return self._compute_from_gramian(gramian)

    def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
        """
        Compute the (Multi-)Krum weights from the Gramian ``matrix @ matrix.T`` of the input.

        :param gramian: Square matrix of pairwise inner products between the rows of the input.
        :returns: A vector of length ``gramian.shape[0]`` with ``1 / n_selected`` at the selected
            rows and ``0`` elsewhere.
        """
        n_rows = gramian.shape[0]
        gradient_norms_squared = torch.diagonal(gramian)

        # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>
        distances_squared = (
            gradient_norms_squared.unsqueeze(0) + gradient_norms_squared.unsqueeze(1) - 2 * gramian
        )
        # Rounding in the expansion above can make entries slightly negative for (nearly)
        # identical rows; sqrt would then produce NaN, which topk ranks as the largest value and
        # thereby hides a row's true nearest neighbour. Clamp to the true lower bound of 0.
        distances = torch.sqrt(distances_squared.clamp(min=0))

        # Each row is scored by its n - f - 2 closest other rows. The topk below also returns the
        # row's distance to itself (exactly 0, hence a row minimum after clamping), which is
        # dropped by skipping the first column.
        n_closest = n_rows - self.n_byzantine - 2
        smallest_distances, _ = torch.topk(distances, k=n_closest + 1, largest=False)
        smallest_distances_excluding_self = smallest_distances[:, 1:]
        scores = smallest_distances_excluding_self.sum(dim=1)

        # Multi-Krum: keep the n_selected lowest-scoring rows and average them uniformly.
        _, selected_indices = torch.topk(scores, k=self.n_selected, largest=False)
        one_hot_selected_indices = F.one_hot(selected_indices, num_classes=n_rows)
        weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected
        return weights

    def _check_matrix_shape(self, matrix: Tensor) -> None:
        """
        Validate that ``matrix`` has enough rows for the (Multi-)Krum rule.

        At least ``n_byzantine + 3`` rows are needed so that every row has at least one other row
        to be scored against (``n_closest >= 1``), and at least ``n_selected`` rows are needed so
        that ``n_selected`` of them can be chosen.

        :raises ValueError: If either requirement is violated.
        """
        min_rows = self.n_byzantine + 3
        if matrix.shape[0] < min_rows:
            raise ValueError(
                f"Parameter `matrix` should have at least {min_rows} rows (n_byzantine + 3). Found "
                f"`matrix` with {matrix.shape[0]} rows."
            )

        if matrix.shape[0] < self.n_selected:
            raise ValueError(
                f"Parameter `matrix` should have at least {self.n_selected} rows (n_selected). "
                f"Found `matrix` with {matrix.shape[0]} rows."
            )