Skip to content

Commit 3043876

Browse files
committed
tidy
1 parent 172e6c3 commit 3043876

4 files changed

Lines changed: 90 additions & 98 deletions

File tree

recordprocessor/src/redis_cacher.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

recordprocessor/src/redis_disease_mapping.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,39 @@
11
"Upload the content from a config file in S3 to ElastiCache (Redis)"
2-
from redis_cacher import RedisCacher
32

3+
import json
4+
import redis
45
from constants import DISEASE_MAPPING_FILE_KEY
56

67

8+
class RedisCacher():
9+
""" RedisCacher abstraction class to decouple application code
10+
from direct use of Redis client."""
11+
12+
def __init__(self, redis_host, redis_port, logger):
13+
try:
14+
# Attempt to connect to Redis
15+
self.redis_client = redis.StrictRedis(redis_host, redis_port, decode_responses=True)
16+
# Check the connection with a PING command
17+
if self.redis_client.ping():
18+
logger.info("Successfully connected to Redis.")
19+
else:
20+
logger.error("Failed to connect to Redis.")
21+
except Exception as e:
22+
logger.exception(f"Connection to Redis failed: {e}")
23+
24+
def get(self, key: str) -> dict:
25+
"""Gets the value from Redis cache for the given key."""
26+
value = self.redis_client.get(key)
27+
if value is not None:
28+
return json.loads(value)
29+
return {}
30+
31+
732
class DiseaseMapping:
833
"""Class to handle disease mapping operations."""
934
# redis_cache instance is found in clients.py
1035
def __init__(self, redis_cache: RedisCacher):
11-
mapping = redis_cache.get_cache(DISEASE_MAPPING_FILE_KEY)
36+
mapping = redis_cache.get(DISEASE_MAPPING_FILE_KEY)
1237
self.vaccines = mapping["vaccine"]
1338
self.diseases = mapping["disease"]
1439
self.load_vaccines_into_diseases()
@@ -43,14 +68,14 @@ def load_vaccines_into_diseases(self):
4368
if vaccine not in self.diseases[disease]["vaccines"]:
4469
self.diseases[disease]["vaccines"].append(vaccine)
4570

46-
def get_diseases(self, vaccine: str) -> list:
71+
def get_diseases_from_vaccine(self, vaccine: str) -> list:
4772
"""Returns a list of diseases for the given vaccine."""
4873
vaccine = self.vaccines.get(vaccine, {})
4974
if not vaccine:
5075
return []
5176
return vaccine.get("diseases", [])
5277

53-
def get_vaccines(self, disease: str) -> list:
78+
def get_vaccines_from_disease(self, disease: str) -> list:
5479
"""Returns a list of vaccines for the given disease."""
5580
disease = self.disease_map.get(disease, {})
5681
if not disease:

recordprocessor/tests/test_redis_cacher.py

Lines changed: 0 additions & 67 deletions
This file was deleted.

recordprocessor/tests/test_redis_disease_mapping.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,69 @@
11
import unittest
22
from unittest.mock import patch, MagicMock
3-
from src.redis_disease_mapping import DiseaseMapping
4-
from src.redis_cacher import RedisCacher
3+
from src.redis_disease_mapping import DiseaseMapping, RedisCacher
54
from src.constants import DISEASE_MAPPING_FILE_KEY
65

76

7+
class TestRedisCacher(unittest.TestCase):
8+
def setUp(self):
9+
10+
self.strict_redis_patcher = patch("redis.StrictRedis")
11+
self.strict_redis = self.strict_redis_patcher.start()
12+
self.mock_redis_client = MagicMock()
13+
self.strict_redis.return_value = self.mock_redis_client
14+
15+
self.logger_info_patcher = patch("logging.Logger.info")
16+
self.mock_logger_info = self.logger_info_patcher.start()
17+
18+
self.logger_exception_patcher = patch("logging.Logger.exception")
19+
self.mock_logger_exception = self.logger_exception_patcher.start()
20+
21+
self.logger_warning_patcher = patch("logging.Logger.warning")
22+
self.mock_logger_warning = self.logger_warning_patcher.start()
23+
24+
self.logger_error_patcher = patch("logging.Logger.error")
25+
self.mock_logger_error = self.logger_error_patcher.start()
26+
27+
def tearDown(self):
28+
self.strict_redis_patcher.stop()
29+
self.logger_info_patcher.stop()
30+
self.logger_exception_patcher.stop()
31+
self.logger_warning_patcher.stop()
32+
self.logger_error_patcher.stop()
33+
34+
def test_successful_connection(self):
35+
self.mock_redis_client.ping.return_value = True
36+
cacher = RedisCacher("localhost", 6379)
37+
self.assertTrue(hasattr(cacher, "redis_client"))
38+
self.mock_logger_info.assert_called_once()
39+
self.mock_redis_client.ping.assert_called_once()
40+
print("test_successful_connection...Done")
41+
42+
def test_failed_connection(self):
43+
self.mock_redis_client.ping.return_value = False
44+
cacher = RedisCacher("localhost", 6379)
45+
self.assertTrue(hasattr(cacher, "redis_client"))
46+
self.mock_redis_client.ping.assert_called_once()
47+
print("test_failed_connection...Done")
48+
49+
def test_get_cache_returns_dict(self):
50+
self.mock_redis_client.ping.return_value = True
51+
self.mock_redis_client.get.return_value = '{"foo": "bar"}'
52+
cacher = RedisCacher("localhost", 6379)
53+
result = cacher.get_cache("some_key")
54+
self.assertEqual(result, {"foo": "bar"})
55+
self.mock_redis_client.get.assert_called_once_with("some_key")
56+
print("test_get_cache_returns_dict...Done")
57+
58+
def test_get_cache_returns_empty_dict(self):
59+
self.mock_redis_client.ping.return_value = True
60+
self.mock_redis_client.get.return_value = None
61+
cacher = RedisCacher("localhost", 6379)
62+
result = cacher.get_cache("missing_key")
63+
self.assertEqual(result, {})
64+
self.mock_redis_client.get.assert_called_once_with("missing_key")
65+
66+
867
class TestDiseaseMapping(unittest.TestCase):
968

1069
basic_cache = {

0 commit comments

Comments
 (0)