|
6 | 6 | "fmt" |
7 | 7 | "maps" |
8 | 8 | "math" |
| 9 | + "reflect" |
9 | 10 | "regexp" |
10 | 11 | "slices" |
11 | 12 | "sort" |
@@ -1004,7 +1005,12 @@ func (exec *Execution) CallBlock(block Value, args []Value) (Value, error) { |
1004 | 1005 | } else { |
1005 | 1006 | val = NewNil() |
1006 | 1007 | } |
1007 | | - blockEnv.Define(param, val) |
| 1008 | + if param.Type != nil { |
| 1009 | + if err := checkValueType(val, param.Type); err != nil { |
| 1010 | + return NewNil(), exec.errorAt(param.Type.position, "%s", formatArgumentTypeMismatch(param.Name, err)) |
| 1011 | + } |
| 1012 | + } |
| 1013 | + blockEnv.Define(param.Name, val) |
1008 | 1014 | } |
1009 | 1015 | val, returned, err := exec.evalStatements(blk.Body, blockEnv) |
1010 | 1016 | if err != nil { |
@@ -3282,7 +3288,7 @@ func checkValueType(val Value, ty *TypeExpr) error { |
3282 | 3288 | } |
3283 | 3289 | return &typeMismatchError{ |
3284 | 3290 | Expected: formatTypeExpr(ty), |
3285 | | - Actual: val.Kind().String(), |
| 3291 | + Actual: formatValueTypeExpr(val), |
3286 | 3292 | } |
3287 | 3293 | } |
3288 | 3294 |
|
@@ -3512,6 +3518,127 @@ func formatShapeType(ty *TypeExpr) string { |
3512 | 3518 | return "{ " + strings.Join(parts, ", ") + " }" |
3513 | 3519 | } |
3514 | 3520 |
|
| 3521 | +func formatValueTypeExpr(val Value) string { |
| 3522 | + state := valueTypeFormatState{ |
| 3523 | + seenArrays: make(map[uintptr]struct{}), |
| 3524 | + seenHashes: make(map[uintptr]struct{}), |
| 3525 | + } |
| 3526 | + return state.format(val) |
| 3527 | +} |
| 3528 | + |
| 3529 | +type valueTypeFormatState struct { |
| 3530 | + seenArrays map[uintptr]struct{} |
| 3531 | + seenHashes map[uintptr]struct{} |
| 3532 | +} |
| 3533 | + |
| 3534 | +func (s *valueTypeFormatState) format(val Value) string { |
| 3535 | + switch val.Kind() { |
| 3536 | + case KindNil: |
| 3537 | + return "nil" |
| 3538 | + case KindBool: |
| 3539 | + return "bool" |
| 3540 | + case KindInt: |
| 3541 | + return "int" |
| 3542 | + case KindFloat: |
| 3543 | + return "float" |
| 3544 | + case KindString: |
| 3545 | + return "string" |
| 3546 | + case KindMoney: |
| 3547 | + return "money" |
| 3548 | + case KindDuration: |
| 3549 | + return "duration" |
| 3550 | + case KindTime: |
| 3551 | + return "time" |
| 3552 | + case KindSymbol: |
| 3553 | + return "symbol" |
| 3554 | + case KindRange: |
| 3555 | + return "range" |
| 3556 | + case KindFunction: |
| 3557 | + return "function" |
| 3558 | + case KindBuiltin: |
| 3559 | + return "builtin" |
| 3560 | + case KindBlock: |
| 3561 | + return "block" |
| 3562 | + case KindClass: |
| 3563 | + return "class" |
| 3564 | + case KindInstance: |
| 3565 | + return "instance" |
| 3566 | + case KindArray: |
| 3567 | + return s.formatArray(val.Array()) |
| 3568 | + case KindHash, KindObject: |
| 3569 | + return s.formatHash(val.Hash()) |
| 3570 | + default: |
| 3571 | + return val.Kind().String() |
| 3572 | + } |
| 3573 | +} |
| 3574 | + |
| 3575 | +func (s *valueTypeFormatState) formatArray(values []Value) string { |
| 3576 | + if len(values) == 0 { |
| 3577 | + return "array<empty>" |
| 3578 | + } |
| 3579 | + |
| 3580 | + id := reflect.ValueOf(values).Pointer() |
| 3581 | + if id != 0 { |
| 3582 | + if _, seen := s.seenArrays[id]; seen { |
| 3583 | + return "array<...>" |
| 3584 | + } |
| 3585 | + s.seenArrays[id] = struct{}{} |
| 3586 | + defer delete(s.seenArrays, id) |
| 3587 | + } |
| 3588 | + |
| 3589 | + elementTypes := make(map[string]struct{}, len(values)) |
| 3590 | + for _, value := range values { |
| 3591 | + elementTypes[s.format(value)] = struct{}{} |
| 3592 | + } |
| 3593 | + return "array<" + joinSortedTypes(elementTypes) + ">" |
| 3594 | +} |
| 3595 | + |
| 3596 | +func (s *valueTypeFormatState) formatHash(values map[string]Value) string { |
| 3597 | + if len(values) == 0 { |
| 3598 | + return "{}" |
| 3599 | + } |
| 3600 | + |
| 3601 | + id := reflect.ValueOf(values).Pointer() |
| 3602 | + if id != 0 { |
| 3603 | + if _, seen := s.seenHashes[id]; seen { |
| 3604 | + return "{ ... }" |
| 3605 | + } |
| 3606 | + s.seenHashes[id] = struct{}{} |
| 3607 | + defer delete(s.seenHashes, id) |
| 3608 | + } |
| 3609 | + |
| 3610 | + if len(values) <= 6 { |
| 3611 | + fields := make([]string, 0, len(values)) |
| 3612 | + for field := range values { |
| 3613 | + fields = append(fields, field) |
| 3614 | + } |
| 3615 | + sort.Strings(fields) |
| 3616 | + parts := make([]string, len(fields)) |
| 3617 | + for i, field := range fields { |
| 3618 | + parts[i] = fmt.Sprintf("%s: %s", field, s.format(values[field])) |
| 3619 | + } |
| 3620 | + return "{ " + strings.Join(parts, ", ") + " }" |
| 3621 | + } |
| 3622 | + |
| 3623 | + valueTypes := make(map[string]struct{}, len(values)) |
| 3624 | + for _, value := range values { |
| 3625 | + valueTypes[s.format(value)] = struct{}{} |
| 3626 | + } |
| 3627 | + return "hash<string, " + joinSortedTypes(valueTypes) + ">" |
| 3628 | +} |
| 3629 | + |
| 3630 | +func joinSortedTypes(typeSet map[string]struct{}) string { |
| 3631 | + if len(typeSet) == 0 { |
| 3632 | + return "empty" |
| 3633 | + } |
| 3634 | + parts := make([]string, 0, len(typeSet)) |
| 3635 | + for typeName := range typeSet { |
| 3636 | + parts = append(parts, typeName) |
| 3637 | + } |
| 3638 | + sort.Strings(parts) |
| 3639 | + return strings.Join(parts, " | ") |
| 3640 | +} |
| 3641 | + |
3515 | 3642 | // Function looks up a compiled function by name. |
3516 | 3643 | func (s *Script) Function(name string) (*ScriptFunction, bool) { |
3517 | 3644 | fn, ok := s.functions[name] |
|
0 commit comments