Skip to content

Commit 0a8da47

Browse files
committed
Add init_sql support to DuckDB connection
1 parent 9c95340 commit 0a8da47

6 files changed

Lines changed: 233 additions & 7 deletions

File tree

sidemantic/cli.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import typer
66

77
from sidemantic import SemanticLayer, __version__, load_from_directory
8-
from sidemantic.config import SidemanticConfig, build_connection_string, find_config, load_config
8+
from sidemantic.config import SidemanticConfig, build_connection_string, find_config, get_init_sql, load_config
99

1010

1111
def version_callback(value: bool):
@@ -348,6 +348,7 @@ def query(
348348
try:
349349
# Build connection string from args or config
350350
connection_str = None
351+
init_sql = None
351352
if connection:
352353
# Explicit --connection arg provided
353354
connection_str = connection
@@ -357,6 +358,7 @@ def query(
357358
elif _loaded_config and _loaded_config.connection:
358359
# Use connection from config
359360
connection_str = build_connection_string(_loaded_config)
361+
init_sql = get_init_sql(_loaded_config)
360362
else:
361363
# Try to find database file in data/
362364
data_dir = models / "data"
@@ -369,7 +371,9 @@ def query(
369371
preagg_db = _loaded_config.preagg_database if _loaded_config else None
370372
preagg_sch = _loaded_config.preagg_schema if _loaded_config else None
371373
if connection_str:
372-
layer = SemanticLayer(connection=connection_str, preagg_database=preagg_db, preagg_schema=preagg_sch)
374+
layer = SemanticLayer(
375+
connection=connection_str, preagg_database=preagg_db, preagg_schema=preagg_sch, init_sql=init_sql
376+
)
373377
else:
374378
layer = SemanticLayer(preagg_database=preagg_db, preagg_schema=preagg_sch)
375379
load_from_directory(layer, str(models))

sidemantic/config.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ class DuckDBConnection(BaseModel):
1111

1212
type: Literal["duckdb"] = "duckdb"
1313
path: str = Field(..., description="Path to DuckDB database file or :memory:")
14+
init_sql: list[str] | None = Field(
15+
default=None,
16+
description="SQL statements to run after connecting (e.g., loading extensions, attaching catalogs)",
17+
)
1418

1519

1620
class PostgreSQLConnection(BaseModel):
@@ -186,7 +190,7 @@ def resolve_paths(self, base_dir: Path | None = None) -> "SidemanticConfig":
186190
db_p = Path(connection.path)
187191
if not db_p.is_absolute():
188192
db_p = (base / db_p).resolve()
189-
connection = DuckDBConnection(type="duckdb", path=str(db_p))
193+
connection = DuckDBConnection(type="duckdb", path=str(db_p), init_sql=connection.init_sql)
190194

191195
return SidemanticConfig(
192196
models_dir=str(models_path),
@@ -263,6 +267,20 @@ def find_config(start_dir: Path | None = None) -> Path | None:
263267
return None
264268

265269

270+
def get_init_sql(config: SidemanticConfig) -> list[str] | None:
271+
"""Get init_sql statements from config, if any.
272+
273+
Args:
274+
config: Sidemantic configuration
275+
276+
Returns:
277+
List of SQL statements to run after connecting, or None
278+
"""
279+
if config.connection and isinstance(config.connection, DuckDBConnection):
280+
return config.connection.init_sql
281+
return None
282+
283+
266284
def build_connection_string(config: SidemanticConfig) -> str:
267285
"""Build database connection string from config.
268286

sidemantic/core/semantic_layer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424
use_preaggregations: bool = False,
2525
preagg_database: str | None = None,
2626
preagg_schema: str | None = None,
27+
init_sql: list[str] | None = None,
2728
):
2829
"""Initialize semantic layer.
2930
@@ -45,6 +46,8 @@ def __init__(
4546
use_preaggregations: Enable automatic pre-aggregation routing (default: False)
4647
preagg_database: Optional database name for pre-aggregation tables
4748
preagg_schema: Optional schema name for pre-aggregation tables
49+
init_sql: SQL statements to run after connecting (DuckDB and MotherDuck connections only, e.g.,
50+
loading extensions, attaching catalogs, creating secrets)
4851
"""
4952
from sidemantic.db.base import BaseDatabaseAdapter
5053

@@ -66,10 +69,14 @@ def __init__(
6669

6770
self.adapter = MotherDuckAdapter.from_url(connection)
6871
self.dialect = dialect or "duckdb"
72+
# Run init_sql after MotherDuck connection
73+
if init_sql:
74+
for stmt in init_sql:
75+
self.adapter.execute(stmt)
6976
elif connection.startswith("duckdb://"):
7077
from sidemantic.db.duckdb import DuckDBAdapter
7178

72-
self.adapter = DuckDBAdapter.from_url(connection)
79+
self.adapter = DuckDBAdapter.from_url(connection, init_sql=init_sql)
7380
self.dialect = dialect or "duckdb"
7481
elif connection.startswith(("postgres://", "postgresql://")):
7582
from sidemantic.db.postgres import PostgreSQLAdapter

sidemantic/db/duckdb.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,21 @@ class DuckDBAdapter(BaseDatabaseAdapter):
1414
Wraps DuckDB connection to provide unified adapter interface.
1515
"""
1616

17-
def __init__(self, path: str = ":memory:", read_only: bool = False, config: dict[str, Any] | None = None):
17+
def __init__(
18+
self,
19+
path: str = ":memory:",
20+
read_only: bool = False,
21+
config: dict[str, Any] | None = None,
22+
init_sql: list[str] | None = None,
23+
):
1824
"""Initialize DuckDB adapter.
1925
2026
Args:
2127
path: Database file path or ":memory:" for in-memory database
28+
read_only: Open database in read-only mode
29+
config: DuckDB configuration options
30+
init_sql: SQL statements to run immediately after connecting
31+
(e.g., loading extensions, attaching catalogs, creating secrets)
2232
"""
2333
if not read_only and config is None:
2434
self.conn = duckdb.connect(path)
@@ -27,6 +37,10 @@ def __init__(self, path: str = ":memory:", read_only: bool = False, config: dict
2737
else:
2838
self.conn = duckdb.connect(path, read_only=read_only, config=config)
2939

40+
if init_sql:
41+
for stmt in init_sql:
42+
self.conn.execute(stmt)
43+
3044
def execute(self, sql: str) -> Any:
3145
"""Execute SQL and return DuckDB relation."""
3246
return self.conn.execute(sql)
@@ -88,11 +102,12 @@ def raw_connection(self) -> Any:
88102
return self.conn
89103

90104
@classmethod
91-
def from_url(cls, url: str) -> "DuckDBAdapter":
105+
def from_url(cls, url: str, init_sql: list[str] | None = None) -> "DuckDBAdapter":
92106
"""Create adapter from connection URL.
93107
94108
Args:
95109
url: Connection URL (e.g., "duckdb:///:memory:" or "duckdb:///path/to/db.duckdb")
110+
init_sql: SQL statements to run after connecting (overrides any init_sql in URL params)
96111
97112
Returns:
98113
DuckDBAdapter instance
@@ -117,6 +132,7 @@ def from_url(cls, url: str) -> "DuckDBAdapter":
117132
query = parse_qs(parsed.query)
118133
read_only = False
119134
config: dict[str, Any] = {}
135+
url_init_sql: list[str] | None = None
120136

121137
def parse_value(value: str) -> Any:
122138
lowered = value.lower()
@@ -132,10 +148,15 @@ def parse_value(value: str) -> Any:
132148
for key, values in query.items():
133149
if not values:
134150
continue
151+
if key == "init_sql":
152+
url_init_sql = values
153+
continue
135154
value = values[-1]
136155
if key == "read_only":
137156
read_only = bool(parse_value(value))
138157
else:
139158
config[key] = parse_value(value)
140159

141-
return cls(db_path, read_only=read_only, config=config or None)
160+
# Parameter init_sql overrides any init_sql found in URL query params
161+
effective_init_sql = init_sql if init_sql is not None else url_init_sql
162+
return cls(db_path, read_only=read_only, config=config or None, init_sql=effective_init_sql)

tests/db/test_duckdb_adapter.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,83 @@ def test_duckdb_valid_schema_names_accepted(schema):
189189
# Should not raise
190190
columns = adapter.get_columns("orders", schema=schema)
191191
assert len(columns) >= 1
192+
193+
194+
def test_duckdb_adapter_init_sql():
195+
"""Test that init_sql statements run after connecting."""
196+
adapter = DuckDBAdapter(
197+
":memory:",
198+
init_sql=["CREATE TABLE setup_test (id INT, name VARCHAR)"],
199+
)
200+
result = adapter.execute("SELECT * FROM setup_test")
201+
assert result.fetchall() == []
202+
203+
# Verify columns were created
204+
columns = adapter.get_columns("setup_test")
205+
col_names = {c["column_name"] for c in columns}
206+
assert col_names == {"id", "name"}
207+
208+
209+
def test_duckdb_adapter_init_sql_multiple_statements():
210+
"""Test multiple init_sql statements execute in order."""
211+
adapter = DuckDBAdapter(
212+
":memory:",
213+
init_sql=[
214+
"CREATE TABLE t1 (x INT)",
215+
"INSERT INTO t1 VALUES (42)",
216+
"CREATE TABLE t2 AS SELECT x * 2 AS doubled FROM t1",
217+
],
218+
)
219+
result = adapter.execute("SELECT doubled FROM t2")
220+
assert result.fetchone()[0] == 84
221+
222+
223+
def test_duckdb_adapter_init_sql_none():
224+
"""Test that init_sql=None is fine (no-op)."""
225+
adapter = DuckDBAdapter(":memory:", init_sql=None)
226+
result = adapter.execute("SELECT 1")
227+
assert result.fetchone()[0] == 1
228+
229+
230+
def test_duckdb_adapter_from_url_with_init_sql():
231+
"""Test from_url passes init_sql through."""
232+
adapter = DuckDBAdapter.from_url(
233+
"duckdb:///:memory:",
234+
init_sql=["CREATE TABLE url_test (val INT)"],
235+
)
236+
result = adapter.execute("SELECT COUNT(*) FROM url_test")
237+
assert result.fetchone()[0] == 0
238+
239+
240+
def test_duckdb_adapter_from_url_init_sql_in_query_params():
241+
"""Test init_sql can be passed via URL query parameters."""
242+
adapter = DuckDBAdapter.from_url("duckdb:///:memory:?init_sql=CREATE+TABLE+qs_test+(id+INT)")
243+
result = adapter.execute("SELECT COUNT(*) FROM qs_test")
244+
assert result.fetchone()[0] == 0
245+
246+
247+
def test_duckdb_adapter_from_url_param_overrides_query():
248+
"""Test that explicit init_sql parameter overrides URL query params."""
249+
adapter = DuckDBAdapter.from_url(
250+
"duckdb:///:memory:?init_sql=CREATE+TABLE+url_table+(id+INT)",
251+
init_sql=["CREATE TABLE param_table (id INT)"],
252+
)
253+
# param_table should exist (from explicit param)
254+
result = adapter.execute("SELECT COUNT(*) FROM param_table")
255+
assert result.fetchone()[0] == 0
256+
257+
# url_table should NOT exist (overridden)
258+
with pytest.raises(Exception):
259+
adapter.execute("SELECT COUNT(*) FROM url_table")
260+
261+
262+
def test_duckdb_semantic_layer_init_sql():
263+
"""Test init_sql flows through SemanticLayer constructor."""
264+
from sidemantic import SemanticLayer
265+
266+
layer = SemanticLayer(
267+
connection="duckdb:///:memory:",
268+
init_sql=["CREATE TABLE layer_test (id INT)"],
269+
)
270+
result = layer.adapter.execute("SELECT COUNT(*) FROM layer_test")
271+
assert result.fetchone()[0] == 0
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""Tests for init_sql config parsing and plumbing."""
2+
3+
from sidemantic.config import (
4+
DuckDBConnection,
5+
SidemanticConfig,
6+
build_connection_string,
7+
get_init_sql,
8+
load_config,
9+
)
10+
11+
12+
def test_duckdb_connection_init_sql_field():
13+
"""Test DuckDBConnection accepts init_sql."""
14+
conn = DuckDBConnection(path=":memory:", init_sql=["LOAD httpfs", "LOAD iceberg"])
15+
assert conn.init_sql == ["LOAD httpfs", "LOAD iceberg"]
16+
17+
18+
def test_duckdb_connection_init_sql_none_by_default():
19+
"""Test DuckDBConnection init_sql defaults to None."""
20+
conn = DuckDBConnection(path=":memory:")
21+
assert conn.init_sql is None
22+
23+
24+
def test_get_init_sql_returns_statements():
25+
"""Test get_init_sql extracts init_sql from config."""
26+
config = SidemanticConfig(
27+
connection=DuckDBConnection(
28+
path=":memory:",
29+
init_sql=["INSTALL httpfs", "LOAD httpfs"],
30+
)
31+
)
32+
result = get_init_sql(config)
33+
assert result == ["INSTALL httpfs", "LOAD httpfs"]
34+
35+
36+
def test_get_init_sql_returns_none_when_no_init_sql():
37+
"""Test get_init_sql returns None when no init_sql configured."""
38+
config = SidemanticConfig(connection=DuckDBConnection(path=":memory:"))
39+
assert get_init_sql(config) is None
40+
41+
42+
def test_get_init_sql_returns_none_for_non_duckdb():
43+
"""Test get_init_sql returns None when the config has no connection (connection=None)."""
44+
config = SidemanticConfig(connection=None)
45+
assert get_init_sql(config) is None
46+
47+
48+
def test_load_config_with_init_sql(tmp_path):
49+
"""Test loading YAML config with init_sql."""
50+
config_path = tmp_path / "sidemantic.yaml"
51+
config_path.write_text(
52+
"""
53+
models_dir: .
54+
connection:
55+
type: duckdb
56+
path: ":memory:"
57+
init_sql:
58+
- "INSTALL httpfs"
59+
- "LOAD httpfs"
60+
- "ATTACH 's3://bucket/db.duckdb' AS remote"
61+
"""
62+
)
63+
config = load_config(config_path)
64+
assert config.connection is not None
65+
init_sql = get_init_sql(config)
66+
assert init_sql == [
67+
"INSTALL httpfs",
68+
"LOAD httpfs",
69+
"ATTACH 's3://bucket/db.duckdb' AS remote",
70+
]
71+
72+
73+
def test_build_connection_string_ignores_init_sql():
74+
"""Test that build_connection_string produces URL without init_sql."""
75+
config = SidemanticConfig(
76+
connection=DuckDBConnection(
77+
path=":memory:",
78+
init_sql=["LOAD httpfs"],
79+
)
80+
)
81+
url = build_connection_string(config)
82+
assert url == "duckdb:///:memory:"
83+
assert "init_sql" not in url
84+
85+
86+
def test_resolve_paths_preserves_init_sql(tmp_path):
87+
"""Test that resolve_paths keeps init_sql when resolving DuckDB paths."""
88+
config = SidemanticConfig(
89+
connection=DuckDBConnection(
90+
path="data/warehouse.db",
91+
init_sql=["LOAD httpfs"],
92+
)
93+
)
94+
resolved = config.resolve_paths(tmp_path)
95+
assert isinstance(resolved.connection, DuckDBConnection)
96+
assert resolved.connection.init_sql == ["LOAD httpfs"]

0 commit comments

Comments
 (0)