Skip to content

Commit b59f24c

Browse files
committed
Merge remote-tracking branch 'origin/master' into mgomes/v0-18-0
# Conflicts: # vibes/execution.go
2 parents ac6ae93 + c869929 commit b59f24c

7 files changed

Lines changed: 428 additions & 21 deletions

File tree

vibes/execution.go

Lines changed: 89 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,11 @@ const (
111111
)
112112

113113
var (
114-
errLoopBreak = errors.New("loop break")
115-
errLoopNext = errors.New("loop next")
116-
stringTemplatePattern = regexp.MustCompile(`\{\{\s*([A-Za-z_][A-Za-z0-9_.-]*)\s*\}\}`)
114+
errLoopBreak = errors.New("loop break")
115+
errLoopNext = errors.New("loop next")
116+
errStepQuotaExceeded = errors.New("step quota exceeded")
117+
errMemoryQuotaExceeded = errors.New("memory quota exceeded")
118+
stringTemplatePattern = regexp.MustCompile(`\{\{\s*([A-Za-z_][A-Za-z0-9_.-]*)\s*\}\}`)
117119
)
118120

119121
func (re *RuntimeError) Error() string {
@@ -178,7 +180,7 @@ func newAssertionFailureError(message string) error {
178180
func (exec *Execution) step() error {
179181
exec.steps++
180182
if exec.quota > 0 && exec.steps > exec.quota {
181-
return fmt.Errorf("step quota exceeded (%d)", exec.quota)
183+
return fmt.Errorf("%w (%d)", errStepQuotaExceeded, exec.quota)
182184
}
183185
if exec.memoryQuota > 0 && (exec.steps&15) == 0 {
184186
if err := exec.checkMemory(); err != nil {
@@ -237,6 +239,9 @@ func (exec *Execution) wrapError(err error, pos Position) error {
237239
if err == nil {
238240
return nil
239241
}
242+
if isHostControlSignal(err) {
243+
return err
244+
}
240245
if _, ok := err.(*RuntimeError); ok {
241246
return err
242247
}
@@ -462,6 +467,14 @@ func (exec *Execution) assignToMember(obj Value, property string, value Value, p
462467
return exec.errorAt(pos, "private method %s", setterName)
463468
}
464469
_, err := exec.callFunction(fn, obj, []Value{value}, nil, NewNil(), pos)
470+
if err != nil {
471+
if errors.Is(err, errLoopBreak) {
472+
return exec.errorAt(pos, "break cannot cross call boundary")
473+
}
474+
if errors.Is(err, errLoopNext) {
475+
return exec.errorAt(pos, "next cannot cross call boundary")
476+
}
477+
}
465478
return err
466479
}
467480

@@ -1411,7 +1424,7 @@ func (exec *Execution) evalRaiseStatement(stmt *RaiseStmt, env *Env) (Value, boo
14111424
func (exec *Execution) evalTryStatement(stmt *TryStmt, env *Env) (Value, bool, error) {
14121425
val, returned, err := exec.evalStatements(stmt.Body, env)
14131426

1414-
if err != nil && len(stmt.Rescue) > 0 && runtimeErrorMatchesRescueType(err, stmt.RescueTy) {
1427+
if err != nil && !isLoopControlSignal(err) && !isHostControlSignal(err) && len(stmt.Rescue) > 0 && runtimeErrorMatchesRescueType(err, stmt.RescueTy) {
14151428
exec.pushRescuedError(err)
14161429
rescueVal, rescueReturned, rescueErr := exec.evalStatements(stmt.Rescue, env)
14171430
exec.popRescuedError()
@@ -1442,7 +1455,22 @@ func (exec *Execution) evalTryStatement(stmt *TryStmt, env *Env) (Value, bool, e
14421455
return val, returned, nil
14431456
}
14441457

1458+
func isLoopControlSignal(err error) bool {
1459+
return errors.Is(err, errLoopBreak) || errors.Is(err, errLoopNext)
1460+
}
1461+
1462+
func isHostControlSignal(err error) bool {
1463+
return errors.Is(err, context.Canceled) ||
1464+
errors.Is(err, context.DeadlineExceeded) ||
1465+
errors.Is(err, errStepQuotaExceeded) ||
1466+
errors.Is(err, errMemoryQuotaExceeded)
1467+
}
1468+
14451469
func runtimeErrorMatchesRescueType(err error, rescueTy *TypeExpr) bool {
1470+
var runtimeErr *RuntimeError
1471+
if !errors.As(err, &runtimeErr) {
1472+
return false
1473+
}
14461474
if rescueTy == nil {
14471475
return true
14481476
}
@@ -3961,7 +3989,33 @@ func formatReturnTypeMismatch(fnName string, err error) string {
39613989
return fmt.Sprintf("return type check failed for %s: %s", fnName, err.Error())
39623990
}
39633991

3992+
type typeValidationVisit struct {
3993+
valueKind ValueKind
3994+
valueID uintptr
3995+
ty *TypeExpr
3996+
}
3997+
3998+
type typeValidationState struct {
3999+
active map[typeValidationVisit]struct{}
4000+
}
4001+
39644002
func valueMatchesType(val Value, ty *TypeExpr) (bool, error) {
4003+
state := typeValidationState{
4004+
active: make(map[typeValidationVisit]struct{}),
4005+
}
4006+
return state.matches(val, ty)
4007+
}
4008+
4009+
func (s *typeValidationState) matches(val Value, ty *TypeExpr) (bool, error) {
4010+
if visit, ok := typeValidationVisitFor(val, ty); ok {
4011+
if _, seen := s.active[visit]; seen {
4012+
// Recursive value/type pair already being validated higher in the stack.
4013+
return true, nil
4014+
}
4015+
s.active[visit] = struct{}{}
4016+
defer delete(s.active, visit)
4017+
}
4018+
39654019
if ty.Nullable && val.Kind() == KindNil {
39664020
return true, nil
39674021
}
@@ -3998,7 +4052,7 @@ func valueMatchesType(val Value, ty *TypeExpr) (bool, error) {
39984052
}
39994053
elemType := ty.TypeArgs[0]
40004054
for _, elem := range val.Array() {
4001-
matches, err := valueMatchesType(elem, elemType)
4055+
matches, err := s.matches(elem, elemType)
40024056
if err != nil {
40034057
return false, err
40044058
}
@@ -4020,14 +4074,14 @@ func valueMatchesType(val Value, ty *TypeExpr) (bool, error) {
40204074
keyType := ty.TypeArgs[0]
40214075
valueType := ty.TypeArgs[1]
40224076
for key, value := range val.Hash() {
4023-
keyMatches, err := valueMatchesType(NewString(key), keyType)
4077+
keyMatches, err := s.matches(NewString(key), keyType)
40244078
if err != nil {
40254079
return false, err
40264080
}
40274081
if !keyMatches {
40284082
return false, nil
40294083
}
4030-
valueMatches, err := valueMatchesType(value, valueType)
4084+
valueMatches, err := s.matches(value, valueType)
40314085
if err != nil {
40324086
return false, err
40334087
}
@@ -4051,7 +4105,7 @@ func valueMatchesType(val Value, ty *TypeExpr) (bool, error) {
40514105
if !ok {
40524106
return false, nil
40534107
}
4054-
matches, err := valueMatchesType(fieldVal, fieldType)
4108+
matches, err := s.matches(fieldVal, fieldType)
40554109
if err != nil {
40564110
return false, err
40574111
}
@@ -4067,7 +4121,7 @@ func valueMatchesType(val Value, ty *TypeExpr) (bool, error) {
40674121
return true, nil
40684122
case TypeUnion:
40694123
for _, option := range ty.Union {
4070-
matches, err := valueMatchesType(val, option)
4124+
matches, err := s.matches(val, option)
40714125
if err != nil {
40724126
return false, err
40734127
}
@@ -4081,6 +4135,31 @@ func valueMatchesType(val Value, ty *TypeExpr) (bool, error) {
40814135
}
40824136
}
40834137

4138+
func typeValidationVisitFor(val Value, ty *TypeExpr) (typeValidationVisit, bool) {
4139+
if ty == nil {
4140+
return typeValidationVisit{}, false
4141+
}
4142+
4143+
var valueID uintptr
4144+
switch val.Kind() {
4145+
case KindArray:
4146+
valueID = reflect.ValueOf(val.Array()).Pointer()
4147+
case KindHash, KindObject:
4148+
valueID = reflect.ValueOf(val.Hash()).Pointer()
4149+
default:
4150+
return typeValidationVisit{}, false
4151+
}
4152+
if valueID == 0 {
4153+
return typeValidationVisit{}, false
4154+
}
4155+
4156+
return typeValidationVisit{
4157+
valueKind: val.Kind(),
4158+
valueID: valueID,
4159+
ty: ty,
4160+
}, true
4161+
}
4162+
40844163
func formatTypeExpr(ty *TypeExpr) string {
40854164
if ty == nil {
40864165
return "unknown"

vibes/memory.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func (exec *Execution) checkMemoryWith(extras ...Value) error {
5757

5858
used := exec.estimateMemoryUsage(extras...)
5959
if used > exec.memoryQuota {
60-
return fmt.Errorf("memory quota exceeded (%d bytes)", exec.memoryQuota)
60+
return fmt.Errorf("%w (%d bytes)", errMemoryQuotaExceeded, exec.memoryQuota)
6161
}
6262
return nil
6363
}

vibes/modules.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ func validateModulePolicyPatterns(patterns []string, label string) error {
6363
}
6464

6565
func modulePolicyMatch(pattern string, module string) bool {
66+
if pattern == "*" {
67+
return module != ""
68+
}
6669
matched, err := path.Match(pattern, module)
6770
if err != nil {
6871
return false
@@ -517,6 +520,17 @@ func isValidModuleAlias(name string) bool {
517520
}
518521

519522
func bindRequireAlias(root *Env, alias string, module Value) error {
523+
if err := validateRequireAliasBinding(root, alias, module); err != nil {
524+
return err
525+
}
526+
if alias == "" {
527+
return nil
528+
}
529+
root.Define(alias, module)
530+
return nil
531+
}
532+
533+
func validateRequireAliasBinding(root *Env, alias string, module Value) error {
520534
if alias == "" {
521535
return nil
522536
}
@@ -526,7 +540,6 @@ func bindRequireAlias(root *Env, alias string, module Value) error {
526540
}
527541
return fmt.Errorf("require: alias %q already defined", alias)
528542
}
529-
root.Define(alias, module)
530543
return nil
531544
}
532545

@@ -610,12 +623,14 @@ func builtinRequire(exec *Execution, receiver Value, args []Value, kwargs map[st
610623
}
611624
}
612625

613-
bindModuleExportsWithoutOverwrite(exec.root, exports)
614-
615626
exportsVal := NewObject(exports)
616-
exec.modules[entry.key] = exportsVal
617-
if err := bindRequireAlias(exec.root, alias, exportsVal); err != nil {
627+
if err := validateRequireAliasBinding(exec.root, alias, exportsVal); err != nil {
618628
return NewNil(), err
619629
}
630+
bindModuleExportsWithoutOverwrite(exec.root, exports)
631+
exec.modules[entry.key] = exportsVal
632+
if alias != "" {
633+
exec.root.Define(alias, exportsVal)
634+
}
620635
return exportsVal, nil
621636
}

vibes/modules_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,39 @@ end`)
125125
}
126126
}
127127

128+
func TestRequireAliasConflictDoesNotLeakExportsWhenRescued(t *testing.T) {
129+
engine := MustNewEngine(Config{ModulePaths: []string{filepath.Join("testdata", "modules")}})
130+
131+
script, err := engine.Compile(`def helpers(value)
132+
value
133+
end
134+
135+
def run(value)
136+
begin
137+
require("helper", as: "helpers")
138+
rescue
139+
nil
140+
end
141+
142+
begin
143+
double(value)
144+
rescue
145+
"missing"
146+
end
147+
end`)
148+
if err != nil {
149+
t.Fatalf("compile failed: %v", err)
150+
}
151+
152+
result, err := script.Call(context.Background(), "run", []Value{NewInt(3)}, CallOptions{})
153+
if err != nil {
154+
t.Fatalf("call failed: %v", err)
155+
}
156+
if result.Kind() != KindString || result.String() != "missing" {
157+
t.Fatalf("expected leaked export lookup to fail, got %#v", result)
158+
}
159+
}
160+
128161
func TestRequirePreservesModuleLocalResolution(t *testing.T) {
129162
engine := MustNewEngine(Config{ModulePaths: []string{filepath.Join("testdata", "modules")}})
130163

@@ -781,6 +814,17 @@ end`)
781814
if err == nil || !strings.Contains(err.Error(), "export is only supported for top-level functions") {
782815
t.Fatalf("expected top-level export parse error, got %v", err)
783816
}
817+
818+
_, err = engine.Compile(`def outer()
819+
if true
820+
export def nested()
821+
1
822+
end
823+
end
824+
end`)
825+
if err == nil || !strings.Contains(err.Error(), "export is only supported for top-level functions") {
826+
t.Fatalf("expected nested export parse error, got %v", err)
827+
}
784828
}
785829

786830
func TestRequirePrivateFunctionsAreNotInjectedAsGlobals(t *testing.T) {
@@ -981,6 +1025,29 @@ end`)
9811025
}
9821026
}
9831027

1028+
func TestRequireModuleAllowListStarMatchesNestedModules(t *testing.T) {
1029+
engine := MustNewEngine(Config{
1030+
ModulePaths: []string{filepath.Join("testdata", "modules")},
1031+
ModuleAllowList: []string{"*"},
1032+
})
1033+
1034+
script, err := engine.Compile(`def run(value)
1035+
mod = require("shared/math")
1036+
mod.double(value)
1037+
end`)
1038+
if err != nil {
1039+
t.Fatalf("compile failed: %v", err)
1040+
}
1041+
1042+
result, err := script.Call(context.Background(), "run", []Value{NewInt(4)}, CallOptions{})
1043+
if err != nil {
1044+
t.Fatalf("call failed: %v", err)
1045+
}
1046+
if result.Kind() != KindInt || result.Int() != 8 {
1047+
t.Fatalf("expected nested module call result 8, got %#v", result)
1048+
}
1049+
}
1050+
9841051
func TestRequireModuleDenyListOverridesAllowList(t *testing.T) {
9851052
engine := MustNewEngine(Config{
9861053
ModulePaths: []string{filepath.Join("testdata", "modules")},

0 commit comments

Comments
 (0)