Skip to content

Commit 1d3b4d8

Browse files
committed
DP-2022 Refactor refresh token logic and fix inactive session handling
1 parent e5483d1 commit 1d3b4d8

3 files changed

Lines changed: 28 additions & 21 deletions

File tree

okdata/sdk/auth/auth.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
import logging
33

44
from okdata.sdk.auth.credentials.client_credentials import ClientCredentialsProvider
5-
from okdata.sdk.auth.credentials.common import TokenProviderNotInitialized
5+
from okdata.sdk.auth.credentials.common import (
6+
TokenProviderNotInitialized,
7+
TokenRefreshError,
8+
)
69
from okdata.sdk.auth.credentials.password_grant import TokenServiceProvider
710
from okdata.sdk.auth.util import is_token_expired
811
from okdata.sdk.exceptions import ApiAuthenticateError
@@ -55,16 +58,6 @@ def access_token(self):
5558
self.refresh_access_token()
5659
return self._access_token
5760

58-
# read only
59-
@property
60-
def refresh_token(self):
61-
if not self.token_provider:
62-
return None
63-
# If expired, relog
64-
if is_token_expired(self._refresh_token):
65-
self.token_provider.new_token()
66-
return self._refresh_token
67-
6861
def login(self, force=False):
6962
if not self.token_provider:
7063
return
@@ -77,21 +70,26 @@ def login(self, force=False):
7770
if self._access_token and not is_token_expired(self._access_token):
7871
log.info("Token not expired, skipping")
7972
return
80-
tokens = self.token_provider.new_token()
81-
if "access_token" not in tokens:
82-
raise ApiAuthenticateError
83-
self._access_token = tokens["access_token"]
84-
self._refresh_token = tokens["refresh_token"]
85-
self.file_cache.write_credentials(credentials=self)
73+
self.refresh_access_token()
8674

8775
def refresh_access_token(self):
8876
if not self.token_provider:
8977
return
9078

91-
if is_token_expired(self._refresh_token):
79+
tokens = None
80+
81+
if self._refresh_token and not is_token_expired(self._refresh_token):
82+
try:
83+
tokens = self.token_provider.refresh_token(self._refresh_token)
84+
except TokenRefreshError as e:
85+
log.warn(f"Error refreshing token: {e}")
86+
87+
if not tokens:
9288
tokens = self.token_provider.new_token()
93-
else:
94-
tokens = self.token_provider.refresh_token(self.refresh_token)
89+
if "access_token" not in tokens:
90+
raise ApiAuthenticateError
91+
self._refresh_token = tokens["refresh_token"]
92+
9593
self._access_token = tokens["access_token"]
9694
self.file_cache.write_credentials(credentials=self)
9795

okdata/sdk/auth/credentials/client_credentials.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from typing import Optional
2+
from keycloak.exceptions import KeycloakGetError # type: ignore
23
from keycloak.keycloak_openid import KeycloakOpenID # type: ignore
34

45
from okdata.sdk.auth.credentials.common import (
56
TokenProvider,
67
TokenProviderNotInitialized,
8+
TokenRefreshError,
79
)
810
from okdata.sdk.config import Config
911

@@ -39,7 +41,10 @@ def __post_init__(self):
3941
)
4042

4143
def refresh_token(self, refresh_token):
42-
return self.client.refresh_token(refresh_token=refresh_token)
44+
try:
45+
return self.client.refresh_token(refresh_token=refresh_token)
46+
except KeycloakGetError as e:
47+
raise TokenRefreshError(str(e))
4348

4449
def new_token(self):
4550
return self.client.token(grant_type=["client_credentials"])

okdata/sdk/auth/credentials/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ class TokenProviderNotInitialized(Exception):
55
pass
66

77

8+
class TokenRefreshError(Exception):
9+
pass
10+
11+
812
class TokenProvider:
913
config: Config
1014

0 commit comments

Comments
 (0)