diff --git a/internal/server/server.go b/internal/server/server.go index d6e13ba..101d9ee 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,6 +9,8 @@ import ( "net/http" "os" "strconv" + "sync" + "time" "github.com/gorilla/websocket" "github.com/jordanorelli/astro-domu/internal/errors" @@ -17,10 +19,13 @@ import ( ) type Server struct { + sync.Mutex *blammo.Log Host string Port int + http *http.Server lastSessionID int + sessions map[int]*session } func (s *Server) Start() error { @@ -56,12 +61,43 @@ func (s *Server) Start() error { } func (s *Server) runHTTPServer(lis net.Listener) { - err := http.Serve(lis, s) + zzz := http.Server{ + Handler: s, + } + s.http = &zzz + err := zzz.Serve(lis) if err != nil && !errors.Is(err, http.ErrServerClosed) { s.Error("error in http.Serve: %v", err) } } +func (s *Server) createSession(conn *websocket.Conn) *session { + s.Lock() + defer s.Unlock() + + s.lastSessionID++ + sn := &session{ + Log: s.Log.Child("sessions").Child(strconv.Itoa(s.lastSessionID)), + id: s.lastSessionID, + conn: conn, + outbox: make(chan wire.Response), + done: make(chan chan struct{}, 1), + } + if s.sessions == nil { + s.sessions = make(map[int]*session) + } + s.sessions[sn.id] = sn + return sn +} + +func (s *Server) dropSession(sn *session) { + s.Lock() + defer s.Unlock() + + close(sn.done) + delete(s.sessions, sn.id) +} + func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { upgrader := websocket.Upgrader{} @@ -78,17 +114,10 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + sn := s.createSession(conn) + defer s.dropSession(sn) - s.lastSessionID++ - sn := session{ - Log: s.Log.Child("sessions").Child(strconv.Itoa(s.lastSessionID)), - id: s.lastSessionID, - conn: conn, - outbox: make(chan wire.Response), - } - go sn.pump(ctx) + go sn.run() for { t, r, err := conn.NextReader() @@ -121,3 +150,22 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } } + +func (s *Server) Shutdown() { + s.Info("shutting down") + s.http.Shutdown(context.Background()) + + s.Lock() + zzz := make([]chan struct{}, 0, len(s.sessions)) + for id, sn := range s.sessions { + s.Info("sending done signal to session: %d", id) + c := make(chan struct{}) + zzz = append(zzz, c) + sn.done <- c + } + s.Unlock() + for _, c := range zzz { + <-c + } + time.Sleep(time.Second) +} diff --git a/internal/server/session.go b/internal/server/session.go index 0500f37..aadd184 100644 --- a/internal/server/session.go +++ b/internal/server/session.go @@ -1,7 +1,6 @@ package server import ( - "context" "encoding/json" "fmt" "time" @@ -16,19 +15,30 @@ type session struct { id int conn *websocket.Conn outbox chan wire.Response + done chan chan struct{} } // pump is the session send loop. Pump should pump the session's outbox // messages to the underlying connection until the context is closed. -func (sn *session) pump(ctx context.Context) { +func (sn *session) run() { for { select { case res := <-sn.outbox: if err := sn.sendResponse(res); err != nil { sn.Error(err.Error()) } - case <-ctx.Done(): - sn.Info("parent context done, shutting down write pump") + case c, ok := <-sn.done: + sn.Info("saw done signal: %t", ok) + if ok { + sn.Info("sending close frame") + msg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + if err := sn.conn.WriteMessage(websocket.CloseMessage, msg); err != nil { + sn.Error("failed to write close message: %v", err) + } else { + sn.Info("sent close frame") + } + close(c) + } return } } diff --git a/main.go b/main.go index 186b464..1335419 100644 --- a/main.go +++ b/main.go @@ -43,6 +43,8 @@ func main() { sig := make(chan os.Signal, 1) signal.Notify(sig, os.Interrupt) <-sig + s.Shutdown() + default: exit.WithMessage(1, "supported options are [client|server]") }