fix table merge problem by defining the merge identity

master
Jordan Orelli 3 years ago
parent 1958ff49f8
commit e8710f676d

@ -11,6 +11,7 @@ import (
// heterogeneous members, but should likely be avoided otherwise.
type Boxed struct {
val interface{}
ident func() interface{}
merge func(interface{}) error
}
@ -21,8 +22,18 @@ func (b Boxed) Merge(from Boxed) error {
return nil
}
func (b Boxed) MergeIdentity() Boxed {
return Boxed{
val: b.ident(),
ident: b.ident,
merge: b.merge,
}
}
var typeMismatch = errors.New("mismatched types")
// strip takes a typed function and erases the type information from its
// parameter
func strip[X any](f func(X) error) func(interface{}) error {
return func(v interface{}) error {
vv, ok := v.(X)
@ -38,6 +49,9 @@ func strip[X any](f func(X) error) func(interface{}) error {
func Box[X Merges[X]](x X) Boxed {
return Boxed{
val: x,
ident: func() interface{} {
return x.MergeIdentity()
},
merge: strip(x.Merge),
}
}

@ -18,6 +18,7 @@ package merge
// Although the type B satisfies the interface Merges[*C], it does not satisfy
// the constraint [X Merges[X]], which is what is used throughout this package.
type Merges[X any] interface {
MergeIdentity() X
Merge(X) error
}

@ -10,6 +10,10 @@ type additive struct {
total int
}
func (*additive) MergeIdentity() *additive {
return new(additive)
}
func (a *additive) Merge(b *additive) error {
a.total += b.total
return nil
@ -23,6 +27,10 @@ type multiplicative struct {
scale int
}
func (m *multiplicative) MergeIdentity() *multiplicative {
return &multiplicative{scale: 1}
}
func (m *multiplicative) Merge(v *multiplicative) error {
m.scale *= v.scale
return nil
@ -38,6 +46,8 @@ type exclusive struct {
stock int
}
func (*exclusive) MergeIdentity() *exclusive { return new(exclusive) }
func (e *exclusive) Merge(source *exclusive) error {
e.stock += source.stock
source.stock = 0
@ -75,5 +85,3 @@ func TestMerge(t *testing.T) {
}
})
}

@ -10,7 +10,12 @@ func (t Table[K, V]) Merge(from Table[K, V]) error {
for k, v := range from {
e, ok := t[k]
if !ok {
t[k] = v
var z V
z = z.MergeIdentity()
if err := z.Merge(v); err != nil {
return fmt.Errorf("tables failed to merge: %w", err)
}
t[k] = z
continue
}

@ -35,4 +35,7 @@ func TestMergeTables(t *testing.T) {
check("strawberry", 2)
check("pistacchio", 5)
bob["pistacchio"].Merge(add(100))
bob["pistacchio"] = add(1)
check("pistacchio", 5)
}

Loading…
Cancel
Save