Skip to content

Commit 991c4d3

Browse files
radimclaude
andcommitted
test: add unit tests for library and CLI
Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 6f9df64 commit 991c4d3

5 files changed

Lines changed: 382 additions & 0 deletions

File tree

cluster_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package qshape
2+
3+
import "testing"
4+
5+
func TestGroupAggregatesCalls(t *testing.T) {
6+
in := []Query{
7+
{Raw: "SELECT id FROM users WHERE id = 1", Calls: 100},
8+
{Raw: "SELECT id FROM users WHERE id = 99", Calls: 200},
9+
}
10+
out, err := Group(in)
11+
if err != nil {
12+
t.Fatal(err)
13+
}
14+
if len(out) != 1 {
15+
t.Fatalf("expected 1 cluster, got %d", len(out))
16+
}
17+
if out[0].TotalCalls != 300 {
18+
t.Errorf("TotalCalls = %d, want 300", out[0].TotalCalls)
19+
}
20+
if len(out[0].Members) != 2 {
21+
t.Errorf("Members len = %d, want 2", len(out[0].Members))
22+
}
23+
}
24+
25+
func TestGroupAggregatesTiming(t *testing.T) {
26+
in := []Query{
27+
{Raw: "SELECT id FROM users WHERE id = 1", Calls: 100, TotalExecTimeMs: 250.0, Rows: 100},
28+
{Raw: "SELECT id FROM users WHERE id = 99", Calls: 400, TotalExecTimeMs: 750.0, Rows: 400},
29+
}
30+
out, err := Group(in)
31+
if err != nil {
32+
t.Fatal(err)
33+
}
34+
if len(out) != 1 {
35+
t.Fatalf("expected 1 cluster, got %d", len(out))
36+
}
37+
if out[0].TotalExecTimeMs != 1000.0 {
38+
t.Errorf("TotalExecTimeMs = %v, want 1000.0", out[0].TotalExecTimeMs)
39+
}
40+
if out[0].Rows != 500 {
41+
t.Errorf("Rows = %d, want 500", out[0].Rows)
42+
}
43+
wantMean := 1000.0 / 500.0
44+
if out[0].MeanExecTimeMs != wantMean {
45+
t.Errorf("MeanExecTimeMs = %v, want %v", out[0].MeanExecTimeMs, wantMean)
46+
}
47+
}
48+
49+
func TestGroupSortsByTimingWhenPresent(t *testing.T) {
50+
in := []Query{
51+
{Raw: "SELECT id FROM users", Calls: 1000, TotalExecTimeMs: 50.0},
52+
{Raw: "SELECT name FROM users", Calls: 10, TotalExecTimeMs: 5000.0},
53+
}
54+
out, err := Group(in)
55+
if err != nil {
56+
t.Fatal(err)
57+
}
58+
if len(out) != 2 {
59+
t.Fatalf("expected 2 clusters, got %d", len(out))
60+
}
61+
if out[0].TotalExecTimeMs < out[1].TotalExecTimeMs {
62+
t.Errorf("expected sort by TotalExecTimeMs desc, got %v then %v",
63+
out[0].TotalExecTimeMs, out[1].TotalExecTimeMs)
64+
}
65+
}
66+
67+
func TestGroupOrdering(t *testing.T) {
68+
in := []Query{
69+
{Raw: "SELECT name FROM users", Calls: 10},
70+
{Raw: "SELECT id FROM users", Calls: 500},
71+
}
72+
out, err := Group(in)
73+
if err != nil {
74+
t.Fatal(err)
75+
}
76+
if len(out) != 2 {
77+
t.Fatalf("expected 2 clusters, got %d", len(out))
78+
}
79+
if out[0].TotalCalls < out[1].TotalCalls {
80+
t.Errorf("clusters not sorted by TotalCalls desc: %d then %d",
81+
out[0].TotalCalls, out[1].TotalCalls)
82+
}
83+
}
84+
85+
// Documents MVP behavior: alias stripping is deferred, so ORM variants
86+
// that differ only in aliases currently produce more than one cluster.
87+
// When Phase 0.x lands alias-strip, tighten this to `== 1`.
88+
func TestGroupORMVariantsCurrentBehavior(t *testing.T) {
89+
in := []Query{
90+
{Raw: "SELECT id, name FROM users WHERE id = $1", Calls: 1},
91+
{Raw: "SELECT u.id, u.name FROM users u WHERE u.id = $1", Calls: 1},
92+
{Raw: "SELECT id, name FROM users WHERE id = $1 LIMIT $2", Calls: 1},
93+
}
94+
out, err := Group(in)
95+
if err != nil {
96+
t.Fatal(err)
97+
}
98+
if len(out) < 2 {
99+
t.Errorf("expected >= 2 clusters in MVP (alias strip deferred), got %d", len(out))
100+
}
101+
total := int64(0)
102+
for _, c := range out {
103+
total += c.TotalCalls
104+
}
105+
if total != 3 {
106+
t.Errorf("total calls across clusters = %d, want 3", total)
107+
}
108+
}
109+
110+
func TestGroupUnparseable(t *testing.T) {
111+
in := []Query{
112+
{Raw: "SELECT FROM WHERE", Calls: 5},
113+
}
114+
out, err := Group(in)
115+
if err != nil {
116+
t.Fatal(err)
117+
}
118+
if len(out) != 1 {
119+
t.Fatalf("expected 1 cluster, got %d", len(out))
120+
}
121+
if out[0].Fingerprint != "" {
122+
t.Errorf("unparseable cluster should have empty fingerprint, got %q", out[0].Fingerprint)
123+
}
124+
if out[0].Canonical != "SELECT FROM WHERE" {
125+
t.Errorf("unparseable Canonical should be raw, got %q", out[0].Canonical)
126+
}
127+
}
128+
129+
func TestGroupEmpty(t *testing.T) {
130+
out, err := Group(nil)
131+
if err != nil {
132+
t.Fatal(err)
133+
}
134+
if len(out) != 0 {
135+
t.Errorf("expected empty slice, got %d clusters", len(out))
136+
}
137+
}

cmd/qshape/attribute_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package main
2+
3+
import (
4+
"testing"
5+
6+
"github.com/boringsql/qshape"
7+
)
8+
9+
func TestAttributeCondAliasedEqual(t *testing.T) {
10+
ctx := &attrCtx{byPosition: map[int]*qshape.ParamAttribution{}}
11+
aliases := map[string]tableRef{
12+
"u": {Schema: "auth", Table: "user_account"},
13+
}
14+
attributeCond("(u.user_id = $1)", aliases, "auth", "user_account", ctx)
15+
16+
a, ok := ctx.byPosition[1]
17+
if !ok {
18+
t.Fatal("expected param 1 attributed")
19+
}
20+
if a.Table != "user_account" || a.Column != "user_id" || a.Schema != "auth" {
21+
t.Errorf("wrong attribution: %+v", a)
22+
}
23+
if a.Confidence != "exact" {
24+
t.Errorf("expected exact confidence, got %s", a.Confidence)
25+
}
26+
}
27+
28+
func TestAttributeCondUnaliasedFallback(t *testing.T) {
29+
ctx := &attrCtx{byPosition: map[int]*qshape.ParamAttribution{}}
30+
aliases := map[string]tableRef{}
31+
attributeCond("(id = $1)", aliases, "auth", "session", ctx)
32+
33+
a, ok := ctx.byPosition[1]
34+
if !ok {
35+
t.Fatal("expected param 1 attributed")
36+
}
37+
if a.Table != "session" || a.Column != "id" {
38+
t.Errorf("wrong fallback attribution: %+v", a)
39+
}
40+
if a.Confidence != "heuristic" {
41+
t.Errorf("expected heuristic confidence, got %s", a.Confidence)
42+
}
43+
}
44+
45+
func TestAttributeCondMultipleParams(t *testing.T) {
46+
ctx := &attrCtx{byPosition: map[int]*qshape.ParamAttribution{}}
47+
aliases := map[string]tableRef{
48+
"t": {Schema: "auth", Table: "oauth_token"},
49+
}
50+
attributeCond("((t.access_sha = $2) AND (t.access_hash = hashtext($1)))", aliases, "auth", "oauth_token", ctx)
51+
52+
if a, ok := ctx.byPosition[2]; !ok || a.Column != "access_sha" {
53+
t.Errorf("param 2 wrong: %+v ok=%v", a, ok)
54+
}
55+
// $1 is wrapped in hashtext(...) — no direct column comparison, so we
56+
// don't attribute it. That's fine: unattributed, not incorrect.
57+
if _, ok := ctx.byPosition[1]; ok {
58+
t.Logf("note: $1 got attributed even though wrapped in function — acceptable but brittle")
59+
}
60+
}
61+
62+
func TestAttributeCondPreservesExactOverHeuristic(t *testing.T) {
63+
ctx := &attrCtx{byPosition: map[int]*qshape.ParamAttribution{}}
64+
aliases := map[string]tableRef{
65+
"u": {Schema: "auth", Table: "user_account"},
66+
}
67+
// First hit: exact
68+
attributeCond("(u.user_id = $1)", aliases, "", "", ctx)
69+
// Second hit that would be heuristic on a different relation
70+
attributeCond("(user_id = $1)", map[string]tableRef{}, "public", "other_table", ctx)
71+
72+
a := ctx.byPosition[1]
73+
if a.Table != "user_account" || a.Confidence != "exact" {
74+
t.Errorf("exact attribution overwritten by heuristic: %+v", a)
75+
}
76+
}

cmd/qshape/regresql_stub_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package main
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/boringsql/qshape"
8+
)
9+
10+
func TestRewriteParams(t *testing.T) {
11+
got, names := rewriteParams("SELECT id FROM users WHERE id = $1 AND tenant_id = $2 AND id = $1")
12+
want := "SELECT id FROM users WHERE id = :param1 AND tenant_id = :param2 AND id = :param1"
13+
if got != want {
14+
t.Errorf("rewrite mismatch:\n got: %s\nwant: %s", got, want)
15+
}
16+
if len(names) != 2 || names[0] != "param1" || names[1] != "param2" {
17+
t.Errorf("unexpected names: %v", names)
18+
}
19+
}
20+
21+
func TestSampleValuesForParams(t *testing.T) {
22+
fixJSON := `{
23+
"tables": {
24+
"auth.user_account": {
25+
"columns": ["user_id", "email"],
26+
"rows": [[42, "a@b.co"], [99, "x@y.co"], [null, "z@z.co"]]
27+
}
28+
}
29+
}`
30+
var fix fixtureDoc
31+
if err := json.Unmarshal([]byte(fixJSON), &fix); err != nil {
32+
t.Fatal(err)
33+
}
34+
attrs := []qshape.ParamAttribution{
35+
{Position: 1, Schema: "auth", Table: "user_account", Column: "user_id", Confidence: "exact"},
36+
}
37+
out := sampleValuesForParams([]string{"param1"}, attrs, &fix, 3)
38+
if len(out["param1"]) != 2 {
39+
t.Errorf("expected 2 non-null values, got %v", out["param1"])
40+
}
41+
if out["param1"][0].(float64) != 42 {
42+
t.Errorf("first value = %v, want 42", out["param1"][0])
43+
}
44+
}
45+
46+
func TestSampleValuesSkipsUnattributed(t *testing.T) {
47+
fix := &fixtureDoc{}
48+
attrs := []qshape.ParamAttribution{
49+
{Position: 1, Confidence: "none"}, // no table/column
50+
}
51+
out := sampleValuesForParams([]string{"param1"}, attrs, fix, 2)
52+
if len(out) != 0 {
53+
t.Errorf("expected no samples, got %+v", out)
54+
}
55+
}
56+
57+
func TestYAMLScalarStringEscaping(t *testing.T) {
58+
cases := map[any]string{
59+
"hello": `"hello"`,
60+
`a"b`: `"a\"b"`,
61+
int64(5): `5`,
62+
nil: `~`,
63+
true: `true`,
64+
3.14: `3.14`,
65+
}
66+
for in, want := range cases {
67+
if got := yamlScalar(in); got != want {
68+
t.Errorf("yamlScalar(%v) = %q, want %q", in, got, want)
69+
}
70+
}
71+
}

fingerprint_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package qshape
2+
3+
import (
4+
"strings"
5+
"testing"
6+
)
7+
8+
func TestFingerprintStable(t *testing.T) {
9+
fp1, err := Fingerprint("SELECT id FROM users WHERE id = 1")
10+
if err != nil {
11+
t.Fatal(err)
12+
}
13+
fp2, err := Fingerprint("SELECT id FROM users WHERE id = 1")
14+
if err != nil {
15+
t.Fatal(err)
16+
}
17+
if fp1 != fp2 {
18+
t.Errorf("fingerprint not stable: %s vs %s", fp1, fp2)
19+
}
20+
if !strings.HasPrefix(fp1, "sha1:") {
21+
t.Errorf("expected sha1: prefix, got %q", fp1)
22+
}
23+
}
24+
25+
func TestFingerprintIgnoresWhitespace(t *testing.T) {
26+
a, _ := Fingerprint("SELECT id FROM users WHERE id = 1")
27+
b, _ := Fingerprint("SELECT id FROM users WHERE id=1")
28+
if a != b {
29+
t.Errorf("whitespace variation changed fingerprint: %s vs %s", a, b)
30+
}
31+
}
32+
33+
func TestFingerprintIgnoresLiterals(t *testing.T) {
34+
a, _ := Fingerprint("SELECT id FROM users WHERE id = 1")
35+
b, _ := Fingerprint("SELECT id FROM users WHERE id = 42")
36+
if a != b {
37+
t.Errorf("literal variation changed fingerprint: %s vs %s", a, b)
38+
}
39+
}
40+
41+
func TestFingerprintDistinguishesColumns(t *testing.T) {
42+
a, _ := Fingerprint("SELECT id FROM users WHERE id = 1")
43+
b, _ := Fingerprint("SELECT id FROM users WHERE name = 'x'")
44+
if a == b {
45+
t.Errorf("different predicates should fingerprint differently, both %s", a)
46+
}
47+
}
48+
49+
func TestFingerprintInvalid(t *testing.T) {
50+
if _, err := Fingerprint("SELECT FROM WHERE"); err == nil {
51+
t.Error("expected error on invalid SQL")
52+
}
53+
}

normalize_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package qshape
2+
3+
import (
4+
"strings"
5+
"testing"
6+
)
7+
8+
func TestNormalize(t *testing.T) {
9+
cases := []struct {
10+
name string
11+
in string
12+
}{
13+
{"plain select", "SELECT id FROM users"},
14+
{"select with AS", "SELECT id AS x FROM users"},
15+
{"aliased table", "SELECT u.id FROM users u WHERE u.id = $1"},
16+
{"where with literal", "SELECT id FROM users WHERE id = 42"},
17+
{"multi-column", "SELECT id, name, email FROM users WHERE id = $1"},
18+
{"insert", "INSERT INTO users (id, name) VALUES ($1, $2)"},
19+
{"update", "UPDATE users SET name = $1 WHERE id = $2"},
20+
}
21+
for _, tc := range cases {
22+
t.Run(tc.name, func(t *testing.T) {
23+
got, err := Normalize(tc.in)
24+
if err != nil {
25+
t.Fatalf("Normalize error: %v", err)
26+
}
27+
if strings.TrimSpace(got) == "" {
28+
t.Fatalf("Normalize returned empty output for %q", tc.in)
29+
}
30+
again, err := Normalize(got)
31+
if err != nil {
32+
t.Fatalf("idempotence Normalize error: %v", err)
33+
}
34+
if again != got {
35+
t.Errorf("not idempotent:\n first: %q\n second: %q", got, again)
36+
}
37+
})
38+
}
39+
}
40+
41+
func TestNormalizeInvalid(t *testing.T) {
42+
if _, err := Normalize("SELECT FROM WHERE"); err == nil {
43+
t.Error("expected error on invalid SQL")
44+
}
45+
}

0 commit comments

Comments
 (0)