This repository was archived by the owner on May 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 305
Expand file tree
/
Copy pathtest_database.py
More file actions
165 lines (125 loc) · 5.36 KB
/
test_database.py
File metadata and controls
165 lines (125 loc) · 5.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import unittest
from datetime import datetime
from typing import Callable, List, Tuple
import pytz
from data_diff.sqeleton import connect
from data_diff.sqeleton import databases as dbs
from data_diff.sqeleton.queries import table, current_timestamp, NormalizeAsString
from tests.common import TEST_MYSQL_CONN_STRING, test_each_database_in_list, get_conn, str_to_checksum, random_table_suffix
from data_diff.sqeleton.abcs.database_types import TimestampTZ
# Database driver classes that parametrized suites below run against.
# Each entry is a sqeleton database class; @test_each_database generates
# one concrete TestCase subclass per entry.
TEST_DATABASES = {
    dbs.MySQL,
    dbs.PostgreSQL,
    dbs.Oracle,
    dbs.Redshift,
    dbs.Snowflake,
    dbs.DuckDB,
    dbs.BigQuery,
    dbs.Presto,
    dbs.Trino,
    dbs.Vertica,
    dbs.MsSQL,
}
# Class decorator: applies a TestCase to every database in TEST_DATABASES,
# exposing the current driver class on the instance as `self.db_cls`.
test_each_database: Callable = test_each_database_in_list(TEST_DATABASES)
class TestDatabase(unittest.TestCase):
    """Smoke test: a MySQL connection can be opened and answer a trivial query."""

    def setUp(self):
        # Open a fresh connection for each test via the shared test URI.
        self.mysql = connect(TEST_MYSQL_CONN_STRING)

    def test_connect_to_db(self):
        result = self.mysql.query("SELECT 1", int)
        self.assertEqual(result, 1)
class TestMD5(unittest.TestCase):
    """Verify that the dialect's md5_as_int() matches the local str_to_checksum()."""

    def test_md5_as_int(self):
        # Compose a MySQL dialect that gains md5_as_int from the MD5 mixin.
        class MD5Dialect(dbs.mysql.Dialect, dbs.mysql.Mixin_MD5):
            pass

        self.mysql = connect(TEST_MYSQL_CONN_STRING)
        self.mysql.dialect = MD5Dialect()

        # Renamed from `str`, which shadowed the builtin `str` type.
        text = "hello world"
        query_fragment = self.mysql.dialect.md5_as_int("'{0}'".format(text))
        query = f"SELECT {query_fragment}"

        # The database-side checksum must agree with the Python-side one.
        self.assertEqual(str_to_checksum(text), self.mysql.query(query, int))
class TestConnect(unittest.TestCase):
    """Malformed or over-specified connection URIs must raise ValueError."""

    def test_bad_uris(self):
        bad_uris = (
            "p",
            "postgresql:///bla/foo",
            "snowflake://user:pass@foo/bar/TEST1",
            "snowflake://user:pass@foo/bar/TEST1?warehouse=ha&schema=dup",
        )
        for uri in bad_uris:
            with self.assertRaises(ValueError):
                connect(uri)
@test_each_database
class TestSchema(unittest.TestCase):
    """Create/list/drop round-trip checks, run against every configured database."""

    def _roundtrip(self, db, tbl, name):
        # Assert the table is absent, appears after create(), and vanishes after drop().
        listing = db.dialect.list_tables(db.default_schema, name)
        assert not db.query(listing)
        db.query(tbl.create())
        self.assertEqual(db.query(listing, List[str]), [name])
        db.query(tbl.drop())
        assert not db.query(listing)

    def test_table_list(self):
        name = "tbl_" + random_table_suffix()
        db = get_conn(self.db_cls)
        tbl = table(db.parse_table_name(name), schema={"id": int})
        self._roundtrip(db, tbl, name)

    def test_type_mapping(self):
        name = "tbl_" + random_table_suffix()
        db = get_conn(self.db_cls)
        # One column per mapped Python type; column names mirror the types.
        columns = {
            "int": int,
            "float": float,
            "datetime": datetime,
            "str": str,
            "bool": bool,
        }
        tbl = table(db.parse_table_name(name), schema=columns)
        self._roundtrip(db, tbl, name)
@test_each_database
class TestQueries(unittest.TestCase):
    """Query-helper checks (timestamps, timezone normalization) across all databases."""

    def test_current_timestamp(self):
        db = get_conn(self.db_cls)
        res = db.query(current_timestamp(), datetime)
        assert isinstance(res, datetime), (res, type(res))

    def test_correct_timezone(self):
        """Insert a tz-aware timestamp, then read it back normalized to UTC."""
        if self.db_cls in [dbs.MsSQL]:
            self.skipTest("No support for session tz.")
        name = "tbl_" + random_table_suffix()
        db = get_conn(self.db_cls)
        tbl = table(name, schema={"id": int, "created_at": TimestampTZ(9), "updated_at": TimestampTZ(9)})
        db.query(tbl.create())

        tz = pytz.timezone("Europe/Berlin")
        now = datetime.now(tz)
        if isinstance(db, dbs.Presto):
            ms = now.microsecond // 1000 * 1000  # Presto max precision is 3
            now = now.replace(microsecond=ms)

        db.query(table(name).insert_row(1, now, now))
        db.query(db.dialect.set_timezone_to_utc())

        t = db.table(name).query_schema()
        # NOTE(review): this is a no-op (replaces precision with its own value);
        # kept to preserve original behavior — confirm whether a different
        # precision/cast was intended here.
        t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision)
        tbl = table(name, schema=t.schema)

        results = db.query(tbl.select(NormalizeAsString(tbl[c]) for c in ["created_at", "updated_at"]), List[Tuple])
        # BUG FIX: created_at previously read results[0][1] (the updated_at
        # column), making its assertion a duplicate of updated_at's. Column 0
        # is created_at per the select order above.
        created_at = results[0][0]
        updated_at = results[0][1]

        utc = now.astimezone(pytz.UTC)
        expected = utc.__format__("%Y-%m-%d %H:%M:%S.%f")

        self.assertEqual(created_at, expected)
        self.assertEqual(updated_at, expected)

        db.query(tbl.drop())
@test_each_database
class TestThreePartIds(unittest.TestCase):
    """Tables must be addressable by 1-, 2- and 3-part identifiers (db.schema.table)."""

    def test_three_part_support(self):
        supported = [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake, dbs.DuckDB, dbs.MsSQL]
        if self.db_cls not in supported:
            self.skipTest("Limited support for 3 part ids")

        table_name = "tbl_" + random_table_suffix()
        db = get_conn(self.db_cls)

        # Resolve the current database and schema names from the live connection.
        db_name = db.query(f"SELECT {db.dialect.current_database()}").rows[0][0]
        schema_name = db.query(f"SELECT {db.dialect.current_schema()}").rows[0][0]

        # Same table addressed with increasing path specificity.
        paths = (
            (table_name,),
            (schema_name, table_name),
            (db_name, schema_name, table_name),
        )
        for path in paths:
            part = table(path, schema={"id": int})
            db.query(part.create())
            described = db.query_table_schema(part.path)
            assert len(described) == 1
            db.query(part.drop())