11from __future__ import annotations
22
3- import os
3+ from dataclasses import replace
44import uuid
55from pathlib import Path
66
77import typer
88
99from sqlcompare .config import get_default_schema
1010from sqlcompare .db import DBConnection
11- from sqlcompare .helpers import create_table_from_select , detect_input , ensure_schema
11+ from sqlcompare .helpers import (
12+ detect_input ,
13+ materialize_sql_inputs ,
14+ resolve_connection ,
15+ resolve_materialized_tables ,
16+ )
1217from sqlcompare .log import log
1318from sqlcompare .stats .checks import get_check_map
1419from sqlcompare .stats .models import ColumnPair , StatsContext
1520from sqlcompare .stats .report import render_report
16-
17-
18- def _resolve_connection (connection : str | None ) -> str :
19- if connection :
20- return connection
21- default_conn = os .getenv ("SQLCOMPARE_CONN_DEFAULT" ) or os .getenv ("DTK_CONN_DEFAULT" )
22- if not default_conn :
23- raise ValueError (
24- "No connection specified. Use --connection or set SQLCOMPARE_CONN_DEFAULT."
25- )
26- return default_conn
21+ from sqlcompare .utils .concurrency import run_ordered
2722
2823
2924def _resolve_checks (checks : str | None ) -> list [str ]:
@@ -101,6 +96,29 @@ def _build_context(db: DBConnection, previous_name: str, current_name: str) -> S
10196 )
10297
10398
def _run_selected_checks(
    context: StatsContext,
    selected_check_names: list[str],
    check_map: dict[str, object],
    connection_id: str | None,
    parallel_safe: bool,
) -> list:
    """Run every selected check and collect the results via ``run_ordered``.

    Args:
        context: Stats context handed to each check runner.
        selected_check_names: Names of the checks to execute; each one is
            looked up in ``check_map``.
        check_map: Mapping from check name to its definition. Each definition
            is called as ``definition.runner(context, definition)``.
        connection_id: Connection identifier used to open a dedicated
            ``DBConnection`` per check when ``parallel_safe`` is true.
        parallel_safe: When true, every check runs against its own
            ``DBConnection`` and a copy of ``context`` whose ``db`` field is
            swapped for that connection. When false, every check reuses the
            ``context`` passed in unchanged.

    Returns:
        Whatever ``run_ordered`` returns for the selected names.
        NOTE(review): presumably one result per name, in input order -
        confirm against ``sqlcompare.utils.concurrency.run_ordered``.
    """

    def _execute(check_name: str):
        spec = check_map[check_name]
        if parallel_safe:
            # Give this check a private connection instead of the one
            # carried by the shared context.
            with DBConnection(connection_id) as worker_db:
                return spec.runner(replace(context, db=worker_db), spec)
        return spec.runner(context, spec)

    return run_ordered(
        selected_check_names,
        _execute,
        enabled=parallel_safe,
        max_workers=4,
    )
120+
121+
104122def compare_table_stats (
105123 table1 : str ,
106124 table2 : str ,
@@ -125,42 +143,63 @@ def compare_table_stats(
125143 table1_name = Path (spec_prev .value ).stem
126144 table2_name = Path (spec_new .value ).stem
127145 elif spec_prev .kind == "sql" or spec_new .kind == "sql" :
128- connection_id = _resolve_connection (connection )
146+ connection_id = resolve_connection (connection , error_cls = ValueError )
129147 schema = get_default_schema ()
130- schema_prefix = f"{ schema } ." if schema else ""
131148 suffix = uuid .uuid4 ().hex [:8 ]
132- table1_name = (
133- spec_prev .value
134- if spec_prev .kind == "table"
135- else f"{ schema_prefix } sqlcompare_stats_{ suffix } _previous"
136- )
137- table2_name = (
138- spec_new .value
139- if spec_new .kind == "table"
140- else f"{ schema_prefix } sqlcompare_stats_{ suffix } _new"
149+ table1_name , table2_name = resolve_materialized_tables (
150+ spec_prev ,
151+ spec_new ,
152+ schema = schema ,
153+ prefix = "sqlcompare_stats" ,
154+ suffix = suffix ,
141155 )
142156 else :
143157 table1_name = spec_prev .value
144158 table2_name = spec_new .value
145159
146- with DBConnection (connection_id ) as db :
160+ parallel_safe = not (
161+ spec_prev .kind == "file" and spec_new .kind == "file" and connection is None
162+ )
163+
164+ def prepare_inputs (db : DBConnection ) -> None :
147165 if spec_prev .kind == "file" and spec_new .kind == "file" :
148166 if connection is None and connection_id == "duckdb:///:memory:" :
149167 db .create_table_from_file (table1_name , spec_prev .value )
150168 db .create_table_from_file (table2_name , spec_new .value )
151- if spec_prev .kind == "sql" or spec_new .kind == "sql" :
152- schema = get_default_schema ()
153- ensure_schema (db , schema )
154- if spec_prev .kind == "sql" :
155- create_table_from_select (db , table1_name , spec_prev .value )
156- if spec_new .kind == "sql" :
157- create_table_from_select (db , table2_name , spec_new .value )
158-
159- context = _build_context (db , table1_name , table2_name )
160- check_map = get_check_map ()
161- results = [
162- check_map [name ].runner (context , check_map [name ])
163- for name in selected_check_names
164- ]
169+ materialize_sql_inputs (
170+ db ,
171+ previous_spec = spec_prev ,
172+ current_spec = spec_new ,
173+ previous_table = table1_name ,
174+ current_table = table2_name ,
175+ schema = get_default_schema (),
176+ )
177+
178+ if not parallel_safe :
179+ with DBConnection (connection_id ) as db :
180+ prepare_inputs (db )
181+ context = _build_context (db , table1_name , table2_name )
182+ check_map = get_check_map ()
183+ results = _run_selected_checks (
184+ context = context ,
185+ selected_check_names = selected_check_names ,
186+ check_map = check_map ,
187+ connection_id = connection_id ,
188+ parallel_safe = False ,
189+ )
190+ else :
191+ with DBConnection (connection_id ) as db :
192+ prepare_inputs (db )
193+
194+ with DBConnection (connection_id ) as db :
195+ context = _build_context (db , table1_name , table2_name )
196+ check_map = get_check_map ()
197+ results = _run_selected_checks (
198+ context = context ,
199+ selected_check_names = selected_check_names ,
200+ check_map = check_map ,
201+ connection_id = connection_id ,
202+ parallel_safe = True ,
203+ )
165204
166205 log .info (render_report (context , selected_check_names , results ))
0 commit comments