diff --git a/internal/server/server.go b/internal/server/server.go index 0ac88af..b4fd32e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,13 +17,15 @@ import ( ) type Server struct { - sync.Mutex *blammo.Log - Host string - Port int - http *http.Server - lastSessionID int - sessions map[int]*session + Host string + Port int + http *http.Server + + sync.Mutex + lastSessionID int + sessions map[int]*session + waitOnSessions sync.WaitGroup } func (s *Server) Start() error { @@ -77,23 +79,31 @@ func (s *Server) createSession(conn *websocket.Conn) *session { sn := &session{ Log: s.Log.Child("sessions").Child(strconv.Itoa(s.lastSessionID)), id: s.lastSessionID, + start: time.Now(), conn: conn, outbox: make(chan wire.Response), - done: make(chan chan struct{}, 1), + done: make(chan bool, 1), } if s.sessions == nil { s.sessions = make(map[int]*session) } + s.waitOnSessions.Add(1) s.sessions[sn.id] = sn + s.Info("created session %d, %d sessions active", sn.id, len(s.sessions)) return sn } +// dropSession removes a session from the server. This should only be called as +// a result of the connection's read loop terminating func (s *Server) dropSession(sn *session) { s.Lock() defer s.Unlock() close(sn.done) delete(s.sessions, sn.id) + s.waitOnSessions.Add(-1) + + s.Info("dropped session %d after %v time connected, %d sessions active", sn.id, time.Since(sn.start), len(s.sessions)) } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -105,35 +115,48 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - defer func() { - s.Info("closing connection") - if err := conn.Close(); err != nil { - s.Error("error closing connection: %v", err) - } - }() - sn := s.createSession(conn) - defer s.dropSession(sn) - go sn.run() sn.read() + s.dropSession(sn) + + sn.Info("closing connection") + if err := conn.Close(); err != nil { + s.Error("error closing connection: %v", err) + } } func (s *Server) Shutdown() { - s.Info("shutting down") - s.http.Shutdown(context.Background()) + s.Info("starting shutdown procedure") + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + s.Info("shutting down http server") + if err := s.http.Shutdown(context.Background()); err != nil { + s.Error("error shutting down http server: %v", err) + } else { + s.Info("http server has shut down") + } + }() - s.Lock() - zzz := make([]chan struct{}, 0, len(s.sessions)) - for id, sn := range s.sessions { - s.Info("sending done signal to session: %d", id) - c := make(chan struct{}) - zzz = append(zzz, c) - sn.done <- c - } - s.Unlock() - for _, c := range zzz { - <-c - } - time.Sleep(time.Second) + go func() { + defer wg.Done() + + s.Info("broadcasting shutdown to all active sessions") + + s.Lock() + for id, sn := range s.sessions { + s.Info("sending done signal to session: %d", id) + sn.done <- true + } + s.Unlock() + + s.Info("waiting on connected sessions to shut down") + s.waitOnSessions.Wait() + s.Info("all sessions have shut down") + }() + wg.Wait() + s.Info("shutdown procedure complete") } diff --git a/internal/server/session.go b/internal/server/session.go index ad302cd..29c1d89 100644 --- a/internal/server/session.go +++ b/internal/server/session.go @@ -13,9 +13,10 @@ import ( type session struct { *blammo.Log id int + start time.Time conn *websocket.Conn outbox chan wire.Response - done chan chan struct{} + done chan bool } // run is the session run loop. @@ -26,9 +27,9 @@ func (sn *session) run() { if err := sn.sendResponse(res); err != nil { sn.Error(err.Error()) } - case c, ok := <-sn.done: - sn.Info("saw done signal: %t", ok) - if ok { + case sendCloseFrame := <-sn.done: + sn.Info("saw done signal") + if sendCloseFrame { sn.Info("sending close frame") msg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") if err := sn.conn.WriteMessage(websocket.CloseMessage, msg); err != nil { @@ -36,7 +37,6 @@ func (sn *session) run() { } else { sn.Info("sent close frame") } - close(c) } return }