From b79fb00f2e856d9a5f675b05791b74d828644c56 Mon Sep 17 00:00:00 2001 From: Jordan Orelli Date: Sun, 28 Nov 2021 23:03:38 +0000 Subject: [PATCH] bags --- bag/bag.go | 77 +++++++++++++++++++++++++++++++++++++++++ bag/bag_test.go | 92 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 169 insertions(+) create mode 100644 bag/bag.go create mode 100644 bag/bag_test.go diff --git a/bag/bag.go b/bag/bag.go new file mode 100644 index 0000000..e6183ec --- /dev/null +++ b/bag/bag.go @@ -0,0 +1,77 @@ +package bag + +import ( + "errors" +) + +var errNotFound = errors.New("not found") +var errTypeError = errors.New("type error") + +// Bag is a read-only collection of values. Consumer can add elements to the +// bag if and only no element has been added for that key in the past. Elements +// can be added to the bag either as values or as pointers. +type Bag map[string]bagged + +// Add adds a value to a bag. The provided value can be retrieved from the bag +// directly. There's really no reason to call this with a pointer but I don't +// know how to prevent that. +func Add(b Bag, k string, v interface{}) bool { + if b.Has(k) { + return false + } + b[k] = bagged{val: v} + return true +} + +// Ref adds a reference to a bag. The provided value must be a pointer. Once +// added, the pointer is never retrievable from the bag; reading this key from +// the bag dereferences the pointer at the time of reading. +func Ref[V any](b Bag, k string, v *V) bool { + if b.Has(k) { + return false + } + + if v == nil { + return false + } + + b[k] = bagged{val: v, ref: true} + return true +} + +// Get retrieves a value from a bag. Whether a value was added or a ref was +// added, you always get a value out. +func Get[V any](b Bag, k string) (V, error) { + bv, ok := b[k] + if !ok { + var zero V + return zero, errNotFound + } + + if bv.ref { + ptr, ok := bv.val.(*V) + if !ok { + var zero V + return zero, errTypeError + } + return *ptr, nil + } + + v, ok := bv.val.(V) + if !ok { + var zero V + return zero, errTypeError + } + return v, nil +} + +// Has describes whether or not the bag contains the given key +func (b Bag) Has(k string) bool { + _, ok := b[k] + return ok +} + +type bagged struct { + val interface{} + ref bool +} diff --git a/bag/bag_test.go b/bag/bag_test.go new file mode 100644 index 0000000..c0c569b --- /dev/null +++ b/bag/bag_test.go @@ -0,0 +1,92 @@ +package bag + +import ( + "errors" + "testing" +) + +func TestEmpty(t *testing.T) { + b := make(Bag) + + _, err := Get[string](b, "foo") + if !errors.Is(err, errNotFound) { + t.Fatalf("expected not found error, saw %v", err) + } +} + +func TestAdd(t *testing.T) { + b := make(Bag) + + if !Add(b, "foo", "bar") { + t.Fatalf("weird add failure") + } + + foo, err := Get[string](b, "foo") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if foo != "bar" { + t.Fatalf("unexpected value: %v", foo) + } + + if Add(b, "foo", "again") { + t.Fatalf("weird add success") + } + + _, err = Get[int](b, "foo") + if !errors.Is(err, errTypeError) { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRef(t *testing.T) { + b := make(Bag) + + name := "Jordan" + if !Ref(b, "name", &name) { + t.Fatal("ref failed") + } + + _, err := Get[*string](b, "name") + if !errors.Is(err, errTypeError) { + t.Fatal("retrieving pointer for ref did not fail") + } + + _, err = Get[int](b, "name") + if !errors.Is(err, errTypeError) { + t.Fatal("retrieving value of differing type for ref did not fail") + } + + readName, err := Get[string](b, "name") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if readName != "Jordan" { + t.Fatalf("unexpected value: %v", readName) + } + + name = "Jordan Orelli" + if readName != "Jordan" { + t.Fatalf("unexpected value: %v", readName) + } + readName, err = Get[string](b, "name") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if readName != "Jordan Orelli" { + t.Fatalf("unexpected value: %v", readName) + } + + fn := func(s *string) { + *s = "mute" + } + fn(&name) + + readName, err = Get[string](b, "name") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if readName != "mute" { + t.Fatalf("unexpected value: %v", readName) + } +}