Skip to content

Commit bda8c17

Browse files
committed
Add typed block signatures and richer type mismatch displays
1 parent 98170af commit bda8c17

7 files changed

Lines changed: 347 additions & 21 deletions

File tree

vibes/ast.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ func (e *RangeExpr) exprNode() {}
277277
func (e *RangeExpr) Pos() Position { return e.position }
278278

279279
type BlockLiteral struct {
280-
Params []string
280+
Params []Param
281281
Body []Statement
282282
position Position
283283
}

vibes/execution.go

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"maps"
88
"math"
9+
"reflect"
910
"regexp"
1011
"slices"
1112
"sort"
@@ -1004,7 +1005,12 @@ func (exec *Execution) CallBlock(block Value, args []Value) (Value, error) {
10041005
} else {
10051006
val = NewNil()
10061007
}
1007-
blockEnv.Define(param, val)
1008+
if param.Type != nil {
1009+
if err := checkValueType(val, param.Type); err != nil {
1010+
return NewNil(), exec.errorAt(param.Type.position, "%s", formatArgumentTypeMismatch(param.Name, err))
1011+
}
1012+
}
1013+
blockEnv.Define(param.Name, val)
10081014
}
10091015
val, returned, err := exec.evalStatements(blk.Body, blockEnv)
10101016
if err != nil {
@@ -3282,7 +3288,7 @@ func checkValueType(val Value, ty *TypeExpr) error {
32823288
}
32833289
return &typeMismatchError{
32843290
Expected: formatTypeExpr(ty),
3285-
Actual: val.Kind().String(),
3291+
Actual: formatValueTypeExpr(val),
32863292
}
32873293
}
32883294

@@ -3512,6 +3518,127 @@ func formatShapeType(ty *TypeExpr) string {
35123518
return "{ " + strings.Join(parts, ", ") + " }"
35133519
}
35143520

3521+
func formatValueTypeExpr(val Value) string {
3522+
state := valueTypeFormatState{
3523+
seenArrays: make(map[uintptr]struct{}),
3524+
seenHashes: make(map[uintptr]struct{}),
3525+
}
3526+
return state.format(val)
3527+
}
3528+
3529+
type valueTypeFormatState struct {
3530+
seenArrays map[uintptr]struct{}
3531+
seenHashes map[uintptr]struct{}
3532+
}
3533+
3534+
func (s *valueTypeFormatState) format(val Value) string {
3535+
switch val.Kind() {
3536+
case KindNil:
3537+
return "nil"
3538+
case KindBool:
3539+
return "bool"
3540+
case KindInt:
3541+
return "int"
3542+
case KindFloat:
3543+
return "float"
3544+
case KindString:
3545+
return "string"
3546+
case KindMoney:
3547+
return "money"
3548+
case KindDuration:
3549+
return "duration"
3550+
case KindTime:
3551+
return "time"
3552+
case KindSymbol:
3553+
return "symbol"
3554+
case KindRange:
3555+
return "range"
3556+
case KindFunction:
3557+
return "function"
3558+
case KindBuiltin:
3559+
return "builtin"
3560+
case KindBlock:
3561+
return "block"
3562+
case KindClass:
3563+
return "class"
3564+
case KindInstance:
3565+
return "instance"
3566+
case KindArray:
3567+
return s.formatArray(val.Array())
3568+
case KindHash, KindObject:
3569+
return s.formatHash(val.Hash())
3570+
default:
3571+
return val.Kind().String()
3572+
}
3573+
}
3574+
3575+
func (s *valueTypeFormatState) formatArray(values []Value) string {
3576+
if len(values) == 0 {
3577+
return "array<empty>"
3578+
}
3579+
3580+
id := reflect.ValueOf(values).Pointer()
3581+
if id != 0 {
3582+
if _, seen := s.seenArrays[id]; seen {
3583+
return "array<...>"
3584+
}
3585+
s.seenArrays[id] = struct{}{}
3586+
defer delete(s.seenArrays, id)
3587+
}
3588+
3589+
elementTypes := make(map[string]struct{}, len(values))
3590+
for _, value := range values {
3591+
elementTypes[s.format(value)] = struct{}{}
3592+
}
3593+
return "array<" + joinSortedTypes(elementTypes) + ">"
3594+
}
3595+
3596+
func (s *valueTypeFormatState) formatHash(values map[string]Value) string {
3597+
if len(values) == 0 {
3598+
return "{}"
3599+
}
3600+
3601+
id := reflect.ValueOf(values).Pointer()
3602+
if id != 0 {
3603+
if _, seen := s.seenHashes[id]; seen {
3604+
return "{ ... }"
3605+
}
3606+
s.seenHashes[id] = struct{}{}
3607+
defer delete(s.seenHashes, id)
3608+
}
3609+
3610+
if len(values) <= 6 {
3611+
fields := make([]string, 0, len(values))
3612+
for field := range values {
3613+
fields = append(fields, field)
3614+
}
3615+
sort.Strings(fields)
3616+
parts := make([]string, len(fields))
3617+
for i, field := range fields {
3618+
parts[i] = fmt.Sprintf("%s: %s", field, s.format(values[field]))
3619+
}
3620+
return "{ " + strings.Join(parts, ", ") + " }"
3621+
}
3622+
3623+
valueTypes := make(map[string]struct{}, len(values))
3624+
for _, value := range values {
3625+
valueTypes[s.format(value)] = struct{}{}
3626+
}
3627+
return "hash<string, " + joinSortedTypes(valueTypes) + ">"
3628+
}
3629+
3630+
func joinSortedTypes(typeSet map[string]struct{}) string {
3631+
if len(typeSet) == 0 {
3632+
return "empty"
3633+
}
3634+
parts := make([]string, 0, len(typeSet))
3635+
for typeName := range typeSet {
3636+
parts = append(parts, typeName)
3637+
}
3638+
sort.Strings(parts)
3639+
return strings.Join(parts, " | ")
3640+
}
3641+
35153642
// Function looks up a compiled function by name.
35163643
func (s *Script) Function(name string) (*ScriptFunction, bool) {
35173644
fn, ok := s.functions[name]

vibes/memory.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ func (est *memoryEstimator) value(val Value) int {
163163
est.seenBlocks[blk] = struct{}{}
164164
size += estimatedBlockBytes + estimatedSliceBaseBytes + len(blk.Params)*estimatedStringHeaderBytes
165165
for _, param := range blk.Params {
166-
size += len(param)
166+
size += len(param.Name)
167167
}
168168
size += estimatedStringHeaderBytes*3 + len(blk.moduleKey) + len(blk.modulePath) + len(blk.moduleRoot)
169169
size += est.env(blk.Env)

vibes/parser.go

Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ func (p *parser) parseCallArgument(args *[]Expression, kwargs *[]KeywordArg) {
781781

782782
func (p *parser) parseBlockLiteral() *BlockLiteral {
783783
pos := p.curToken.Pos
784-
params := []string{}
784+
params := []Param{}
785785

786786
p.nextToken()
787787
if p.curToken.Type == tokenPipe {
@@ -801,18 +801,18 @@ func (p *parser) parseBlockLiteral() *BlockLiteral {
801801
return &BlockLiteral{Params: params, Body: body, position: pos}
802802
}
803803

804-
func (p *parser) parseBlockParameters() ([]string, bool) {
805-
params := []string{}
804+
func (p *parser) parseBlockParameters() ([]Param, bool) {
805+
params := []Param{}
806806
p.nextToken()
807807
if p.curToken.Type == tokenPipe {
808808
return params, true
809809
}
810810

811-
if p.curToken.Type != tokenIdent {
812-
p.errorExpected(p.curToken, "block parameter")
811+
param, ok := p.parseBlockParameter()
812+
if !ok {
813813
return nil, false
814814
}
815-
params = append(params, p.curToken.Literal)
815+
params = append(params, param)
816816

817817
for p.peekToken.Type == tokenComma {
818818
p.nextToken()
@@ -821,11 +821,11 @@ func (p *parser) parseBlockParameters() ([]string, bool) {
821821
p.addParseError(p.curToken.Pos, "trailing comma in block parameter list")
822822
return nil, false
823823
}
824-
if p.curToken.Type != tokenIdent {
825-
p.errorExpected(p.curToken, "block parameter")
824+
param, ok := p.parseBlockParameter()
825+
if !ok {
826826
return nil, false
827827
}
828-
params = append(params, p.curToken.Literal)
828+
params = append(params, param)
829829
}
830830

831831
if !p.expectPeek(tokenPipe) {
@@ -835,6 +835,78 @@ func (p *parser) parseBlockParameters() ([]string, bool) {
835835
return params, true
836836
}
837837

838+
func (p *parser) parseBlockParameter() (Param, bool) {
839+
if p.curToken.Type != tokenIdent {
840+
p.errorExpected(p.curToken, "block parameter")
841+
return Param{}, false
842+
}
843+
param := Param{Name: p.curToken.Literal}
844+
if p.peekToken.Type == tokenColon {
845+
p.nextToken()
846+
p.nextToken()
847+
param.Type = p.parseBlockParamType()
848+
if param.Type == nil {
849+
return Param{}, false
850+
}
851+
}
852+
return param, true
853+
}
854+
855+
func (p *parser) parseBlockParamType() *TypeExpr {
856+
first := p.parseTypeAtom()
857+
if first == nil {
858+
return nil
859+
}
860+
861+
union := []*TypeExpr{first}
862+
for p.peekToken.Type == tokenPipe && p.blockParamUnionContinues() {
863+
p.nextToken()
864+
p.nextToken()
865+
next := p.parseTypeAtom()
866+
if next == nil {
867+
return nil
868+
}
869+
union = append(union, next)
870+
}
871+
872+
if len(union) == 1 {
873+
return first
874+
}
875+
876+
names := make([]string, len(union))
877+
for i, option := range union {
878+
names[i] = formatTypeExpr(option)
879+
}
880+
return &TypeExpr{
881+
Name: strings.Join(names, " | "),
882+
Kind: TypeUnion,
883+
Union: union,
884+
position: first.position,
885+
}
886+
}
887+
888+
func (p *parser) blockParamUnionContinues() bool {
889+
if p.peekToken.Type != tokenPipe {
890+
return false
891+
}
892+
893+
savedLexer := *p.l
894+
savedCur := p.curToken
895+
savedPeek := p.peekToken
896+
savedErrors := len(p.errors)
897+
898+
p.nextToken()
899+
p.nextToken()
900+
atom := p.parseTypeAtom()
901+
ok := atom != nil && (p.peekToken.Type == tokenComma || p.peekToken.Type == tokenPipe)
902+
903+
p.l = &savedLexer
904+
p.curToken = savedCur
905+
p.peekToken = savedPeek
906+
p.errors = p.errors[:savedErrors]
907+
return ok
908+
}
909+
838910
func (p *parser) parseMemberExpression(object Expression) Expression {
839911
if object == nil {
840912
return nil

vibes/parser_types_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,60 @@ end`
8181
t.Fatalf("unexpected parse error: %s", got)
8282
}
8383
}
84+
85+
func TestParserTypeSyntaxTypedBlockParameters(t *testing.T) {
86+
source := `def run(values)
87+
values.map do |value: int | string, label: string?|
88+
label
89+
end
90+
end`
91+
92+
p := newParser(source)
93+
program, errs := p.ParseProgram()
94+
if len(errs) > 0 {
95+
t.Fatalf("expected no parse errors, got %v", errs)
96+
}
97+
if len(program.Statements) != 1 {
98+
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
99+
}
100+
fn, ok := program.Statements[0].(*FunctionStmt)
101+
if !ok {
102+
t.Fatalf("expected function statement, got %T", program.Statements[0])
103+
}
104+
if len(fn.Body) != 1 {
105+
t.Fatalf("expected 1 body statement, got %d", len(fn.Body))
106+
}
107+
exprStmt, ok := fn.Body[0].(*ExprStmt)
108+
if !ok {
109+
t.Fatalf("expected expression statement, got %T", fn.Body[0])
110+
}
111+
call, ok := exprStmt.Expr.(*CallExpr)
112+
if !ok {
113+
t.Fatalf("expected call expression, got %T", exprStmt.Expr)
114+
}
115+
if call.Block == nil {
116+
t.Fatalf("expected call block")
117+
}
118+
if len(call.Block.Params) != 2 {
119+
t.Fatalf("expected 2 block params, got %d", len(call.Block.Params))
120+
}
121+
122+
first := call.Block.Params[0]
123+
if first.Name != "value" {
124+
t.Fatalf("expected first param name value, got %q", first.Name)
125+
}
126+
if first.Type == nil || first.Type.Kind != TypeUnion || len(first.Type.Union) != 2 {
127+
t.Fatalf("expected first param union type, got %#v", first.Type)
128+
}
129+
if first.Type.Union[0].Kind != TypeInt || first.Type.Union[1].Kind != TypeString {
130+
t.Fatalf("expected union int|string, got %#v", first.Type.Union)
131+
}
132+
133+
second := call.Block.Params[1]
134+
if second.Name != "label" {
135+
t.Fatalf("expected second param name label, got %q", second.Name)
136+
}
137+
if second.Type == nil || second.Type.Kind != TypeString || !second.Type.Nullable {
138+
t.Fatalf("expected nullable string type, got %#v", second.Type)
139+
}
140+
}

0 commit comments

Comments
 (0)