1010
1111import os
1212from enum import StrEnum
13- from typing import Any , cast
13+ from typing import Any
1414
15+ from fhir .resources import Bundle , Device , Endpoint , Resource
1516from stubs import SdsFhirApiStub
1617
1718from gateway_api .get_structured_record import ACCESS_RECORD_STRUCTURED_INTERACTION_ID
2930 sds = SdsFhirApiStub ()
3031 get = sds .get # type: ignore
3132
32- # Recursive JSON-like structure typing used for parsed FHIR bodies.
33- type ResultStructureDict = dict [str , ResultStructure ]
34- type ResultList = list [ResultStructureDict ]
35- type ResultStructure = str | ResultStructureDict | list ["ResultStructure" ]
36-
3733
3834class SdsResourceType (StrEnum ):
3935 """SDS FHIR resource types."""
@@ -134,12 +130,14 @@ def get_org_details(
134130 querytype = SdsResourceType .DEVICE ,
135131 )
136132
137- device = self ._extract_first_entry (device_bundle )
133+ device = self ._extract_first_resource (device_bundle , Device )
138134
139- # TODO: Post-steel-thread handle case where no device is found for ODS code
135+ if not device :
136+ empty_search_results = SdsSearchResults (asid = None , endpoint = None )
137+ return empty_search_results
140138
141- asid = self ._extract_identifier (device , self .ASID_SYSTEM )
142- party_key = self ._extract_identifier (device , self .PARTYKEY_SYSTEM )
139+ asid = self ._extract_device_identifier (device , self .ASID_SYSTEM )
140+ party_key = self ._extract_device_identifier (device , self .PARTYKEY_SYSTEM )
143141
144142 # Step 2: Get Endpoint to obtain endpoint URL
145143 endpoint_url : str | None = None
@@ -154,11 +152,9 @@ def get_org_details(
154152 timeout = timeout ,
155153 querytype = SdsResourceType .ENDPOINT ,
156154 )
157- endpoint = self ._extract_first_entry (endpoint_bundle )
158- if endpoint :
159- address = endpoint .get ("address" )
160- if address :
161- endpoint_url = str (address ).strip ()
155+ endpoint = self ._extract_first_resource (endpoint_bundle , Endpoint )
156+ if endpoint and endpoint .address :
157+ endpoint_url = str (endpoint .address ).strip ()
162158
163159 return SdsSearchResults (asid = asid , endpoint = endpoint_url )
164160
@@ -182,7 +178,7 @@ def _query_sds(
182178 correlation_id : str | None = None ,
183179 timeout : int | None = 10 ,
184180 querytype : SdsResourceType = SdsResourceType .DEVICE ,
185- ) -> ResultStructureDict :
181+ ) -> Bundle :
186182 """
187183 Query SDS /Device or /Endpoint endpoint.
188184 """
@@ -206,38 +202,29 @@ def _query_sds(
206202
207203 # TODO: Post-steel-thread we probably want a raise_for_status() here
208204
209- body = response .json ()
210- return cast ( "ResultStructureDict" , body )
205+ bundle = Bundle . model_validate ( response .json () )
206+ return bundle
211207
212208 @staticmethod
213- def _extract_first_entry (
214- bundle : ResultStructureDict ,
215- ) -> ResultStructureDict : # TODO: Post-steel-thread this may return a None as well
216- """
217- Extract the first resource from a Bundle.
218- """
219- entries = cast ("ResultList" , bundle .get ("entry" , []))
220-
209+ def _extract_first_resource [T : Resource ](
210+ bundle : Bundle , resource : type [T ]
211+ ) -> T | None :
221212 # TODO: Post-steel-thread handle case where bundle contains no entries
222213
223214 # TODO: more carefully consider business logic for handling multiple
224215 # entries in beta
225- if not entries :
226- return {}
227- first_entry = entries [0 ]
228- return cast ("ResultStructureDict" , first_entry .get ("resource" , {}))
229-
230- def _extract_identifier (
231- self , device : ResultStructureDict , system : str
232- ) -> str | None :
216+ resources = bundle .find_resources (resource )
217+ if not resources :
218+ return None
219+ first_entry = resources [0 ]
220+ return first_entry
221+
222+ def _extract_device_identifier (self , device : Device , system : str ) -> str | None :
233223 """
234224 Extract an identifier value from a Device resource for a given system.
235225 """
236- identifiers = cast ("ResultList" , device .get ("identifier" , []))
237-
238- for identifier in identifiers :
239- id_system = str (identifier .get ("system" , "" ))
240- if id_system == system :
241- return cast ("str" , identifier .get ("value" , "" ))
226+ for identifier in device .identifier :
227+ if identifier .system == system :
228+ return identifier .value or ""
242229
243230 return None
0 commit comments