diff --git a/auth.go b/auth.go index cec0c3d..d2fd2ab 100644 --- a/auth.go +++ b/auth.go @@ -12,3 +12,5 @@ type AuthRequest struct { func (a *AuthRequest) Kind() string { return "auth" } + +func init() { registerRequestType(func() request { return new(AuthRequest) }) } diff --git a/client.go b/client.go index f397bb0..c4bdaff 100644 --- a/client.go +++ b/client.go @@ -344,7 +344,7 @@ func (c *Client) getNote(args []string) { c.err("that doesn't look like an int: %v", err) return } - res, err := c.sendRequest(GetNoteRequest(id)) + res, err := c.sendRequest(GetNoteRequest{Id: id}) if err != nil { c.err("couldn't request note: %v", err) return diff --git a/error.go b/error.go index 7f4be1b..b7ada1b 100644 --- a/error.go +++ b/error.go @@ -11,3 +11,5 @@ func (e ErrorDoc) Kind() string { func (e ErrorDoc) Error() string { return string(e) } + +func init() { registerRequestType(func() request { return new(ErrorDoc) }) } diff --git a/key.go b/key.go index 8b8bb88..d15f8c1 100644 --- a/key.go +++ b/key.go @@ -126,6 +126,8 @@ func (k KeyRequest) Nick() string { return string(k) } +func init() { registerRequestType(func() request { return new(KeyRequest) }) } + type KeyResponse struct { Nick string Key rsa.PublicKey @@ -134,3 +136,5 @@ type KeyResponse struct { func (k KeyResponse) Kind() string { return "key-response" } + +func init() { registerRequestType(func() request { return new(KeyResponse) }) } diff --git a/message.go b/message.go index 632c204..14c4a3a 100644 --- a/message.go +++ b/message.go @@ -13,6 +13,8 @@ func (m Message) Kind() string { return "send-message" } +func init() { registerRequestType(func() request { return new(Message) }) } + type ListMessages struct { N int } @@ -21,6 +23,8 @@ func (l ListMessages) Kind() string { return "list-messages" } +func init() { registerRequestType(func() request { return new(ListMessages) }) } + type ListMessagesResponseItem struct { Id int Key []byte @@ -33,6 +37,8 @@ func (l ListMessagesResponse) Kind() string { return "list-messages-response" } +func init() { registerRequestType(func() request { return new(ListMessagesResponse) }) } + type GetMessage struct { Id int } @@ -40,3 +46,5 @@ type GetMessage struct { func (g GetMessage) Kind() string { return "get-message" } + +func init() { registerRequestType(func() request { return new(GetMessage) }) } diff --git a/meta.go b/meta.go deleted file mode 100644 index 016c4a4..0000000 --- a/meta.go +++ /dev/null @@ -1,9 +0,0 @@ -package main - -import () - -type Meta string - -func (m Meta) Kind() string { - return "meta" -} diff --git a/note.go b/note.go index 1cdc003..098f43e 100644 --- a/note.go +++ b/note.go @@ -15,12 +15,16 @@ func decodeInt(s string) (int, error) { return numEncoder.DecodeInt(s) } -type GetNoteRequest int +type GetNoteRequest struct { + Id int +} func (g GetNoteRequest) Kind() string { return "get-note" } +func init() { registerRequestType(func() request { return new(GetNoteRequest) }) } + type Note struct { Title string Body []byte @@ -32,6 +36,8 @@ type EncryptedNote struct { Body []byte } +func init() { registerRequestType(func() request { return new(EncryptedNote) }) } + func (n EncryptedNote) Kind() string { return "note" } @@ -53,6 +59,8 @@ func (l ListNotes) Kind() string { return "list-notes-request" } +func init() { registerRequestType(func() request { return new(ListNotes) }) } + type ListNotesResponseItem struct { Id int Key []byte @@ -64,3 +72,5 @@ type ListNotesResponse []ListNotesResponseItem func (l ListNotesResponse) Kind() string { return "list-notes-response" } + +func init() { registerRequestType(func() request { return new(ListNotesResponse) }) } diff --git a/request.go b/request.go index 5790438..b364362 100644 --- a/request.go +++ b/request.go @@ -4,21 +4,46 @@ import ( "encoding/json" "fmt" "io" - "net" ) +var requestTypes = make(map[string]func() request, 32) + +func registerRequestType(fn func() request) { + r := fn() + if _, ok := requestTypes[r.Kind()]; ok { + panic("request type already registered") + } + requestTypes[r.Kind()] = fn +} + type Envelope struct { Id int `json:"id"` Kind string `json:"kind"` Body json.RawMessage `json:"body"` } +func (e Envelope) Open() (request, error) { + fn, ok := requestTypes[e.Kind] + if !ok { + return nil, fmt.Errorf("unknown request type: %s", e.Kind) + } + r := fn() + if err := json.Unmarshal(e.Body, r); err != nil { + return nil, fmt.Errorf("failed to unmarshal json in note open: %v", err) + } + return r, nil +} + +// Bool is used to acknowledge that a request has been received and that there +// is no useful information for the user. type Bool bool func (b Bool) Kind() string { return "bool" } +func init() { registerRequestType(func() request { return new(Bool) }) } + type request interface { Kind() string } @@ -55,21 +80,3 @@ func writeRequest(w io.Writer, id int, r request) error { } return nil } - -func decodeRequest(conn net.Conn) (request, error) { - d := json.NewDecoder(conn) - var env Envelope - if err := d.Decode(&env); err != nil { - return nil, fmt.Errorf("unable to decode client request: %v", err) - } - switch env.Kind { - case "auth": - var auth AuthRequest - if err := json.Unmarshal(env.Body, &auth); err != nil { - return nil, fmt.Errorf("unable to unmarshal auth request: %v", err) - } - return &auth, nil - default: - return nil, fmt.Errorf("unknown request type: %s", env.Kind) - } -} diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..44ad4f3 --- /dev/null +++ b/request_test.go @@ -0,0 +1,87 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "reflect" + "testing" +) + +var requests = []request{ + &Message{ + Key: []byte("hmm maybe this should be checked."), + From: []byte("bob"), + To: "alice", + Text: []byte("this is my great message"), + }, + &ListMessages{N: 10}, + &ListMessagesResponse{ + {0, []byte("key"), []byte("from")}, + {1, []byte("key"), []byte("from")}, + {2, []byte("key"), []byte("from")}, + {3, []byte("key"), []byte("from")}, + }, + &GetMessage{Id: 8}, + &GetNoteRequest{Id: 12}, + &EncryptedNote{ + Key: []byte("this is not a key"), + Title: []byte("likewise, this is not a title's ciphertext"), + Body: []byte("nor is this the ciphertext of an encrypted note"), + }, + &ListNotes{N: 10}, + &ListNotesResponse{ + {0, []byte("key"), []byte("title")}, + {1, []byte("key"), []byte("title")}, + {2, []byte("key"), []byte("title")}, + {3, []byte("key"), []byte("title")}, + }, +} + +func TestEnvelope(t *testing.T) { + t.Logf("envelope test") + + tru, falz := Bool(true), Bool(false) + requests = append(requests, &tru, &falz) + + key, err := rsa.GenerateKey(rand.Reader, 128) + if err != nil { + t.Errorf("unable to create key for testing: %v", err) + } else { + requests = append(requests, &AuthRequest{ + Nick: "nick", + Key: &key.PublicKey, + }) + } + + e := ErrorDoc("this is my error document.") + requests = append(requests, &e) + + r := KeyRequest("bob") + requests = append(requests, &r) + + key2, err := rsa.GenerateKey(rand.Reader, 128) + if err != nil { + t.Errorf("unable to create key for testing: %v", err) + } else { + requests = append(requests, &KeyResponse{ + Nick: "nick", + Key: key2.PublicKey, + }) + } + + for i, r := range requests { + t.Logf("wrapping request %d of type %v", i, reflect.TypeOf(r)) + e, err := wrapRequest(i, r) + if err != nil { + t.Errorf("unable to wrap %v request: %v", reflect.TypeOf(r), err) + continue + } + r2, err := e.Open() + if err != nil { + t.Errorf("unable to open envelope %d of kind %v: %v", e.Id, e.Kind, err) + } + if !reflect.DeepEqual(r, r2) { + t.Errorf("request didn't envelope and unenvelope correctly: %v != %v", r, r2) + } + } +} diff --git a/server.go b/server.go index 9e274a0..4029f49 100644 --- a/server.go +++ b/server.go @@ -138,7 +138,7 @@ func (s *serverConnection) handleGetNoteRequest(requestId int, body json.RawMess if err := json.Unmarshal(body, &req); err != nil { return fmt.Errorf("bad getnote request: %v", err) } - key := fmt.Sprintf("notes/%s", encodeInt(int(req))) + key := fmt.Sprintf("notes/%s", encodeInt(int(req.Id))) b, err := s.db.Get([]byte(key), nil) if err != nil { return fmt.Errorf("couldn't retrieve note: %v", err)