redo the bit reader shit

master
Jordan Orelli 9 years ago
parent 39bfe68f79
commit 8010c47cab

117
bit.go

@ -1,117 +0,0 @@
package main
import (
"fmt"
"io"
)
// a bitBuffer is a buffer of raw data that is not necessarily byte-aligned,
// for performing bitwise reads and manipulation. The bulk of this source is
// adapted from the bitReader defined in the standard library bzip2 package.
type bitBuffer struct {
source []byte // the source data to be read. This slice is never modified
index int // position of the last byte read out of source
scratch uint64 // scratch register of bits worthy of manipulation
bits uint // bit width of the scratch register
err error // stored error
}
func newBitBuffer(buf []byte) *bitBuffer {
return &bitBuffer{source: buf}
}
func (b *bitBuffer) readBits(bits uint) uint64 {
for bits > b.bits {
if b.index >= len(b.source) {
b.err = io.ErrUnexpectedEOF
return 0
}
b.scratch <<= 8
b.scratch |= uint64(b.source[b.index])
b.index += 1
b.bits += 8
}
// b.scratch looks like this (assuming that b.bits = 14 and bits = 6):
// Bit: 111111
// 5432109876543210
//
// (6 bits, the desired output)
// |-----|
// V V
// 0101101101001110
// ^ ^
// |------------|
// b.bits (num valid bits)
//
// This the next line right shifts the desired bits into the
// least-significant places and masks off anything above.
n := (b.scratch >> (b.bits - bits)) & ((1 << bits) - 1)
b.bits -= bits
return n
}
func (b *bitBuffer) readByte() (out byte) {
if b.bits == 0 {
if b.index >= len(b.source) {
b.err = io.ErrUnexpectedEOF
return
}
out = b.source[b.index]
b.index += 1
return
}
return byte(b.readBits(8))
}
func (b *bitBuffer) readBytes(n int) []byte {
if b.bits == 0 {
b.index += n
if b.index > len(b.source) {
b.err = io.ErrUnexpectedEOF
return []byte{}
}
return b.source[b.index-n : b.index]
}
buf := make([]byte, n)
for i := 0; i < n; i++ {
buf[i] = byte(b.readBits(8))
}
return buf
}
// readVarUint reads a variable-length uint32, encoded with some scheme that I
// can't find a standard for. first two bits are a length prefix, followed by a
// 4, 8, 12, or 32-bit wide uint.
func (b *bitBuffer) readVarUint() uint32 {
switch b.readBits(2) {
case 0:
return uint32(b.readBits(4))
case 1:
return uint32(b.readBits(4) | b.readBits(4)<<4)
case 2:
return uint32(b.readBits(4) | b.readBits(8)<<4)
case 3:
return uint32(b.readBits(4) | b.readBits(28)<<4)
default:
// this switch is already exhaustive, the compiler just can't tell.
panic(fmt.Sprintf("invalid varuint prefix"))
}
}
// readVarInt reads a varint-encoded value off of the front of the buffer. This
// is the varint encoding used in protobuf. That is: each byte utilizes a 7-bit
// group. the msb of each byte indicates whether there are more bytes to
// follow.
func (b *bitBuffer) readVarInt() uint64 {
var x, n uint64
for shift := uint(0); shift < 64; shift += 7 {
n = b.readBits(8)
if n < 0x80 {
return x | n<<shift
}
x |= n &^ 0x80 << shift
}
b.err = fmt.Errorf("readVarInt never saw the end of varint")
return 0
}

@ -0,0 +1,52 @@
package bit
import (
"bufio"
"bytes"
"io"
)
// bit.Reader allows for bit-level reading of arbitrary source data. This is
// based on the bit reader found in the standard library's bzip2 package.
// https://golang.org/src/compress/bzip2/bit_reader.go
type Reader struct {
src io.ByteReader // source of data
n uint64 // bit buffer
bits uint // number of valid bits in n
err error // stored error
}
// NewReader creates a new bit.Reader for any arbitrary reader.
func NewReader(r io.Reader) *Reader {
br, ok := r.(io.ByteReader)
if !ok {
br = bufio.NewReader(r)
}
return &Reader{src: br}
}
// NewByteReader creates a bit.Reader for a static slice of bytes. It's just
// using a bytes.Reader internally.
func NewBytesReader(b []byte) *Reader {
return NewReader(bytes.NewReader(b))
}
// ReadBits reads the given number of bits and returns them in the
// least-significant part of a uint64.
func (r *Reader) ReadBits(bits uint) (n uint64) {
for bits > r.bits {
b, err := r.src.ReadByte()
if err != nil {
r.err = err
return 0
}
r.n <<= 8
r.n |= uint64(b)
r.bits += 8
}
n = (r.n >> (r.bits - bits)) & ((1 << bits) - 1)
r.bits -= bits
return
}
func (r *Reader) Err() error { return r.err }

@ -0,0 +1,52 @@
package bit
import (
"testing"
"github.com/stretchr/testify/assert"
)
var (
// 1000 1011 1010 1101 1111 0000 0000 1101
badFood = []byte{0x8b, 0xad, 0xf0, 0x0d}
)
func TestRead(t *testing.T) {
assert := assert.New(t)
var r *Reader
// aligned reading
r = NewBytesReader(badFood)
assert.Equal(uint64(0x8b), r.ReadBits(8))
assert.Equal(uint64(0xad), r.ReadBits(8))
assert.Equal(uint64(0xf0), r.ReadBits(8))
assert.Equal(uint64(0x0d), r.ReadBits(8))
// misaligned reading
r = NewBytesReader(badFood)
// 1000 1011 1010 1101 1111 0000 0000 1101
// ^
assert.Equal(uint64(0x01), r.ReadBits(1))
// 1000 1011 1010 1101 1111 0000 0000 1101
// ^-^
assert.Equal(uint64(0), r.ReadBits(3))
// 1000 1011 1010 1101 1111 0000 0000 1101
// ^--^
assert.Equal(uint64(0xb), r.ReadBits(4))
// 1000 1011 1010 1101 1111 0000 0000 1101
// ^----^
assert.Equal(uint64(0x15), r.ReadBits(5))
// 1000 1011 1010 1101 1111 0000 0000 1101
// ^---------^
assert.Equal(uint64(0x17c), r.ReadBits(9))
// 1000 1011 1010 1101 1111 0000 0000 1101
// ^----------^
assert.Equal(uint64(0xd), r.ReadBits(10))
}

@ -1,83 +0,0 @@
package main
import (
"io"
"testing"
)
func TestBits(t *testing.T) {
buf := []byte{0x00}
bb := newBitBuffer(buf)
for i := 0; i < 8; i++ {
if bb.readBits(1) != 0x00 {
t.Error("hahha what")
}
if bb.err != nil {
t.Errorf("oh weird error: %v", bb.err)
}
}
if bb.readBits(1) != 0x00 {
t.Error("hahha what")
}
if bb.err != io.ErrUnexpectedEOF {
t.Errorf("oh weird error: %v", bb.err)
}
buf = []byte{0x10}
bb = newBitBuffer(buf)
if n := bb.readBits(4); n != 0x01 {
t.Errorf("shit. wanted %v, got %v", 0x01, n)
}
if n := bb.readBits(4); n != 0x00 {
t.Errorf("poop. wanted %v, got %v", 0x00, n)
}
if bb.err != nil {
t.Errorf("fuck")
}
buf = []byte{0x4}
bb = newBitBuffer(buf)
u := bb.readVarUint()
if u != 1 {
t.Errorf("feck. wanted %v, got %v", 1, u)
}
if bb.readBits(2); bb.err != nil {
t.Errorf("shouldn't have an error yet")
}
if bb.readBits(1); bb.err == nil {
t.Errorf("we should be at EOF now")
}
buf = []byte{0x3c}
bb = newBitBuffer(buf)
u = bb.readVarUint()
if u != 15 {
t.Errorf("feck. wanted %v, got %v", 15, u)
}
if bb.readBits(2); bb.err != nil {
t.Errorf("shouldn't have an error yet")
}
if bb.readBits(1); bb.err == nil {
t.Errorf("we should be at EOF now")
}
buf = []byte{0x48, 0x10}
// 0100 1000 0001 0000
// 01 - prefix bits. indicates length 12
// 00 10 - least significant four
// 00 0001 00 - most significant eight
// 00 - not read.
//
// 0000 0100 0010 - actual value (0x42, or 66)
bb = newBitBuffer(buf)
u = bb.readVarUint()
if u != 66 {
t.Errorf("feck. wanted %v, got %v", 66, u)
}
if bb.readBits(2); bb.err != nil {
t.Errorf("shouldn't have an error yet")
}
if bb.readBits(1); bb.err == nil {
t.Errorf("we should be at EOF now")
}
}

19
glide.lock generated

@ -1,10 +1,23 @@
hash: 9208b00dc0b7be6b23958e0642efb66c8584c0be1d97fd81a9de0df643ac2877
updated: 2016-07-31T18:52:05.106198608-04:00
hash: 3f7fbcf64c0749e5f78dc8188c594871ab368257d8a05f238cb2ff901d76f8f8
updated: 2016-08-01T20:29:43.617478897-04:00
imports:
- name: github.com/golang/protobuf
version: c3cefd437628a0b7d31b34fe44b3a7a540e98527
subpackages:
- proto
- protoc-gen-go/descriptor
- name: github.com/golang/snappy
version: d9eb7a3d35ec988b8585d4a0068e462c27d28380
testImports: []
- name: github.com/stretchr/testify
version: f390dcf405f7b83c997eac1b06768bb9f44dec18
subpackages:
- assert
testImports:
- name: github.com/davecgh/go-spew
version: 5215b55f46b2b919f50a1df0eaa5886afe4e3b3d
subpackages:
- spew
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:
- difflib

@ -2,3 +2,5 @@ package: github.com/jordanorelli/hyperstone
import:
- package: github.com/golang/protobuf
- package: github.com/golang/snappy
- package: github.com/stretchr/testify
version: ^1.1.3

@ -42,18 +42,7 @@ func (m *message) check(dump bool) error {
}
if dump {
shit := packet.GetData()[:4]
bb := newBitBuffer(packet.GetData())
type T struct {
t int32
}
var v T
v.t = int32(bb.readVarUint())
if bb.err != nil {
fmt.Printf("packet error: %v\n", bb.err)
} else {
fmt.Printf("{in: %d out: %d data: %v shit: %x}\n", packet.GetSequenceIn(), packet.GetSequenceOutAck(), v, shit)
}
fmt.Println("I broke packet dumping.")
}
return nil
}

Loading…
Cancel
Save