1- from contextlib import AbstractContextManager
21from datetime import date , datetime , time
32from decimal import Decimal
43from functools import wraps
54from typing import (
65 Any ,
76 Callable ,
8- ContextManager ,
97 Dict ,
108 Iterator ,
119 List ,
2321from sqlalchemy import Column
2422from sqlalchemy .dialects import mysql
2523from sqlalchemy .engine import Dialect , default
26- from sqlalchemy .orm .query import Query
27- from sqlalchemy .sql import Delete , Insert , Select , Update
2824
2925from pydataapi .exceptions import DataAPIError , MultipleResultsFound , NoResultFound
3026
5753DATE_TYPE_HINT : str = "DATE"
5854
5955
60- def generate_sql (query : Union [Query , Insert , Update , Delete , Select ]) -> str :
61- if hasattr (query , "statement" ):
62- sql : str = query .statement .compile (** QUERY_STATEMENT_COMPILE_PARAMS )
63- else :
64- sql = query .compile (** QUERY_STATEMENT_COMPILE_PARAMS )
65- return str (sql )
66-
67-
68- def wrap_process_result_value_function (
69- process_result_value : Callable [..., Any ], dialect : default .DefaultDialect
70- ) -> Callable [..., Any ]:
71- @wraps (process_result_value )
72- def wrapped (value : Any ) -> Callable [..., Any ]:
73- return process_result_value (value , dialect )
74-
75- return wrapped
76-
77-
78- def get_process_result_value_function (
79- table_name : str ,
80- column_name : str ,
81- query : Union [Select , Query ],
82- dialect : default .DefaultDialect ,
83- ) -> Callable [..., Any ]:
84- process_result_value : Optional [Callable [..., Any ]] = None
85- if isinstance (query , Select ): # pragma: no cover
86- for column in query .columns :
87- if column .name == column_name :
88- process_result_value = getattr (
89- column .type , "process_result_value" , None
90- )
91- break
92- elif isinstance (query , Query ): # pragma: no cover
93- for column_description in query .column_descriptions :
94- type_ = column_description ["type" ]
95- if type_ .__tablename__ == table_name :
96- column = getattr (type_ , column_name , None )
97- if column :
98- expression = getattr (column , "expression" , None )
99- if (
100- isinstance (expression , Column )
101- and expression .name == column_name
102- ):
103- process_result_value = getattr (
104- expression .type , "process_result_value" , None
105- )
106- break
107- if process_result_value :
108- return wrap_process_result_value_function (process_result_value , dialect )
109- return lambda v : v
110-
111-
112- def create_process_result_value_function_list (
113- column_metadata : List [Dict [str , Any ]],
114- query : Union [Select , Query ],
115- dialect : default .DefaultDialect ,
116- ) -> List [Callable [..., Any ]]:
117- return [
118- get_process_result_value_function (cm ["tableName" ], cm ["name" ], query , dialect )
119- for cm in column_metadata
120- ]
121-
122-
12356def convert_array_value (value : Union [List [Any ], Tuple [Any , ...]]) -> Dict [str , Any ]:
12457 first_value : Any = value [0 ]
12558 if isinstance (first_value , (list , tuple )):
@@ -312,27 +245,12 @@ def __getitem__(self, i: Union[int, slice]) -> Union["Record", List["Record"]]:
312245 def __len__ (self ) -> int :
313246 return len (self ._rows )
314247
315- def __init__ (
316- self ,
317- response : Dict [Any , Any ],
318- process_result_value_function_list : Optional [List [Callable [..., Any ]]] = None ,
319- ) -> None :
248+ def __init__ (self , response : Dict [Any , Any ],) -> None :
320249 self ._response = response
321- if process_result_value_function_list :
322- self ._rows : Sequence [List [Any ]] = [
323- [
324- process_result_value (_get_value_from_row (column ))
325- for column , process_result_value in zip (
326- row , process_result_value_function_list
327- )
328- ]
329- for row in response .get ("records" , [])
330- ]
331- else :
332- self ._rows = [
333- [_get_value_from_row (column ) for column in row ]
334- for row in response .get ("records" , [])
335- ]
250+ self ._rows = [
251+ [_get_value_from_row (column ) for column in row ]
252+ for row in response .get ("records" , [])
253+ ]
336254 self ._column_metadata : List [Dict [str , Any ]] = response .get ("columnMetadata" , [])
337255 self ._headers : Optional [List [str ]] = None
338256 self ._index : int = - 1
@@ -418,12 +336,6 @@ def convert_parameter_sets(cls, v: Any) -> Any:
418336 return [create_sql_parameters (parameter ) for parameter in v ]
419337 return v # pragma: no cover
420338
421- @validator ("sql" , pre = True )
422- def validate_sql (cls , v : Any ) -> Any :
423- if isinstance (v , str ):
424- return v
425- return generate_sql (v )
426-
427339 def build (self ) -> Dict [str , Any ]:
428340 return self .dict (exclude_unset = True , by_alias = True )
429341
@@ -547,7 +459,7 @@ def rollback(self, transaction_id: Optional[str] = None) -> str:
547459
548460 def execute (
549461 self ,
550- query : Union [ Query , Insert , Update , Delete , Select , str ] ,
462+ query : str ,
551463 parameters : Optional [Dict [str , Any ]] = None ,
552464 transaction_id : Optional [str ] = None ,
553465 continue_after_timeout : bool = True ,
@@ -568,18 +480,11 @@ def execute(
568480 includeResultMetadata = True , ** options .build ()
569481 )
570482
571- if isinstance (query , (Query , Select )):
572- process_result_value_function_list = create_process_result_value_function_list (
573- response .get ("columnMetadata" , []),
574- query ,
575- QUERY_STATEMENT_COMPILE_PARAMS ["dialect" ],
576- )
577- return Result (response , process_result_value_function_list )
578483 return Result (response )
579484
580485 def batch_execute (
581486 self ,
582- query : Union [ Query , Insert , Update , Delete , Select , str ] ,
487+ query : str ,
583488 parameter_sets : Optional [List [Dict [str , Any ]]],
584489 transaction_id : Optional [str ] = None ,
585490 database : Optional [str ] = None ,
0 commit comments