Skip to content

Commit aa16592

Browse files
Use Pydantic FHIR types in SDS client.
1 parent 370d2f5 commit aa16592

3 files changed

Lines changed: 54 additions & 41 deletions

File tree

gateway-api/src/fhir/resources.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _validate_resource_type(cls, value: str) -> str:
8282
return value
8383

8484

85-
type BundleType = Literal["document", "transaction"]
85+
type BundleType = Literal["document", "transaction", "searchset"]
8686

8787

8888
class Bundle(Resource, resource_type="Bundle"):
@@ -117,6 +117,30 @@ def empty(cls, bundle_type: BundleType) -> "Bundle":
117117
return cls.create(type=bundle_type, entry=None)
118118

119119

120+
class Device(Resource, resource_type="Device"):
121+
"""A FHIR R4 Device resource."""
122+
123+
class ASIDIdentifier(
124+
Identifier, expected_system="https://fhir.nhs.uk/Id/nhsSpineASID"
125+
):
126+
"""A FHIR R4 ASID Identifier."""
127+
128+
class PartyKeyIdentifier(
129+
Identifier, expected_system="https://fhir.nhs.uk/Id/nhsMhsPartyKey"
130+
):
131+
"""A FHIR R4 Party Key Identifier."""
132+
133+
identifier: Annotated[
134+
list[ASIDIdentifier | PartyKeyIdentifier], Field(frozen=True, min_length=1)
135+
]
136+
137+
138+
class Endpoint(Resource, resource_type="Endpoint"):
139+
"""A FHIR R4 Endpoint resource."""
140+
141+
address: str | None = Field(None, frozen=True)
142+
143+
120144
class OperationOutcome(Resource, resource_type="OperationOutcome"):
121145
"""A FHIR R4 OperationOutcome resource."""
122146

gateway-api/src/gateway_api/get_structured_record/request.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from typing import TYPE_CHECKING, ClassVar
33

44
from fhir import OperationOutcome, Parameters
5+
6+
# TODO: may be able to remove the use of the FHIR type entirely.
57
from fhir.operation_outcome import OperationOutcomeIssue
68
from flask.wrappers import Request, Response
79
from werkzeug.exceptions import BadRequest

gateway-api/src/gateway_api/sds/client.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010

1111
import os
1212
from enum import StrEnum
13-
from typing import Any, cast
13+
from typing import Any
1414

15+
from fhir.resources import Bundle, Device, Endpoint, Resource
1516
from stubs import SdsFhirApiStub
1617

1718
from gateway_api.get_structured_record import ACCESS_RECORD_STRUCTURED_INTERACTION_ID
@@ -29,11 +30,6 @@
2930
sds = SdsFhirApiStub()
3031
get = sds.get # type: ignore
3132

32-
# Recursive JSON-like structure typing used for parsed FHIR bodies.
33-
type ResultStructureDict = dict[str, ResultStructure]
34-
type ResultList = list[ResultStructureDict]
35-
type ResultStructure = str | ResultStructureDict | list["ResultStructure"]
36-
3733

3834
class SdsResourceType(StrEnum):
3935
"""SDS FHIR resource types."""
@@ -134,12 +130,14 @@ def get_org_details(
134130
querytype=SdsResourceType.DEVICE,
135131
)
136132

137-
device = self._extract_first_entry(device_bundle)
133+
device = self._extract_first_resource(device_bundle, Device)
138134

139-
# TODO: Post-steel-thread handle case where no device is found for ODS code
135+
if not device:
136+
empty_search_results = SdsSearchResults(asid=None, endpoint=None)
137+
return empty_search_results
140138

141-
asid = self._extract_identifier(device, self.ASID_SYSTEM)
142-
party_key = self._extract_identifier(device, self.PARTYKEY_SYSTEM)
139+
asid = self._extract_device_identifier(device, self.ASID_SYSTEM)
140+
party_key = self._extract_device_identifier(device, self.PARTYKEY_SYSTEM)
143141

144142
# Step 2: Get Endpoint to obtain endpoint URL
145143
endpoint_url: str | None = None
@@ -154,11 +152,9 @@ def get_org_details(
154152
timeout=timeout,
155153
querytype=SdsResourceType.ENDPOINT,
156154
)
157-
endpoint = self._extract_first_entry(endpoint_bundle)
158-
if endpoint:
159-
address = endpoint.get("address")
160-
if address:
161-
endpoint_url = str(address).strip()
155+
endpoint = self._extract_first_resource(endpoint_bundle, Endpoint)
156+
if endpoint and endpoint.address:
157+
endpoint_url = str(endpoint.address).strip()
162158

163159
return SdsSearchResults(asid=asid, endpoint=endpoint_url)
164160

@@ -182,7 +178,7 @@ def _query_sds(
182178
correlation_id: str | None = None,
183179
timeout: int | None = 10,
184180
querytype: SdsResourceType = SdsResourceType.DEVICE,
185-
) -> ResultStructureDict:
181+
) -> Bundle:
186182
"""
187183
Query SDS /Device or /Endpoint endpoint.
188184
"""
@@ -206,38 +202,29 @@ def _query_sds(
206202

207203
# TODO: Post-steel-thread we probably want a raise_for_status() here
208204

209-
body = response.json()
210-
return cast("ResultStructureDict", body)
205+
bundle = Bundle.model_validate(response.json())
206+
return bundle
211207

212208
@staticmethod
213-
def _extract_first_entry(
214-
bundle: ResultStructureDict,
215-
) -> ResultStructureDict: # TODO: Post-steel-thread this may return a None as well
216-
"""
217-
Extract the first resource from a Bundle.
218-
"""
219-
entries = cast("ResultList", bundle.get("entry", []))
220-
209+
def _extract_first_resource[T: Resource](
210+
bundle: Bundle, resource: type[T]
211+
) -> T | None:
221212
# TODO: Post-steel-thread handle case where bundle contains no entries
222213

223214
# TODO: more carefully consider business logic for handling multiple
224215
# entries in beta
225-
if not entries:
226-
return {}
227-
first_entry = entries[0]
228-
return cast("ResultStructureDict", first_entry.get("resource", {}))
229-
230-
def _extract_identifier(
231-
self, device: ResultStructureDict, system: str
232-
) -> str | None:
216+
resources = bundle.find_resources(resource)
217+
if not resources:
218+
return None
219+
first_entry = resources[0]
220+
return first_entry
221+
222+
def _extract_device_identifier(self, device: Device, system: str) -> str | None:
233223
"""
234224
Extract an identifier value from a Device resource for a given system.
235225
"""
236-
identifiers = cast("ResultList", device.get("identifier", []))
237-
238-
for identifier in identifiers:
239-
id_system = str(identifier.get("system", ""))
240-
if id_system == system:
241-
return cast("str", identifier.get("value", ""))
226+
for identifier in device.identifier:
227+
if identifier.system == system:
228+
return identifier.value or ""
242229

243230
return None

0 commit comments

Comments
 (0)