Skip to content

Commit 8daf6a6

Browse files
committed
Add typed rescue error matching
1 parent 0125440 commit 8daf6a6

7 files changed

Lines changed: 229 additions & 9 deletions

File tree

ROADMAP.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ Goal: improve language ergonomics for complex script logic and recovery behavior
218218
### Error Handling Constructs
219219

220220
- [x] Add structured error handling syntax (`begin/rescue/ensure` or equivalent).
221-
- [ ] Add typed error matching where feasible.
221+
- [x] Add typed error matching where feasible.
222222
- [ ] Define re-raise semantics and stack preservation.
223223
- [x] Ensure runtime errors preserve original position and call frames.
224224

docs/errors.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Use `begin` with `rescue` and/or `ensure` for script-level recovery:
7070
def safe_div(a, b)
7171
begin
7272
a / b
73-
rescue
73+
rescue(RuntimeError)
7474
"fallback"
7575
ensure
7676
audit("safe_div attempted")
@@ -81,9 +81,11 @@ end
8181
Semantics:
8282

8383
- `rescue` runs only when the `begin` body raises an error.
84+
- `rescue` supports optional typed matching via `rescue(<Type>)`.
85+
- `rescue` supports `AssertionError`, `RuntimeError`, and unions such as `rescue(AssertionError | RuntimeError)`.
8486
- `ensure` always runs (success, rescue path, or failure path).
8587
- Without `rescue`, original runtime errors still propagate after `ensure` executes.
86-
- `rescue` currently catches runtime failures broadly (typed matching and re-raise semantics are separate work).
88+
- Unmatched typed rescues do not swallow the original error.
8789

8890
## REPL Debugging
8991

vibes/ast.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ func (s *NextStmt) Pos() Position { return s.position }
156156

157157
type TryStmt struct {
158158
Body []Statement
159+
RescueTy *TypeExpr
159160
Rescue []Statement
160161
Ensure []Statement
161162
position Position

vibes/builtins.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package vibes
22

33
import (
4-
"errors"
54
"fmt"
65
"time"
76
)
@@ -20,7 +19,7 @@ func builtinAssert(exec *Execution, receiver Value, args []Value, kwargs map[str
2019
} else if msg, ok := kwargs["message"]; ok {
2120
message = msg.String()
2221
}
23-
return NewNil(), errors.New(message)
22+
return NewNil(), newAssertionFailureError(message)
2423
}
2524

2625
func builtinMoney(exec *Execution, receiver Value, args []Value, kwargs map[string]Value, block Value) (Value, error) {

vibes/execution.go

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,25 @@ type StackFrame struct {
8989
}
9090

9191
type RuntimeError struct {
92+
Type string
9293
Message string
9394
CodeFrame string
9495
Frames []StackFrame
9596
}
9697

98+
type assertionFailureError struct {
99+
message string
100+
}
101+
102+
func (e *assertionFailureError) Error() string {
103+
return e.message
104+
}
105+
106+
const (
107+
runtimeErrorTypeBase = "RuntimeError"
108+
runtimeErrorTypeAssertion = "AssertionError"
109+
)
110+
97111
var (
98112
errLoopBreak = errors.New("loop break")
99113
errLoopNext = errors.New("loop next")
@@ -127,6 +141,37 @@ func (re *RuntimeError) Unwrap() error {
127141
return nil
128142
}
129143

144+
func canonicalRuntimeErrorType(name string) (string, bool) {
145+
switch {
146+
case strings.EqualFold(name, runtimeErrorTypeBase), strings.EqualFold(name, "Error"):
147+
return runtimeErrorTypeBase, true
148+
case strings.EqualFold(name, runtimeErrorTypeAssertion):
149+
return runtimeErrorTypeAssertion, true
150+
default:
151+
return "", false
152+
}
153+
}
154+
155+
func classifyRuntimeErrorType(err error) string {
156+
if err == nil {
157+
return runtimeErrorTypeBase
158+
}
159+
var assertionErr *assertionFailureError
160+
if errors.As(err, &assertionErr) {
161+
return runtimeErrorTypeAssertion
162+
}
163+
if runtimeErr, ok := err.(*RuntimeError); ok {
164+
if kind, known := canonicalRuntimeErrorType(runtimeErr.Type); known {
165+
return kind
166+
}
167+
}
168+
return runtimeErrorTypeBase
169+
}
170+
171+
func newAssertionFailureError(message string) error {
172+
return &assertionFailureError{message: message}
173+
}
174+
130175
func (exec *Execution) step() error {
131176
exec.steps++
132177
if exec.quota > 0 && exec.steps > exec.quota {
@@ -152,6 +197,16 @@ func (exec *Execution) errorAt(pos Position, format string, args ...any) error {
152197
}
153198

154199
func (exec *Execution) newRuntimeError(message string, pos Position) error {
200+
return exec.newRuntimeErrorWithType(runtimeErrorTypeBase, message, pos)
201+
}
202+
203+
func (exec *Execution) newRuntimeErrorWithType(kind string, message string, pos Position) error {
204+
if canonical, ok := canonicalRuntimeErrorType(kind); ok {
205+
kind = canonical
206+
} else {
207+
kind = runtimeErrorTypeBase
208+
}
209+
155210
frames := make([]StackFrame, 0, len(exec.callStack)+1)
156211

157212
if len(exec.callStack) > 0 {
@@ -172,7 +227,7 @@ func (exec *Execution) newRuntimeError(message string, pos Position) error {
172227
if exec.script != nil {
173228
codeFrame = formatCodeFrame(exec.script.source, pos)
174229
}
175-
return &RuntimeError{Message: message, CodeFrame: codeFrame, Frames: frames}
230+
return &RuntimeError{Type: kind, Message: message, CodeFrame: codeFrame, Frames: frames}
176231
}
177232

178233
func (exec *Execution) wrapError(err error, pos Position) error {
@@ -182,7 +237,7 @@ func (exec *Execution) wrapError(err error, pos Position) error {
182237
if _, ok := err.(*RuntimeError); ok {
183238
return err
184239
}
185-
return exec.newRuntimeError(err.Error(), pos)
240+
return exec.newRuntimeErrorWithType(classifyRuntimeErrorType(err), err.Error(), pos)
186241
}
187242

188243
func (exec *Execution) pushReceiver(v Value) {
@@ -1317,7 +1372,7 @@ func (exec *Execution) evalUntilStatement(stmt *UntilStmt, env *Env) (Value, boo
13171372
func (exec *Execution) evalTryStatement(stmt *TryStmt, env *Env) (Value, bool, error) {
13181373
val, returned, err := exec.evalStatements(stmt.Body, env)
13191374

1320-
if err != nil && len(stmt.Rescue) > 0 {
1375+
if err != nil && len(stmt.Rescue) > 0 && runtimeErrorMatchesRescueType(err, stmt.RescueTy) {
13211376
rescueVal, rescueReturned, rescueErr := exec.evalStatements(stmt.Rescue, env)
13221377
if rescueErr != nil {
13231378
val = NewNil()
@@ -1346,6 +1401,36 @@ func (exec *Execution) evalTryStatement(stmt *TryStmt, env *Env) (Value, bool, e
13461401
return val, returned, nil
13471402
}
13481403

1404+
func runtimeErrorMatchesRescueType(err error, rescueTy *TypeExpr) bool {
1405+
if rescueTy == nil {
1406+
return true
1407+
}
1408+
errKind := classifyRuntimeErrorType(err)
1409+
return rescueTypeMatchesErrorKind(rescueTy, errKind)
1410+
}
1411+
1412+
func rescueTypeMatchesErrorKind(ty *TypeExpr, errKind string) bool {
1413+
if ty == nil {
1414+
return false
1415+
}
1416+
if ty.Kind == TypeUnion {
1417+
for _, option := range ty.Union {
1418+
if rescueTypeMatchesErrorKind(option, errKind) {
1419+
return true
1420+
}
1421+
}
1422+
return false
1423+
}
1424+
canonical, ok := canonicalRuntimeErrorType(ty.Name)
1425+
if !ok {
1426+
return false
1427+
}
1428+
if canonical == runtimeErrorTypeBase {
1429+
return true
1430+
}
1431+
return canonical == errKind
1432+
}
1433+
13491434
func (exec *Execution) getMember(obj Value, property string, pos Position) (Value, error) {
13501435
switch obj.Kind() {
13511436
case KindHash, KindObject:

vibes/parser.go

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,24 @@ func (p *parser) parseBeginStatement() Statement {
440440
p.nextToken()
441441
body := p.parseBlock(tokenRescue, tokenEnsure, tokenEnd)
442442

443+
var rescueTy *TypeExpr
443444
var rescueBody []Statement
444445
if p.curToken.Type == tokenRescue {
446+
rescuePos := p.curToken.Pos
447+
if p.peekToken.Type == tokenLParen && p.peekToken.Pos.Line == rescuePos.Line {
448+
p.nextToken()
449+
p.nextToken()
450+
rescueTy = p.parseTypeExpr()
451+
if rescueTy == nil {
452+
return nil
453+
}
454+
if !p.validateRescueTypeExpr(rescueTy, rescuePos) {
455+
return nil
456+
}
457+
if !p.expectPeek(tokenRParen) {
458+
return nil
459+
}
460+
}
445461
p.nextToken()
446462
rescueBody = p.parseBlock(tokenEnsure, tokenEnd)
447463
}
@@ -462,7 +478,34 @@ func (p *parser) parseBeginStatement() Statement {
462478
return nil
463479
}
464480

465-
return &TryStmt{Body: body, Rescue: rescueBody, Ensure: ensureBody, position: pos}
481+
return &TryStmt{Body: body, RescueTy: rescueTy, Rescue: rescueBody, Ensure: ensureBody, position: pos}
482+
}
483+
484+
func (p *parser) validateRescueTypeExpr(ty *TypeExpr, pos Position) bool {
485+
if ty == nil {
486+
p.addParseError(pos, "rescue type cannot be empty")
487+
return false
488+
}
489+
490+
if ty.Kind == TypeUnion {
491+
ok := true
492+
for _, option := range ty.Union {
493+
if !p.validateRescueTypeExpr(option, option.position) {
494+
ok = false
495+
}
496+
}
497+
return ok
498+
}
499+
500+
if len(ty.TypeArgs) > 0 || len(ty.Shape) > 0 {
501+
p.addParseError(pos, fmt.Sprintf("rescue type must be an error class, got %s", formatTypeExpr(ty)))
502+
return false
503+
}
504+
if _, ok := canonicalRuntimeErrorType(ty.Name); !ok {
505+
p.addParseError(pos, fmt.Sprintf("unknown rescue error type %s", ty.Name))
506+
return false
507+
}
508+
return true
466509
}
467510

468511
func (p *parser) parseBlock(stop ...TokenType) []Statement {

vibes/runtime_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,96 @@ func TestBeginRescueEnsure(t *testing.T) {
10141014
}
10151015
}
10161016

1017+
func TestBeginRescueTypedMatching(t *testing.T) {
1018+
script := compileScript(t, `
1019+
def typed_assertion()
1020+
begin
1021+
assert false, "boom"
1022+
rescue(AssertionError)
1023+
"assertion"
1024+
end
1025+
end
1026+
1027+
def typed_runtime()
1028+
begin
1029+
assert false, "boom"
1030+
rescue(RuntimeError)
1031+
"runtime"
1032+
end
1033+
end
1034+
1035+
def typed_union()
1036+
begin
1037+
assert false, "boom"
1038+
rescue(AssertionError | RuntimeError)
1039+
"union"
1040+
end
1041+
end
1042+
1043+
def rescue_mismatch()
1044+
begin
1045+
1 / 0
1046+
rescue(AssertionError)
1047+
"nope"
1048+
end
1049+
end
1050+
1051+
def assertion_passthrough()
1052+
assert false, "raw"
1053+
end
1054+
`)
1055+
1056+
if got := callFunc(t, script, "typed_assertion", nil); !got.Equal(NewString("assertion")) {
1057+
t.Fatalf("typed_assertion mismatch: %v", got)
1058+
}
1059+
if got := callFunc(t, script, "typed_runtime", nil); !got.Equal(NewString("runtime")) {
1060+
t.Fatalf("typed_runtime mismatch: %v", got)
1061+
}
1062+
if got := callFunc(t, script, "typed_union", nil); !got.Equal(NewString("union")) {
1063+
t.Fatalf("typed_union mismatch: %v", got)
1064+
}
1065+
1066+
_, err := script.Call(context.Background(), "rescue_mismatch", nil, CallOptions{})
1067+
if err == nil || !strings.Contains(err.Error(), "division by zero") {
1068+
t.Fatalf("expected typed rescue mismatch to preserve original error, got %v", err)
1069+
}
1070+
var divideErr *RuntimeError
1071+
if !errors.As(err, &divideErr) {
1072+
t.Fatalf("expected RuntimeError, got %T", err)
1073+
}
1074+
if divideErr.Type != runtimeErrorTypeBase {
1075+
t.Fatalf("expected runtime error type %s, got %s", runtimeErrorTypeBase, divideErr.Type)
1076+
}
1077+
1078+
_, err = script.Call(context.Background(), "assertion_passthrough", nil, CallOptions{})
1079+
if err == nil || !strings.Contains(err.Error(), "raw") {
1080+
t.Fatalf("expected assertion passthrough error, got %v", err)
1081+
}
1082+
var assertionErr *RuntimeError
1083+
if !errors.As(err, &assertionErr) {
1084+
t.Fatalf("expected RuntimeError, got %T", err)
1085+
}
1086+
if assertionErr.Type != runtimeErrorTypeAssertion {
1087+
t.Fatalf("expected runtime error type %s, got %s", runtimeErrorTypeAssertion, assertionErr.Type)
1088+
}
1089+
}
1090+
1091+
func TestBeginRescueTypedUnknownTypeFailsCompile(t *testing.T) {
1092+
engine := MustNewEngine(Config{})
1093+
_, err := engine.Compile(`
1094+
def bad()
1095+
begin
1096+
1 / 0
1097+
rescue(NotARealError)
1098+
"fallback"
1099+
end
1100+
end
1101+
`)
1102+
if err == nil || !strings.Contains(err.Error(), "unknown rescue error type NotARealError") {
1103+
t.Fatalf("expected unknown rescue type compile error, got %v", err)
1104+
}
1105+
}
1106+
10171107
func TestLoopControlBreakAndNext(t *testing.T) {
10181108
script := compileScript(t, `
10191109
def for_break()

0 commit comments

Comments
 (0)