Skip to content

Commit 0c5a80f

Browse files
committed
Fix MCP server tools to return plain dictionaries
The MCP tools (list_models, get_models, run_query, create_chart) were returning Pydantic model instances which included internal fields like model_fields and model_config. This caused confusing output when using the MCP server. Changed all tools to return plain dictionaries instead of Pydantic models. Updated tests to use dictionary access instead of attribute access.
1 parent b792691 commit 0c5a80f

2 files changed

Lines changed: 72 additions & 70 deletions

File tree

sidemantic/mcp_server.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def get_layer() -> SemanticLayer:
9191

9292

9393
@mcp.tool()
94-
def list_models() -> list[ModelInfo]:
94+
def list_models() -> list[dict[str, Any]]:
9595
"""List all available models in the semantic layer.
9696
9797
Models are the core building blocks of the semantic layer. Each model represents
@@ -111,20 +111,20 @@ def list_models() -> list[ModelInfo]:
111111
models = []
112112
for model_name, model in layer.graph.models.items():
113113
models.append(
114-
ModelInfo(
115-
name=model_name,
116-
table=model.table,
117-
dimensions=[d.name for d in model.dimensions],
118-
metrics=[m.name for m in model.metrics],
119-
relationships=len(model.relationships),
120-
)
114+
{
115+
"name": model_name,
116+
"table": model.table,
117+
"dimensions": [d.name for d in model.dimensions],
118+
"metrics": [m.name for m in model.metrics],
119+
"relationships": len(model.relationships),
120+
}
121121
)
122122

123123
return models
124124

125125

126126
@mcp.tool()
127-
def get_models(model_names: list[str]) -> list[ModelDetail]:
127+
def get_models(model_names: list[str]) -> list[dict[str, Any]]:
128128
"""Get detailed information about one or more models.
129129
130130
Returns comprehensive details about models including:
@@ -209,17 +209,19 @@ def get_models(model_names: list[str]) -> list[ModelDetail]:
209209
}
210210
)
211211

212-
details.append(
213-
ModelDetail(
214-
name=model_name,
215-
table=model.table,
216-
dimensions=dims,
217-
metrics=metrics,
218-
relationships=rels,
219-
source_format=getattr(model, "_source_format", None),
220-
source_file=getattr(model, "_source_file", None),
221-
)
222-
)
212+
detail = {
213+
"name": model_name,
214+
"table": model.table,
215+
"dimensions": dims,
216+
"metrics": metrics,
217+
"relationships": rels,
218+
}
219+
if source_format := getattr(model, "_source_format", None):
220+
detail["source_format"] = source_format
221+
if source_file := getattr(model, "_source_file", None):
222+
detail["source_file"] = source_file
223+
224+
details.append(detail)
223225

224226
return details
225227

@@ -231,7 +233,7 @@ def run_query(
231233
where: str | None = None,
232234
order_by: list[str] | None = None,
233235
limit: int | None = None,
234-
) -> QueryResult:
236+
) -> dict[str, Any]:
235237
"""Run a query against the semantic layer.
236238
237239
Sidemantic automatically generates SQL from semantic references and handles joins between models.
@@ -306,11 +308,11 @@ def run_query(
306308
columns = [desc[0] for desc in result.description]
307309
row_dicts = [dict(zip(columns, row)) for row in rows]
308310

309-
return QueryResult(
310-
sql=sql,
311-
rows=row_dicts,
312-
row_count=len(row_dicts),
313-
)
311+
return {
312+
"sql": sql,
313+
"rows": row_dicts,
314+
"row_count": len(row_dicts),
315+
}
314316

315317

316318
@mcp.tool()
@@ -324,7 +326,7 @@ def create_chart(
324326
title: str | None = None,
325327
width: int = 600,
326328
height: int = 400,
327-
) -> ChartResult:
329+
) -> dict[str, Any]:
328330
"""Generate a beautiful chart from a semantic layer query.
329331
330332
This tool combines query execution with intelligent chart generation, producing
@@ -431,12 +433,12 @@ def create_chart(
431433
vega_spec = chart_to_vega(chart)
432434
png_base64 = chart_to_base64_png(chart)
433435

434-
return ChartResult(
435-
sql=sql,
436-
vega_spec=vega_spec,
437-
png_base64=png_base64,
438-
row_count=len(row_dicts),
439-
)
436+
return {
437+
"sql": sql,
438+
"vega_spec": vega_spec,
439+
"png_base64": png_base64,
440+
"row_count": len(row_dicts),
441+
}
440442

441443

442444
def _generate_chart_title(dimensions: list[str], metrics: list[str]) -> str:

tests/test_mcp_server.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,15 @@ def test_list_models(demo_layer):
7575
models = list_models()
7676

7777
assert len(models) == 1
78-
assert models[0].name == "orders"
79-
assert models[0].table == "orders_table"
80-
assert len(models[0].dimensions) == 3
81-
assert len(models[0].metrics) == 2
82-
assert "order_id" in models[0].dimensions
83-
assert "customer_name" in models[0].dimensions
84-
assert "order_date" in models[0].dimensions
85-
assert "total_revenue" in models[0].metrics
86-
assert "order_count" in models[0].metrics
78+
assert models[0]["name"] == "orders"
79+
assert models[0]["table"] == "orders_table"
80+
assert len(models[0]["dimensions"]) == 3
81+
assert len(models[0]["metrics"]) == 2
82+
assert "order_id" in models[0]["dimensions"]
83+
assert "customer_name" in models[0]["dimensions"]
84+
assert "order_date" in models[0]["dimensions"]
85+
assert "total_revenue" in models[0]["metrics"]
86+
assert "order_count" in models[0]["metrics"]
8787

8888

8989
def test_get_models(demo_layer):
@@ -92,24 +92,24 @@ def test_get_models(demo_layer):
9292

9393
assert len(models) == 1
9494
model = models[0]
95-
assert model.name == "orders"
96-
assert model.table == "orders_table"
95+
assert model["name"] == "orders"
96+
assert model["table"] == "orders_table"
9797

9898
# Check dimensions
99-
assert len(model.dimensions) == 3
100-
dim_names = [d["name"] for d in model.dimensions]
99+
assert len(model["dimensions"]) == 3
100+
dim_names = [d["name"] for d in model["dimensions"]]
101101
assert "order_id" in dim_names
102102
assert "customer_name" in dim_names
103103
assert "order_date" in dim_names
104104

105105
# Check metrics
106-
assert len(model.metrics) == 2
107-
metric_names = [m["name"] for m in model.metrics]
106+
assert len(model["metrics"]) == 2
107+
metric_names = [m["name"] for m in model["metrics"]]
108108
assert "total_revenue" in metric_names
109109
assert "order_count" in metric_names
110110

111111
# Check metric details
112-
revenue_metric = next(m for m in model.metrics if m["name"] == "total_revenue")
112+
revenue_metric = next(m for m in model["metrics"] if m["name"] == "total_revenue")
113113
assert revenue_metric["agg"] == "sum"
114114
assert revenue_metric["sql"] == "amount"
115115

@@ -124,7 +124,7 @@ def test_get_models_multiple(demo_layer):
124124
"""Test getting multiple models (only one exists)."""
125125
models = get_models(["orders", "nonexistent"])
126126
assert len(models) == 1
127-
assert models[0].name == "orders"
127+
assert models[0]["name"] == "orders"
128128

129129

130130
def test_run_query_basic(demo_layer):
@@ -134,13 +134,13 @@ def test_run_query_basic(demo_layer):
134134
metrics=["orders.total_revenue"],
135135
)
136136

137-
assert result.sql is not None
138-
assert "SELECT" in result.sql.upper()
139-
assert "customer_name" in result.sql
140-
assert "SUM" in result.sql.upper()
137+
assert result["sql"] is not None
138+
assert "SELECT" in result["sql"].upper()
139+
assert "customer_name" in result["sql"]
140+
assert "SUM" in result["sql"].upper()
141141
# Should have 2 rows (Alice and Bob)
142-
assert result.row_count == 2
143-
assert len(result.rows) == 2
142+
assert result["row_count"] == 2
143+
assert len(result["rows"]) == 2
144144

145145

146146
def test_run_query_with_filter(demo_layer):
@@ -151,9 +151,9 @@ def test_run_query_with_filter(demo_layer):
151151
where="orders.customer_name = 'Alice'",
152152
)
153153

154-
assert result.sql is not None
155-
assert "WHERE" in result.sql.upper()
156-
assert "Alice" in result.sql
154+
assert result["sql"] is not None
155+
assert "WHERE" in result["sql"].upper()
156+
assert "Alice" in result["sql"]
157157

158158

159159
def test_run_query_with_order_by(demo_layer):
@@ -164,8 +164,8 @@ def test_run_query_with_order_by(demo_layer):
164164
order_by=["orders.total_revenue desc"],
165165
)
166166

167-
assert result.sql is not None
168-
assert "ORDER BY" in result.sql.upper()
167+
assert result["sql"] is not None
168+
assert "ORDER BY" in result["sql"].upper()
169169

170170

171171
def test_run_query_with_limit(demo_layer):
@@ -176,9 +176,9 @@ def test_run_query_with_limit(demo_layer):
176176
limit=10,
177177
)
178178

179-
assert result.sql is not None
180-
assert "LIMIT" in result.sql.upper()
181-
assert "10" in result.sql
179+
assert result["sql"] is not None
180+
assert "LIMIT" in result["sql"].upper()
181+
assert "10" in result["sql"]
182182

183183

184184
def test_run_query_dimensions_only(demo_layer):
@@ -187,9 +187,9 @@ def test_run_query_dimensions_only(demo_layer):
187187
dimensions=["orders.customer_name", "orders.order_date"],
188188
)
189189

190-
assert result.sql is not None
191-
assert "customer_name" in result.sql
192-
assert "order_date" in result.sql
190+
assert result["sql"] is not None
191+
assert "customer_name" in result["sql"]
192+
assert "order_date" in result["sql"]
193193

194194

195195
def test_run_query_metrics_only(demo_layer):
@@ -198,6 +198,6 @@ def test_run_query_metrics_only(demo_layer):
198198
metrics=["orders.total_revenue", "orders.order_count"],
199199
)
200200

201-
assert result.sql is not None
202-
assert "SUM" in result.sql.upper()
203-
assert "COUNT" in result.sql.upper()
201+
assert result["sql"] is not None
202+
assert "SUM" in result["sql"].upper()
203+
assert "COUNT" in result["sql"].upper()

0 commit comments

Comments
 (0)