refactored builtin type

this is done to avoid repeating all the arity checks all over the place.
That was getting a little tedious.
master
Jordan Orelli 12 years ago
parent b9d2afce74
commit c4fed38dd9

@ -5,6 +5,9 @@ import (
"reflect"
)
// type accumulator describes an accumulator. That is, it is a numerical
// structure that applies a pair of functions across a list of values that are
// expected to be numerical; i.e. of type int64 or float64.
type accumulator struct {
name string
floatFn func(float64, float64) (float64, error)
@ -14,7 +17,7 @@ type accumulator struct {
floating bool
}
func (a *accumulator) total(vals []interface{}) (interface{}, error) {
func (a accumulator) total(vals []interface{}) (interface{}, error) {
if vals == nil || len(vals) == 0 {
return a.acc, nil
}

@ -0,0 +1,235 @@
package main
import (
"errors"
"fmt"
"reflect"
)
type builtin struct {
// name of the function
name string
// minimum number of arguments
arity int
// whether we will accept arbitrarily large numbers of arguments or not
variadic bool
// function to be called
fn func([]interface{}) (interface{}, error)
}
// begins by evaluating all of its inputs. An error on input evaluation will
// stop evaluation of a builtin. After evaluating its inputs, an arity check
// is performed to see if the proper number of arguments have been supplied.
// Perhaps this is the wrong order, I'm unsure. Finally, the procudure is
// passed the post-evaluation arguments to be executed.
func (b builtin) call(env *environment, rawArgs []interface{}) (interface{}, error) {
// eval all arguments first
args := make([]interface{}, 0, len(rawArgs))
for _, raw := range rawArgs {
v, err := eval(raw, env)
if err != nil {
return nil, err
}
args = append(args, v)
}
if err := b.checkArity(len(rawArgs)); err != nil {
return nil, err
}
return b.fn(args)
}
func (b builtin) checkArity(n int) error {
if n == b.arity {
return nil
}
if b.variadic && n > b.arity {
return nil
}
return arityError{
expected: b.arity,
received: n,
name: b.name,
variadic: b.variadic,
}
}
var add = builtin{
name: "+",
variadic: true,
fn: func(vals []interface{}) (interface{}, error) {
return accumulator{
name: "+",
floatFn: func(left, right float64) (float64, error) {
return left + right, nil
},
intFn: func(left, right int64) (int64, error) {
return left + right, nil
},
}.total(vals)
},
}
var sub = builtin{
name: "-",
variadic: true,
fn: func(vals []interface{}) (interface{}, error) {
return accumulator{
name: "-",
floatFn: func(left, right float64) (float64, error) {
return left - right, nil
},
intFn: func(left, right int64) (int64, error) {
return left - right, nil
},
}.total(vals)
},
}
var mul = builtin{
name: "*",
variadic: true,
fn: func(vals []interface{}) (interface{}, error) {
return accumulator{
name: "*",
floatFn: func(left, right float64) (float64, error) {
return left * right, nil
},
intFn: func(left, right int64) (int64, error) {
return left * right, nil
},
acc: 1,
accf: 1.0,
}.total(vals)
},
}
var div = builtin{
name: "/",
variadic: true,
fn: func(vals []interface{}) (interface{}, error) {
return accumulator{
name: "/",
floatFn: func(left, right float64) (float64, error) {
if right == 0.0 {
return 0.0, errors.New("float division by zero")
}
return left / right, nil
},
intFn: func(left, right int64) (int64, error) {
if right == 0 {
return 0, errors.New("int division by zero")
}
return left / right, nil
},
}.total(vals)
},
}
var not = builtin{
name: "not",
arity: 1,
fn: func(vals []interface{}) (interface{}, error) {
return !booleanize(vals[0]), nil
},
}
var length = builtin{
name: "length",
arity: 1,
fn: func(vals []interface{}) (interface{}, error) {
switch t := vals[0].(type) {
case sexp:
return len(t), nil
case list:
return len(t.sexp), nil
}
return nil, fmt.Errorf("first argument must be sexp, received %v", reflect.TypeOf(vals[0]))
},
}
var lst = builtin{
name: "list",
variadic: true,
fn: func(vals []interface{}) (interface{}, error) {
return list{sexp(vals), 1}, nil
},
}
var islist = builtin{
name: "list?",
arity: 1,
fn: func(vals []interface{}) (interface{}, error) {
switch vals[0].(type) {
case list, sexp:
return true, nil
}
return false, nil
},
}
var isnull = builtin{
name: "null?",
arity: 1,
fn: func(vals []interface{}) (interface{}, error) {
s, ok := vals[0].(sexp)
if !ok {
return false, nil
}
return len(s) == 0, nil
},
}
var issymbol = builtin{
name: "symbol?",
arity: 1,
fn: func(vals []interface{}) (interface{}, error) {
_, ok := vals[0].(symbol)
return ok, nil
},
}
var cons = builtin{
name: "cons",
arity: 2,
fn: func(vals []interface{}) (interface{}, error) {
s := sexp{vals[0]}
switch t := vals[1].(type) {
case sexp:
return append(s, t...), nil
default:
return append(s, t), nil
}
panic("not reached")
},
}
var car = builtin{
name: "car",
arity: 1,
fn: func(vals []interface{}) (interface{}, error) {
s, ok := vals[0].(sexp)
if !ok {
return nil, errors.New("expected list")
}
return s[0], nil
},
}
var cdr = builtin{
name: "cdr",
arity: 1,
fn: func(vals []interface{}) (interface{}, error) {
s, ok := vals[0].(sexp)
if !ok {
return nil, errors.New("expected list")
}
return s[1:], nil
},
}

101
cmp.go

@ -0,0 +1,101 @@
package main
import (
"errors"
"fmt"
"reflect"
)
type cmp_bin_i func(int64, int64) bool
type cmp_bin_f func(float64, float64) bool
func cmp_left(vals []interface{}, fni cmp_bin_i, fnf cmp_bin_f) (bool, error) {
if len(vals) < 2 {
return false, errors.New("expected at least 2 arguments")
}
var lasti int64
var lastf float64
var floating bool
switch v := vals[0].(type) {
case float64:
floating = true
lastf = v
case int64:
lasti = v
default:
return false, fmt.Errorf("gt is not defined for %v", reflect.TypeOf(v))
}
for _, raw := range vals[1:] {
switch v := raw.(type) {
case float64:
if !floating {
floating = true
lastf = float64(lasti)
}
if !fnf(lastf, v) {
return false, nil
}
lastf = v
case int64:
if floating {
f := float64(v)
if !fnf(lastf, f) {
return false, nil
}
lastf = f
} else {
if !fni(lasti, v) {
return false, nil
}
lasti = v
}
default:
return false, errors.New("ooga booga")
}
}
return true, nil
}
var gt = builtin{
name: ">",
variadic: true,
fn: func(vals []interface{}) (interface{}, error) {
fni := func(x, y int64) bool { return x > y }
fnf := func(x, y float64) bool { return x > y }
return cmp_left(vals, fni, fnf)
},
}
var gte = builtin{
name: ">=",
variadic: true,
fn: func(vals []interface{}) (interface{}, error) {
fni := func(x, y int64) bool { return x >= y }
fnf := func(x, y float64) bool { return x >= y }
return cmp_left(vals, fni, fnf)
},
}
var lt = builtin{
name: "<",
variadic: true,
fn: func(vals []interface{}) (interface{}, error) {
fni := func(x, y int64) bool { return x < y }
fnf := func(x, y float64) bool { return x < y }
return cmp_left(vals, fni, fnf)
},
}
var lte = builtin{
name: "<=",
variadic: true,
fn: func(vals []interface{}) (interface{}, error) {
fni := func(x, y int64) bool { return x <= y }
fnf := func(x, y float64) bool { return x <= y }
return cmp_left(vals, fni, fnf)
},
}

@ -1,257 +0,0 @@
package main
import (
"errors"
"fmt"
"reflect"
)
type builtin func([]interface{}) (interface{}, error)
// evaluates all of the arguments, and then calls the function with the results
// of the evaluations
func (b *builtin) call(env *environment, rawArgs []interface{}) (interface{}, error) {
if rawArgs == nil {
return (*b)(nil)
}
// eval all arguments first
args := make([]interface{}, 0, len(rawArgs))
for _, raw := range rawArgs {
v, err := eval(raw, env)
if err != nil {
return nil, err
}
args = append(args, v)
}
return (*b)(args)
}
func addition(vals []interface{}) (interface{}, error) {
a := accumulator{
name: "addition",
floatFn: func(left, right float64) (float64, error) {
return left + right, nil
},
intFn: func(left, right int64) (int64, error) {
return left + right, nil
},
}
return a.total(vals)
}
func subtraction(vals []interface{}) (interface{}, error) {
a := accumulator{
name: "subtraction",
floatFn: func(left, right float64) (float64, error) {
return left - right, nil
},
intFn: func(left, right int64) (int64, error) {
return left - right, nil
},
}
return a.total(vals)
}
func multiplication(vals []interface{}) (interface{}, error) {
a := accumulator{
name: "multiplication",
floatFn: func(left, right float64) (float64, error) {
return left * right, nil
},
intFn: func(left, right int64) (int64, error) {
return left * right, nil
},
acc: 1,
accf: 1.0,
}
return a.total(vals)
}
func division(vals []interface{}) (interface{}, error) {
a := accumulator{
name: "division",
floatFn: func(left, right float64) (float64, error) {
if right == 0.0 {
return 0.0, errors.New("float division by zero")
}
return left / right, nil
},
intFn: func(left, right int64) (int64, error) {
if right == 0 {
return 0, errors.New("int division by zero")
}
return left / right, nil
},
}
return a.total(vals)
}
func not(vals []interface{}) (interface{}, error) {
if err := checkArity(1, vals, "not"); err != nil {
return nil, err
}
return !booleanize(vals[0]), nil
}
func length(vals []interface{}) (interface{}, error) {
if err := checkArity(1, vals, "length"); err != nil {
return nil, err
}
x, ok := vals[0].(sexp)
if !ok {
return nil, fmt.Errorf("first argument must be sexp, received %v", reflect.TypeOf(vals[0]))
}
return len(x), nil
}
func lst(vals []interface{}) (interface{}, error) {
return sexp(vals), nil
}
func islist(vals []interface{}) (interface{}, error) {
if err := checkArity(1, vals, "list?"); err != nil {
return nil, err
}
_, ok := vals[0].(sexp)
return ok, nil
}
func isnull(vals []interface{}) (interface{}, error) {
if err := checkArity(1, vals, "null?"); err != nil {
return nil, err
}
s, ok := vals[0].(sexp)
if !ok {
return false, nil
}
return len(s) == 0, nil
}
func issymbol(vals []interface{}) (interface{}, error) {
if err := checkArity(1, vals, "symbol?"); err != nil {
return nil, err
}
_, ok := vals[0].(symbol)
return ok, nil
}
func cons(vals []interface{}) (interface{}, error) {
if err := checkArity(2, vals, "cons"); err != nil {
return nil, err
}
s := sexp{vals[0]}
switch t := vals[1].(type) {
case sexp:
return append(s, t...), nil
default:
return append(s, t), nil
}
panic("not reached")
}
func car(vals []interface{}) (interface{}, error) {
if err := checkArity(1, vals, "car"); err != nil {
return nil, err
}
s, ok := vals[0].(sexp)
if !ok {
return nil, errors.New("expected list")
}
return s[0], nil
}
func cdr(vals []interface{}) (interface{}, error) {
if err := checkArity(1, vals, "cdr"); err != nil {
return nil, err
}
s, ok := vals[0].(sexp)
if !ok {
return nil, errors.New("expected list")
}
return s[1:], nil
}
type cmp_bin_i func(int64, int64) bool
type cmp_bin_f func(float64, float64) bool
func cmp_left(vals []interface{}, fni cmp_bin_i, fnf cmp_bin_f) (bool, error) {
if len(vals) < 2 {
return false, errors.New("expected at least 2 arguments")
}
var lasti int64
var lastf float64
var floating bool
switch v := vals[0].(type) {
case float64:
floating = true
lastf = v
case int64:
lasti = v
default:
return false, fmt.Errorf("gt is not defined for %v", reflect.TypeOf(v))
}
for _, raw := range vals[1:] {
switch v := raw.(type) {
case float64:
if !floating {
floating = true
lastf = float64(lasti)
}
if !fnf(lastf, v) {
return false, nil
}
lastf = v
case int64:
if floating {
f := float64(v)
if !fnf(lastf, f) {
return false, nil
}
lastf = f
} else {
if !fni(lasti, v) {
return false, nil
}
lasti = v
}
default:
return false, errors.New("ooga booga")
}
}
return true, nil
}
func gt(vals []interface{}) (interface{}, error) {
fni := func(x, y int64) bool { return x > y }
fnf := func(x, y float64) bool { return x > y }
return cmp_left(vals, fni, fnf)
}
func gte(vals []interface{}) (interface{}, error) {
fni := func(x, y int64) bool { return x >= y }
fnf := func(x, y float64) bool { return x >= y }
return cmp_left(vals, fni, fnf)
}
func lt(vals []interface{}) (interface{}, error) {
fni := func(x, y int64) bool { return x < y }
fnf := func(x, y float64) bool { return x < y }
return cmp_left(vals, fni, fnf)
}
func lte(vals []interface{}) (interface{}, error) {
fni := func(x, y int64) bool { return x <= y }
fnf := func(x, y float64) bool { return x <= y }
return cmp_left(vals, fni, fnf)
}

@ -41,25 +41,25 @@ var universe = &environment{map[symbol]interface{}{
"null": nil,
// builtin functions
"+": builtin(addition),
"-": builtin(subtraction),
"*": builtin(multiplication),
"/": builtin(division),
">": builtin(gt),
">=": builtin(gte),
"<": builtin(lt),
"<=": builtin(lte),
"cons": builtin(cons),
"car": builtin(car),
"cdr": builtin(cdr),
"length": builtin(length),
"list": builtin(lst),
"list?": builtin(islist),
"not": builtin(not),
"null?": builtin(isnull),
"symbol?": builtin(issymbol),
// "="
// "equal?"
symbol(add.name): add,
symbol(sub.name): sub,
symbol(mul.name): mul,
symbol(div.name): div,
symbol(gt.name): gt,
symbol(gte.name): gte,
symbol(lt.name): lt,
symbol(lte.name): lte,
symbol(cons.name): cons,
symbol(car.name): car,
symbol(cdr.name): cdr,
symbol(length.name): length,
symbol(lst.name): lst,
symbol(islist.name): islist,
symbol(not.name): not,
symbol(isnull.name): isnull,
symbol(issymbol.name): issymbol,
// "=": builtin(equal),
// "equal?": builtin(equal),
// "eq?"
// "append"

@ -17,9 +17,15 @@ type arityError struct {
expected int
received int
name string
variadic bool
}
func (n arityError) Error() string {
if n.variadic {
return fmt.Sprintf(`received %d arguments in *%v*, expected %d (or more)`,
n.received, n.name, n.expected)
}
return fmt.Sprintf(`received %d arguments in *%v*, expected %d`,
n.received, n.name, n.expected)
}
@ -31,10 +37,10 @@ func checkArity(arity int, args []interface{}, name string) error {
if arity == 0 {
return nil
}
return arityError{arity, 0, name}
return arityError{arity, 0, name, false}
}
if len(args) != arity {
return arityError{arity, len(args), name}
return arityError{arity, len(args), name, false}
}
return nil
}

Loading…
Cancel
Save