Skip to content

Commit 8f2660d

Browse files
Use TypeGuard in _can_skip_jacobian_combination
Co-authored-by: Pierre Quinton <pierre.quinton@epfl.ch>
1 parent 48cd70b commit 8f2660d

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import deque
22
from collections.abc import Iterable
3-
from typing import cast
3+
from typing import cast, TypeGuard
44

55
import torch
66
from torch import Tensor, nn
@@ -76,13 +76,13 @@ def jac_to_grad(
7676
_free_jacs(tensors_)
7777

7878
if _can_skip_jacobian_combination(aggregator):
79-
gradients = _gramian_based(cast(GramianWeightedAggregator, aggregator), jacobians, tensors_)
79+
gradients = _gramian_based(aggregator, jacobians, tensors_)
8080
else:
8181
gradients = _jacobian_based(aggregator, jacobians, tensors_)
8282
accumulate_grads(tensors_, gradients)
8383

8484

85-
def _can_skip_jacobian_combination(aggregator: Aggregator) -> bool:
85+
def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianWeightedAggregator]:
8686
return isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator)
8787

8888

0 commit comments

Comments (0)