Skip to content

Commit bf0fa2d

Browse files
authored
Merge pull request #76 from oslokommune/DP-2022_fix_token_refresh
DP-2022 Fix token refresh on inactive session
2 parents 179cab7 + 33e64c8 commit bf0fa2d

5 files changed

Lines changed: 83 additions & 42 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
read errors, and redirects). The `retry` parameter now only controls the
77
maximum number of retries to perform on bad HTTP status codes.
88

9+
* Fix refresh of Keycloak access token when refresh token is invalid, e.g.
10+
due to inactive session because Keycloak server restarted.
11+
912
## 0.7.0
1013

1114
* `Dataset.update_dataset` now supports partial metadata updates when the

okdata/sdk/auth/auth.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
import logging
33

44
from okdata.sdk.auth.credentials.client_credentials import ClientCredentialsProvider
5-
from okdata.sdk.auth.credentials.common import TokenProviderNotInitialized
5+
from okdata.sdk.auth.credentials.common import (
6+
TokenProviderNotInitialized,
7+
TokenRefreshError,
8+
)
69
from okdata.sdk.auth.credentials.password_grant import TokenServiceProvider
710
from okdata.sdk.auth.util import is_token_expired
811
from okdata.sdk.exceptions import ApiAuthenticateError
@@ -55,16 +58,6 @@ def access_token(self):
5558
self.refresh_access_token()
5659
return self._access_token
5760

58-
# read only
59-
@property
60-
def refresh_token(self):
61-
if not self.token_provider:
62-
return None
63-
# If expired, relog
64-
if is_token_expired(self._refresh_token):
65-
self.token_provider.new_token()
66-
return self._refresh_token
67-
6861
def login(self, force=False):
6962
if not self.token_provider:
7063
return
@@ -77,21 +70,26 @@ def login(self, force=False):
7770
if self._access_token and not is_token_expired(self._access_token):
7871
log.info("Token not expired, skipping")
7972
return
80-
tokens = self.token_provider.new_token()
81-
if "access_token" not in tokens:
82-
raise ApiAuthenticateError
83-
self._access_token = tokens["access_token"]
84-
self._refresh_token = tokens["refresh_token"]
85-
self.file_cache.write_credentials(credentials=self)
73+
self.refresh_access_token()
8674

8775
def refresh_access_token(self):
8876
if not self.token_provider:
8977
return
9078

91-
if is_token_expired(self._refresh_token):
79+
tokens = None
80+
81+
if self._refresh_token and not is_token_expired(self._refresh_token):
82+
try:
83+
tokens = self.token_provider.refresh_token(self._refresh_token)
84+
except TokenRefreshError as e:
85+
log.warn(f"Error refreshing token: {e}")
86+
87+
if not tokens:
9288
tokens = self.token_provider.new_token()
93-
else:
94-
tokens = self.token_provider.refresh_token(self.refresh_token)
89+
if "access_token" not in tokens:
90+
raise ApiAuthenticateError
91+
self._refresh_token = tokens["refresh_token"]
92+
9593
self._access_token = tokens["access_token"]
9694
self.file_cache.write_credentials(credentials=self)
9795

okdata/sdk/auth/credentials/client_credentials.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from typing import Optional
2+
from keycloak.exceptions import KeycloakGetError # type: ignore
23
from keycloak.keycloak_openid import KeycloakOpenID # type: ignore
34

45
from okdata.sdk.auth.credentials.common import (
56
TokenProvider,
67
TokenProviderNotInitialized,
8+
TokenRefreshError,
79
)
810
from okdata.sdk.config import Config
911

@@ -39,7 +41,10 @@ def __post_init__(self):
3941
)
4042

4143
def refresh_token(self, refresh_token):
42-
return self.client.refresh_token(refresh_token=refresh_token)
44+
try:
45+
return self.client.refresh_token(refresh_token=refresh_token)
46+
except KeycloakGetError as e:
47+
raise TokenRefreshError(str(e))
4348

4449
def new_token(self):
4550
return self.client.token(grant_type=["client_credentials"])

okdata/sdk/auth/credentials/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ class TokenProviderNotInitialized(Exception):
55
pass
66

77

8+
class TokenRefreshError(Exception):
9+
pass
10+
11+
812
class TokenProvider:
913
config: Config
1014

tests/auth/auth_test.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
logging.basicConfig(level=logging.INFO)
2121

2222

23-
config = Config()
24-
token_endpoint = "https://login-test.oslo.kommune.no/auth/realms/api-catalog/protocol/openid-connect/token"
23+
config = Config(env="prod")
24+
token_endpoint = "https://login.oslo.kommune.no/auth/realms/api-catalog/protocol/openid-connect/token"
2525

2626

2727
@pytest.fixture(scope="function")
@@ -39,11 +39,11 @@ def test_authenticate_cache_disabled(self, requests_mock, mock_home_dir):
3939

4040
response = json.dumps(client_credentials_response)
4141
matcher = re.compile(token_endpoint)
42-
requests_mock.register_uri("POST", matcher, text=response, status_code=204)
42+
requests_mock.register_uri("POST", matcher, text=response, status_code=200)
4343

4444
auth.login()
45-
assert auth.access_token == client_credentials_response["access_token"]
46-
assert auth.refresh_token == client_credentials_response["refresh_token"]
45+
assert auth._access_token == client_credentials_response["access_token"]
46+
assert auth._refresh_token == client_credentials_response["refresh_token"]
4747

4848
def test_authenticat_no_cache(self, requests_mock, mock_home_dir):
4949

@@ -54,11 +54,11 @@ def test_authenticat_no_cache(self, requests_mock, mock_home_dir):
5454

5555
response = json.dumps(client_credentials_response)
5656
matcher = re.compile(token_endpoint)
57-
requests_mock.register_uri("POST", matcher, text=response, status_code=204)
57+
requests_mock.register_uri("POST", matcher, text=response, status_code=200)
5858

5959
auth.login()
60-
assert auth.access_token == client_credentials_response["access_token"]
61-
assert auth.refresh_token == client_credentials_response["refresh_token"]
60+
assert auth._access_token == client_credentials_response["access_token"]
61+
assert auth._refresh_token == client_credentials_response["refresh_token"]
6262

6363
def test_authenticate_cached_credentials(self, mock_home_dir):
6464
client_credentials_provider = ClientCredentialsProvider(config)
@@ -73,8 +73,8 @@ def test_authenticate_cached_credentials(self, mock_home_dir):
7373

7474
auth.file_cache.write_credentials(json.dumps(cached_credentials))
7575
auth.login()
76-
assert auth.access_token == cached_credentials["access_token"]
77-
assert auth.refresh_token == cached_credentials["refresh_token"]
76+
assert auth._access_token == cached_credentials["access_token"]
77+
assert auth._refresh_token == cached_credentials["refresh_token"]
7878

7979
def test_authenticate_refresh_credentials(self, requests_mock, mock_home_dir):
8080

@@ -93,11 +93,11 @@ def test_authenticate_refresh_credentials(self, requests_mock, mock_home_dir):
9393

9494
response = json.dumps(client_credentials_response)
9595
matcher = re.compile(token_endpoint)
96-
requests_mock.register_uri("POST", matcher, text=response, status_code=204)
96+
requests_mock.register_uri("POST", matcher, text=response, status_code=200)
9797

9898
auth.login()
99-
assert auth.access_token == cached_credentials["access_token"]
100-
assert auth.refresh_token == cached_credentials["refresh_token"]
99+
assert auth._access_token == cached_credentials["access_token"]
100+
assert auth._refresh_token == cached_credentials["refresh_token"]
101101

102102
def test_authenticate_expired_tokens(self, requests_mock, mock_home_dir):
103103
client_credentials_provider = ClientCredentialsProvider(config)
@@ -115,13 +115,13 @@ def test_authenticate_expired_tokens(self, requests_mock, mock_home_dir):
115115

116116
response = json.dumps(client_credentials_response)
117117
matcher = re.compile(token_endpoint)
118-
requests_mock.register_uri("POST", matcher, text=response, status_code=204)
118+
requests_mock.register_uri("POST", matcher, text=response, status_code=200)
119119

120120
auth.login()
121121
print(from_cache_not_expired_token)
122122
print(from_cache_expired_token)
123-
assert auth.access_token == client_credentials_response["access_token"]
124-
assert auth.refresh_token == client_credentials_response["access_token"]
123+
assert auth._access_token == client_credentials_response["access_token"]
124+
assert auth._refresh_token == client_credentials_response["access_token"]
125125

126126
def test_authenticate_expired_access_token(self, requests_mock, mock_home_dir):
127127
client_credentials_provider = ClientCredentialsProvider(config)
@@ -139,11 +139,11 @@ def test_authenticate_expired_access_token(self, requests_mock, mock_home_dir):
139139

140140
response = json.dumps(client_credentials_response)
141141
matcher = re.compile(token_endpoint)
142-
requests_mock.register_uri("POST", matcher, text=response, status_code=204)
142+
requests_mock.register_uri("POST", matcher, text=response, status_code=200)
143143

144144
auth.login()
145-
assert auth.access_token == from_cache_not_expired_token
146-
assert auth.refresh_token == cached_credentials["refresh_token"]
145+
assert auth._access_token == from_cache_not_expired_token
146+
assert auth._refresh_token == cached_credentials["refresh_token"]
147147

148148
def test_authenticate_fail(self, requests_mock, mock_home_dir):
149149
client_credentials_provider = ClientCredentialsProvider(
@@ -152,12 +152,43 @@ def test_authenticate_fail(self, requests_mock, mock_home_dir):
152152
auth = Authenticate(config=config, token_provider=client_credentials_provider)
153153

154154
response = json.dumps(
155-
{"error": "authenitcation error", "error_description": "No such client"}
155+
{"error": "authentication error", "error_description": "No such client"}
156156
)
157157
matcher = re.compile(token_endpoint)
158-
requests_mock.register_uri("POST", matcher, text=response, status_code=204)
158+
requests_mock.register_uri("POST", matcher, text=response, status_code=200)
159159

160160
try:
161161
auth.login()
162162
except ApiAuthenticateError:
163163
assert True
164+
165+
def test_refresh_inactive_session(self, requests_mock, mock_home_dir):
166+
client_credentials_provider = ClientCredentialsProvider(config)
167+
auth = Authenticate(config=config, token_provider=client_credentials_provider)
168+
169+
auth.file_cache.credentials_cache_enabled = True
170+
171+
cached_credentials = {
172+
"provider": "TokenServiceProvider",
173+
"access_token": from_cache_expired_token,
174+
"refresh_token": from_cache_not_expired_token,
175+
}
176+
177+
auth.file_cache.write_credentials(json.dumps(cached_credentials))
178+
179+
error_msg = {
180+
"error": "invalid_grant",
181+
"error_description": "Session not active",
182+
}
183+
refresh_response = {"text": json.dumps(error_msg), "status_code": 400}
184+
login_response = {
185+
"text": json.dumps(client_credentials_response),
186+
"status_code": 200,
187+
}
188+
matcher = re.compile(token_endpoint)
189+
requests_mock.register_uri("POST", matcher, [refresh_response, login_response])
190+
191+
auth.login()
192+
193+
assert auth._access_token == from_cache_not_expired_token
194+
assert auth._refresh_token == cached_credentials["refresh_token"]

0 commit comments

Comments
 (0)