From 8c072e959ab66da970bbeb2deb0fe967867d29c7 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 25 Sep 2023 16:58:22 +0200 Subject: [PATCH 1/6] Remove unused mixin for regexs --- data_diff/abcs/mixins.py | 6 ------ data_diff/databases/base.py | 13 ++----------- data_diff/databases/duckdb.py | 6 ------ data_diff/databases/mysql.py | 6 ------ data_diff/queries/ast_classes.py | 9 --------- 5 files changed, 2 insertions(+), 38 deletions(-) diff --git a/data_diff/abcs/mixins.py b/data_diff/abcs/mixins.py index 17f06064..d40c2f60 100644 --- a/data_diff/abcs/mixins.py +++ b/data_diff/abcs/mixins.py @@ -134,12 +134,6 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: """ -class AbstractMixin_Regex(AbstractMixin): - @abstractmethod - def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: - """Tests whether the regex pattern matches the string. Returns a bool expression.""" - - class AbstractMixin_RandomSample(AbstractMixin): @abstractmethod def random_sample_n(self, tbl: str, size: int) -> str: diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index dc43d8d7..bb892f50 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -27,7 +27,7 @@ Join, \ Param, \ Random, \ - Root, TableAlias, TableOp, TablePath, TestRegex, \ + Root, TableAlias, TableOp, TablePath, \ TimeTravel, TruncateTable, UnaryOp, WhenThen, _ResolveColumn from data_diff.abcs.database_types import ( AbstractDatabase, @@ -52,7 +52,7 @@ Boolean, JSON, ) -from data_diff.abcs.mixins import AbstractMixin_Regex, AbstractMixin_TimeTravel, Compilable +from data_diff.abcs.mixins import AbstractMixin_TimeTravel, Compilable from data_diff.abcs.mixins import ( AbstractMixin_Schema, AbstractMixin_RandomSample, @@ -225,8 +225,6 @@ def render_compilable(self, c: Compiler, elem: Compilable) -> str: return self.render_checksum(c, elem) elif isinstance(elem, Concat): return self.render_concat(c, elem) - elif isinstance(elem, 
TestRegex): - return self.render_testregex(c, elem) elif isinstance(elem, Func): return self.render_func(c, elem) elif isinstance(elem, WhenThen): @@ -372,13 +370,6 @@ def render_concat(self, c: Compiler, elem: Concat) -> str: def render_alias(self, c: Compiler, elem: Alias) -> str: return f"{self.compile(c, elem.expr)} AS {self.quote(elem.name)}" - def render_testregex(self, c: Compiler, elem: TestRegex) -> str: - # TODO: move this method to that mixin! raise here instead, unconditionally. - if not isinstance(self, AbstractMixin_Regex): - raise NotImplementedError(f"No regex implementation for database '{c.dialect}'") - regex = self.test_regex(elem.string, elem.pattern) - return self.compile(c, regex) - def render_count(self, c: Compiler, elem: Count) -> str: expr = self.compile(c, elem.expr) if elem.expr else "*" if elem.distinct: diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index f7fdaadd..03234edb 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -20,7 +20,6 @@ AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_RandomSample, - AbstractMixin_Regex, ) from data_diff.databases.base import ( Database, @@ -70,11 +69,6 @@ def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> Abstra return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio)) -class Mixin_Regex(AbstractMixin_Regex): - def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: - return Func("regexp_matches", [string, pattern]) - - class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "DuckDB" ROUNDS_ON_PREC_LOSS = False diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 910ff78d..d6dcba9e 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -14,7 +14,6 @@ from data_diff.abcs.mixins import ( AbstractMixin_MD5, 
AbstractMixin_NormalizeValue, - AbstractMixin_Regex, ) from data_diff.databases.base import ( Mixin_OptimizerHints, @@ -61,11 +60,6 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: return f"TRIM(CAST({value} AS char))" -class Mixin_Regex(AbstractMixin_Regex): - def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable: - return BinBoolOp("REGEXP", [string, pattern]) - - class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "MySQL" ROUNDS_ON_PREC_LOSS = True diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 0013fef7..93edfebf 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -218,9 +218,6 @@ def is_distinct_from(self, other): def like(self, other): return BinBoolOp("LIKE", [self, other]) - def test_regex(self, other): - return TestRegex(self, other) - def sum(self): return Func("SUM", [self]) @@ -231,12 +228,6 @@ def min(self): return Func("MIN", [self]) -@dataclass -class TestRegex(ExprNode, LazyOps): - string: Expr - pattern: Expr - - @dataclass(eq=False) class Func(ExprNode, LazyOps): name: str From e072b35c1fa3eb24f350bc3c9ef958178aeb106b Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 25 Sep 2023 17:10:39 +0200 Subject: [PATCH 2/6] Remove unneeded "bound" nodes & tables --- data_diff/bound_exprs.py | 94 ------------------------------------- data_diff/databases/base.py | 10 ---- tests/test_database.py | 7 ++- 3 files changed, 6 insertions(+), 105 deletions(-) delete mode 100644 data_diff/bound_exprs.py diff --git a/data_diff/bound_exprs.py b/data_diff/bound_exprs.py deleted file mode 100644 index 4b53846d..00000000 --- a/data_diff/bound_exprs.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Expressions bound to a specific database""" - -import inspect -from functools import wraps -from typing import Union, TYPE_CHECKING - -from runtype import dataclass 
-from typing_extensions import Self - -from data_diff.abcs.database_types import AbstractDatabase -from data_diff.queries.ast_classes import ExprNode, TablePath, Compilable -from data_diff.queries.api import table -from data_diff.schema import create_schema - - -@dataclass -class BoundNode(ExprNode): - database: AbstractDatabase - node: Compilable - - def __getattr__(self, attr): - value = getattr(self.node, attr) - if inspect.ismethod(value): - - @wraps(value) - def bound_method(*args, **kw): - return BoundNode(self.database, value(*args, **kw)) - - return bound_method - return value - - def query(self, res_type=list): - return self.database.query(self.node, res_type=res_type) - - @property - def type(self): - return self.node.type - - -def bind_node(node, database): - return BoundNode(database, node) - - -ExprNode.bind = bind_node - - -@dataclass -class BoundTable(BoundNode): # ITable - database: AbstractDatabase - node: TablePath - - def with_schema(self, schema) -> Self: - table_path = self.node.replace(schema=schema) - return self.replace(node=table_path) - - def query_schema(self, *, columns=None, where=None, case_sensitive=True) -> Self: - table_path = self.node - - if table_path.schema: - return self - - raw_schema = self.database.query_table_schema(table_path.path) - schema = self.database._process_table_schema(table_path.path, raw_schema, columns, where) - schema = create_schema(self.database, table_path, schema, case_sensitive) - return self.with_schema(schema) - - @property - def schema(self): - return self.node.schema - - -def bound_table(database: AbstractDatabase, table_path: Union[TablePath, str, tuple], **kw): - return BoundTable(database, table(table_path, **kw)) - - -# Database.table = bound_table - -# def test(): -# from data_diff import connect -# from data_diff.queries.api import table -# d = connect("mysql://erez:qweqwe123@localhost/erez") -# t = table(('Rating',)) - -# b = BoundTable(d, t) -# b2 = b.with_schema() - -# breakpoint() - -# 
test() - -if TYPE_CHECKING: - - class BoundTable(BoundTable, TablePath): - pass diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index bb892f50..9107bbf8 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -59,7 +59,6 @@ AbstractMixin_NormalizeValue, AbstractMixin_OptimizerHints, ) -from data_diff.bound_exprs import BoundNode, bound_table logger = logging.getLogger("database") cv_params = contextvars.ContextVar("params") @@ -275,8 +274,6 @@ def render_compilable(self, c: Compiler, elem: Compilable) -> str: return self.render_inserttotable(c, elem) elif isinstance(elem, Code): return self.render_code(c, elem) - elif isinstance(elem, BoundNode): - return self.render_boundnode(c, elem) elif isinstance(elem, _ResolveColumn): return self.render__resolvecolumn(c, elem) @@ -424,10 +421,6 @@ def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str: table_expr = f"({table_expr})" return table_expr - def render_boundnode(self, c: Compiler, elem: BoundNode) -> str: - assert self is elem.database.dialect - return self.compile(c, elem.node) - def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str: return self.compile(c, elem._get_resolved()) @@ -970,9 +963,6 @@ def close(self): def list_tables(self, tables_like, schema=None): return self.query(self.dialect.list_tables(schema or self.default_schema, tables_like)) - def table(self, *path, **kw): - return bound_table(self, path, **kw) - class ThreadedDatabase(Database): """Access the database through singleton threads. 
diff --git a/tests/test_database.py b/tests/test_database.py index 4f4c8ce1..f5998609 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -8,6 +8,7 @@ from data_diff import databases as dbs from data_diff.queries.api import table, current_timestamp from data_diff.queries.extras import NormalizeAsString +from data_diff.schema import create_schema from tests.common import TEST_MYSQL_CONN_STRING, test_each_database_in_list, get_conn, str_to_checksum, random_table_suffix from data_diff.abcs.database_types import TimestampTZ @@ -123,7 +124,11 @@ def test_correct_timezone(self): db.query(table(name).insert_row(1, now, now)) db.query(db.dialect.set_timezone_to_utc()) - t = db.table(name).query_schema() + t = table(name) + raw_schema = db.query_table_schema(t.path) + schema = db._process_table_schema(t.path, raw_schema) + schema = create_schema(self.database, t, schema, case_sensitive=True) + t = t.replace(schema=schema) t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision) tbl = table(name, schema=t.schema) From 6db9f663d0f215e1f545d3a073ca98dacd173ff3 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 25 Sep 2023 17:27:44 +0200 Subject: [PATCH 3/6] Move the compiler closer to dialects & databases that it uses & is used in --- data_diff/databases/base.py | 64 +++++++++++++++++++++++++++++++- data_diff/queries/ast_classes.py | 1 - data_diff/queries/compiler.py | 57 ---------------------------- tests/test_query.py | 2 +- 4 files changed, 63 insertions(+), 61 deletions(-) delete mode 100644 data_diff/queries/compiler.py diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 9107bbf8..4094d2f8 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,4 +1,5 @@ import functools +from dataclasses import field from datetime import datetime import math import sys @@ -15,10 +16,10 @@ from runtype import dataclass from typing_extensions import Self -from 
data_diff.queries.compiler import CompileError +from data_diff.abcs.compiler import AbstractCompiler from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString from data_diff.utils import ArithString, is_uuid, join_iter, safezip -from data_diff.queries.api import Expr, Compiler, table, Select, SKIP, Explain, Code, this +from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this from data_diff.queries.ast_classes import Alias, BinOp, CaseWhen, Cast, Column, Commit, Concat, ConstantTable, Count, \ CreateTable, Cte, \ CurrentTimestamp, DropTable, Func, \ @@ -64,6 +65,65 @@ cv_params = contextvars.ContextVar("params") +class CompileError(Exception): + pass + + +# TODO: LATER: Resolve the circular imports of databases-compiler-dialects: +# A database uses a compiler to render the SQL query. +# The compiler delegates to a dialect. +# The dialect renders the SQL. +# AS IS: The dialect requires the db to normalize table paths — leading to the back-dependency. +# TO BE: All the tables paths must be pre-normalized before SQL rendering. +# Also: c.database.is_autocommit in render_commit(). +# After this, the Compiler can cease referring Database/Dialect at all, +# and be used only as a CompilingContext (a counter/data-bearing class). +# As a result, it becomes low-level util, and the circular dependency auto-resolves. +# Meanwhile, the easy fix is to simply move the Compiler here. +@dataclass +class Compiler(AbstractCompiler): + """ + Compiler bears the context for a single compilation. + + There can be multiple compilation per app run. + There can be multiple compilers in one compilation (with varying contexts). + """ + + # Database is needed to normalize tables. Dialect is needed for recursive compilations. + # In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects. + # In practice, we currently bind the dialects to the specific database classes. 
+ database: AbstractDatabase + + in_select: bool = False # Compilation runtime flag + in_join: bool = False # Compilation runtime flag + + _table_context: List = field(default_factory=list) # List[ITable] + _subqueries: Dict[str, Any] = field(default_factory=dict) # XXX not thread-safe + root: bool = True + + _counter: List = field(default_factory=lambda: [0]) + + @property + def dialect(self) -> AbstractDialect: + return self.database.dialect + + # TODO: DEPRECATED: Remove once the dialect is used directly in all places. + def compile(self, elem, params=None) -> str: + return self.dialect.compile(self, elem, params) + + def new_unique_name(self, prefix="tmp"): + self._counter[0] += 1 + return f"{prefix}{self._counter[0]}" + + def new_unique_table_name(self, prefix="tmp") -> DbPath: + self._counter[0] += 1 + table_name = f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}" + return self.database.dialect.parse_table_name(table_name) + + def add_table_context(self, *tables: Sequence, **kw) -> Self: + return self.replace(_table_context=self._table_context + list(tables), **kw) + + def parse_table_name(t): return tuple(t.split(".")) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 93edfebf..54069e00 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -10,7 +10,6 @@ from data_diff.abcs.database_types import AbstractTable from data_diff.schema import Schema -from data_diff.queries.compiler import Compiler from data_diff.queries.base import SKIP, args_as_tuple, SqeletonError from data_diff.abcs.database_types import DbPath diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py deleted file mode 100644 index 224ad636..00000000 --- a/data_diff/queries/compiler.py +++ /dev/null @@ -1,57 +0,0 @@ -import random -from dataclasses import field -from typing import Any, Dict, Sequence, List - -from runtype import dataclass -from typing_extensions import Self - -from 
data_diff.abcs.database_types import AbstractDatabase, AbstractDialect, DbPath -from data_diff.abcs.compiler import AbstractCompiler - - -class CompileError(Exception): - pass - - -@dataclass -class Compiler(AbstractCompiler): - """ - Compiler bears the context for a single compilation. - - There can be multiple compilation per app run. - There can be multiple compilers in one compilation (with varying contexts). - """ - - # Database is needed to normalize tables. Dialect is needed for recursive compilations. - # In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects. - # In practice, we currently bind the dialects to the specific database classes. - database: AbstractDatabase - - in_select: bool = False # Compilation runtime flag - in_join: bool = False # Compilation runtime flag - - _table_context: List = field(default_factory=list) # List[ITable] - _subqueries: Dict[str, Any] = field(default_factory=dict) # XXX not thread-safe - root: bool = True - - _counter: List = field(default_factory=lambda: [0]) - - @property - def dialect(self) -> AbstractDialect: - return self.database.dialect - - # TODO: DEPRECATED: Remove once the dialect is used directly in all places. 
- def compile(self, elem, params=None) -> str: - return self.dialect.compile(self, elem, params) - - def new_unique_name(self, prefix="tmp"): - self._counter[0] += 1 - return f"{prefix}{self._counter[0]}" - - def new_unique_table_name(self, prefix="tmp") -> DbPath: - self._counter[0] += 1 - table_name = f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}" - return self.database.dialect.parse_table_name(table_name) - - def add_table_context(self, *tables: Sequence, **kw) -> Self: - return self.replace(_table_context=self._table_context + list(tables), **kw) diff --git a/tests/test_query.py b/tests/test_query.py index cc11b533..bd731cfb 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -4,7 +4,7 @@ from data_diff.abcs.database_types import AbstractDatabase, AbstractDialect from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict -from data_diff.queries.compiler import Compiler, CompileError +from data_diff.databases.base import Compiler, CompileError from data_diff.queries.api import outerjoin, cte, when, coalesce from data_diff.queries.ast_classes import Random from data_diff.queries.api import code, this, table From f213b26ad8c7fbc930ff631ba1e440453f6c71e4 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 25 Sep 2023 17:37:13 +0200 Subject: [PATCH 4/6] Squash abstract database into simply base database --- data_diff/abcs/database_types.py | 70 -------------------------------- data_diff/databases/base.py | 47 +++++++++++++++++++-- data_diff/schema.py | 5 ++- 3 files changed, 47 insertions(+), 75 deletions(-) diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index a679db67..9ea4c5d7 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -268,76 +268,6 @@ def to_comparable(self, value: str, coltype: ColType) -> str: """Ensure that the expression is comparable in ``IS DISTINCT FROM``.""" -from typing import TypeVar, Generic - -T_Dialect = TypeVar("T_Dialect", 
bound=AbstractDialect) - - -class AbstractDatabase(Generic[T_Dialect]): - @property - @abstractmethod - def dialect(self) -> T_Dialect: - "The dialect of the database. Used internally by Database, and also available publicly." - - @classmethod - @abstractmethod - def load_mixins(cls, *abstract_mixins) -> type: - "Extend the dialect with a list of mixins that implement the given abstract mixins." - - @property - @abstractmethod - def CONNECT_URI_HELP(self) -> str: - "Example URI to show the user in help and error messages" - - @property - @abstractmethod - def CONNECT_URI_PARAMS(self) -> List[str]: - "List of parameters given in the path of the URI" - - @abstractmethod - def _query(self, sql_code: str) -> list: - "Send query to database and return result" - - @abstractmethod - def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: - """Query the table for its schema for table in 'path', and return {column: tuple} - where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?) - - Note: This method exists instead of select_table_schema(), just because not all databases support - accessing the schema using a SQL query. - """ - - @abstractmethod - def select_table_unique_columns(self, path: DbPath) -> str: - "Provide SQL for selecting the names of unique columns in the table" - - @abstractmethod - def query_table_unique_columns(self, path: DbPath) -> List[str]: - """Query the table for its unique columns for table in 'path', and return {column}""" - - @abstractmethod - def _process_table_schema( - self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None - ): - """Process the result of query_table_schema(). - - Done in a separate step, to minimize the amount of processed columns. 
- Needed because processing each column may: - * throw errors and warnings - * query the database to sample values - - """ - - @abstractmethod - def close(self): - "Close connection(s) to the database instance. Querying will stop functioning." - - @property - @abstractmethod - def is_autocommit(self) -> bool: - "Return whether the database autocommits changes. When false, COMMIT statements are skipped." - - class AbstractTable(ABC): @abstractmethod def select(self, *exprs, distinct=False, **named_exprs) -> "AbstractTable": diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 4094d2f8..d093fba2 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,3 +1,4 @@ +import abc import functools from dataclasses import field from datetime import datetime @@ -31,7 +32,6 @@ Root, TableAlias, TableOp, TablePath, \ TimeTravel, TruncateTable, UnaryOp, WhenThen, _ResolveColumn from data_diff.abcs.database_types import ( - AbstractDatabase, Array, Struct, AbstractDialect, @@ -92,7 +92,7 @@ class Compiler(AbstractCompiler): # Database is needed to normalize tables. Dialect is needed for recursive compilations. # In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects. # In practice, we currently bind the dialects to the specific database classes. - database: AbstractDatabase + database: "Database" in_select: bool = False # Compilation runtime flag in_join: bool = False # Compilation runtime flag @@ -789,7 +789,7 @@ def __getitem__(self, i): return self.rows[i] -class Database(AbstractDatabase[T]): +class Database(abc.ABC): """Base abstract class for databases. Used for providing connection code and implementation specific SQL utilities. 
@@ -898,6 +898,12 @@ def select_table_schema(self, path: DbPath) -> str: ) def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: + """Query the table for its schema for table in 'path', and return {column: tuple} + where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?) + + Note: This method exists instead of select_table_schema(), just because not all databases support + accessing the schema using a SQL query. + """ rows = self.query(self.select_table_schema(path), list) if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") @@ -907,6 +913,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: return d def select_table_unique_columns(self, path: DbPath) -> str: + "Provide SQL for selecting the names of unique columns in the table" schema, name = self._normalize_table_path(path) return ( @@ -916,6 +923,7 @@ def select_table_unique_columns(self, path: DbPath) -> str: ) def query_table_unique_columns(self, path: DbPath) -> List[str]: + """Query the table for its unique columns for table in 'path', and return {column}""" if not self.SUPPORTS_UNIQUE_CONSTAINT: raise NotImplementedError("This database doesn't support 'unique' constraints") res = self.query(self.select_table_unique_columns(path), List[str]) @@ -924,6 +932,14 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]: def _process_table_schema( self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str] = None, where: str = None ): + """Process the result of query_table_schema(). + + Done in a separate step, to minimize the amount of processed columns. 
+ Needed because processing each column may: + * throw errors and warnings + * query the database to sample values + + """ if filter_columns is None: filtered_schema = raw_schema else: @@ -1017,12 +1033,37 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> Que return apply_query(callback, sql_code) def close(self): + "Close connection(s) to the database instance. Querying will stop functioning." self.is_closed = True return super().close() def list_tables(self, tables_like, schema=None): return self.query(self.dialect.list_tables(schema or self.default_schema, tables_like)) + @property + @abstractmethod + def dialect(self) -> BaseDialect: + "The dialect of the database. Used internally by Database, and also available publicly." + + @property + @abstractmethod + def CONNECT_URI_HELP(self) -> str: + "Example URI to show the user in help and error messages" + + @property + @abstractmethod + def CONNECT_URI_PARAMS(self) -> List[str]: + "List of parameters given in the path of the URI" + + @abstractmethod + def _query(self, sql_code: str) -> list: + "Send query to database and return result" + + @property + @abstractmethod + def is_autocommit(self) -> bool: + "Return whether the database autocommits changes. When false, COMMIT statements are skipped." + class ThreadedDatabase(Database): """Access the database through singleton threads. 
diff --git a/data_diff/schema.py b/data_diff/schema.py index 847bbf23..ae0b3935 100644 --- a/data_diff/schema.py +++ b/data_diff/schema.py @@ -1,14 +1,15 @@ import logging +from data_diff import Database from data_diff.utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict -from data_diff.abcs.database_types import AbstractDatabase, DbPath +from data_diff.abcs.database_types import DbPath logger = logging.getLogger("schema") Schema = CaseAwareMapping -def create_schema(db: AbstractDatabase, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping: +def create_schema(db: Database, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping: logger.debug(f"[{db.name}] Schema = {schema}") if case_sensitive: From c1bac7f3c36a8952d93192f2646d32460d858a5e Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 25 Sep 2023 17:37:46 +0200 Subject: [PATCH 5/6] Squash abstract dialect into simply base dialect --- data_diff/abcs/database_types.py | 94 -------------------------------- data_diff/databases/base.py | 42 ++++++++++++-- 2 files changed, 38 insertions(+), 98 deletions(-) diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index 9ea4c5d7..6fd19ca5 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -174,100 +174,6 @@ class UnknownColType(ColType): supported = False -class AbstractDialect(ABC): - """Dialect-dependent query expressions""" - - @abstractmethod - def compile(self, compiler: AbstractCompiler, elem, params=None) -> str: - raise NotImplementedError - - @abstractmethod - def parse_table_name(self, name: str) -> DbPath: - "Parse the given table name into a DbPath" - - @property - @abstractmethod - def name(self) -> str: - "Name of the dialect" - - @classmethod - @abstractmethod - def load_mixins(cls, *abstract_mixins) -> Self: - "Load a list of mixins that implement the given abstract mixins" - - @property - @abstractmethod - def 
ROUNDS_ON_PREC_LOSS(self) -> bool: - "True if db rounds real values when losing precision, False if it truncates." - - @abstractmethod - def quote(self, s: str): - "Quote SQL name" - - @abstractmethod - def concat(self, items: List[str]) -> str: - "Provide SQL for concatenating a bunch of columns into a string" - - @abstractmethod - def is_distinct_from(self, a: str, b: str) -> str: - "Provide SQL for a comparison where NULL = NULL is true" - - @abstractmethod - def to_string(self, s: str) -> str: - # TODO rewrite using cast_to(x, str) - "Provide SQL for casting a column to string" - - @abstractmethod - def random(self) -> str: - "Provide SQL for generating a random number betweein 0..1" - - @abstractmethod - def current_timestamp(self) -> str: - "Provide SQL for returning the current timestamp, aka now" - - @abstractmethod - def current_database(self) -> str: - "Provide SQL for returning the current default database." - - @abstractmethod - def current_schema(self) -> str: - "Provide SQL for returning the current default schema." 
- - @abstractmethod - def offset_limit( - self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None - ) -> str: - "Provide SQL fragment for limit and offset inside a select" - - @abstractmethod - def explain_as_text(self, query: str) -> str: - "Provide SQL for explaining a query, returned as table(varchar)" - - @abstractmethod - def timestamp_value(self, t: datetime) -> str: - "Provide SQL for the given timestamp value" - - @abstractmethod - def set_timezone_to_utc(self) -> str: - "Provide SQL for setting the session timezone to UTC" - - @abstractmethod - def parse_type( - self, - table_path: DbPath, - col_name: str, - type_repr: str, - datetime_precision: int = None, - numeric_precision: int = None, - numeric_scale: int = None, - ) -> ColType: - "Parse type info as returned by the database" - - @abstractmethod - def to_comparable(self, value: str, coltype: ColType) -> str: - """Ensure that the expression is comparable in ``IS DISTINCT FROM``.""" - - class AbstractTable(ABC): @abstractmethod def select(self, *exprs, distinct=False, **named_exprs) -> "AbstractTable": diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index d093fba2..64535094 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -34,7 +34,6 @@ from data_diff.abcs.database_types import ( Array, Struct, - AbstractDialect, AbstractTable, ColType, Integer, @@ -104,7 +103,7 @@ class Compiler(AbstractCompiler): _counter: List = field(default_factory=lambda: [0]) @property - def dialect(self) -> AbstractDialect: + def dialect(self) -> "Dialect": return self.database.dialect # TODO: DEPRECATED: Remove once the dialect is used directly in all places. 
@@ -221,7 +220,7 @@ def optimizer_hints(self, hints: str) -> str:
         return f"/*+ {hints} */ "
 
 
-class BaseDialect(AbstractDialect):
+class BaseDialect(abc.ABC):
     SUPPORTS_PRIMARY_KEY = False
     SUPPORTS_INDEXES = False
     TYPE_CLASSES: Dict[str, type] = {}
@@ -230,6 +229,7 @@ class BaseDialect(AbstractDialect):
     PLACEHOLDER_TABLE = None  # Used for Oracle
 
     def parse_table_name(self, name: str) -> DbPath:
+        "Parse the given table name into a DbPath"
         return parse_table_name(name)
 
     def compile(self, compiler: Compiler, elem, params=None) -> str:
@@ -638,12 +638,14 @@ def render_inserttotable(self, c: Compiler, elem: InsertToTable) -> str:
     def offset_limit(
         self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
     ) -> str:
+        "Provide SQL fragment for limit and offset inside a select"
         if offset:
             raise NotImplementedError("No support for OFFSET in query")
 
         return f"LIMIT {limit}"
 
     def concat(self, items: List[str]) -> str:
+        "Provide SQL for concatenating a bunch of columns into a string"
         assert len(items) > 1
         joined_exprs = ", ".join(items)
         return f"concat({joined_exprs})"
@@ -653,24 +655,31 @@ def to_comparable(self, value: str, coltype: ColType) -> str:
         return value
 
     def is_distinct_from(self, a: str, b: str) -> str:
+        "Provide SQL for a comparison where NULL = NULL is true"
         return f"{a} is distinct from {b}"
 
     def timestamp_value(self, t: DbTime) -> str:
+        "Provide SQL for the given timestamp value"
         return f"'{t.isoformat()}'"
 
     def random(self) -> str:
+        "Provide SQL for generating a random number between 0..1"
         return "random()"
 
     def current_timestamp(self) -> str:
+        "Provide SQL for returning the current timestamp, aka now"
         return "current_timestamp()"
 
     def current_database(self) -> str:
+        "Provide SQL for returning the current default database."
         return "current_database()"
 
     def current_schema(self) -> str:
+        "Provide SQL for returning the current default schema."
return "current_schema()" def explain_as_text(self, query: str) -> str: + "Provide SQL for explaining a query, returned as table(varchar)" return f"EXPLAIN {query}" def _constant_value(self, v): @@ -719,7 +728,7 @@ def parse_type( numeric_precision: int = None, numeric_scale: int = None, ) -> ColType: - """ """ + "Parse type info as returned by the database" cls = self._parse_type_repr(type_repr) if cls is None: @@ -762,6 +771,7 @@ def _convert_db_precision_to_digits(self, p: int) -> int: @classmethod def load_mixins(cls, *abstract_mixins) -> Self: + "Load a list of mixins that implement the given abstract mixins" mixins = {m for m in cls.MIXINS if issubclass(m, abstract_mixins)} class _DialectWithMixins(cls, *mixins, *abstract_mixins): @@ -771,6 +781,30 @@ class _DialectWithMixins(cls, *mixins, *abstract_mixins): return _DialectWithMixins() + @property + @abstractmethod + def name(self) -> str: + "Name of the dialect" + + @property + @abstractmethod + def ROUNDS_ON_PREC_LOSS(self) -> bool: + "True if db rounds real values when losing precision, False if it truncates." 
+ + @abstractmethod + def quote(self, s: str): + "Quote SQL name" + + @abstractmethod + def to_string(self, s: str) -> str: + # TODO rewrite using cast_to(x, str) + "Provide SQL for casting a column to string" + + @abstractmethod + def set_timezone_to_utc(self) -> str: + "Provide SQL for setting the session timezone to UTC" + + T = TypeVar("T", bound=BaseDialect) From 7aec378e89bfb60a6bfed25db6758f3afe0cac43 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 25 Sep 2023 17:42:30 +0200 Subject: [PATCH 6/6] Squash abstract table into already existent ITable --- data_diff/abcs/database_types.py | 92 +------------------------------- data_diff/abcs/mixins.py | 2 +- data_diff/databases/base.py | 7 ++- data_diff/databases/duckdb.py | 7 ++- data_diff/queries/ast_classes.py | 53 +++++++++++++++--- 5 files changed, 55 insertions(+), 106 deletions(-) diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index 6fd19ca5..43764b39 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -1,12 +1,10 @@ import decimal from abc import ABC, abstractmethod -from typing import Sequence, Optional, Tuple, Union, Dict, List +from typing import Tuple, Union from datetime import datetime from runtype import dataclass -from typing_extensions import Self -from data_diff.abcs.compiler import AbstractCompiler from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown @@ -172,91 +170,3 @@ class UnknownColType(ColType): text: str supported = False - - -class AbstractTable(ABC): - @abstractmethod - def select(self, *exprs, distinct=False, **named_exprs) -> "AbstractTable": - """Choose new columns, based on the old ones. (aka Projection) - - Parameters: - exprs: List of expressions to constitute the columns of the new table. - If not provided, returns all columns in source table (i.e. 
``select *``) - distinct: 'select' or 'select distinct' - named_exprs: More expressions to constitute the columns of the new table, aliased to keyword name. - - """ - # XXX distinct=SKIP - - @abstractmethod - def where(self, *exprs) -> "AbstractTable": - """Filter the rows, based on the given predicates. (aka Selection)""" - - @abstractmethod - def order_by(self, *exprs) -> "AbstractTable": - """Order the rows lexicographically, according to the given expressions.""" - - @abstractmethod - def limit(self, limit: int) -> "AbstractTable": - """Stop yielding rows after the given limit. i.e. take the first 'n=limit' rows""" - - @abstractmethod - def join(self, target) -> "AbstractTable": - """Join the current table with the target table, returning a new table containing both side-by-side. - - When joining, it's recommended to use explicit tables names, instead of `this`, in order to avoid potential name collisions. - - Example: - :: - - person = table('person') - city = table('city') - - name_and_city = ( - person - .join(city) - .on(person['city_id'] == city['id']) - .select(person['id'], city['name']) - ) - """ - - @abstractmethod - def group_by(self, *keys): - """Behaves like in SQL, except for a small change in syntax: - - A call to `.agg()` must follow every call to `.group_by()`. 
- - Example: - :: - - # SELECT a, sum(b) FROM tmp GROUP BY 1 - table('tmp').group_by(this.a).agg(this.b.sum()) - - # SELECT a, sum(b) FROM a GROUP BY 1 HAVING (b > 10) - (table('tmp') - .group_by(this.a) - .agg(this.b.sum()) - .having(this.b > 10) - ) - - """ - - @abstractmethod - def count(self) -> int: - """SELECT count() FROM self""" - - @abstractmethod - def union(self, other: "ITable"): - """SELECT * FROM self UNION other""" - - @abstractmethod - def union_all(self, other: "ITable"): - """SELECT * FROM self UNION ALL other""" - - @abstractmethod - def minus(self, other: "ITable"): - """SELECT * FROM self EXCEPT other""" - - @abstractmethod - def intersect(self, other: "ITable"): - """SELECT * FROM self INTERSECT other""" diff --git a/data_diff/abcs/mixins.py b/data_diff/abcs/mixins.py index d40c2f60..9a30f41e 100644 --- a/data_diff/abcs/mixins.py +++ b/data_diff/abcs/mixins.py @@ -146,7 +146,7 @@ def random_sample_ratio_approx(self, tbl: str, ratio: float) -> str: i.e. the actual mount of rows returned may vary by standard deviation. 
""" - # def random_sample_ratio(self, table: AbstractTable, ratio: float): + # def random_sample_ratio(self, table: ITable, ratio: float): # """Take a random sample of the size determined by the ratio (0..1), where 0 means no rows, and 1 means all rows # """ diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 64535094..d55c59a6 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -25,7 +25,7 @@ CreateTable, Cte, \ CurrentTimestamp, DropTable, Func, \ GroupBy, \ - In, InsertToTable, IsDistinctFrom, \ + ITable, In, InsertToTable, IsDistinctFrom, \ Join, \ Param, \ Random, \ @@ -34,7 +34,6 @@ from data_diff.abcs.database_types import ( Array, Struct, - AbstractTable, ColType, Integer, Decimal, @@ -207,11 +206,11 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: class Mixin_RandomSample(AbstractMixin_RandomSample): - def random_sample_n(self, tbl: AbstractTable, size: int) -> AbstractTable: + def random_sample_n(self, tbl: ITable, size: int) -> ITable: # TODO use a more efficient algorithm, when the table count is known return tbl.order_by(Random()).limit(size) - def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> AbstractTable: + def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable: return tbl.where(Random() < ratio) diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index 03234edb..ba6afd63 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -14,7 +14,6 @@ Text, FractionalType, Boolean, - AbstractTable, ) from data_diff.abcs.mixins import ( AbstractMixin_MD5, @@ -30,7 +29,7 @@ TIMESTAMP_PRECISION_POS, ) from data_diff.databases.base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Mixin_Schema -from data_diff.queries.ast_classes import Func, Compilable +from data_diff.queries.ast_classes import Func, Compilable, ITable from data_diff.queries.api import code @@ -62,10 +61,10 @@ def 
normalize_boolean(self, value: str, _coltype: Boolean) -> str: class Mixin_RandomSample(AbstractMixin_RandomSample): - def random_sample_n(self, tbl: AbstractTable, size: int) -> AbstractTable: + def random_sample_n(self, tbl: ITable, size: int) -> ITable: return code("SELECT * FROM ({tbl}) USING SAMPLE {size};", tbl=tbl, size=size) - def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> AbstractTable: + def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable: return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio)) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 54069e00..56efdb20 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -7,7 +7,6 @@ from data_diff.utils import ArithString from data_diff.abcs.compiler import Compilable -from data_diff.abcs.database_types import AbstractTable from data_diff.schema import Schema from data_diff.queries.base import SKIP, args_as_tuple, SqeletonError @@ -81,12 +80,20 @@ def _drop_skips_dict(exprs_dict): return {k: v for k, v in exprs_dict.items() if v is not SKIP} -class ITable(AbstractTable): +class ITable: source_table: Any schema: Schema = None def select(self, *exprs, distinct=SKIP, optimizer_hints=SKIP, **named_exprs) -> "ITable": - """Create a new table with the specified fields""" + """Choose new columns, based on the old ones. (aka Projection) + + Parameters: + exprs: List of expressions to constitute the columns of the new table. + If not provided, returns all columns in source table (i.e. ``select *``) + distinct: 'select' or 'select distinct' + named_exprs: More expressions to constitute the columns of the new table, aliased to keyword name. 
+ + """ exprs = args_as_tuple(exprs) exprs = _drop_skips(exprs) named_exprs = _drop_skips_dict(named_exprs) @@ -95,6 +102,7 @@ def select(self, *exprs, distinct=SKIP, optimizer_hints=SKIP, **named_exprs) -> return Select.make(self, columns=exprs, distinct=distinct, optimizer_hints=optimizer_hints) def where(self, *exprs): + """Filter the rows, based on the given predicates. (aka Selection)""" exprs = args_as_tuple(exprs) exprs = _drop_skips(exprs) if not exprs: @@ -104,6 +112,7 @@ def where(self, *exprs): return Select.make(self, where_exprs=exprs) def order_by(self, *exprs): + """Order the rows lexicographically, according to the given expressions.""" exprs = _drop_skips(exprs) if not exprs: return self @@ -112,19 +121,50 @@ def order_by(self, *exprs): return Select.make(self, order_by_exprs=exprs) def limit(self, limit: int): + """Stop yielding rows after the given limit. i.e. take the first 'n=limit' rows""" if limit is SKIP: return self return Select.make(self, limit_expr=limit) def join(self, target: "ITable"): - """Join this table with the target table.""" + """Join the current table with the target table, returning a new table containing both side-by-side. + + When joining, it's recommended to use explicit tables names, instead of `this`, in order to avoid potential name collisions. + + Example: + :: + + person = table('person') + city = table('city') + + name_and_city = ( + person + .join(city) + .on(person['city_id'] == city['id']) + .select(person['id'], city['name']) + ) + """ return Join([self, target]) def group_by(self, *keys) -> "GroupBy": - """Group according to the given keys. + """Behaves like in SQL, except for a small change in syntax: + + A call to `.agg()` must follow every call to `.group_by()`. 
+ + Example: + :: + + # SELECT a, sum(b) FROM tmp GROUP BY 1 + table('tmp').group_by(this.a).agg(this.b.sum()) + + # SELECT a, sum(b) FROM a GROUP BY 1 HAVING (b > 10) + (table('tmp') + .group_by(this.a) + .agg(this.b.sum()) + .having(this.b > 10) + ) - Must be followed by a call to :ref:``GroupBy.agg()`` """ keys = _drop_skips(keys) resolve_names(self.source_table, keys) @@ -145,6 +185,7 @@ def __getitem__(self, column): return self._get_column(column) def count(self): + """SELECT count() FROM self""" return Select(self, [Count()]) def union(self, other: "ITable"):