Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions data_diff/databases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
from .clickhouse import Clickhouse
from .vertica import Vertica
from .duckdb import DuckDB
from .mssql import MsSql

from ._connect import connect
2 changes: 2 additions & 0 deletions data_diff/databases/_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .clickhouse import Clickhouse
from .vertica import Vertica
from .duckdb import DuckDB
from .mssql import MsSql


DATABASE_BY_SCHEME = {
Expand All @@ -29,6 +30,7 @@
"trino": Trino,
"clickhouse": Clickhouse,
"vertica": Vertica,
"mssql": MsSql,
}


Expand Down
10 changes: 10 additions & 0 deletions data_diff/databases/mssql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from data_diff.sqeleton.databases import mssql
from .base import DatadiffDialect


class Dialect(mssql.Dialect, mssql.Mixin_MD5, mssql.Mixin_NormalizeValue, DatadiffDialect):
    """MS SQL Server dialect for data-diff.

    Combines sqeleton's MsSQL dialect with its MD5 and value-normalization
    mixins, plus the data-diff-specific dialect extensions.
    """

    pass


class MsSql(mssql.MsSQL):
    """data-diff wrapper around sqeleton's MsSQL database class."""

    # Use the mixin-augmented dialect defined in this module.
    dialect = Dialect()
13 changes: 10 additions & 3 deletions data_diff/joindiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from runtype import dataclass

from data_diff.sqeleton.databases import Database, MySQL, BigQuery, Presto, Oracle, Snowflake, DbPath
from data_diff.sqeleton.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake, DbPath
from data_diff.sqeleton.abcs import NumericType
from data_diff.sqeleton.queries import (
table,
Expand All @@ -25,9 +25,10 @@
leftjoin,
rightjoin,
this,
when,
Compiler,
)
from data_diff.sqeleton.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code, ITable
from data_diff.sqeleton.queries.ast_classes import Concat, Count, Expr, Func, Random, TablePath, Code, ITable
from data_diff.sqeleton.queries.extras import NormalizeAsString

from .info_tree import InfoTree
Expand Down Expand Up @@ -82,6 +83,12 @@ def _outerjoin(db: Database, a: ITable, b: ITable, keys1: List[str], keys2: List

is_exclusive_a = and_(b[k] == None for k in keys2)
is_exclusive_b = and_(a[k] == None for k in keys1)

if isinstance(db, MsSQL):
# There is no "IS NULL" or "ISNULL()" as expressions, only as conditions.
is_exclusive_a = when(is_exclusive_a).then(1).else_(0)
is_exclusive_b = when(is_exclusive_b).then(1).else_(0)

if isinstance(db, Oracle):
is_exclusive_a = bool_to_int(is_exclusive_a)
is_exclusive_b = bool_to_int(is_exclusive_b)
Expand Down Expand Up @@ -342,7 +349,7 @@ def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols):
self.stats["diff_counts"] = diff_counts

def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols):
if isinstance(db, Oracle):
if isinstance(db, (Oracle, MsSQL)):
exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1))
else:
exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b)
Expand Down
12 changes: 11 additions & 1 deletion data_diff/sqeleton/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,17 @@ def current_timestamp(self) -> str:
"Provide SQL for returning the current timestamp, aka now"

@abstractmethod
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
def current_database(self) -> str:
"Provide SQL for returning the current default database."

@abstractmethod
def current_schema(self) -> str:
"Provide SQL for returning the current default schema."

@abstractmethod
def offset_limit(
self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
) -> str:
"Provide SQL fragment for limit and offset inside a select"

@abstractmethod
Expand Down
1 change: 1 addition & 0 deletions data_diff/sqeleton/databases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
from .clickhouse import Clickhouse
from .vertica import Vertica
from .duckdb import DuckDB
from .mssql import MsSQL

connect = Connect()
2 changes: 2 additions & 0 deletions data_diff/sqeleton/databases/_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .clickhouse import Clickhouse
from .vertica import Vertica
from .duckdb import DuckDB
from .mssql import MsSQL


@dataclass
Expand Down Expand Up @@ -86,6 +87,7 @@ def match_path(self, dsn):
"trino": Trino,
"clickhouse": Clickhouse,
"vertica": Vertica,
"mssql": MsSQL,
}


Expand Down
18 changes: 15 additions & 3 deletions data_diff/sqeleton/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ class BaseDialect(AbstractDialect):

PLACEHOLDER_TABLE = None # Used for Oracle

def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
def offset_limit(
self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
) -> str:
if offset:
raise NotImplementedError("No support for OFFSET in query")

Expand All @@ -182,6 +184,12 @@ def random(self) -> str:
def current_timestamp(self) -> str:
return "current_timestamp()"

def current_database(self) -> str:
    "Provide SQL for returning the current default database."
    return "current_database()"

def current_schema(self) -> str:
    "Provide SQL for returning the current default schema."
    return "current_schema()"

def explain_as_text(self, query: str) -> str:
return f"EXPLAIN {query}"

Expand Down Expand Up @@ -518,7 +526,10 @@ def _query_cursor(self, c, sql_code: str) -> QueryResult:
c.execute(sql_code)
if sql_code.lower().startswith(("select", "explain", "show")):
columns = [col[0] for col in c.description]
return QueryResult(c.fetchall(), columns)

fetched = c.fetchall()
result = QueryResult(fetched, columns)
return result
except Exception as _e:
# logger.exception(e)
# logger.error(f'Caused by SQL: {sql_code}')
Expand Down Expand Up @@ -590,7 +601,8 @@ def is_autocommit(self) -> bool:
return False


CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower, otherwise SUM() overflows
# TODO FYI mssql md5_as_int currently requires this to be reduced
CHECKSUM_HEXDIGITS = 14 # Must be 15 or lower, otherwise SUM() overflows
MD5_HEXDIGITS = 32

_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2
Expand Down
227 changes: 208 additions & 19 deletions data_diff/sqeleton/databases/mssql.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,214 @@
# class MsSQL(ThreadedDatabase):
# "AKA sql-server"
from typing import Optional
from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from .base import (
CHECKSUM_HEXDIGITS,
Mixin_OptimizerHints,
Mixin_RandomSample,
QueryError,
ThreadedDatabase,
import_helper,
ConnectError,
BaseDialect,
)
from .base import Mixin_Schema
from ..abcs.database_types import (
JSON,
Timestamp,
TimestampTZ,
DbPath,
Float,
Decimal,
Integer,
TemporalType,
Native_UUID,
Text,
FractionalType,
Boolean,
)

# def __init__(self, host, port, user, password, *, database, thread_count, **kw):
# args = dict(server=host, port=port, database=database, user=user, password=password, **kw)
# self._args = {k: v for k, v in args.items() if v is not None}

# super().__init__(thread_count=thread_count)
@import_helper("mssql")
def import_mssql():
    "Lazily import and return the ``pyodbc`` module used to connect to MS SQL Server."
    import pyodbc

    return pyodbc

# def quote(self, s: str):
# return f"[{s}]"

# def md5_as_int(self, s: str) -> str:
# return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', {s}), 2))"
# # return f"CONVERT(bigint, (CHECKSUM({s})))"
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
    """Render MS SQL Server values as canonical strings for cross-database diffing."""

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Return SQL rendering *value* as ``yyyy-MM-dd HH:mm:ss[.f...]``.

        The fractional seconds are truncated to ``coltype.precision`` digits;
        with precision 0, no fractional part (and no dot) is emitted.
        """
        if coltype.precision > 0:
            formatted_value = (
                f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss') + '.' + "
                f"SUBSTRING(FORMAT({value}, 'fffffff'), 1, {coltype.precision})"
            )
        else:
            formatted_value = f"FORMAT({value}, 'yyyy-MM-dd HH:mm:ss')"

        return formatted_value

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Return SQL rendering a numeric *value* with ``coltype.precision`` decimals.

        With precision 0 the value is floored and cast to a plain VARCHAR.

        NOTE(review): T-SQL ``FORMAT(..., 'N<p>')`` inserts group (thousands)
        separators, e.g. ``1,234.50`` — confirm the normalizers of the other
        databases produce the same shape, otherwise equal values will diff.
        """
        if coltype.precision == 0:
            return f"CAST(FLOOR({value}) AS VARCHAR)"

        return f"FORMAT({value}, 'N{coltype.precision}')"


class Mixin_MD5(AbstractMixin_MD5):
    """MD5 checksum support for MS SQL Server."""

    def md5_as_int(self, s: str) -> str:
        """Return SQL hashing *s* with MD5 and converting the low hex digits to a bigint.

        HashBytes' digest is rendered as hex (CONVERT style 2), truncated to the
        rightmost CHECKSUM_HEXDIGITS characters so the bigint cannot overflow,
        then reinterpreted as a varbinary literal (style 1) and cast to bigint.
        """
        hex_digest = f"CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2)"
        truncated = f"RIGHT({hex_digest}, {CHECKSUM_HEXDIGITS})"
        return f"convert(bigint, convert(varbinary, '0x' + {truncated}, 1))"


class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints):
    """SQL dialect for Microsoft SQL Server (T-SQL)."""

    name = "MsSQL"
    ROUNDS_ON_PREC_LOSS = True
    SUPPORTS_PRIMARY_KEY = True
    SUPPORTS_INDEXES = True
    # Map of SQL Server type names to sqeleton's abstract column types.
    TYPE_CLASSES = {
        # Timestamps
        "datetimeoffset": TimestampTZ,
        "datetime": Timestamp,
        "datetime2": Timestamp,
        "smalldatetime": Timestamp,
        "date": Timestamp,
        # Numbers
        "float": Float,
        "real": Float,
        "decimal": Decimal,
        "money": Decimal,
        "smallmoney": Decimal,
        # int
        "int": Integer,
        "bigint": Integer,
        "tinyint": Integer,
        "smallint": Integer,
        # Text
        "varchar": Text,
        "char": Text,
        "text": Text,
        "ntext": Text,
        "nvarchar": Text,
        "nchar": Text,
        "binary": Text,
        "varbinary": Text,
        # UUID
        "uniqueidentifier": Native_UUID,
        # Bool
        "bit": Boolean,
        # JSON
        "json": JSON,
    }

    MIXINS = {Mixin_Schema, Mixin_NormalizeValue, Mixin_RandomSample}

    def quote(self, s: str):
        "Quote an identifier with SQL Server's bracket syntax."
        return f"[{s}]"

    def set_timezone_to_utc(self) -> str:
        raise NotImplementedError("MsSQL does not support a session timezone setting.")

    def current_timestamp(self) -> str:
        return "GETDATE()"

    def current_database(self) -> str:
        return "DB_NAME()"

    def current_schema(self) -> str:
        # NOTE(review): this returns a column + FROM + WHERE *fragment*, not a
        # bare expression — it only works if the caller embeds it directly as
        # "SELECT <this>". Confirm every call site does so.
        return """default_schema_name
FROM sys.database_principals
WHERE name = CURRENT_USER"""

    def to_string(self, s: str):
        return f"CONVERT(varchar, {s})"

    def type_repr(self, t) -> str:
        # SQL Server has no boolean type; represent Python bool as BIT.
        try:
            return {bool: "bit"}[t]
        except KeyError:
            return super().type_repr(t)

    def random(self) -> str:
        return "rand()"

    def is_distinct_from(self, a: str, b: str) -> str:
        # IS (NOT) DISTINCT FROM is available only since SQLServer 2022.
        # See: https://stackoverflow.com/a/18684859/857383
        return f"(({a}<>{b} OR {a} IS NULL OR {b} IS NULL) AND NOT({a} IS NULL AND {b} IS NULL))"

    def offset_limit(
        self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None
    ) -> str:
        """Build a T-SQL OFFSET/FETCH clause for *limit*.

        SQL Server requires an ORDER BY before OFFSET/FETCH, so a dummy
        "ORDER BY 1" is injected when the query has none. OFFSET itself is
        not supported and raises.
        """
        if offset:
            raise NotImplementedError("No support for OFFSET in query")

        result = ""
        if not has_order_by:
            result += "ORDER BY 1"

        # When has_order_by is true, the fragment starts with a leading space
        # so it concatenates cleanly after the caller's ORDER BY.
        result += f" OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY"
        return result

    def constant_values(self, rows) -> str:
        "Render rows of constants as a VALUES (...), (...) list."
        values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)
        return f"VALUES {values}"


class MsSQL(ThreadedDatabase):
    """Microsoft SQL Server connection (aka sql-server), via pyodbc.

    Requires a default database and schema, taken from the connection
    arguments; table paths shorter than 3 parts are completed from them.
    """

    dialect = Dialect()
    CONNECT_URI_HELP = "mssql://<user>:<password>@<host>/<database>/<schema>"
    CONNECT_URI_PARAMS = ["database", "schema"]

    def __init__(self, host, port, user, password, *, database, thread_count, **kw):
        args = dict(server=host, port=port, database=database, user=user, password=password, **kw)
        # Drop unset arguments so pyodbc only sees what was actually provided.
        self._args = {k: v for k, v in args.items() if v is not None}
        self._args["driver"] = "{ODBC Driver 18 for SQL Server}"

        # TODO(review): development-only setting — disables TLS certificate
        # validation; remove (or make configurable) before production use.
        self._args["TrustServerCertificate"] = "yes"

        try:
            self.default_database = self._args["database"]
            # NOTE(review): "schema" stays in _args and is forwarded to
            # pyodbc.connect as a connection-string attribute — confirm the
            # ODBC driver tolerates/ignores it.
            self.default_schema = self._args["schema"]
        except KeyError:
            raise ValueError("Specify a default database and schema.")

        super().__init__(thread_count=thread_count)

    def create_connection(self):
        "Open a new pyodbc connection (called once per worker thread)."
        self._mssql = import_mssql()
        try:
            connection = self._mssql.connect(**self._args)
            return connection
        except self._mssql.Error as error:
            raise ConnectError(*error.args) from error

    def select_table_schema(self, path: DbPath) -> str:
        """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
        database, schema, name = self._normalize_table_path(path)
        info_schema_path = ["information_schema", "columns"]
        if database:
            info_schema_path.insert(0, self.dialect.quote(database))

        # NOTE(review): `name` and `schema` are interpolated unescaped into the
        # SQL string — acceptable for trusted configuration, but consider
        # doubling single quotes to be safe.
        return (
            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
            f"FROM {'.'.join(info_schema_path)} "
            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
        )

    def _normalize_table_path(self, path: DbPath) -> DbPath:
        # Complete a 1- or 2-part path to (database, schema, table) using the
        # defaults captured in __init__.
        if len(path) == 1:
            return self.default_database, self.default_schema, path[0]
        elif len(path) == 2:
            return self.default_database, path[0], path[1]
        elif len(path) == 3:
            return path

        raise ValueError(
            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
        )

    def _query_cursor(self, c, sql_code: str):
        # Translate driver-level DatabaseError into sqeleton's QueryError.
        try:
            return super()._query_cursor(c, sql_code)
        except self._mssql.DatabaseError as e:
            raise QueryError(e)
Loading