-
Notifications
You must be signed in to change notification settings - Fork 109
Expand file tree
/
Copy pathsql_server_adapter.py
More file actions
181 lines (159 loc) · 6.05 KB
/
sql_server_adapter.py
File metadata and controls
181 lines (159 loc) · 6.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from typing import List, Optional
import agate
from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.cache import _make_ref_key_dict
from dbt.adapters.sql import SQLAdapter
from dbt.adapters.sql.impl import CREATE_SCHEMA_MACRO_NAME
from dbt.events.functions import fire_event
from dbt.events.types import SchemaCreation
from dbt.adapters.sqlserver.sql_server_column import SQLServerColumn
from dbt.adapters.sqlserver.sql_server_configs import SQLServerConfigs
from dbt.adapters.sqlserver.sql_server_connection_manager import (
SQLServerConnectionManager,
)
class SQLServerAdapter(SQLAdapter):
    """dbt adapter implementation for SQL Server, layered on ``SQLAdapter``.

    Wires in the SQL Server specific connection manager, column class and
    adapter config class, maps agate column types to T-SQL type names, and
    provides a handful of SQL-generating helpers used by the adapter tests.
    """

    ConnectionManager = SQLServerConnectionManager
    Column = SQLServerColumn
    AdapterSpecificConfigs = SQLServerConfigs

    def create_schema(self, relation: BaseRelation) -> None:
        """Create the schema that ``relation`` belongs to.

        Fires a ``SchemaCreation`` event, then runs the standard create-schema
        macro. When the credentials carry a truthy ``schema_authorization``,
        the ``sqlserver__create_schema_with_authorization`` macro is run
        instead and receives that value as an extra keyword argument.

        Args:
            relation: Any relation inside the target schema; its identifier
                is dropped before use.
        """
        relation = relation.without_identifier()
        fire_event(SchemaCreation(relation=_make_ref_key_dict(relation)))

        macro_name = CREATE_SCHEMA_MACRO_NAME
        kwargs = {"relation": relation}

        schema_authorization = self.config.credentials.schema_authorization
        if schema_authorization:
            kwargs["schema_authorization"] = schema_authorization
            macro_name = "sqlserver__create_schema_with_authorization"

        self.execute_macro(macro_name, kwargs=kwargs)
        self.commit_if_has_connection()

    @classmethod
    def date_function(cls) -> str:
        """Return the T-SQL expression ``getdate()``."""
        return "getdate()"

    @classmethod
    def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        """Return the varchar type used for a text column of ``agate_table``.

        The width is the longest non-null value measured in UTF-8 bytes, with
        a floor of 16. A column with no non-null values gets ``varchar(64)``.
        Widths above 8000 are returned as ``varchar(max)`` because T-SQL does
        not accept a larger explicit length.

        Args:
            agate_table: Table holding the column.
            col_idx: Index of the column within ``agate_table.columns``.
        """
        column = agate_table.columns[col_idx]
        # see https://github.com/fishtown-analytics/dbt/pull/2255
        lens = [len(d.encode("utf-8")) for d in column.values_without_nulls()]
        max_len = max(lens) if lens else 64
        length = max(max_len, 16)
        # varchar(n) only allows n in 1..8000; anything longer needs varchar(max).
        if length > 8000:
            return "varchar(max)"
        return f"varchar({length})"

    @classmethod
    def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        """Return ``datetime`` as the type for date-time columns."""
        return "datetime"

    @classmethod
    def convert_boolean_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        """Return ``bit`` as the type for boolean columns."""
        return "bit"

    @classmethod
    def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        """Return ``float`` or ``int`` for a numeric column.

        ``float`` is chosen when the ``agate.MaxPrecision`` aggregate for the
        column is non-zero, ``int`` otherwise.
        """
        decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
        return "float" if decimals else "int"

    @classmethod
    def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
        """Return ``datetime`` as the type for time columns."""
        return "datetime"

    # Methods used in adapter tests
    def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
        """Return a ``DATEADD`` expression adding ``number`` ``interval``s to ``add_to``.

        Args:
            add_to: SQL expression for the timestamp being shifted.
            number: How many intervals to add.
            interval: Date part name passed to ``DATEADD``.
        """
        # note: the '+ interval' syntax used in postgres/redshift is not
        # supported by T-SQL, so DATEADD is emitted instead. The defaults are
        # kept for backwards compatibility.
        return f"DATEADD({interval},{number},{add_to})"

    def string_add_sql(
        self,
        add_to: str,
        value: str,
        location="append",
    ) -> str:
        """Return SQL that concatenates the literal ``value`` onto ``add_to``.

        `+` is T-SQL's string concatenation operator. ``value`` is wrapped in
        single quotes as-is; it is not escaped.

        Args:
            add_to: SQL expression to extend.
            value: Text placed inside a quoted string literal.
            location: ``"append"`` puts the literal after ``add_to``,
                ``"prepend"`` puts it before.

        Raises:
            ValueError: If ``location`` is neither ``"append"`` nor ``"prepend"``.
        """
        if location == "append":
            return f"{add_to} + '{value}'"
        elif location == "prepend":
            return f"'{value}' + {add_to}"
        else:
            raise ValueError(f'Got an unexpected location value of "{location}"')

    def get_rows_different_sql(
        self,
        relation_a: BaseRelation,
        relation_b: BaseRelation,
        column_names: Optional[List[str]] = None,
        except_operator: str = "EXCEPT",
    ) -> str:
        """Generate SQL that compares the contents of two relations.

        The query returns a single row with two columns: the difference in
        row counts between the two relations and the number of mismatched
        rows.

        note: `using` is not supported on Synapse, so COLUMNS_EQUAL_SQL is
        adjusted to join with `on` instead.

        Args:
            relation_a: First relation to compare.
            relation_b: Second relation to compare.
            column_names: Columns to compare. When ``None``, the columns of
                ``relation_a`` are looked up and used.
            except_operator: Set-difference operator placed in the query.
        """
        # This method only really exists for test reasons.
        names: List[str]
        if column_names is None:
            columns = self.get_columns_in_relation(relation_a)
            names = sorted((self.quote(c.name) for c in columns))
        else:
            names = sorted((self.quote(n) for n in column_names))
        columns_csv = ", ".join(names)

        sql = COLUMNS_EQUAL_SQL.format(
            columns=columns_csv,
            relation_a=str(relation_a),
            relation_b=str(relation_b),
            except_op=except_operator,
        )
        return sql

    def valid_incremental_strategies(self):
        """The set of standard builtin strategies which this adapter supports out-of-the-box.
        Not used to validate custom strategies defined by end users.
        """
        return ["append", "delete+insert", "merge", "insert_overwrite"]

    # This is for use in the test suite
    def run_sql_for_tests(self, sql, fetch, conn):
        """Execute ``sql`` on ``conn`` and optionally fetch results.

        Args:
            sql: Statement to execute.
            fetch: ``"one"`` returns ``cursor.fetchone()``, ``"all"`` returns
                ``cursor.fetchall()``. A falsy value commits the statement
                and returns ``None``. Any other value returns ``None``
                without committing.
            conn: Connection object whose ``handle`` provides the cursor.

        Raises:
            BaseException: Whatever the cursor or handle raised, re-raised
                after a rollback attempt.
        """
        cursor = conn.handle.cursor()
        try:
            cursor.execute(sql)
            if not fetch:
                conn.handle.commit()
            if fetch == "one":
                return cursor.fetchone()
            elif fetch == "all":
                return cursor.fetchall()
            else:
                return
        except BaseException:
            # Roll back only when the handle exposes `closed` and reports it
            # as still open; a handle without that attribute is treated as closed.
            if conn.handle and not getattr(conn.handle, "closed", True):
                conn.handle.rollback()
            raise
        finally:
            # Runs on success and failure alike.
            conn.transaction_open = False
# SQL template rendered by SQLServerAdapter.get_rows_different_sql via str.format.
# Placeholders:
#   {columns}    - comma-separated list of column names to compare
#   {relation_a} - first relation
#   {relation_b} - second relation
#   {except_op}  - set-difference operator placed between the two SELECTs
# The rendered query yields a single row with two columns:
#   row_count_difference - COUNT(*) of relation_a minus COUNT(*) of relation_b
#   num_mismatched       - number of rows in
#                          (a {except_op} b) UNION ALL (b {except_op} a)
# The two CTEs are joined with ON (on a constant id of 1) rather than USING.
COLUMNS_EQUAL_SQL = """
with diff_count as (
SELECT
1 as id,
COUNT(*) as num_missing FROM (
(SELECT {columns} FROM {relation_a} {except_op}
SELECT {columns} FROM {relation_b})
UNION ALL
(SELECT {columns} FROM {relation_b} {except_op}
SELECT {columns} FROM {relation_a})
) as a
), table_a as (
SELECT COUNT(*) as num_rows FROM {relation_a}
), table_b as (
SELECT COUNT(*) as num_rows FROM {relation_b}
), row_count_diff as (
select
1 as id,
table_a.num_rows - table_b.num_rows as difference
from table_a, table_b
)
select
row_count_diff.difference as row_count_difference,
diff_count.num_missing as num_mismatched
from row_count_diff
join diff_count on row_count_diff.id = diff_count.id
""".strip()