i'm going nuts

master
Jordan Orelli 4 years ago
parent cbee897739
commit 9bf04dd6c8

@ -9,6 +9,8 @@ import (
"net/http" "net/http"
"os" "os"
"strconv" "strconv"
"sync"
"time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/jordanorelli/astro-domu/internal/errors" "github.com/jordanorelli/astro-domu/internal/errors"
@ -17,10 +19,13 @@ import (
) )
type Server struct { type Server struct {
sync.Mutex
*blammo.Log *blammo.Log
Host string Host string
Port int Port int
http *http.Server
lastSessionID int lastSessionID int
sessions map[int]*session
} }
func (s *Server) Start() error { func (s *Server) Start() error {
@ -56,12 +61,43 @@ func (s *Server) Start() error {
} }
func (s *Server) runHTTPServer(lis net.Listener) { func (s *Server) runHTTPServer(lis net.Listener) {
err := http.Serve(lis, s) zzz := http.Server{
Handler: s,
}
s.http = &zzz
err := zzz.Serve(lis)
if err != nil && !errors.Is(err, http.ErrServerClosed) { if err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Error("error in http.Serve: %v", err) s.Error("error in http.Serve: %v", err)
} }
} }
func (s *Server) createSession(conn *websocket.Conn) *session {
s.Lock()
defer s.Unlock()
s.lastSessionID++
sn := &session{
Log: s.Log.Child("sessions").Child(strconv.Itoa(s.lastSessionID)),
id: s.lastSessionID,
conn: conn,
outbox: make(chan wire.Response),
done: make(chan chan struct{}, 1),
}
if s.sessions == nil {
s.sessions = make(map[int]*session)
}
s.sessions[sn.id] = sn
return sn
}
func (s *Server) dropSession(sn *session) {
s.Lock()
defer s.Unlock()
close(sn.done)
delete(s.sessions, sn.id)
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{} upgrader := websocket.Upgrader{}
@ -78,17 +114,10 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
}() }()
ctx, cancel := context.WithCancel(context.Background()) sn := s.createSession(conn)
defer cancel() defer s.dropSession(sn)
s.lastSessionID++ go sn.run()
sn := session{
Log: s.Log.Child("sessions").Child(strconv.Itoa(s.lastSessionID)),
id: s.lastSessionID,
conn: conn,
outbox: make(chan wire.Response),
}
go sn.pump(ctx)
for { for {
t, r, err := conn.NextReader() t, r, err := conn.NextReader()
@ -121,3 +150,22 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
} }
func (s *Server) Shutdown() {
s.Info("shutting down")
s.http.Shutdown(context.Background())
s.Lock()
zzz := make([]chan struct{}, 0, len(s.sessions))
for id, sn := range s.sessions {
s.Info("sending done signal to session: %d", id)
c := make(chan struct{})
zzz = append(zzz, c)
sn.done <- c
}
s.Unlock()
for _, c := range zzz {
<-c
}
time.Sleep(time.Second)
}

@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"time" "time"
@ -16,19 +15,30 @@ type session struct {
id int id int
conn *websocket.Conn conn *websocket.Conn
outbox chan wire.Response outbox chan wire.Response
done chan chan struct{}
} }
// pump is the session send loop. Pump should pump the session's outbox // pump is the session send loop. Pump should pump the session's outbox
// messages to the underlying connection until the context is closed. // messages to the underlying connection until the context is closed.
func (sn *session) pump(ctx context.Context) { func (sn *session) run() {
for { for {
select { select {
case res := <-sn.outbox: case res := <-sn.outbox:
if err := sn.sendResponse(res); err != nil { if err := sn.sendResponse(res); err != nil {
sn.Error(err.Error()) sn.Error(err.Error())
} }
case <-ctx.Done(): case c, ok := <-sn.done:
sn.Info("parent context done, shutting down write pump") sn.Info("saw done signal: %t", ok)
if ok {
sn.Info("sending close frame")
msg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
if err := sn.conn.WriteMessage(websocket.CloseMessage, msg); err != nil {
sn.Error("failed to write close message: %v", err)
} else {
sn.Info("sent close frame")
}
close(c)
}
return return
} }
} }

@ -43,6 +43,8 @@ func main() {
sig := make(chan os.Signal, 1) sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt) signal.Notify(sig, os.Interrupt)
<-sig <-sig
s.Shutdown()
default: default:
exit.WithMessage(1, "supported options are [client|server]") exit.WithMessage(1, "supported options are [client|server]")
} }

Loading…
Cancel
Save