This repository was archived by the owner on May 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 305
Expand file tree
/
Copy path snowflake.py
More file actions
207 lines (168 loc) · 6.99 KB
/
snowflake.py
File metadata and controls
207 lines (168 loc) · 6.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import base64
from typing import Any, ClassVar, Union, List, Type, Optional
import logging
import attrs
from data_diff.abcs.database_types import (
Timestamp,
TimestampTZ,
Decimal,
Float,
Text,
FractionalType,
TemporalType,
DbPath,
Boolean,
Date,
Time,
)
from data_diff.databases.base import (
BaseDialect,
ConnectError,
Database,
import_helper,
CHECKSUM_MASK,
ThreadLocalInterpreter,
CHECKSUM_OFFSET,
)
@import_helper("snowflake")
def import_snowflake():
    """Import the Snowflake connector and the cryptography helpers on demand.

    Returns a 3-tuple ``(snowflake, serialization, default_backend)``: the
    top-level ``snowflake`` package (with ``snowflake.connector`` loaded), the
    ``cryptography`` serialization module, and the ``default_backend`` factory.
    The imports live inside the function so the optional dependencies are only
    required when a Snowflake connection is actually requested.
    """
    # NOTE(review): `import_helper` presumably converts an ImportError into a
    # friendlier install hint; its behavior is defined outside this file.
    import snowflake.connector
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization

    modules = (snowflake, serialization, default_backend)
    return modules
class Dialect(BaseDialect):
    """SQL dialect for Snowflake.

    Maps Snowflake column type names to data-diff's abstract column types and
    renders the Snowflake-specific SQL fragments (identifier quoting, casts,
    hashing and value normalization) used when comparing tables.
    """

    name = "Snowflake"
    ROUNDS_ON_PREC_LOSS = False
    TYPE_CLASSES = {
        # Timestamps
        "TIMESTAMP_NTZ": Timestamp,
        "TIMESTAMP_LTZ": Timestamp,
        "TIMESTAMP_TZ": TimestampTZ,
        "DATE": Date,
        "TIME": Time,
        # Numbers
        "NUMBER": Decimal,
        "FLOAT": Float,
        # Text
        "TEXT": Text,
        # Boolean
        "BOOLEAN": Boolean,
    }

    def explain_as_text(self, query: str) -> str:
        """Return *query* wrapped in a text-format ``EXPLAIN`` statement."""
        return f"EXPLAIN USING TEXT {query}"

    def quote(self, s: str):
        """Return *s* as a double-quoted identifier."""
        # NOTE(review): embedded double quotes in `s` are not escaped (Snowflake
        # expects them doubled); confirm callers never pass such names.
        return f'"{s}"'

    def to_string(self, s: str):
        """Return SQL casting the expression *s* to a string."""
        return f"cast({s} as string)"

    def set_timezone_to_utc(self) -> str:
        """Return the statement that switches the session time zone to UTC."""
        return "ALTER SESSION SET TIMEZONE = 'UTC'"

    def optimizer_hints(self, hints: str) -> str:
        """Not supported for Snowflake.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Optimizer hints not yet implemented in snowflake")

    def type_repr(self, t) -> str:
        """Render column type *t* as SQL; ``TimestampTZ`` maps to ``timestamp_tz(p)``."""
        if isinstance(t, TimestampTZ):
            return f"timestamp_tz({t.precision})"
        return super().type_repr(t)

    def md5_as_int(self, s: str) -> str:
        """Return SQL turning the MD5 of *s* into a masked, offset integer checksum.

        Takes the lower 64 bits of the MD5 (``md5_number_lower64``), keeps the
        bits selected by ``CHECKSUM_MASK`` and subtracts ``CHECKSUM_OFFSET``.
        """
        return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK}) - {CHECKSUM_OFFSET}"

    def md5_as_hex(self, s: str) -> str:
        """Return SQL computing the MD5 of *s* via Snowflake's ``md5``."""
        return f"md5({s})"

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Return SQL that renders the temporal expression *value* as a string.

        - Dates (``Date`` instances or types flagged ``is_date``) are cast to varchar.
        - Times (``Time`` instances or types flagged ``is_time``) are rounded to
          ``coltype.precision`` fractional digits and formatted ``HH24:MI:SS.FF6``.
        - Anything else is converted to UTC and brought to ``coltype.precision``.
          It is rounded when ``coltype.rounds`` is set, otherwise passed through a
          cast to ``timestamp(precision)``. It is then formatted
          ``YYYY-MM-DD HH24:MI:SS.FF6``.
        """
        # Some column types may not define the is_date / is_time flags. A missing
        # attribute counts as False; any other error is allowed to propagate.
        is_date = getattr(coltype, "is_date", False)
        is_time = getattr(coltype, "is_time", False)

        if isinstance(coltype, Date) or is_date:
            return f"({value}::varchar)"

        if isinstance(coltype, Time) or is_time:
            # Microseconds since midnight, rounded to the column precision
            # (round(x, -6 + p) keeps p fractional-second digits), then added back
            # onto midnight so the expression is a TIME again.
            microseconds = f"TIMEDIFF(microsecond, cast('00:00:00' as time), {value})"
            rounded = f"round({microseconds}, -6 + {coltype.precision})"
            time_value = f"TIMEADD(microsecond, {rounded}, cast('00:00:00' as time))"
            return f"TO_VARCHAR({time_value}, 'HH24:MI:SS.FF6')"

        if coltype.rounds:
            # Epoch nanoseconds / 1e9 = epoch seconds, rounded to `precision` decimals.
            timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))"
        else:
            timestamp = f"cast(convert_timezone('UTC', {value}) as timestamp({coltype.precision}))"

        return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')"

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Return SQL casting *value* to ``decimal(38, precision)`` and then to a string."""
        return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")

    def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
        """Return SQL rendering a boolean as the string of its integer value."""
        return self.to_string(f"{value}::int")
@attrs.define(frozen=False, init=False, kw_only=True)
class Snowflake(Database):
    """Database connection to Snowflake through ``snowflake.connector``.

    Authenticates with whatever credentials are passed in ``**kw`` (for example
    a password). Alternatively it uses a PEM private key given either as a file
    path (``key``) or as base64-encoded content (``key_content``).
    """

    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
    CONNECT_URI_HELP = "snowflake://<user>:<password>@<account>/<database>/<SCHEMA>?warehouse=<WAREHOUSE>"
    CONNECT_URI_PARAMS = ["database", "schema"]
    CONNECT_URI_KWPARAMS = ["warehouse"]

    _conn: Any

    def __init__(
        self,
        *,
        schema: str,
        key: Optional[str] = None,
        key_content: Optional[str] = None,
        **kw,
    ) -> None:
        """Open a Snowflake connection.

        Args:
            schema: Default schema; it is passed to the connector double-quoted.
            key: Path to a PEM-encoded private key file.
            key_content: Base64-encoded PEM private key (alternative to ``key``).
            **kw: Forwarded to ``snowflake.connector.connect``. The optional
                ``private_key_passphrase`` decrypts the private key.

        Raises:
            ConnectError: if ``schema`` contains a double quote, if both ``key``
                and ``key_content`` are given, or if a key is combined with a
                ``password`` keyword.
        """
        super().__init__()
        snowflake, serialization, default_backend = import_snowflake()
        logging.getLogger("snowflake.connector").setLevel(logging.WARNING)

        # Ignore the error: snowflake.connector.network.RetryRequest: could not find io module state
        # It's a known issue: https://github.com/snowflakedb/snowflake-connector-python/issues/145
        logging.getLogger("snowflake.connector.network").disabled = True

        # Explicit check instead of `assert`, so the validation also runs under `python -O`.
        if '"' in schema:
            raise ConnectError("Schema name should not contain quotes!")

        if key_content and key:
            raise ConnectError("Only key value or key file path can be specified, not both")

        key_bytes = None
        if key:
            with open(key, "rb") as f:
                key_bytes = f.read()
        if key_content:
            key_bytes = base64.b64decode(key_content)

        # A private key (from the file or from key_content) is re-encoded as
        # unencrypted DER/PKCS8 and handed to the connector as "private_key".
        if key_bytes:
            if "password" in kw:
                raise ConnectError("Cannot use password and key at the same time")
            passphrase = kw.get("private_key_passphrase")
            encoded_passphrase = passphrase.encode() if passphrase else None

            p_key = serialization.load_pem_private_key(
                key_bytes,
                password=encoded_passphrase,
                backend=default_backend(),
            )
            kw["private_key"] = p_key.private_bytes(
                encoding=serialization.Encoding.DER,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )

        self._conn = snowflake.connector.connect(schema=f'"{schema}"', **kw)
        self.default_schema = schema

    def close(self):
        """Run the base-class cleanup, then close the Snowflake connection."""
        super().close()
        self._conn.close()

    def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
        """Execute *sql_code* on this connection via the standard SQL cursor interface."""
        return self._query_conn(self._conn, sql_code)

    @staticmethod
    def _sql_string_literal(value: str) -> str:
        """Render *value* as a safely escaped single-quoted Snowflake string literal."""
        # Snowflake interprets backslash escape sequences inside single-quoted
        # strings, so backslashes are doubled first, then single quotes are doubled.
        escaped = value.replace("\\", "\\\\").replace("'", "''")
        return f"'{escaped}'"

    def select_table_schema(self, path: DbPath) -> str:
        """Provide SQL selecting the schema of the table at *path*.

        The query returns one row per column: (column_name, data_type,
        datetime_precision, numeric_precision, numeric_scale, collation). A
        missing collation is reported as 'utf8'.
        """
        database, schema, name = self._normalize_table_path(path)

        info_schema_path = ["information_schema", "columns"]
        if database:
            info_schema_path.insert(0, database)

        # Escape the names so quotes or backslashes in them cannot break the SQL.
        table_literal = self._sql_string_literal(name)
        schema_literal = self._sql_string_literal(schema)
        return (
            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale"
            " , coalesce(collation_name, 'utf8') "
            f"FROM {'.'.join(info_schema_path)} "
            f"WHERE table_name = {table_literal} AND table_schema = {schema_literal}"
        )

    def _normalize_table_path(self, path: DbPath) -> DbPath:
        """Expand *path* to a (database, schema, table) triple.

        A single name uses ``default_schema`` and a two-part path has no
        database (``None``).

        Raises:
            ValueError: if *path* does not have 1, 2 or 3 parts.
        """
        if len(path) == 1:
            return None, self.default_schema, path[0]
        elif len(path) == 2:
            return None, path[0], path[1]
        elif len(path) == 3:
            return path

        raise ValueError(
            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
        )

    @property
    def is_autocommit(self) -> bool:
        """Always True for this driver."""
        return True

    def query_table_unique_columns(self, path: DbPath) -> List[str]:
        """Unique-column discovery is not implemented here; always returns an empty list."""
        return []