Added fix for pairing server race on port

This commit is contained in:
Samuel Hawksby-Robinson 2022-10-12 11:00:14 +01:00
parent f5666bfcb8
commit 6cdd24a048
7 changed files with 90 additions and 33 deletions

View File

@ -30,7 +30,8 @@ func (s *ConnectionParamsSuite) SetupSuite() {
s.Require().NoError(err)
bs := NewServer(&cert, defaultIP.String(), func(int) {})
bs.port = 1337
err = bs.SetPort(1337)
s.Require().NoError(err)
s.server = &PairingServer{
Server: bs,

View File

@ -407,6 +407,7 @@ func handlePairingReceive(ps *PairingServer) http.HandlerFunc {
if err != nil {
signal.SendLocalPairingEvent(Event{Type: EventTransferError, Error: err})
ps.logger.Error("ioutil.ReadAll(r.Body)", zap.Error(err))
return
}
signal.SendLocalPairingEvent(Event{Type: EventTransferSuccess})
@ -414,6 +415,7 @@ func handlePairingReceive(ps *PairingServer) http.HandlerFunc {
if err != nil {
signal.SendLocalPairingEvent(Event{Type: EventProcessError, Error: err})
ps.logger.Error("ps.PayloadManager.Receive(payload)", zap.Error(err))
return
}
signal.SendLocalPairingEvent(Event{Type: EventProcessSuccess})
}
@ -429,6 +431,7 @@ func handlePairingSend(ps *PairingServer) http.HandlerFunc {
if err != nil {
signal.SendLocalPairingEvent(Event{Type: EventTransferError, Error: err})
ps.logger.Error("w.Write(ps.PayloadManager.ToSend())", zap.Error(err))
return
}
signal.SendLocalPairingEvent(Event{Type: EventTransferSuccess})
}
@ -440,6 +443,7 @@ func challengeMiddleware(ps *PairingServer, next http.Handler) http.HandlerFunc
if err != nil {
ps.logger.Error("ps.cookieStore.Get(r, pairingStoreChallenge)", zap.Error(err))
http.Error(w, "error", http.StatusInternalServerError)
return
}
blocked, ok := s.Values[sessionBlocked].(bool)
@ -491,6 +495,7 @@ func handlePairingChallenge(ps *PairingServer) http.HandlerFunc {
s, err := ps.cookieStore.Get(r, sessionChallenge)
if err != nil {
ps.logger.Error("ps.cookieStore.Get(r, pairingStoreChallenge)", zap.Error(err))
return
}
var challenge []byte
@ -500,12 +505,14 @@ func handlePairingChallenge(ps *PairingServer) http.HandlerFunc {
_, err = rand.Read(challenge)
if err != nil {
ps.logger.Error("_, err = rand.Read(auth)", zap.Error(err))
return
}
s.Values[sessionChallenge] = challenge
err = s.Save(r, w)
if err != nil {
ps.logger.Error("err = s.Save(r, w)", zap.Error(err))
return
}
}
@ -513,6 +520,7 @@ func handlePairingChallenge(ps *PairingServer) http.HandlerFunc {
_, err = w.Write(challenge)
if err != nil {
ps.logger.Error("_, err = w.Write(challenge)", zap.Error(err))
return
}
}
}

44
server/ports.go Normal file
View File

@ -0,0 +1,44 @@
package server
import (
"fmt"
"sync"
)
type portManger struct {
port int
afterPortChanged func(port int)
portWait *sync.Mutex
}
func newPortManager(afterPortChanged func(int)) portManger {
pm := portManger{
afterPortChanged: afterPortChanged,
portWait: new(sync.Mutex),
}
pm.portWait.Lock()
return pm
}
func (p *portManger) SetPort(port int) error {
if port == 0 {
return fmt.Errorf("port can not be `0`, use ResetPort() instead")
}
p.port = port
p.portWait.Unlock()
return nil
}
func (p *portManger) ResetPort() {
if p.portWait.TryLock() {
p.port = 0
}
}
func (p *portManger) MustGetPort() int {
p.portWait.Lock()
defer p.portWait.Unlock()
return p.port
}

View File

@ -14,18 +14,22 @@ import (
)
type Server struct {
run bool
server *http.Server
logger *zap.Logger
cert *tls.Certificate
hostname string
port int
handlers HandlerPatternMap
afterPortChanged func(port int)
isRunning bool
server *http.Server
logger *zap.Logger
cert *tls.Certificate
hostname string
handlers HandlerPatternMap
portManger
}
func NewServer(cert *tls.Certificate, hostname string, afterPortChanged func(int)) Server {
return Server{logger: logutils.ZapLogger(), cert: cert, hostname: hostname, afterPortChanged: afterPortChanged}
return Server{
logger: logutils.ZapLogger(),
cert: cert,
hostname: hostname,
portManger: newPortManager(afterPortChanged),
}
}
func (s *Server) getHost() string {
@ -40,7 +44,7 @@ func (s *Server) listenAndServe() {
listener, err := tls.Listen("tcp", s.getHost(), cfg)
if err != nil {
s.logger.Error("failed to start server, retrying", zap.Error(err))
s.port = 0
s.ResetPort()
err = s.Start()
if err != nil {
s.logger.Error("server start failed, giving up", zap.Error(err))
@ -48,11 +52,17 @@ func (s *Server) listenAndServe() {
return
}
s.port = listener.Addr().(*net.TCPAddr).Port
if s.afterPortChanged != nil {
s.afterPortChanged(s.port)
err = s.SetPort(listener.Addr().(*net.TCPAddr).Port)
if err != nil {
s.logger.Error("failed to set Server.port", zap.Error(err))
return
}
s.run = true
if s.afterPortChanged != nil {
s.afterPortChanged(s.MustGetPort())
}
s.isRunning = true
err = s.server.Serve(listener)
if err != http.ErrServerClosed {
@ -64,11 +74,12 @@ func (s *Server) listenAndServe() {
return
}
s.run = false
s.isRunning = false
}
func (s *Server) resetServer() {
s.server = new(http.Server)
s.ResetPort()
}
func (s *Server) applyHandlers() {
@ -100,7 +111,7 @@ func (s *Server) Stop() error {
}
func (s *Server) ToForeground() {
if !s.run && (s.server != nil) {
if !s.isRunning && (s.server != nil) {
err := s.Start()
if err != nil {
s.logger.Error("server start failed during foreground transition", zap.Error(err))
@ -109,7 +120,7 @@ func (s *Server) ToForeground() {
}
func (s *Server) ToBackground() {
if s.run {
if s.isRunning {
err := s.Stop()
if err != nil {
s.logger.Error("server stop failed during background transition", zap.Error(err))

View File

@ -91,11 +91,7 @@ func (s *PairingServer) MakeConnectionParams() (*ConnectionParams, error) {
netIP = netIP4
}
if s.port == 0 {
return nil, fmt.Errorf("port is 0, listener is not yet set")
}
return NewConnectionParams(netIP, s.port, s.pk, s.ek, s.mode), nil
return NewConnectionParams(netIP, s.MustGetPort(), s.pk, s.ek, s.mode), nil
}
func (s *PairingServer) StartPairing() error {

View File

@ -4,7 +4,6 @@ import (
"crypto/ecdsa"
"crypto/rand"
"testing"
"time"
"github.com/stretchr/testify/suite"
)
@ -18,7 +17,7 @@ type PairingServerSuite struct {
TestPairingServerComponents
}
func (s *PairingServerSuite) SetupSuite() {
func (s *PairingServerSuite) SetupTest() {
s.SetupPairingServerComponents(s.T())
}
@ -39,9 +38,6 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
err = s.PS.StartPairing()
s.Require().NoError(err)
// Give time for the sever to be ready, hacky I know, I'll iron this out
time.Sleep(10 * time.Millisecond)
cp, err := s.PS.MakeConnectionParams()
s.Require().NoError(err)
@ -89,6 +85,7 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
// Reset the server's PayloadEncryptionManager
s.PS.PayloadManager.(*MockEncryptOnlyPayloadManager).ResetPayload()
s.PS.ResetPort()
}
}
@ -102,9 +99,6 @@ func (s *PairingServerSuite) sendingSetup() *PairingClient {
err = s.PS.StartPairing()
s.Require().NoError(err)
// Give time for the sever to be ready, hacky I know, I'll iron this out
time.Sleep(10 * time.Millisecond)
cp, err := s.PS.MakeConnectionParams()
s.Require().NoError(err)

View File

@ -22,9 +22,12 @@ func (s *ServerURLSuite) SetupSuite() {
s.SetupKeyComponents(s.T())
s.server = &MediaServer{Server: Server{
hostname: defaultIP.String(),
port: 1337,
hostname: defaultIP.String(),
portManger: newPortManager(nil),
}}
err := s.server.SetPort(1337)
s.Require().NoError(err)
s.serverNoPort = &MediaServer{Server: Server{
hostname: defaultIP.String(),
}}