fix table merge problem by defining the merge identity

master
Jordan Orelli 3 years ago
parent 1958ff49f8
commit e8710f676d

@ -11,6 +11,7 @@ import (
// heterogeneous members, but should likely be avoided otherwise. // heterogeneous members, but should likely be avoided otherwise.
type Boxed struct { type Boxed struct {
val interface{} val interface{}
ident func() interface{}
merge func(interface{}) error merge func(interface{}) error
} }
@ -21,8 +22,18 @@ func (b Boxed) Merge(from Boxed) error {
return nil return nil
} }
func (b Boxed) MergeIdentity() Boxed {
return Boxed{
val: b.ident(),
ident: b.ident,
merge: b.merge,
}
}
var typeMismatch = errors.New("mismatched types") var typeMismatch = errors.New("mismatched types")
// strip takes a typed function and erases the type information from its
// parameter
func strip[X any](f func(X) error) func(interface{}) error { func strip[X any](f func(X) error) func(interface{}) error {
return func(v interface{}) error { return func(v interface{}) error {
vv, ok := v.(X) vv, ok := v.(X)
@ -38,6 +49,9 @@ func strip[X any](f func(X) error) func(interface{}) error {
func Box[X Merges[X]](x X) Boxed { func Box[X Merges[X]](x X) Boxed {
return Boxed{ return Boxed{
val: x, val: x,
ident: func() interface{} {
return x.MergeIdentity()
},
merge: strip(x.Merge), merge: strip(x.Merge),
} }
} }

@ -18,6 +18,7 @@ package merge
// Although the type B satisfies the interface Merges[*C], it does not satisfy // Although the type B satisfies the interface Merges[*C], it does not satisfy
// the constraint [X Merges[X]], which is what is used throughout this package. // the constraint [X Merges[X]], which is what is used throughout this package.
type Merges[X any] interface { type Merges[X any] interface {
MergeIdentity() X
Merge(X) error Merge(X) error
} }

@ -10,6 +10,10 @@ type additive struct {
total int total int
} }
func (*additive) MergeIdentity() *additive {
return new(additive)
}
func (a *additive) Merge(b *additive) error { func (a *additive) Merge(b *additive) error {
a.total += b.total a.total += b.total
return nil return nil
@ -23,6 +27,10 @@ type multiplicative struct {
scale int scale int
} }
func (m *multiplicative) MergeIdentity() *multiplicative {
return &multiplicative{scale: 1}
}
func (m *multiplicative) Merge(v *multiplicative) error { func (m *multiplicative) Merge(v *multiplicative) error {
m.scale *= v.scale m.scale *= v.scale
return nil return nil
@ -38,6 +46,8 @@ type exclusive struct {
stock int stock int
} }
func (*exclusive) MergeIdentity() *exclusive { return new(exclusive) }
func (e *exclusive) Merge(source *exclusive) error { func (e *exclusive) Merge(source *exclusive) error {
e.stock += source.stock e.stock += source.stock
source.stock = 0 source.stock = 0
@ -75,5 +85,3 @@ func TestMerge(t *testing.T) {
} }
}) })
} }

@ -10,7 +10,12 @@ func (t Table[K, V]) Merge(from Table[K, V]) error {
for k, v := range from { for k, v := range from {
e, ok := t[k] e, ok := t[k]
if !ok { if !ok {
t[k] = v var z V
z = z.MergeIdentity()
if err := z.Merge(v); err != nil {
return fmt.Errorf("tables failed to merge: %w", err)
}
t[k] = z
continue continue
} }

@ -35,4 +35,7 @@ func TestMergeTables(t *testing.T) {
check("strawberry", 2) check("strawberry", 2)
check("pistacchio", 5) check("pistacchio", 5)
bob["pistacchio"].Merge(add(100))
bob["pistacchio"] = add(1)
check("pistacchio", 5)
} }

Loading…
Cancel
Save