Commit 43eaee6
Add Databricks/Spark SQL adapter
- Add DatabricksAdapter with connection URL support (databricks://)
- Add Databricks to symmetric aggregation using the xxhash64 hash function
- Add 11 integration tests (skipped in CI; requires a real Databricks workspace)
- Update SemanticLayer to recognize databricks:// URLs
- Add databricks-sql-connector to optional dependencies
- Tests at parity with other databases: basic metrics, dimensions, joins, filters, ORDER BY, LIMIT, symmetric aggregates, 3-way joins
- No CI job (requires real Databricks credentials)
1 parent 04b0fdf commit 43eaee6

7 files changed: 768 additions & 2 deletions

pyproject.toml

Lines changed: 4 additions & 0 deletions
```diff
@@ -52,6 +52,10 @@ clickhouse = [
     "clickhouse-connect>=0.6.0",
     "pyarrow>=14.0.0",  # For Arrow support
 ]
+databricks = [
+    "databricks-sql-connector>=2.0.0",
+    "pyarrow>=14.0.0",  # For Arrow support
+]

 [build-system]
 requires = ["hatchling"]
```

sidemantic/core/semantic_layer.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -34,6 +34,7 @@ def __init__(
                 - bigquery://project_id/dataset_id
                 - snowflake://user:password@account/database/schema
                 - clickhouse://user:password@host:port/database
+                - databricks://token@server-hostname/http-path
             dialect: SQL dialect for query generation (optional, inferred from adapter)
             auto_register: Set as current layer for auto-registration (default: True)
             use_preaggregations: Enable automatic pre-aggregation routing (default: False)
@@ -76,10 +77,15 @@ def __init__(
 
                 self.adapter = ClickHouseAdapter.from_url(connection)
                 self.dialect = dialect or "clickhouse"
+            elif connection.startswith("databricks://"):
+                from sidemantic.db.databricks import DatabricksAdapter
+
+                self.adapter = DatabricksAdapter.from_url(connection)
+                self.dialect = dialect or "databricks"
             else:
                 raise ValueError(
                     f"Unsupported connection URL: {connection}. "
-                    "Supported: duckdb:///, postgres://, bigquery://, snowflake://, clickhouse://, or BaseDatabaseAdapter instance"
+                    "Supported: duckdb:///, postgres://, bigquery://, snowflake://, clickhouse://, databricks://, or BaseDatabaseAdapter instance"
                 )
         else:
             raise TypeError(f"connection must be a string URL or BaseDatabaseAdapter instance, got {type(connection)}")
```

sidemantic/core/symmetric_aggregate.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -26,7 +26,7 @@ def build_symmetric_aggregate_sql(
         primary_key: The primary key field to use for deduplication
         agg_type: Type of aggregation (sum, avg, count, count_distinct)
         model_alias: Optional table/CTE alias to prefix columns
-        dialect: SQL dialect (duckdb, bigquery, postgres, snowflake, clickhouse)
+        dialect: SQL dialect (duckdb, bigquery, postgres, snowflake, clickhouse, databricks)
 
     Returns:
         SQL expression using symmetric aggregates
@@ -68,6 +68,12 @@ def hash_func(col):
         def hash_func(col):
             return f"halfMD5(CAST({col} AS String))"
 
+        multiplier = "1048576"  # 2^20 as literal
+    elif dialect == "databricks":
+        # Databricks/Spark SQL xxhash64 returns bigint
+        def hash_func(col):
+            return f"xxhash64(CAST({col} AS STRING))"
+
         multiplier = "1048576"  # 2^20 as literal
     else:  # duckdb
```
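For context on what the new branch feeds: symmetric aggregates keep a measure correct under join fan-out by folding each row's primary-key hash into a SUM(DISTINCT ...) and subtracting the hashes back out. The sketch below shows the classic formulation with the Databricks hash function; it is illustrative, not sidemantic's exact generated SQL, and this diff does not show how the library applies the 2^20 multiplier:

```python
# Illustrative only: the classic symmetric-aggregate identity. Each distinct
# primary key contributes its measure once, even when a join duplicates rows.
def hash_func(col: str) -> str:
    # Same hash expression the diff adds for the databricks dialect.
    return f"xxhash64(CAST({col} AS STRING))"

def symmetric_sum(measure: str, primary_key: str) -> str:
    h = hash_func(primary_key)
    # SUM(DISTINCT hash + measure) - SUM(DISTINCT hash) sums the measure once
    # per key (integer measures assumed; real implementations scale decimals,
    # which is where a constant like 2^20 comes in).
    return f"SUM(DISTINCT {h} + {measure}) - SUM(DISTINCT {h})"

print(symmetric_sum("orders.amount", "orders.order_id"))
```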

sidemantic/db/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -27,4 +27,8 @@ def __getattr__(name):
         from sidemantic.db.clickhouse import ClickHouseAdapter
 
         return ClickHouseAdapter
+    if name == "DatabricksAdapter":
+        from sidemantic.db.databricks import DatabricksAdapter
+
+        return DatabricksAdapter
     raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
```
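The hook being extended here is PEP 562's module-level `__getattr__`, which keeps `import sidemantic.db` cheap: the databricks module is only imported when the attribute is first accessed. A standalone toy version of the pattern (not sidemantic code):

```python
# toy_module.py -- the lazy-attribute pattern in isolation.
def __getattr__(name):
    if name == "Expensive":
        import json  # stand-in for a heavy optional dependency
        return json
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

`from toy_module import Expensive` then triggers the import on first access rather than at module load, so a missing optional dependency only fails when it is actually used.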

sidemantic/db/databricks.py

Lines changed: 214 additions & 0 deletions
```python
"""Databricks/Spark SQL database adapter."""

from typing import Any
from urllib.parse import parse_qs, unquote, urlparse

from sidemantic.db.base import BaseDatabaseAdapter


class DatabricksResult:
    """Wrapper for Databricks cursor to match DuckDB result API."""

    def __init__(self, cursor):
        """Initialize Databricks result wrapper.

        Args:
            cursor: Databricks cursor object
        """
        self.cursor = cursor
        self._description = cursor.description

    def fetchone(self) -> tuple | None:
        """Fetch one row from the result."""
        return self.cursor.fetchone()

    def fetchall(self) -> list[tuple]:
        """Fetch all remaining rows."""
        return self.cursor.fetchall()

    def fetch_record_batch(self) -> Any:
        """Convert result to PyArrow RecordBatchReader."""
        import pyarrow as pa

        # Databricks cursor may support Arrow format directly
        # For now, convert from standard result
        rows = self.cursor.fetchall()
        if not rows:
            # Empty result
            schema = pa.schema([(desc[0], pa.string()) for desc in self._description])
            return pa.RecordBatchReader.from_batches(schema, [])

        # Build Arrow table from rows
        columns = {desc[0]: [row[i] for row in rows] for i, desc in enumerate(self._description)}
        table = pa.table(columns)
        return pa.RecordBatchReader.from_batches(table.schema, table.to_batches())

    @property
    def description(self):
        """Get column descriptions."""
        return self._description


class DatabricksAdapter(BaseDatabaseAdapter):
    """Databricks/Spark SQL database adapter.

    Example:
        >>> adapter = DatabricksAdapter(
        ...     server_hostname="your-workspace.cloud.databricks.com",
        ...     http_path="/sql/1.0/warehouses/abc123",
        ...     access_token="dapi..."
        ... )
        >>> result = adapter.execute("SELECT * FROM table")
    """

    def __init__(
        self,
        server_hostname: str,
        http_path: str,
        access_token: str | None = None,
        catalog: str | None = None,
        schema: str | None = None,
        **kwargs,
    ):
        """Initialize Databricks adapter.

        Args:
            server_hostname: Databricks workspace hostname
            http_path: SQL warehouse HTTP path
            access_token: Personal access token or service principal token
            catalog: Unity Catalog name (optional)
            schema: Schema/database name (optional)
            **kwargs: Additional arguments passed to databricks.sql.connect
        """
        try:
            from databricks import sql
        except ImportError as e:
            raise ImportError(
                "Databricks support requires databricks-sql-connector. "
                "Install with: pip install sidemantic[databricks] or pip install databricks-sql-connector"
            ) from e

        # Build connection params
        conn_params = {
            "server_hostname": server_hostname,
            "http_path": http_path,
        }

        if access_token:
            conn_params["access_token"] = access_token

        if catalog:
            conn_params["catalog"] = catalog

        if schema:
            conn_params["schema"] = schema

        # Merge with additional kwargs
        conn_params.update(kwargs)

        self.conn = sql.connect(**conn_params)
        self.catalog = catalog
        self.schema = schema

    def execute(self, sql: str) -> DatabricksResult:
        """Execute SQL query."""
        cursor = self.conn.cursor()
        cursor.execute(sql)
        return DatabricksResult(cursor)

    def executemany(self, sql: str, params: list) -> DatabricksResult:
        """Execute SQL with multiple parameter sets."""
        cursor = self.conn.cursor()
        cursor.executemany(sql, params)
        return DatabricksResult(cursor)

    def fetchone(self, result: DatabricksResult) -> tuple | None:
        """Fetch one row from result."""
        return result.fetchone()

    def fetch_record_batch(self, result: DatabricksResult) -> Any:
        """Fetch result as PyArrow RecordBatchReader."""
        return result.fetch_record_batch()

    def get_tables(self) -> list[dict]:
        """List all tables in the catalog/schema."""
        if self.schema:
            sql = f"SHOW TABLES IN {self.schema}"
        elif self.catalog:
            sql = f"SHOW TABLES IN {self.catalog}"
        else:
            sql = "SHOW TABLES"

        result = self.execute(sql)
        rows = result.fetchall()
        return [{"table_name": row[1], "schema": row[0]} for row in rows]

    def get_columns(self, table_name: str, schema: str | None = None) -> list[dict]:
        """Get column information for a table."""
        schema = schema or self.schema
        table_ref = f"{schema}.{table_name}" if schema else table_name

        sql = f"DESCRIBE {table_ref}"
        result = self.execute(sql)
        rows = result.fetchall()
        return [{"column_name": row[0], "data_type": row[1]} for row in rows]

    def close(self) -> None:
        """Close the Databricks connection."""
        self.conn.close()

    @property
    def dialect(self) -> str:
        """Return SQL dialect."""
        return "databricks"

    @property
    def raw_connection(self) -> Any:
        """Return raw Databricks connection."""
        return self.conn

    @classmethod
    def from_url(cls, url: str) -> "DatabricksAdapter":
        """Create adapter from connection URL.

        URL format: databricks://token@server-hostname/http-path?catalog=x&schema=y
        Example: databricks://dapi123@my-workspace.cloud.databricks.com/sql/1.0/warehouses/abc?catalog=main&schema=default

        Args:
            url: Connection URL

        Returns:
            DatabricksAdapter instance
        """
        if not url.startswith("databricks://"):
            raise ValueError(f"Invalid Databricks URL: {url}")

        parsed = urlparse(url)

        # Parse hostname
        server_hostname = parsed.hostname
        if not server_hostname:
            raise ValueError("Databricks URL must include server hostname")

        # Parse path as http_path (everything after hostname)
        http_path = parsed.path or ""

        # Parse token from username (password is ignored)
        access_token = unquote(parsed.username) if parsed.username else None

        # Parse query parameters for catalog and schema
        params = {}
        if parsed.query:
            params = {k: v[0] if len(v) == 1 else v for k, v in parse_qs(parsed.query).items()}

        catalog = params.pop("catalog", None)
        schema = params.pop("schema", None)

        return cls(
            server_hostname=server_hostname,
            http_path=http_path,
            access_token=access_token,
            catalog=catalog,
            schema=schema,
            **params,
        )
```
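End to end, `from_url` plus the result wrapper look like this in use. All identifying values (token, workspace, warehouse id) are placeholders:

```python
# Placeholder connection details; the URL shape is documented in from_url above.
adapter = DatabricksAdapter.from_url(
    "databricks://dapiXXXX@my-workspace.cloud.databricks.com"
    "/sql/1.0/warehouses/abc123?catalog=main&schema=default"
)
try:
    result = adapter.execute("SELECT 1 AS ok")
    print(result.fetchone())         # (1,)
    print(result.description[0][0])  # "ok"
finally:
    adapter.close()
```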
