diff --git a/values.go b/values.go index 8664e1c..414e754 100644 --- a/values.go +++ b/values.go @@ -19,7 +19,14 @@ type value interface { } func readValue(r io.Reader) (value, error) { - br := bufio.NewReader(r) + var br *bufio.Reader + switch t := r.(type) { + case *bufio.Reader: + br = t + default: + br = bufio.NewReader(r) + } + line, err := br.ReadBytes('\n') switch err { case io.EOF: @@ -34,7 +41,7 @@ func readValue(r io.Reader) (value, error) { } if len(line) < 3 { - return nil, fmt.Errorf("unable to read redis protocol value: input is too small") + return nil, fmt.Errorf("unable to read redis protocol value: input %q is too small", line) } if line[len(line)-2] != '\r' { return nil, fmt.Errorf("unable to read redis protocol value: bad line terminator") @@ -49,6 +56,8 @@ func readValue(r io.Reader) (value, error) { return readInteger(line[1:]) case start_bulkstring: return readBulkString(line[1:], br) + case start_array: + return readArray(line[1:], br) default: return nil, fmt.Errorf("unable to read redis protocol value: illegal start character: %c", line[0]) } @@ -92,6 +101,7 @@ func readBulkString(prefix []byte, r io.Reader) (value, error) { return nil, fmt.Errorf("unable to read bulkstring in redis protocol: bad prefix: %v", err) } + n += 2 b := make([]byte, n) n_read, err := r.Read(b) switch err { @@ -105,5 +115,30 @@ func readBulkString(prefix []byte, r io.Reader) (value, error) { return nil, fmt.Errorf("unable to read bulkstring in redis protocol: read %d bytes, expected to read %d bytes", int64(n_read), n) } - return BulkString(b), nil + if len(b) < 2 { + return nil, fmt.Errorf("unable to read bulkstring in redis protocol: input %q is too short", b) + } + + return BulkString(b[:len(b)-2]), nil +} + +// ----------------------------------------------------------------------------------------- + +type Array []value + +func readArray(prefix []byte, r *bufio.Reader) (value, error) { + n, err := strconv.ParseInt(string(prefix), 10, 64) + if err != nil { + return nil, fmt.Errorf("unable to read array in redis protocol: bad prefix: %v", err) + } + + a := make(Array, n) + for i := int64(0); i < n; i++ { + v, err := readValue(r) + if err != nil { + return nil, fmt.Errorf("unable to read array value in redis protocol: %v", err) + } + a[i] = v + } + return a, nil } diff --git a/values_test.go b/values_test.go index 777f550..ccc5490 100644 --- a/values_test.go +++ b/values_test.go @@ -15,8 +15,24 @@ func (test valueTest) run(t *testing.T) { if err != nil { t.Errorf("valueTest error: %v", err) } - if v != test.out { - t.Errorf("expected %v, got %v", test.out, v) + switch expected := test.out.(type) { + case Array: + got, ok := v.(Array) + if !ok { + t.Errorf("expected Array value, got %v", v) + } + if len(got) != len(expected) { + t.Errorf("expected Array of length %d, saw Array of length %d", len(expected), len(got)) + } + for i := 0; i < len(got); i++ { + if got[i] != expected[i] { + t.Errorf("Array values do not match: got %v, expected %v", got, expected) + } + } + default: + if v != test.out { + t.Errorf("expected %v, got %v", test.out, v) + } } } @@ -61,6 +77,11 @@ var valueTests = []valueTest{ {":-12345\r\n+extra\r\n", Integer(-12345)}, {":9223372036854775807\r\n+extra\r\n", Integer(9223372036854775807)}, // int64 max {":-9223372036854775808\r\n+extra\r\n", Integer(-9223372036854775808)}, // int64 min + + {"*0\r\n", Array{}}, // is this a thing? I have no idea. + {"*1\r\n+hello\r\n", Array{String("hello")}}, + {"*2\r\n+one\r\n+two", Array{String("one"), String("two")}}, + {"*2\r\n$4\r\necho\r\n$5\r\nhello", Array{BulkString("echo"), BulkString("hello")}}, } func TestValues(t *testing.T) {