2828
2929MAX_RECORDS : int = 1000
3030
31- DIALECT : Dialect = mysql .dialect (paramstyle = ' named' )
31+ DIALECT : Dialect = mysql .dialect (paramstyle = " named" )
3232
3333QUERY_STATEMENT_COMPILE_PARAMS = {
34- ' dialect' : mysql .dialect (paramstyle = ' named' ),
35- ' compile_kwargs' : {"literal_binds" : True },
34+ " dialect" : mysql .dialect (paramstyle = " named" ),
35+ " compile_kwargs" : {"literal_binds" : True },
3636}
3737
38- BOOLEAN_VALUE : str = ' booleanValue'
39- STRING_VALUE : str = ' stringValue'
40- LONG_VALUE : str = ' longValue'
41- DOUBLE_VALUE : str = ' doubleValue'
42- BLOB_VALUE : str = ' blobValue'
43- IS_NULL : str = ' isNull'
44- ARRAY_VALUE : str = ' arrayValue'
45- ARRAY_VALUES : str = ' arrayValues'
46- BOOLEAN_VALUES : str = ' booleanValues'
47- STRING_VALUES : str = ' stringValues'
48- LONG_VALUES : str = ' longValues'
49- DOUBLE_VALUES : str = ' doubleValues'
50- BLOB_VALUES : str = ' blobValues'
51-
52- DECIMAL_TYPE_HINT : str = ' DECIMAL'
53- TIMESTAMP_TYPE_HINT : str = ' TIMESTAMP'
54- TIME_TYPE_HINT : str = ' TIME'
55- DATE_TYPE_HINT : str = ' DATE'
38+ BOOLEAN_VALUE : str = " booleanValue"
39+ STRING_VALUE : str = " stringValue"
40+ LONG_VALUE : str = " longValue"
41+ DOUBLE_VALUE : str = " doubleValue"
42+ BLOB_VALUE : str = " blobValue"
43+ IS_NULL : str = " isNull"
44+ ARRAY_VALUE : str = " arrayValue"
45+ ARRAY_VALUES : str = " arrayValues"
46+ BOOLEAN_VALUES : str = " booleanValues"
47+ STRING_VALUES : str = " stringValues"
48+ LONG_VALUES : str = " longValues"
49+ DOUBLE_VALUES : str = " doubleValues"
50+ BLOB_VALUES : str = " blobValues"
51+
52+ DECIMAL_TYPE_HINT : str = " DECIMAL"
53+ TIMESTAMP_TYPE_HINT : str = " TIMESTAMP"
54+ TIME_TYPE_HINT : str = " TIME"
55+ DATE_TYPE_HINT : str = " DATE"
5656
5757
5858def generate_sql (query : Union [Query , Insert , Update , Delete , Select ]) -> str :
59- if hasattr (query , ' statement' ):
59+ if hasattr (query , " statement" ):
6060 sql : str = query .statement .compile (** QUERY_STATEMENT_COMPILE_PARAMS )
6161 else :
6262 sql = query .compile (** QUERY_STATEMENT_COMPILE_PARAMS )
@@ -84,22 +84,22 @@ def get_process_result_value_function(
8484 for column in query .columns :
8585 if column .name == column_name :
8686 process_result_value = getattr (
87- column .type , ' process_result_value' , None
87+ column .type , " process_result_value" , None
8888 )
8989 break
9090 elif isinstance (query , Query ): # pragma: no cover
9191 for column_description in query .column_descriptions :
92- type_ = column_description [' type' ]
92+ type_ = column_description [" type" ]
9393 if type_ .__tablename__ == table_name :
9494 column = getattr (type_ , column_name , None )
9595 if column :
96- expression = getattr (column , ' expression' , None )
96+ expression = getattr (column , " expression" , None )
9797 if (
9898 isinstance (expression , Column )
9999 and expression .name == column_name
100100 ):
101101 process_result_value = getattr (
102- expression .type , ' process_result_value' , None
102+ expression .type , " process_result_value" , None
103103 )
104104 break
105105 if process_result_value :
@@ -113,7 +113,7 @@ def create_process_result_value_function_list(
113113 dialect : default .DefaultDialect ,
114114) -> List [Callable ]:
115115 return [
116- get_process_result_value_function (cm [' tableName' ], cm [' name' ], query , dialect )
116+ get_process_result_value_function (cm [" tableName" ], cm [" name" ], query , dialect )
117117 for cm in column_metadata
118118 ]
119119
@@ -142,7 +142,7 @@ def convert_array_value(value: Union[List, Tuple]) -> Dict[str, Any]:
142142 values_key = BLOB_VALUES
143143 if values_key :
144144 return {ARRAY_VALUE : {values_key : list (value )}}
145- raise Exception (f' unsupported array type { type (value [0 ])} ]: { value } ' )
145+ raise Exception (f" unsupported array type { type (value [0 ])} ]: { value } " )
146146
147147
148148def create_sql_parameter (key : str , value : Any ) -> Dict [str , Any ]:
@@ -170,20 +170,20 @@ def create_sql_parameter(key: str, value: Any) -> Dict[str, Any]:
170170 converted_value = {STRING_VALUE : str (value )}
171171 type_hint = DECIMAL_TYPE_HINT
172172 elif isinstance (value , datetime ):
173- converted_value = {STRING_VALUE : value .strftime (' %Y-%m-%d %H:%M:%S.%f' )}
173+ converted_value = {STRING_VALUE : value .strftime (" %Y-%m-%d %H:%M:%S.%f" )}
174174 type_hint = TIMESTAMP_TYPE_HINT
175175 elif isinstance (value , time ):
176- converted_value = {STRING_VALUE : value .strftime (' %H:%M:%S.%f' )}
176+ converted_value = {STRING_VALUE : value .strftime (" %H:%M:%S.%f" )}
177177 type_hint = TIME_TYPE_HINT
178178 elif isinstance (value , date ):
179- converted_value = {STRING_VALUE : value .strftime (' %Y-%m-%d' )}
179+ converted_value = {STRING_VALUE : value .strftime (" %Y-%m-%d" )}
180180 type_hint = DATE_TYPE_HINT
181181 else :
182182 # TODO: support structValue
183183 converted_value = {STRING_VALUE : str (value )}
184184 if type_hint :
185- return {' name' : key , ' value' : converted_value , ' typeHint' : type_hint }
186- return {' name' : key , ' value' : converted_value }
185+ return {" name" : key , " value" : converted_value , " typeHint" : type_hint }
186+ return {" name" : key , " value" : converted_value }
187187
188188
189189def create_sql_parameters (
@@ -209,13 +209,13 @@ def _get_value_from_row(row: Dict[str, Any]) -> Any:
209209 return value
210210
211211
212- T = TypeVar ('T' )
212+ T = TypeVar ("T" )
213213
214214
215215class GeneratedFields :
216216 def __repr__ (self ) -> str :
217- values : str = ', ' .join (str (f ) for f in self .generated_fields )
218- return f' <{ self .__class__ .__name__ } ({ values } )>'
217+ values : str = ", " .join (str (f ) for f in self .generated_fields )
218+ return f" <{ self .__class__ .__name__ } ({ values } )>"
219219
220220 def __init__ (self , generated_fields : List [Dict [str , Any ]]):
221221 self ._generated_fields_raw : List [Dict [str , Any ]] = generated_fields
@@ -247,8 +247,8 @@ def __eq__(self, other: Any) -> bool:
247247
248248class Record (Sequence , Iterator ):
249249 def __repr__ (self ) -> str :
250- values : str = ', ' .join (f' { k } ={ str (v )} ' for k , v in self .dict ().items ())
251- return f' <{ self .__class__ .__name__ } ({ values } )>'
250+ values : str = ", " .join (f" { k } ={ str (v )} " for k , v in self .dict ().items ())
251+ return f" <{ self .__class__ .__name__ } ({ values } )>"
252252
253253 def __next__ (self ) -> Any :
254254 self ._index += 1
@@ -300,7 +300,7 @@ def __next__(self) -> Any:
300300
301301 def __getitem__ ( # type: ignore
302302 self , i : Union [int , slice ]
303- ) -> Union [' Record' , List [' Record' ]]:
303+ ) -> Union [" Record" , List [" Record" ]]:
304304 if isinstance (i , slice ):
305305 return [Record (r , self .headers ) for r in self ._rows [i ]]
306306 return Record (self ._rows [i ], self .headers ) # type: ignore
@@ -322,26 +322,26 @@ def __init__(
322322 row , process_result_value_function_list
323323 )
324324 ]
325- for row in response .get (' records' , []) # type: ignore
325+ for row in response .get (" records" , []) # type: ignore
326326 ]
327327 else :
328328 self ._rows = [
329329 [_get_value_from_row (column ) for column in row ]
330- for row in response .get (' records' , []) # type: ignore
330+ for row in response .get (" records" , []) # type: ignore
331331 ]
332- self ._column_metadata : List [Dict [str , Any ]] = response .get (' columnMetadata' , [])
332+ self ._column_metadata : List [Dict [str , Any ]] = response .get (" columnMetadata" , [])
333333 self ._headers : Optional [List [str ]] = None
334334 self ._index : int = - 1
335- super ().__init__ (response .get (' generatedFields' , []))
335+ super ().__init__ (response .get (" generatedFields" , []))
336336
337337 @property
338338 def number_of_records_updated (self ) -> int :
339- return self ._response .get (' numberOfRecordsUpdated' , 0 )
339+ return self ._response .get (" numberOfRecordsUpdated" , 0 )
340340
341341 @property
342342 def headers (self ) -> List [str ]:
343343 if self ._headers is None :
344- self ._headers = [meta [' label' ] for meta in self ._column_metadata ]
344+ self ._headers = [meta [" label" ] for meta in self ._column_metadata ]
345345 return self ._headers
346346
347347 def first (self ) -> Optional [Record ]:
@@ -373,13 +373,13 @@ def all(self) -> List[Record]:
373373class UpdateResults (Sequence [GeneratedFields ]):
374374 def __getitem__ ( # type: ignore
375375 self , i : Union [int , slice ]
376- ) -> Union [' GeneratedFields' , List [' GeneratedFields' ]]:
376+ ) -> Union [" GeneratedFields" , List [" GeneratedFields" ]]:
377377 if isinstance (i , slice ):
378378 return [
379- GeneratedFields (r [' generatedFields' ]) for r in self ._update_results [i ]
379+ GeneratedFields (r [" generatedFields" ]) for r in self ._update_results [i ]
380380 ]
381381 return GeneratedFields (
382- self ._update_results [i ][' generatedFields' ]
382+ self ._update_results [i ][" generatedFields" ]
383383 ) # type: ignore
384384
385385 def __len__ (self ) -> int :
@@ -394,7 +394,7 @@ class Options(BaseModel):
394394 secretArn : str
395395 sql : Optional [str ]
396396 database : Optional [str ]
397- schema_ : Optional [str ] = Field (None , alias = ' schema' )
397+ schema_ : Optional [str ] = Field (None , alias = " schema" )
398398 transactionId : Optional [str ]
399399 continueAfterTimeout : Optional [bool ]
400400 parameters : Optional [List [Dict [str , Any ]]]
@@ -404,19 +404,19 @@ class Options(BaseModel):
404404 def validate_all (cls , values : Dict [str , Any ]) -> Dict [str , Any ]:
405405 return {k : v for k , v in values .items () if v is not None }
406406
407- @validator (' parameters' , pre = True )
407+ @validator (" parameters" , pre = True )
408408 def convert_parameters (cls , v : Any ) -> Any :
409409 if isinstance (v , Dict ):
410410 return create_sql_parameters (v )
411411 return v
412412
413- @validator (' parameterSets' , pre = True )
413+ @validator (" parameterSets" , pre = True )
414414 def convert_parameter_sets (cls , v : Any ) -> Any :
415415 if isinstance (v , (list , tuple )): # pragma: no cover
416416 return [create_sql_parameters (parameter ) for parameter in v ]
417417 return v # pragma: no cover
418418
419- @validator (' sql' , pre = True )
419+ @validator (" sql" , pre = True )
420420 def validate_sql (cls , v : Any ) -> Any :
421421 if isinstance (v , str ):
422422 return v
@@ -430,10 +430,10 @@ def find_arn_by_resource_name(
430430 resource_name : str , boto3_client : Optional [boto3 .session .Session .client ]
431431) -> str :
432432 if not boto3_client :
433- boto3_client = boto3 .client (' rds' )
433+ boto3_client = boto3 .client (" rds" )
434434 return boto3_client .describe_db_clusters (DBClusterIdentifier = resource_name )[
435- ' DBClusters'
436- ][0 ][' DBClusterArn' ]
435+ " DBClusters"
436+ ][0 ][" DBClusterArn" ]
437437
438438
439439class DataAPI (AbstractContextManager ):
@@ -452,27 +452,28 @@ def __init__(
452452 if resource_name :
453453 if resource_arn :
454454 raise DataAPIError (
455- f' resource_name should be set without resource_arn. resource_arn: { resource_arn } ,'
456- f' resource_name: { resource_name } '
455+ f" resource_name should be set without resource_arn. resource_arn: { resource_arn } ,"
456+ f" resource_name: { resource_name } "
457457 )
458458 resource_arn = find_arn_by_resource_name (resource_name , rds_client )
459459 if not resource_arn :
460- raise DataAPIError (' Not Found resource_arn.' )
460+ raise DataAPIError (" Not Found resource_arn." )
461461 self .resource_arn : str = resource_arn
462462 self .secret_arn : str = secret_arn
463463 self .database : Optional [str ] = database
464464
465465 client_kwargs = {}
466- region_name = resource_arn .split (':' )[3 ]
467- if region_name :
468- client_kwargs ['region_name' ] = region_name
466+ region_name = resource_arn .split (":" )[3 ]
467+ client_kwargs ["region_name" ] = region_name
469468
470469 self ._transaction_id : Optional [str ] = transaction_id
471- self ._client : boto3 .session .Session .client = client or boto3 .client ('rds-data' , ** client_kwargs )
470+ self ._client : boto3 .session .Session .client = client or boto3 .client (
471+ "rds-data" , ** client_kwargs
472+ )
472473 self ._transaction_status : Optional [str ] = None
473474 self .rollback_exception : Optional [Type [Exception ]] = rollback_exception
474475
475- def __enter__ (self ) -> ' DataAPI' :
476+ def __enter__ (self ) -> " DataAPI" :
476477 self .begin ()
477478 return self
478479
@@ -512,9 +513,9 @@ def begin(
512513 )
513514
514515 response : Dict [str , str ] = self .client .begin_transaction (** options .build ())
515- self ._transaction_id = response [' transactionId' ]
516+ self ._transaction_id = response [" transactionId" ]
516517
517- return response [' transactionId' ]
518+ return response [" transactionId" ]
518519
519520 def commit (self , transaction_id : Optional [str ] = None ) -> str :
520521
@@ -525,7 +526,7 @@ def commit(self, transaction_id: Optional[str] = None) -> str:
525526 )
526527
527528 response : Dict [str , str ] = self .client .commit_transaction (** options .build ())
528- self ._transaction_status = response [' transactionStatus' ]
529+ self ._transaction_status = response [" transactionStatus" ]
529530
530531 return self ._transaction_status
531532
@@ -538,7 +539,7 @@ def rollback(self, transaction_id: Optional[str] = None) -> str:
538539 )
539540
540541 response : Dict [str , str ] = self .client .rollback_transaction (** options .build ())
541- self ._transaction_status = response [' transactionStatus' ]
542+ self ._transaction_status = response [" transactionStatus" ]
542543
543544 return self ._transaction_status
544545
@@ -567,9 +568,9 @@ def execute(
567568
568569 if isinstance (query , (Query , Select )):
569570 process_result_value_function_list = create_process_result_value_function_list (
570- response .get (' columnMetadata' , []),
571+ response .get (" columnMetadata" , []),
571572 query ,
572- QUERY_STATEMENT_COMPILE_PARAMS [' dialect' ],
573+ QUERY_STATEMENT_COMPILE_PARAMS [" dialect" ],
573574 )
574575 return Result (response , process_result_value_function_list )
575576 return Result (response )
0 commit comments