diff --git a/internal/wire/client.go b/internal/wire/client.go index 2c2642d..5303545 100644 --- a/internal/wire/client.go +++ b/internal/wire/client.go @@ -19,17 +19,17 @@ type Client struct { lastSeq int conn *websocket.Conn - // outbox is the set of requests that we'd like to send. The send loop will - // read off of this channel and write these values to the underlying - // websocket connection. - outbox chan Request + outbox chan *pending + resolved chan Response } // Dial dials the server specified by the client. The returned read-only // channel is a channel of responses from the server that are not replies to a // request sent by the client. func (c *Client) Dial() (<-chan Response, error) { - c.outbox = make(chan Request) + c.outbox = make(chan *pending) + c.resolved = make(chan Response) + dialer := websocket.Dialer{ HandshakeTimeout: 3 * time.Second, ReadBufferSize: 32 * 1024, @@ -58,15 +58,25 @@ func (c *Client) Dial() (<-chan Response, error) { return notifications, nil } -func (c *Client) Send(v Value) { - c.lastSeq++ +func (c *Client) Send(v Value) (Response, error) { d := 3 * time.Second timeout := time.NewTimer(d) + + done := make(chan struct{}) + p := pending{v: v, done: done} + select { - case c.outbox <- NewRequest(c.lastSeq, v): + case c.outbox <- &p: timeout.Stop() case <-timeout.C: - c.Error("send timed out after %v", d) + return Response{}, fmt.Errorf("send timed out after %v", d) + } + + select { + case <-done: + return p.res, p.err + case <-timeout.C: + return Response{}, fmt.Errorf("send timed out (2) after %v", d) } } @@ -92,33 +102,55 @@ func (c *Client) readLoop(notifications chan<- Response) { c.Child("read-frame").Info(string(b)) if res.Re <= 0 { notifications <- res + } else { + c.resolved <- res } } } func (c *Client) writeLoop(done chan bool) { + sent := make(map[int]*pending) + for { select { - case req := <-c.outbox: + case p := <-c.outbox: + c.lastSeq++ + req := NewRequest(c.lastSeq, p.v) + w, err := c.conn.NextWriter(websocket.TextMessage) if err != nil { - c.Error("unable to get a writer frame: %v", err) + p.err = fmt.Errorf("unable to get a writer frame: %w", err) + close(p.done) return } b, err := json.Marshal(req) if err != nil { - c.Error("unable to marshal outgoing response: %v", err) + p.err = fmt.Errorf("unable to marshal outgoing response: %w", err) + close(p.done) break } if _, err := w.Write(b); err != nil { - c.Error("failed to write payload: %v", err) + p.err = fmt.Errorf("failed to write payload: %w", err) + close(p.done) break } if err := w.Close(); err != nil { - c.Error("failed to close write frame: %v", err) + p.err = fmt.Errorf("failed to close write frame: %w", err) + close(p.done) break } c.Child("write-frame").Info(string(b)) + sent[c.lastSeq] = p + + case res := <-c.resolved: + p, ok := sent[res.Re] + if !ok { + c.Error("saw response for unknown seq %d") + break + } + p.res = res + close(p.done) + case shouldClose := <-done: if shouldClose { msg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") @@ -136,5 +168,11 @@ func (c *Client) writeLoop(done chan bool) { } } } +} +type pending struct { + v Value + res Response + err error + done chan struct{} }