11import abc
22import functools
3- from dataclasses import field
3+ import random
44from datetime import datetime
55import math
66import sys
1414import decimal
1515import contextvars
1616
17- from runtype import dataclass
17+ import attrs
1818from typing_extensions import Self
1919
2020from data_diff .abcs .compiler import AbstractCompiler
@@ -90,12 +90,7 @@ class CompileError(Exception):
9090 pass
9191
9292
93- # TODO: remove once switched to attrs, where ForwardRef[]/strings are resolved.
94- class _RuntypeHackToFixCicularRefrencedDatabase :
95- dialect : "BaseDialect"
96-
97-
98- @dataclass
93+ @attrs .define (frozen = True )
9994class Compiler (AbstractCompiler ):
10095 """
10196 Compiler bears the context for a single compilation.
@@ -107,16 +102,16 @@ class Compiler(AbstractCompiler):
107102 # Database is needed to normalize tables. Dialect is needed for recursive compilations.
108103 # In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects.
109104 # In practice, we currently bind the dialects to the specific database classes.
110- database : _RuntypeHackToFixCicularRefrencedDatabase
105+ database : "Database"
111106
112107 in_select : bool = False # Compilation runtime flag
113108 in_join : bool = False # Compilation runtime flag
114109
115- _table_context : List = field (default_factory = list ) # List[ITable]
116- _subqueries : Dict [str , Any ] = field (default_factory = dict ) # XXX not thread-safe
110+ _table_context : List = attrs . field (factory = list ) # List[ITable]
111+ _subqueries : Dict [str , Any ] = attrs . field (factory = dict ) # XXX not thread-safe
117112 root : bool = True
118113
119- _counter : List = field (default_factory = lambda : [0 ])
114+ _counter : List = attrs . field (factory = lambda : [0 ])
120115
121116 @property
122117 def dialect (self ) -> "BaseDialect" :
@@ -136,7 +131,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
136131 return self .database .dialect .parse_table_name (table_name )
137132
138133 def add_table_context (self , * tables : Sequence , ** kw ) -> Self :
139- return self . replace ( _table_context = self ._table_context + list (tables ), ** kw )
134+ return attrs . evolve ( self , table_context = self ._table_context + list (tables ), ** kw )
140135
141136
142137def parse_table_name (t ):
@@ -272,7 +267,7 @@ def _compile(self, compiler: Compiler, elem) -> str:
272267 if elem is None :
273268 return "NULL"
274269 elif isinstance (elem , Compilable ):
275- return self .render_compilable (compiler . replace ( root = False ), elem )
270+ return self .render_compilable (attrs . evolve ( compiler , root = False ), elem )
276271 elif isinstance (elem , str ):
277272 return f"'{ elem } '"
278273 elif isinstance (elem , (int , float )):
@@ -382,7 +377,7 @@ def render_column(self, c: Compiler, elem: Column) -> str:
382377 return self .quote (elem .name )
383378
384379 def render_cte (self , parent_c : Compiler , elem : Cte ) -> str :
385- c : Compiler = parent_c . replace ( _table_context = [], in_select = False )
380+ c : Compiler = attrs . evolve ( parent_c , table_context = [], in_select = False )
386381 compiled = self .compile (c , elem .source_table )
387382
388383 name = elem .name or parent_c .new_unique_name ()
@@ -495,7 +490,7 @@ def render_tablealias(self, c: Compiler, elem: TableAlias) -> str:
495490 return f"{ self .compile (c , elem .source_table )} { self .quote (elem .name )} "
496491
497492 def render_tableop (self , parent_c : Compiler , elem : TableOp ) -> str :
498- c : Compiler = parent_c . replace ( in_select = False )
493+ c : Compiler = attrs . evolve ( parent_c , in_select = False )
499494 table_expr = f"{ self .compile (c , elem .table1 )} { elem .op } { self .compile (c , elem .table2 )} "
500495 if parent_c .in_select :
501496 table_expr = f"({ table_expr } ) { c .new_unique_name ()} "
@@ -507,7 +502,7 @@ def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str:
507502 return self .compile (c , elem ._get_resolved ())
508503
509504 def render_select (self , parent_c : Compiler , elem : Select ) -> str :
510- c : Compiler = parent_c . replace ( in_select = True ) # .add_table_context(self.table)
505+ c : Compiler = attrs . evolve ( parent_c , in_select = True ) # .add_table_context(self.table)
511506 compile_fn = functools .partial (self .compile , c )
512507
513508 columns = ", " .join (map (compile_fn , elem .columns )) if elem .columns else "*"
@@ -545,7 +540,8 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str:
545540
546541 def render_join (self , parent_c : Compiler , elem : Join ) -> str :
547542 tables = [
548- t if isinstance (t , TableAlias ) else TableAlias (t , parent_c .new_unique_name ()) for t in elem .source_tables
543+ t if isinstance (t , TableAlias ) else TableAlias (source_table = t , name = parent_c .new_unique_name ())
544+ for t in elem .source_tables
549545 ]
550546 c = parent_c .add_table_context (* tables , in_join = True , in_select = False )
551547 op = " JOIN " if elem .op is None else f" { elem .op } JOIN "
@@ -578,7 +574,8 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
578574 if isinstance (elem .table , Select ) and elem .table .columns is None and elem .table .group_by_exprs is None :
579575 return self .compile (
580576 c ,
581- elem .table .replace (
577+ attrs .evolve (
578+ elem .table ,
582579 columns = columns ,
583580 group_by_exprs = [Code (k ) for k in keys ],
584581 having_exprs = elem .having_exprs ,
@@ -590,7 +587,7 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
590587 having_str = (
591588 " HAVING " + " AND " .join (map (compile_fn , elem .having_exprs )) if elem .having_exprs is not None else ""
592589 )
593- select = f"SELECT { columns_str } FROM { self .compile (c . replace ( in_select = True ), elem .table )} GROUP BY { keys_str } { having_str } "
590+ select = f"SELECT { columns_str } FROM { self .compile (attrs . evolve ( c , in_select = True ), elem .table )} GROUP BY { keys_str } { having_str } "
594591
595592 if c .in_select :
596593 select = f"({ select } ) { c .new_unique_name ()} "
@@ -827,7 +824,7 @@ def set_timezone_to_utc(self) -> str:
827824T = TypeVar ("T" , bound = BaseDialect )
828825
829826
830- @dataclass
827+ @attrs . define ( frozen = True )
831828class QueryResult :
832829 rows : list
833830 columns : Optional [list ] = None
@@ -842,7 +839,7 @@ def __getitem__(self, i):
842839 return self .rows [i ]
843840
844841
845- class Database (abc .ABC , _RuntypeHackToFixCicularRefrencedDatabase ):
842+ class Database (abc .ABC ):
846843 """Base abstract class for databases.
847844
848845 Used for providing connection code and implementation specific SQL utilities.
0 commit comments