diff --git a/acc.go b/acc.go index be4fb44..2206bc8 100644 --- a/acc.go +++ b/acc.go @@ -14,8 +14,8 @@ type accumulator struct { floating bool } -func (a *accumulator) total(vals ...interface{}) (interface{}, error) { - if len(vals) == 0 { +func (a *accumulator) total(vals []interface{}) (interface{}, error) { + if vals == nil || len(vals) == 0 { return int64(0), nil } diff --git a/proc.go b/proc.go index 3936f60..3da45a5 100644 --- a/proc.go +++ b/proc.go @@ -4,9 +4,29 @@ import ( "errors" ) -type builtin func(...interface{}) (interface{}, error) +type builtin func([]interface{}) (interface{}, error) -func addition(vals ...interface{}) (interface{}, error) { +// evaluates all of the arguments, and then calls the function with the results +// of the evaluations +func (b *builtin) call(env *environment, rawArgs []interface{}) (interface{}, error) { + if rawArgs == nil { + return (*b)(nil) + } + + // eval all arguments first + args := make([]interface{}, 0, len(rawArgs)) + for _, raw := range rawArgs { + v, err := eval(raw, env) + if err != nil { + return nil, err + } + args = append(args, v) + } + + return (*b)(args) +} + +func addition(vals []interface{}) (interface{}, error) { a := accumulator{ name: "addition", floatFn: func(left, right float64) (float64, error) { @@ -16,10 +36,10 @@ func addition(vals ...interface{}) (interface{}, error) { return left + right, nil }, } - return a.total(vals...) + return a.total(vals) } -func subtraction(vals ...interface{}) (interface{}, error) { +func subtraction(vals []interface{}) (interface{}, error) { a := accumulator{ name: "subtraction", floatFn: func(left, right float64) (float64, error) { @@ -29,10 +49,10 @@ func subtraction(vals ...interface{}) (interface{}, error) { return left - right, nil }, } - return a.total(vals...) + return a.total(vals) } -func multiplication(vals ...interface{}) (interface{}, error) { +func multiplication(vals []interface{}) (interface{}, error) { a := accumulator{ name: "multiplication", floatFn: func(left, right float64) (float64, error) { @@ -42,10 +62,10 @@ func multiplication(vals ...interface{}) (interface{}, error) { return left * right, nil }, } - return a.total(vals...) + return a.total(vals) } -func division(vals ...interface{}) (interface{}, error) { +func division(vals []interface{}) (interface{}, error) { a := accumulator{ name: "division", floatFn: func(left, right float64) (float64, error) { @@ -61,5 +81,5 @@ func division(vals ...interface{}) (interface{}, error) { return left / right, nil }, } - return a.total(vals...) + return a.total(vals) } diff --git a/skeam.go b/skeam.go index 0b55184..518dc3b 100644 --- a/skeam.go +++ b/skeam.go @@ -131,6 +131,7 @@ func eval(v interface{}, env *environment) (interface{}, error) { return nil, err } + // check to see if this is a special form if spec, ok := v.(special); ok { if len(t) > 1 { return spec(env, t[1:]...) @@ -139,33 +140,16 @@ func eval(v interface{}, env *environment) (interface{}, error) { } } - fn, ok := v.(builtin) - if !ok { - return nil, fmt.Errorf("expected builtin, found %v", reflect.TypeOf(v)) - } - - if len(t) > 1 { - args := make([]interface{}, 0, len(t)-1) - for _, raw := range t[1:] { - v, err := eval(raw, env) - if err != nil { - return nil, err - } - args = append(args, v) - } - inner, err := fn(args...) - if err != nil { - return nil, err + // exec builtin func if one exists + if b, ok := v.(builtin); ok { + if len(t) > 1 { + return b.call(env, t[1:]) + } else { + return b.call(env, nil) } - return eval(inner, env) - } - - inner, err := fn() - if err != nil { - return nil, err } - return eval(inner, env) + return nil, fmt.Errorf(`expected special form or builtin procedure, received %v`, reflect.TypeOf(v)) default: return v, nil