Skip to content

Commit 099ec9d

Browse files
committed
Quote reserved-word identifiers via sqlglot.to_identifier
_quote_identifier was only quoting names with dots/special chars but passing reserved words like 'order' through unquoted, producing invalid SQL. Now delegates to sqlglot.to_identifier for all simple names so reserved words get quoted automatically. Uses lru_cache to avoid perf regression from repeated sqlglot calls.
1 parent 72a340d commit 099ec9d

1 file changed

Lines changed: 13 additions & 4 deletions

File tree

sidemantic/sql/generator.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import logging
44
import threading
5+
from functools import lru_cache
56

67
import sqlglot
78
from sqlglot import exp, select
@@ -12,6 +13,15 @@
1213
from sidemantic.core.symmetric_aggregate import build_symmetric_aggregate_sql
1314
from sidemantic.sql.aggregation_detection import sql_has_aggregate
1415

16+
17+
@lru_cache(maxsize=4096)
18+
def _quote_identifier_cached(name: str, dialect: str, is_simple: bool) -> str:
19+
"""Cached identifier quoting, shared across all SQLGenerator instances."""
20+
if is_simple:
21+
return sqlglot.to_identifier(name).sql(dialect=dialect)
22+
return sqlglot.to_identifier(name, quoted=True).sql(dialect=dialect)
23+
24+
1525
_dialect_cache: dict[str, Dialect] = {}
1626
_tls = threading.local()
1727

@@ -232,11 +242,10 @@ def _quote_identifier(self, name: str) -> str:
232242
"""Quote a SQL identifier for the current dialect.
233243
234244
Delegates to sqlglot which handles reserved words (e.g., 'order')
235-
and special characters automatically.
245+
and special characters automatically. Results are cached since the
246+
same identifiers are used many times during query generation.
236247
"""
237-
if self._is_simple_identifier(name):
238-
return name
239-
return sqlglot.to_identifier(name, quoted=True).sql(dialect=self.dialect)
248+
return _quote_identifier_cached(name, self.dialect, self._is_simple_identifier(name))
240249

241250
def _cte_name(self, model_name: str) -> str:
242251
"""Get the CTE identifier name for a model."""

0 commit comments

Comments
 (0)