Skip to content

Commit 98dc17b

Browse files
committed
refactor: Optimize Mock PDS service initialization in lambda_handler
- Removed the global service retrieval function and directly initialized the MockPdsService with a Redis client and rate limiter. - Updated the lambda_handler to use the initialized service directly, improving performance and readability. - Adjusted unit tests to accommodate the new service initialization approach, ensuring proper mocking and error handling.
1 parent 28cac67 commit 98dc17b

2 files changed

Lines changed: 34 additions & 34 deletions

File tree

lambdas/mock_pds/src/lambda_handler.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,25 @@
99
logger = logging.getLogger()
1010
logger.setLevel(logging.INFO)
1111

12-
_mock_pds_service: MockPdsService | None = None
13-
14-
15-
def get_mock_pds_service() -> MockPdsService:
16-
global _mock_pds_service
17-
if _mock_pds_service is None:
18-
redis_client = redis.Redis(
19-
host=os.environ["REDIS_HOST"],
20-
port=int(os.getenv("REDIS_PORT", "6379")),
21-
decode_responses=True,
22-
)
23-
rate_limiter = FixedWindowRateLimiter(
24-
redis_client=redis_client,
25-
key_prefix="mock-pds",
26-
average_limit=int(os.getenv("MOCK_PDS_AVERAGE_LIMIT", "125")),
27-
average_window_seconds=int(os.getenv("MOCK_PDS_AVERAGE_WINDOW_SECONDS", "60")),
28-
spike_limit=int(os.getenv("MOCK_PDS_SPIKE_LIMIT", "450")),
29-
spike_window_seconds=int(os.getenv("MOCK_PDS_SPIKE_WINDOW_SECONDS", "1")),
30-
)
31-
_mock_pds_service = MockPdsService(rate_limiter, os.getenv("MOCK_PDS_GP_ODS_CODE", "Y12345"))
32-
return _mock_pds_service
12+
_redis_client = redis.Redis(
13+
host=os.environ["REDIS_HOST"],
14+
port=int(os.getenv("REDIS_PORT", "6379")),
15+
decode_responses=True,
16+
)
17+
_rate_limiter = FixedWindowRateLimiter(
18+
redis_client=_redis_client,
19+
key_prefix="mock-pds",
20+
average_limit=int(os.getenv("MOCK_PDS_AVERAGE_LIMIT", "125")),
21+
average_window_seconds=int(os.getenv("MOCK_PDS_AVERAGE_WINDOW_SECONDS", "60")),
22+
spike_limit=int(os.getenv("MOCK_PDS_SPIKE_LIMIT", "450")),
23+
spike_window_seconds=int(os.getenv("MOCK_PDS_SPIKE_WINDOW_SECONDS", "1")),
24+
)
25+
_mock_pds_service = MockPdsService(_rate_limiter, os.getenv("MOCK_PDS_GP_ODS_CODE", "Y12345"))
3326

3427

3528
def lambda_handler(event, context):
3629
try:
37-
return get_mock_pds_service().handle(event)
30+
return _mock_pds_service.handle(event)
3831
except Exception:
3932
logger.exception("Mock PDS failed to handle request")
4033
return {

lambdas/mock_pds/tests/test_mock_pds_service.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
import importlib
12
import json
3+
import os
24
import unittest
35
from unittest.mock import Mock, patch
46

5-
from lambda_handler import get_mock_pds_service, lambda_handler
7+
os.environ.setdefault("REDIS_HOST", "test-redis-host")
8+
os.environ.setdefault("REDIS_PORT", "6379")
9+
10+
import lambda_handler as lambda_handler_module
611
from mock_pds_service import RATE_LIMIT_MESSAGE, MockPdsService
712
from rate_limiter import FixedWindowRateLimiter, RateLimitDecision
813

@@ -47,7 +52,7 @@ def test_rejects_non_get_requests(self):
4752

4853
class TestLambdaHandler(unittest.TestCase):
4954
def tearDown(self):
50-
get_mock_pds_service.__globals__["_mock_pds_service"] = None
55+
importlib.reload(lambda_handler_module)
5156

5257
@patch.dict(
5358
"os.environ",
@@ -60,23 +65,25 @@ def tearDown(self):
6065
},
6166
clear=False,
6267
)
63-
@patch("lambda_handler.redis.Redis")
64-
def test_lambda_handler_uses_cached_service(self, mock_redis):
68+
@patch("mock_pds_service.MockPdsService")
69+
@patch("redis.Redis")
70+
def test_lambda_handler_uses_cached_service(self, mock_redis, mock_pds_cls):
6571
mock_service = Mock()
6672
mock_service.handle.return_value = {"statusCode": 200}
73+
mock_pds_cls.return_value = mock_service
6774

68-
with patch("lambda_handler.MockPdsService", return_value=mock_service):
69-
first_response = lambda_handler(_event(nhs_number="123"), None)
70-
second_response = lambda_handler(_event(nhs_number="456"), None)
75+
importlib.reload(lambda_handler_module)
76+
first_response = lambda_handler_module.lambda_handler(_event(nhs_number="123"), None)
77+
second_response = lambda_handler_module.lambda_handler(_event(nhs_number="456"), None)
7178

7279
self.assertEqual(first_response, {"statusCode": 200})
7380
self.assertEqual(second_response, {"statusCode": 200})
7481
mock_redis.assert_called_once_with(host="mock-redis", port=6379, decode_responses=True)
7582

76-
@patch("lambda_handler.get_mock_pds_service")
77-
def test_lambda_handler_returns_500_on_unhandled_error(self, mock_get_service):
78-
mock_get_service.return_value.handle.side_effect = RuntimeError("boom")
79-
80-
response = lambda_handler(_event(nhs_number="123"), None)
83+
def test_lambda_handler_returns_500_on_unhandled_error(self):
84+
mock_svc = Mock()
85+
mock_svc.handle.side_effect = RuntimeError("boom")
86+
with patch.object(lambda_handler_module, "_mock_pds_service", mock_svc):
87+
response = lambda_handler_module.lambda_handler(_event(nhs_number="123"), None)
8188

8289
self.assertEqual(response["statusCode"], 500)

0 commit comments

Comments
 (0)