This repository was archived by the owner on May 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 305
Expand file tree
/
Copy pathhashdiff_tables.py
More file actions
224 lines (180 loc) · 9.28 KB
/
hashdiff_tables.py
File metadata and controls
224 lines (180 loc) · 9.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import os
from dataclasses import field
from numbers import Number
import logging
from collections import defaultdict
from typing import Iterator
from operator import attrgetter
from runtype import dataclass
from data_diff.sqeleton.abcs import ColType_UUID, NumericType, PrecisionType, StringType, Boolean, JSON
from data_diff.info_tree import InfoTree
from data_diff.utils import safezip, diffs_are_equiv_jsons
from data_diff.thread_utils import ThreadedYielder
from data_diff.table_segment import TableSegment
from data_diff.diff_tables import TableDiffer
# Benchmark mode: when enabled, HashDiffer._diff_segments skips the checksum
# step for segments below the bisection threshold and downloads them for local
# comparison instead. The variable is parsed as a boolean flag: unset, empty,
# "0", "false", "no" and "off" (case-insensitive) disable it. Using the raw
# string would make e.g. BENCHMARK=0 truthy and silently enable it.
BENCHMARK = os.environ.get("BENCHMARK", "").strip().lower() not in ("", "0", "false", "no", "off")

# Default for HashDiffer.bisection_threshold: the row count below which a
# segment is downloaded and compared locally instead of being bisected further.
DEFAULT_BISECTION_THRESHOLD = 1024 * 16

# Default for HashDiffer.bisection_factor: how many segments to split into per
# bisection iteration.
DEFAULT_BISECTION_FACTOR = 32

logger = logging.getLogger("hashdiff_tables")
def diff_sets(a: list, b: list, json_cols: dict = None) -> Iterator:
    """Yield the rows that appear in only one of *a* and *b*, grouped by key.

    Rows are compared as whole values, so they must be hashable. A row found
    only in *a* is emitted as ``("-", row)``; a row found only in *b* is
    emitted as ``("+", row)``.

    Output is grouped by ``row[0]`` and ordered by ascending key. Within one
    key, all ``"-"`` entries come before the ``"+"`` entries, and each side
    keeps its original order.

    If *json_cols* is non-empty, each key group is passed to
    ``diffs_are_equiv_jsons``. Groups it reports as equivalent are dropped
    from the output, and a warning is logged once per overridden column.
    """
    rows_in_a = set(a)
    rows_in_b = set(b)

    # row[0] is the key column (see TableDiffer.relevant_columns).
    # TODO update when we add compound keys to hashdiff
    grouped = {}
    for sign, rows, other_side in (("-", a, rows_in_b), ("+", b, rows_in_a)):
        for row in rows:
            if row in other_side:
                continue
            grouped.setdefault(row[0], []).append((sign, row))

    already_warned = set()
    for key in sorted(grouped):
        entries = grouped[key]
        if json_cols:
            is_equivalent, overridden_cols = diffs_are_equiv_jsons(entries, json_cols)
            if is_equivalent:
                for col in overridden_cols - already_warned:
                    logger.warning(
                        f"Equivalent JSON objects with different string representations detected "
                        f"in column '{col}'. These cases are NOT reported as differences."
                    )
                    already_warned.add(col)
                continue
        yield from entries
@dataclass
class HashDiffer(TableDiffer):
    """Finds the diff between two SQL tables

    The algorithm uses hashing to quickly check if the tables are different, and then applies a
    bisection search recursively to find the differences efficiently.

    Works best for comparing tables that are mostly the same, with minor discrepancies.

    Parameters:
        bisection_factor (int): Into how many segments to bisect per iteration.
        bisection_threshold (Number): When should we stop bisecting and compare locally (in row count).
        threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads.
        max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto.
                                   Only relevant when `threaded` is ``True``.
                                   There may be many pools, so number of actual threads can be a lot higher.
    """

    bisection_factor: int = DEFAULT_BISECTION_FACTOR
    bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD  # Accepts inf for tests

    # Run statistics accumulated while diffing. Within this class only the
    # "rows_downloaded" counter is written (in _bisect_and_diff_segments).
    stats: dict = field(default_factory=dict)

    def __post_init__(self):
        """Validate the bisection options.

        Raises:
            ValueError: if ``bisection_factor`` is not lower than ``bisection_threshold``,
                or if ``bisection_factor`` is lower than 2.
        """
        if self.bisection_factor >= self.bisection_threshold:
            raise ValueError("Incorrect param values (bisection factor must be lower than threshold)")
        if self.bisection_factor < 2:
            raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)")

    def _validate_and_adjust_columns(self, table1, table2):
        """Check that paired relevant columns exist and have compatible types.

        Columns are paired positionally via ``safezip`` over each table's
        ``relevant_columns``.

        - Precision-bearing pairs: both schemas are rewritten in place to the
          lower of the two precisions and its ``rounds`` setting.
        - Numeric/boolean pairs: both schemas are rewritten in place to the
          lower of the two precisions.

        Finally, a warning is logged for every relevant column whose type is
        not marked ``supported``.

        Raises:
            ValueError: if a relevant column is missing from its table's schema.
            TypeError: if a column's type is incompatible with its counterpart.
        """
        for c1, c2 in safezip(table1.relevant_columns, table2.relevant_columns):
            if c1 not in table1._schema:
                raise ValueError(f"Column '{c1}' not found in schema for table {table1}")
            if c2 not in table2._schema:
                raise ValueError(f"Column '{c2}' not found in schema for table {table2}")

            # Update schemas to minimal mutual precision
            col1 = table1._schema[c1]
            col2 = table2._schema[c2]
            # NOTE(review): compatibility is only checked from col1's type. A col1 of an
            # unlisted type paired with e.g. a PrecisionType col2 is accepted silently.
            if isinstance(col1, PrecisionType):
                if not isinstance(col2, PrecisionType):
                    raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

                lowest = min(col1, col2, key=attrgetter("precision"))

                if col1.precision != col2.precision:
                    logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")

                # Always rewritten (even at equal precision) so both sides share `rounds`.
                table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds)
                table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds)

            elif isinstance(col1, (NumericType, Boolean)):
                if not isinstance(col2, (NumericType, Boolean)):
                    raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

                lowest = min(col1, col2, key=attrgetter("precision"))

                if col1.precision != col2.precision:
                    logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")

                if lowest.precision != col1.precision:
                    table1._schema[c1] = col1.replace(precision=lowest.precision)
                if lowest.precision != col2.precision:
                    table2._schema[c2] = col2.replace(precision=lowest.precision)

            elif isinstance(col1, ColType_UUID):
                if not isinstance(col2, ColType_UUID):
                    raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

            elif isinstance(col1, StringType):
                if not isinstance(col2, StringType):
                    raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

        for t in [table1, table2]:
            for c in t.relevant_columns:
                ctype = t._schema[c]
                if not ctype.supported:
                    logger.warning(
                        f"[{t.database.name}] Column '{c}' of type '{ctype}' has no compatibility handling. "
                        "If encoding/formatting differs between databases, it may result in false positives."
                    )

    def _diff_segments(
        self,
        ti: ThreadedYielder,
        table1: TableSegment,
        table2: TableSegment,
        info_tree: InfoTree,
        max_rows: int,
        level=0,
        segment_index=None,
        segment_count=None,
    ):
        """Diff one pair of segments, bisecting further only when they differ.

        Both segments are counted and checksummed via ``_threaded_call``, which
        presumably runs the two calls concurrently when threading is enabled.
        The row counts are recorded in ``info_tree.info.rowcounts``.

        - Both segments empty, or checksums equal: the segment is marked as not
          differing and nothing is returned.
        - Otherwise: the pair is passed to ``_bisect_and_diff_segments`` with
          ``max_rows`` set to the larger of the two counts.
        """
        # Both segments are expected to share the same key bounds, so report
        # table1's range (matching the debug message below).
        logger.info(
            ". " * level + f"Diffing segment {segment_index}/{segment_count}, "
            f"key-range: {table1.min_key}..{table1.max_key}, "
            f"size <= {max_rows}"
        )

        # When benchmarking, we want the ability to skip checksumming. This
        # allows us to download all rows for comparison in performance. By
        # default, data-diff will checksum the section first (when it's below
        # the threshold) and _then_ download it.
        if BENCHMARK:
            if max_rows < self.bisection_threshold:
                return self._bisect_and_diff_segments(ti, table1, table2, info_tree, level=level, max_rows=max_rows)

        (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2])

        assert not info_tree.info.rowcounts
        info_tree.info.rowcounts = {1: count1, 2: count2}

        if count1 == 0 and count2 == 0:
            logger.debug(
                "Uneven distribution of keys detected in segment %s..%s (big gaps in the key column). "
                "For better performance, we recommend to increase the bisection-threshold.",
                table1.min_key,
                table1.max_key,
            )
            # Empty segments are expected to produce no checksum.
            assert checksum1 is None and checksum2 is None
            info_tree.info.is_diff = False
            return

        if checksum1 == checksum2:
            info_tree.info.is_diff = False
            return

        info_tree.info.is_diff = True
        return self._bisect_and_diff_segments(ti, table1, table2, info_tree, level=level, max_rows=max(count1, count2))

    def _bisect_and_diff_segments(
        self,
        ti: ThreadedYielder,
        table1: TableSegment,
        table2: TableSegment,
        info_tree: InfoTree,
        level=0,
        max_rows=None,
    ):
        """Diff two bounded segments, either locally or by further bisection.

        A segment is compared locally when ``max_rows`` is below
        ``bisection_threshold``, or when the approximate key space is smaller
        than twice ``bisection_factor``. In that case both segments are
        downloaded and compared in memory with ``diff_sets``. The resulting
        diff is recorded in ``info_tree`` and returned.

        Otherwise the work is delegated to the base class implementation.

        Args:
            max_rows: Upper bound on rows in either segment. When ``None``, the
                larger approximate key-space size is used instead.
        """
        assert table1.is_bounded and table2.is_bounded

        max_space_size = max(table1.approximate_size(), table2.approximate_size())
        if max_rows is None:
            # We can be sure that row_count <= max_rows iff the table key is unique
            max_rows = max_space_size
            info_tree.info.max_rows = max_rows

        # If count is below the threshold, just download and compare the columns locally
        # This saves time, as bisection speed is limited by ping and query performance.
        if max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2:
            rows1, rows2 = self._threaded_call("get_values", [table1, table2])
            # Index within extra_columns -> name, for JSON-typed columns only. Lets
            # diff_sets suppress rows whose JSON differs only in string representation.
            json_cols = {
                i: colname
                for i, colname in enumerate(table1.extra_columns)
                if isinstance(table1._schema[colname], JSON)
            }
            diff = list(diff_sets(rows1, rows2, json_cols))

            info_tree.info.set_diff(diff)
            info_tree.info.rowcounts = {1: len(rows1), 2: len(rows2)}

            logger.info(". " * level + f"Diff found {len(diff)} different rows.")
            # Count the larger of the two downloads for this segment.
            self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2))
            return diff

        return super()._bisect_and_diff_segments(ti, table1, table2, info_tree, level, max_rows)