diff --git a/data_diff/databases/__init__.py b/data_diff/databases/__init__.py index 9b9a81ea..0be4e0c0 100644 --- a/data_diff/databases/__init__.py +++ b/data_diff/databases/__init__.py @@ -12,5 +12,6 @@ from .clickhouse import Clickhouse from .vertica import Vertica from .duckdb import DuckDB +from .mssql import MsSql from ._connect import connect diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index 6ca94246..d4293436 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -14,6 +14,7 @@ from .clickhouse import Clickhouse from .vertica import Vertica from .duckdb import DuckDB +from .mssql import MsSql DATABASE_BY_SCHEME = { @@ -29,6 +30,7 @@ "trino": Trino, "clickhouse": Clickhouse, "vertica": Vertica, + "mssql": MsSql, } diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py new file mode 100644 index 00000000..61dc5307 --- /dev/null +++ b/data_diff/databases/mssql.py @@ -0,0 +1,10 @@ +from data_diff.sqeleton.databases import mssql +from .base import DatadiffDialect + + +class Dialect(mssql.Dialect, mssql.Mixin_MD5, mssql.Mixin_NormalizeValue, DatadiffDialect): + pass + + +class MsSql(mssql.MsSQL): + dialect = Dialect() diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 2184276a..0220364b 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -10,7 +10,7 @@ from runtype import dataclass -from data_diff.sqeleton.databases import Database, MySQL, BigQuery, Presto, Oracle, Snowflake, DbPath +from data_diff.sqeleton.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake, DbPath from data_diff.sqeleton.abcs import NumericType from data_diff.sqeleton.queries import ( table, @@ -25,9 +25,10 @@ leftjoin, rightjoin, this, + when, Compiler, ) -from data_diff.sqeleton.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code, ITable +from data_diff.sqeleton.queries.ast_classes import Concat, Count, Expr, Func, Random, TablePath, Code, ITable from data_diff.sqeleton.queries.extras import NormalizeAsString from .info_tree import InfoTree @@ -82,6 +83,12 @@ def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List is_exclusive_a = and_(b[k] == None for k in keys2) is_exclusive_b = and_(a[k] == None for k in keys1) + + if isinstance(db, MsSQL): + # There is no "IS NULL" or "ISNULL()" as expressions, only as conditions. + is_exclusive_a = when(is_exclusive_a).then(1).else_(0) + is_exclusive_b = when(is_exclusive_b).then(1).else_(0) + if isinstance(db, Oracle): is_exclusive_a = bool_to_int(is_exclusive_a) is_exclusive_b = bool_to_int(is_exclusive_b) @@ -342,7 +349,7 @@ def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols): self.stats["diff_counts"] = diff_counts def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols): - if isinstance(db, Oracle): + if isinstance(db, (Oracle, MsSQL)): exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1)) else: exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b) diff --git a/data_diff/sqeleton/abcs/database_types.py b/data_diff/sqeleton/abcs/database_types.py index f82e681b..9bde030b 100644 --- a/data_diff/sqeleton/abcs/database_types.py +++ b/data_diff/sqeleton/abcs/database_types.py @@ -216,7 +216,17 @@ def current_timestamp(self) -> str: "Provide SQL for returning the current timestamp, aka now" @abstractmethod - def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + def current_database(self) -> str: + "Provide SQL for returning the current default database." + + @abstractmethod + def current_schema(self) -> str: + "Provide SQL for returning the current default schema." + + @abstractmethod + def offset_limit( + self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None + ) -> str: "Provide SQL fragment for limit and offset inside a select" @abstractmethod diff --git a/data_diff/sqeleton/databases/__init__.py b/data_diff/sqeleton/databases/__init__.py index 44a7e1c8..5b6e3397 100644 --- a/data_diff/sqeleton/databases/__init__.py +++ b/data_diff/sqeleton/databases/__init__.py @@ -14,5 +14,6 @@ from .clickhouse import Clickhouse from .vertica import Vertica from .duckdb import DuckDB +from .mssql import MsSQL connect = Connect() diff --git a/data_diff/sqeleton/databases/_connect.py b/data_diff/sqeleton/databases/_connect.py index 2d2314fa..c6638d98 100644 --- a/data_diff/sqeleton/databases/_connect.py +++ b/data_diff/sqeleton/databases/_connect.py @@ -21,6 +21,7 @@ from .clickhouse import Clickhouse from .vertica import Vertica from .duckdb import DuckDB +from .mssql import MsSQL @dataclass @@ -86,6 +87,7 @@ def match_path(self, dsn): "trino": Trino, "clickhouse": Clickhouse, "vertica": Vertica, + "mssql": MsSQL, } diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py index 73c69424..78bfe2bf 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/sqeleton/databases/base.py @@ -155,7 +155,9 @@ class BaseDialect(AbstractDialect): PLACEHOLDER_TABLE = None # Used for Oracle - def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + def offset_limit( + self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None + ) -> str: if offset: raise NotImplementedError("No support for OFFSET in query") @@ -182,6 +184,12 @@ def random(self) -> str: def current_timestamp(self) -> str: return "current_timestamp()" + def current_database(self) -> str: + return "current_database()" + + def current_schema(self) -> str: + return "current_schema()" + def explain_as_text(self, query: str) -> str: return f"EXPLAIN {query}" @@ -518,7 +526,10 @@ def _query_cursor(self, c, sql_code: str) -> QueryResult: c.execute(sql_code) if sql_code.lower().startswith(("select", "explain", "show")): columns = [col[0] for col in c.description] - return QueryResult(c.fetchall(), columns) + + fetched = c.fetchall() + result = QueryResult(fetched, columns) + return result except Exception as _e: # logger.exception(e) # logger.error(f'Caused by SQL: {sql_code}') @@ -590,7 +601,8 @@ def is_autocommit(self) -> bool: return False -CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower, otherwise SUM() overflows +# TODO FYI mssql md5_as_int currently requires this to be reduced +CHECKSUM_HEXDIGITS = 14 # Must be 15 or lower, otherwise SUM() overflows MD5_HEXDIGITS = 32 _CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2 diff --git a/data_diff/sqeleton/databases/mssql.py b/data_diff/sqeleton/databases/mssql.py index 8d394e3c..5ea3cab5 100644 --- a/data_diff/sqeleton/databases/mssql.py +++ b/data_diff/sqeleton/databases/mssql.py @@ -1,25 +1,214 @@ -# class MsSQL(ThreadedDatabase): -# "AKA sql-server" +from typing import Optional +from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from .base import ( + CHECKSUM_HEXDIGITS, + Mixin_OptimizerHints, + Mixin_RandomSample, + QueryError, + ThreadedDatabase, + import_helper, + ConnectError, + BaseDialect, +) +from .base import Mixin_Schema +from ..abcs.database_types import ( + JSON, + Timestamp, + TimestampTZ, + DbPath, + Float, + Decimal, + Integer, + TemporalType, + Native_UUID, + Text, + FractionalType, + Boolean, +) -# def __init__(self, host, port, user, password, *, database, thread_count, **kw): -# args = dict(server=host, port=port, database=database, user=user, password=password, **kw) -# self._args = {k: v for k, v in args.items() if v is not None} -# super().__init__(thread_count=thread_count) +@import_helper("mssql") +def import_mssql(): + import pyodbc -# def create_connection(self): -# mssql = import_mssql() -# try: -# return mssql.connect(**self._args) -# except mssql.Error as e: -# raise ConnectError(*e.args) from e + return pyodbc -# def quote(self, s: str): -# return f"[{s}]" -# def md5_as_int(self, s: str) -> str: -# return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', {s}), 2))" -# # return f"CONVERT(bigint, (CHECKSUM({s})))" +class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + if coltype.precision > 0: + formatted_value = ( + f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss') + '.' + " + f"SUBSTRING(FORMAT({value}, 'fffffff'), 1, {coltype.precision})" + ) + else: + formatted_value = f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss')" -# def to_string(self, s: str): -# return f"CONVERT(varchar, {s})" + return formatted_value + + def normalize_number(self, value: str, coltype: FractionalType) -> str: + if coltype.precision == 0: + return f"CAST(FLOOR({value}) AS VARCHAR)" + + return f"FORMAT({value}, 'N{coltype.precision}')" + + +class Mixin_MD5(AbstractMixin_MD5): + def md5_as_int(self, s: str) -> str: + return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))" + + +class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints): + name = "MsSQL" + ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True + SUPPORTS_INDEXES = True + TYPE_CLASSES = { + # Timestamps + "datetimeoffset": TimestampTZ, + "datetime": Timestamp, + "datetime2": Timestamp, + "smalldatetime": Timestamp, + "date": Timestamp, + # Numbers + "float": Float, + "real": Float, + "decimal": Decimal, + "money": Decimal, + "smallmoney": Decimal, + # int + "int": Integer, + "bigint": Integer, + "tinyint": Integer, + "smallint": Integer, + # Text + "varchar": Text, + "char": Text, + "text": Text, + "ntext": Text, + "nvarchar": Text, + "nchar": Text, + "binary": Text, + "varbinary": Text, + # UUID + "uniqueidentifier": Native_UUID, + # Bool + "bit": Boolean, + # JSON + "json": JSON, + } + + MIXINS = {Mixin_Schema, Mixin_NormalizeValue, Mixin_RandomSample} + + def quote(self, s: str): + return f"[{s}]" + + def set_timezone_to_utc(self) -> str: + raise NotImplementedError("MsSQL does not support a session timezone setting.") + + def current_timestamp(self) -> str: + return "GETDATE()" + + def current_database(self) -> str: + return "DB_NAME()" + + def current_schema(self) -> str: + return """default_schema_name + FROM sys.database_principals + WHERE name = CURRENT_USER""" + + def to_string(self, s: str): + return f"CONVERT(varchar, {s})" + + def type_repr(self, t) -> str: + try: + return {bool: "bit"}[t] + except KeyError: + return super().type_repr(t) + + def random(self) -> str: + return "rand()" + + def is_distinct_from(self, a: str, b: str) -> str: + # IS (NOT) DISTINCT FROM is available only since SQLServer 2022. + # See: https://stackoverflow.com/a/18684859/857383 + return f"(({a}<>{b} OR {a} IS NULL OR {b} IS NULL) AND NOT({a} IS NULL AND {b} IS NULL))" + + def offset_limit( + self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None + ) -> str: + if offset: + raise NotImplementedError("No support for OFFSET in query") + + result = "" + if not has_order_by: + result += "ORDER BY 1" + + result += f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY" + return result + + def constant_values(self, rows) -> str: + values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows) + return f"VALUES {values}" + + +class MsSQL(ThreadedDatabase): + dialect = Dialect() + # + CONNECT_URI_HELP = "mssql://:@//" + CONNECT_URI_PARAMS = ["database", "schema"] + + def __init__(self, host, port, user, password, *, database, thread_count, **kw): + args = dict(server=host, port=port, database=database, user=user, password=password, **kw) + self._args = {k: v for k, v in args.items() if v is not None} + self._args["driver"] = "{ODBC Driver 18 for SQL Server}" + + # TODO temp dev debug + self._args["TrustServerCertificate"] = "yes" + + try: + self.default_database = self._args["database"] + self.default_schema = self._args["schema"] + except KeyError: + raise ValueError("Specify a default database and schema.") + + super().__init__(thread_count=thread_count) + + def create_connection(self): + self._mssql = import_mssql() + try: + connection = self._mssql.connect(**self._args) + return connection + except self._mssql.Error as error: + raise ConnectError(*error.args) from error + + def select_table_schema(self, path: DbPath) -> str: + """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)""" + database, schema, name = self._normalize_table_path(path) + info_schema_path = ["information_schema", "columns"] + if database: + info_schema_path.insert(0, self.dialect.quote(database)) + + return ( + "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale " + f"FROM {'.'.join(info_schema_path)} " + f"WHERE table_name = '{name}' AND table_schema = '{schema}'" + ) + + def _normalize_table_path(self, path: DbPath) -> DbPath: + if len(path) == 1: + return self.default_database, self.default_schema, path[0] + elif len(path) == 2: + return self.default_database, path[0], path[1] + elif len(path) == 3: + return path + + raise ValueError( + f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table" + ) + + def _query_cursor(self, c, sql_code: str): + try: + return super()._query_cursor(c, sql_code) + except self._mssql.DatabaseError as e: + raise QueryError(e) diff --git a/data_diff/sqeleton/databases/oracle.py b/data_diff/sqeleton/databases/oracle.py index 557759b3..c26bcf42 100644 --- a/data_diff/sqeleton/databases/oracle.py +++ b/data_diff/sqeleton/databases/oracle.py @@ -104,7 +104,9 @@ def quote(self, s: str): def to_string(self, s: str): return f"cast({s} as varchar(1024))" - def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + def offset_limit( + self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None + ) -> str: if offset: raise NotImplementedError("No support for OFFSET in query") diff --git a/data_diff/sqeleton/queries/ast_classes.py b/data_diff/sqeleton/queries/ast_classes.py index 7975c8fa..aba86c70 100644 --- a/data_diff/sqeleton/queries/ast_classes.py +++ b/data_diff/sqeleton/queries/ast_classes.py @@ -730,7 +730,8 @@ def compile(self, parent_c: Compiler) -> str: select += " ORDER BY " + ", ".join(map(c.compile, self.order_by_exprs)) if self.limit_expr is not None: - select += " " + c.dialect.offset_limit(0, self.limit_expr) + has_order_by = bool(self.order_by_exprs) + select += " " + c.dialect.offset_limit(0, self.limit_expr, has_order_by=has_order_by) if parent_c.in_select: select = f"({select}) {c.new_unique_name()}" diff --git a/poetry.lock b/poetry.lock index 0c728f8c..ef73d2fc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1711,6 +1711,50 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] +[[package]] +name = "pyodbc" +version = "4.0.39" +description = "DB API Module for ODBC" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +files = [ + {file = "pyodbc-4.0.39-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:74af348dbaee4885998858daf50c8964e767629ecf6c195868b016367b0bb861"}, + {file = "pyodbc-4.0.39-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0f5901b57eaef0761f4cf02bca8e7c63f589fd0fd723a79f6ccf1ea1275372e5"}, + {file = "pyodbc-4.0.39-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0db69478d00fcd8d0b9bdde8aca0b0eada341fd6ed8c2da84b594b928c84106"}, + {file = "pyodbc-4.0.39-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5faf2870e9d434c6a85c6adc1cdff55c0e376273baf480f06d9848025405688"}, + {file = "pyodbc-4.0.39-cp310-cp310-win32.whl", hash = "sha256:62bb6d7d0d25dc75d1445e539f946461c9c5a3643ae14676b240f71794ea004f"}, + {file = "pyodbc-4.0.39-cp310-cp310-win_amd64.whl", hash = "sha256:8eb5547282dc73a7784ce7b99584f68687dd85543538ca6f70cffaa6310676e7"}, + {file = "pyodbc-4.0.39-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:530c1ac37ead782803b44fb1934ba4c68ed4a6969f7475cb8bc04ae1da14486e"}, + {file = "pyodbc-4.0.39-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1f7fb65191926308f09ce75ae7ccecf89310232ee50cdea74edf17ee04a9b068"}, + {file = "pyodbc-4.0.39-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ec009180fcd7c8197f45d083e6670623d8dfe198a457ca2a50ebb1bafe4107f"}, + {file = "pyodbc-4.0.39-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:400e911d54980098c6badadecc82385fc0d6a9057db525d63d2652317df43efe"}, + {file = "pyodbc-4.0.39-cp311-cp311-win32.whl", hash = "sha256:f792677b88e1dde12dab46de8647620fc8171742c02780d51744f7b1b2135dbc"}, + {file = "pyodbc-4.0.39-cp311-cp311-win_amd64.whl", hash = "sha256:3d9d70e1635d35ba3aee3df216ec8e35f2824909f43331c0112b17f460a93923"}, + {file = "pyodbc-4.0.39-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:c1a59096f1784d0cda3d0b8f393849f05515c46a10016edb6da1b1960d039800"}, + {file = "pyodbc-4.0.39-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b3467157661615d5c30893efa1069b55c9ffa434097fc3ae3739e740d83d2ec"}, + {file = "pyodbc-4.0.39-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af027a60e84274ea08fad1c75991d37a5f1f6e8bcd30f6bda20db99f0cdfbc7d"}, + {file = "pyodbc-4.0.39-cp36-cp36m-win32.whl", hash = "sha256:64c1de1263281de7b5ce585b0352746ab1a483453017a8589f838a79cbe3d6d9"}, + {file = "pyodbc-4.0.39-cp36-cp36m-win_amd64.whl", hash = "sha256:27d1b3c3159673b44c97c878f9d8056901d45f747ce2e0b4d5d99f0fb6949dc7"}, + {file = "pyodbc-4.0.39-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:efccc11dff6fba684a74ae1030c92ff8b82429d7f00e0a50aa2ac6f56621cd9f"}, + {file = "pyodbc-4.0.39-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea08e9379c08663d7260e2b8a6c451f56d36c17291af735191089f8e29ad9578"}, + {file = "pyodbc-4.0.39-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b36fe804d367d01ad81077fa524a36e667aabc3945e32564c7ef9595b28edfa9"}, + {file = "pyodbc-4.0.39-cp37-cp37m-win32.whl", hash = "sha256:72d364e52f6ca2417881a23834b3a36733c09e0dcd4760f49a6b864218d98d92"}, + {file = "pyodbc-4.0.39-cp37-cp37m-win_amd64.whl", hash = "sha256:39f6c56022c764309aa7552c0eb2c58fbb5902ab5d2010d42b021c0b205aa609"}, + {file = "pyodbc-4.0.39-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ebcb900fcaf19ca2bc38632218c5d48c666fcc19fe38b08cde001917f4581456"}, + {file = "pyodbc-4.0.39-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a3e133621ac2dad22d0870a8521c7e82d4270e24ce02451d64e7eb6a40ad0941"}, + {file = "pyodbc-4.0.39-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05a0912e852ebddaffa8f235b0f3974475021dd8eb604eb46ea67af06efe1239"}, + {file = "pyodbc-4.0.39-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6353044b99c763aeec7ca1760b4340298504d8ee544fdcab3c380a2abec15b78"}, + {file = "pyodbc-4.0.39-cp38-cp38-win32.whl", hash = "sha256:a591a1cf3c251a9c7c1642cfb3774119bf3512f3be56151247238f8a7b22b336"}, + {file = "pyodbc-4.0.39-cp38-cp38-win_amd64.whl", hash = "sha256:8553eaef9f8ec333bbddff6eadf0d322dda34b37f4bab19f0658eb532037840c"}, + {file = "pyodbc-4.0.39-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9253e746c5c94bf61e3e9adb08fb7688d413cb68c06ebb287ec233387534760a"}, + {file = "pyodbc-4.0.39-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a6f4067f46aaa78e77e8a15ade81eb21fb344563d245fb2d9a0aaa553c367cbd"}, + {file = "pyodbc-4.0.39-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdf5a27e6587d1762f7f0e35d6f0309f09019bf3e19ca9177a4b765121f3f106"}, + {file = "pyodbc-4.0.39-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe4ee87b88867867f582dd0c1236cd982508db359a6cbb5e91623ceb6c83e60a"}, + {file = "pyodbc-4.0.39-cp39-cp39-win32.whl", hash = "sha256:42649ed57d09c04aa197bdd4fe0aa9ca319790b7aa86d0b0784cc70e78c426e5"}, + {file = "pyodbc-4.0.39-cp39-cp39-win_amd64.whl", hash = "sha256:305c7d6337e2d4c8350677cc641b343fc0197b7b9bc167815c66b64545c67a53"}, + {file = "pyodbc-4.0.39.tar.gz", hash = "sha256:e528bb70dd6d6299ee429868925df0866e3e919c772b9eff79c8e17920d8f116"}, +] + [[package]] name = "pyopenssl" version = "22.0.0" @@ -2415,4 +2459,4 @@ vertica = ["vertica-python"] [metadata] lock-version = "2.0" python-versions = "^3.7.2" -content-hash = "8e437f479e7e82cdf74aba56e0d4dbc0a69d03de11cd38bef98e3bc9b6346020" +content-hash = "a35a0f4127cafe848ddda5ee5274f5710493bd8789320612061d322d6b57d49c" diff --git a/pyproject.toml b/pyproject.toml index 03ba9a17..591d2445 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ preql = {version="^0.2.19", optional=true} vertica-python = {version="*", optional=true} urllib3 = "<2" oracledb = {version = "*", extras = ["oracle"]} +pyodbc = "^4.0.39" [tool.poetry.dev-dependencies] parameterized = "*" diff --git a/tests/common.py b/tests/common.py index 7fa68d24..baeec717 100644 --- a/tests/common.py +++ b/tests/common.py @@ -35,6 +35,7 @@ # vertica uri provided for docker - "vertica://vertica:Password1@localhost:5433/vertica" TEST_VERTICA_CONN_STRING: str = os.environ.get("DATADIFF_VERTICA_URI") TEST_DUCKDB_CONN_STRING: str = "duckdb://main:@:memory:" +TEST_MSSQL_CONN_STRING: str = os.environ.get("DATADIFF_MSSQL_URI") DEFAULT_N_SAMPLES = 50 @@ -80,6 +81,7 @@ def get_git_revision_short_hash() -> str: db.Clickhouse: TEST_CLICKHOUSE_CONN_STRING, db.Vertica: TEST_VERTICA_CONN_STRING, db.DuckDB: TEST_DUCKDB_CONN_STRING, + db.MsSql: TEST_MSSQL_CONN_STRING, } _database_instances = {} diff --git a/tests/sqeleton/common.py b/tests/sqeleton/common.py index 03625da7..4d2dc3d2 100644 --- a/tests/sqeleton/common.py +++ b/tests/sqeleton/common.py @@ -30,6 +30,7 @@ TEST_DUCKDB_CONN_STRING, N_THREADS, TEST_ACROSS_ALL_DBS, + TEST_MSSQL_CONN_STRING, ) @@ -65,6 +66,7 @@ def get_git_revision_short_hash() -> str: db.Clickhouse: TEST_CLICKHOUSE_CONN_STRING, db.Vertica: TEST_VERTICA_CONN_STRING, db.DuckDB: TEST_DUCKDB_CONN_STRING, + db.MsSQL: TEST_MSSQL_CONN_STRING, } _database_instances = {} diff --git a/tests/sqeleton/test_database.py b/tests/sqeleton/test_database.py index 948cfd12..5faa9abf 100644 --- a/tests/sqeleton/test_database.py +++ b/tests/sqeleton/test_database.py @@ -22,6 +22,7 @@ dbs.Presto, dbs.Trino, dbs.Vertica, + dbs.MsSQL, } test_each_database: Callable = test_each_database_in_list(TEST_DATABASES) @@ -104,6 +105,8 @@ def test_current_timestamp(self): assert isinstance(res, datetime), (res, type(res)) def test_correct_timezone(self): + if self.db_cls in [dbs.MsSQL]: + self.skipTest("No support for session tz.") name = "tbl_" + random_table_suffix() db = get_conn(self.db_cls) tbl = table(name, schema={"id": int, "created_at": TimestampTZ(9), "updated_at": TimestampTZ(9)}) @@ -142,13 +145,13 @@ def test_correct_timezone(self): @test_each_database class TestThreePartIds(unittest.TestCase): def test_three_part_support(self): - if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake, dbs.DuckDB]: + if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake, dbs.DuckDB, dbs.MsSQL]: self.skipTest("Limited support for 3 part ids") table_name = "tbl_" + random_table_suffix() db = get_conn(self.db_cls) - db_res = db.query("SELECT CURRENT_DATABASE()") - schema_res = db.query("SELECT CURRENT_SCHEMA()") + db_res = db.query(f"SELECT {db.dialect.current_database()}") + schema_res = db.query(f"SELECT {db.dialect.current_schema()}") db_name = db_res.rows[0][0] schema_name = schema_res.rows[0][0] diff --git a/tests/sqeleton/test_query.py b/tests/sqeleton/test_query.py index efc41c02..cfa6ada8 100644 --- a/tests/sqeleton/test_query.py +++ b/tests/sqeleton/test_query.py @@ -41,7 +41,15 @@ def random(self) -> str: def current_timestamp(self) -> str: return "now()" - def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): + def current_database(self) -> str: + return "current_database()" + + def current_schema(self) -> str: + return "current_schema()" + + def offset_limit( + self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None + ) -> str: x = offset and f"OFFSET {offset}", limit and f"LIMIT {limit}" return " ".join(filter(None, x)) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index ad67bea6..4a9eaf4b 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -351,6 +351,15 @@ def init_conns(): "boolean", ], }, + db.MsSql: { + "int": ["INT", "BIGINT"], + "datetime": ["datetime2(6)"], + "float": ["DECIMAL(6, 2)", "FLOAT", "REAL"], + "uuid": ["VARCHAR(100)", "CHAR(100)", "UNIQUEIDENTIFIER"], + "boolean": [ + "BIT", + ], + }, } @@ -615,6 +624,9 @@ def _insert_to_table(conn, table_path, values, coltype): ) for i, sample in values ] + # mssql represents with int + elif isinstance(conn, db.MsSql) and coltype in ("BIT"): + values = [(i, int(sample)) for i, sample in values] insert_rows_in_batches(conn, tbl, values, columns=["id", "col"]) conn.query(commit)