Added fix for pairing server race on port
This commit is contained in:
parent
f5666bfcb8
commit
6cdd24a048
|
@ -30,7 +30,8 @@ func (s *ConnectionParamsSuite) SetupSuite() {
|
|||
s.Require().NoError(err)
|
||||
|
||||
bs := NewServer(&cert, defaultIP.String(), func(int) {})
|
||||
bs.port = 1337
|
||||
err = bs.SetPort(1337)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.server = &PairingServer{
|
||||
Server: bs,
|
||||
|
|
|
@ -407,6 +407,7 @@ func handlePairingReceive(ps *PairingServer) http.HandlerFunc {
|
|||
if err != nil {
|
||||
signal.SendLocalPairingEvent(Event{Type: EventTransferError, Error: err})
|
||||
ps.logger.Error("ioutil.ReadAll(r.Body)", zap.Error(err))
|
||||
return
|
||||
}
|
||||
signal.SendLocalPairingEvent(Event{Type: EventTransferSuccess})
|
||||
|
||||
|
@ -414,6 +415,7 @@ func handlePairingReceive(ps *PairingServer) http.HandlerFunc {
|
|||
if err != nil {
|
||||
signal.SendLocalPairingEvent(Event{Type: EventProcessError, Error: err})
|
||||
ps.logger.Error("ps.PayloadManager.Receive(payload)", zap.Error(err))
|
||||
return
|
||||
}
|
||||
signal.SendLocalPairingEvent(Event{Type: EventProcessSuccess})
|
||||
}
|
||||
|
@ -429,6 +431,7 @@ func handlePairingSend(ps *PairingServer) http.HandlerFunc {
|
|||
if err != nil {
|
||||
signal.SendLocalPairingEvent(Event{Type: EventTransferError, Error: err})
|
||||
ps.logger.Error("w.Write(ps.PayloadManager.ToSend())", zap.Error(err))
|
||||
return
|
||||
}
|
||||
signal.SendLocalPairingEvent(Event{Type: EventTransferSuccess})
|
||||
}
|
||||
|
@ -440,6 +443,7 @@ func challengeMiddleware(ps *PairingServer, next http.Handler) http.HandlerFunc
|
|||
if err != nil {
|
||||
ps.logger.Error("ps.cookieStore.Get(r, pairingStoreChallenge)", zap.Error(err))
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
blocked, ok := s.Values[sessionBlocked].(bool)
|
||||
|
@ -491,6 +495,7 @@ func handlePairingChallenge(ps *PairingServer) http.HandlerFunc {
|
|||
s, err := ps.cookieStore.Get(r, sessionChallenge)
|
||||
if err != nil {
|
||||
ps.logger.Error("ps.cookieStore.Get(r, pairingStoreChallenge)", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
var challenge []byte
|
||||
|
@ -500,12 +505,14 @@ func handlePairingChallenge(ps *PairingServer) http.HandlerFunc {
|
|||
_, err = rand.Read(challenge)
|
||||
if err != nil {
|
||||
ps.logger.Error("_, err = rand.Read(auth)", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
s.Values[sessionChallenge] = challenge
|
||||
err = s.Save(r, w)
|
||||
if err != nil {
|
||||
ps.logger.Error("err = s.Save(r, w)", zap.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -513,6 +520,7 @@ func handlePairingChallenge(ps *PairingServer) http.HandlerFunc {
|
|||
_, err = w.Write(challenge)
|
||||
if err != nil {
|
||||
ps.logger.Error("_, err = w.Write(challenge)", zap.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type portManger struct {
|
||||
port int
|
||||
afterPortChanged func(port int)
|
||||
portWait *sync.Mutex
|
||||
}
|
||||
|
||||
func newPortManager(afterPortChanged func(int)) portManger {
|
||||
pm := portManger{
|
||||
afterPortChanged: afterPortChanged,
|
||||
portWait: new(sync.Mutex),
|
||||
}
|
||||
pm.portWait.Lock()
|
||||
return pm
|
||||
}
|
||||
|
||||
func (p *portManger) SetPort(port int) error {
|
||||
if port == 0 {
|
||||
return fmt.Errorf("port can not be `0`, use ResetPort() instead")
|
||||
}
|
||||
|
||||
p.port = port
|
||||
p.portWait.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *portManger) ResetPort() {
|
||||
if p.portWait.TryLock() {
|
||||
p.port = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (p *portManger) MustGetPort() int {
|
||||
p.portWait.Lock()
|
||||
defer p.portWait.Unlock()
|
||||
|
||||
return p.port
|
||||
}
|
|
@ -14,18 +14,22 @@ import (
|
|||
)
|
||||
|
||||
type Server struct {
|
||||
run bool
|
||||
server *http.Server
|
||||
logger *zap.Logger
|
||||
cert *tls.Certificate
|
||||
hostname string
|
||||
port int
|
||||
handlers HandlerPatternMap
|
||||
afterPortChanged func(port int)
|
||||
isRunning bool
|
||||
server *http.Server
|
||||
logger *zap.Logger
|
||||
cert *tls.Certificate
|
||||
hostname string
|
||||
handlers HandlerPatternMap
|
||||
portManger
|
||||
}
|
||||
|
||||
func NewServer(cert *tls.Certificate, hostname string, afterPortChanged func(int)) Server {
|
||||
return Server{logger: logutils.ZapLogger(), cert: cert, hostname: hostname, afterPortChanged: afterPortChanged}
|
||||
return Server{
|
||||
logger: logutils.ZapLogger(),
|
||||
cert: cert,
|
||||
hostname: hostname,
|
||||
portManger: newPortManager(afterPortChanged),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) getHost() string {
|
||||
|
@ -40,7 +44,7 @@ func (s *Server) listenAndServe() {
|
|||
listener, err := tls.Listen("tcp", s.getHost(), cfg)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to start server, retrying", zap.Error(err))
|
||||
s.port = 0
|
||||
s.ResetPort()
|
||||
err = s.Start()
|
||||
if err != nil {
|
||||
s.logger.Error("server start failed, giving up", zap.Error(err))
|
||||
|
@ -48,11 +52,17 @@ func (s *Server) listenAndServe() {
|
|||
return
|
||||
}
|
||||
|
||||
s.port = listener.Addr().(*net.TCPAddr).Port
|
||||
if s.afterPortChanged != nil {
|
||||
s.afterPortChanged(s.port)
|
||||
err = s.SetPort(listener.Addr().(*net.TCPAddr).Port)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to set Server.port", zap.Error(err))
|
||||
return
|
||||
}
|
||||
s.run = true
|
||||
|
||||
if s.afterPortChanged != nil {
|
||||
s.afterPortChanged(s.MustGetPort())
|
||||
}
|
||||
|
||||
s.isRunning = true
|
||||
|
||||
err = s.server.Serve(listener)
|
||||
if err != http.ErrServerClosed {
|
||||
|
@ -64,11 +74,12 @@ func (s *Server) listenAndServe() {
|
|||
return
|
||||
}
|
||||
|
||||
s.run = false
|
||||
s.isRunning = false
|
||||
}
|
||||
|
||||
func (s *Server) resetServer() {
|
||||
s.server = new(http.Server)
|
||||
s.ResetPort()
|
||||
}
|
||||
|
||||
func (s *Server) applyHandlers() {
|
||||
|
@ -100,7 +111,7 @@ func (s *Server) Stop() error {
|
|||
}
|
||||
|
||||
func (s *Server) ToForeground() {
|
||||
if !s.run && (s.server != nil) {
|
||||
if !s.isRunning && (s.server != nil) {
|
||||
err := s.Start()
|
||||
if err != nil {
|
||||
s.logger.Error("server start failed during foreground transition", zap.Error(err))
|
||||
|
@ -109,7 +120,7 @@ func (s *Server) ToForeground() {
|
|||
}
|
||||
|
||||
func (s *Server) ToBackground() {
|
||||
if s.run {
|
||||
if s.isRunning {
|
||||
err := s.Stop()
|
||||
if err != nil {
|
||||
s.logger.Error("server stop failed during background transition", zap.Error(err))
|
||||
|
|
|
@ -91,11 +91,7 @@ func (s *PairingServer) MakeConnectionParams() (*ConnectionParams, error) {
|
|||
netIP = netIP4
|
||||
}
|
||||
|
||||
if s.port == 0 {
|
||||
return nil, fmt.Errorf("port is 0, listener is not yet set")
|
||||
}
|
||||
|
||||
return NewConnectionParams(netIP, s.port, s.pk, s.ek, s.mode), nil
|
||||
return NewConnectionParams(netIP, s.MustGetPort(), s.pk, s.ek, s.mode), nil
|
||||
}
|
||||
|
||||
func (s *PairingServer) StartPairing() error {
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
@ -18,7 +17,7 @@ type PairingServerSuite struct {
|
|||
TestPairingServerComponents
|
||||
}
|
||||
|
||||
func (s *PairingServerSuite) SetupSuite() {
|
||||
func (s *PairingServerSuite) SetupTest() {
|
||||
s.SetupPairingServerComponents(s.T())
|
||||
}
|
||||
|
||||
|
@ -39,9 +38,6 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
|
|||
err = s.PS.StartPairing()
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Give time for the sever to be ready, hacky I know, I'll iron this out
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
cp, err := s.PS.MakeConnectionParams()
|
||||
s.Require().NoError(err)
|
||||
|
||||
|
@ -89,6 +85,7 @@ func (s *PairingServerSuite) TestPairingServer_StartPairing() {
|
|||
|
||||
// Reset the server's PayloadEncryptionManager
|
||||
s.PS.PayloadManager.(*MockEncryptOnlyPayloadManager).ResetPayload()
|
||||
s.PS.ResetPort()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,9 +99,6 @@ func (s *PairingServerSuite) sendingSetup() *PairingClient {
|
|||
err = s.PS.StartPairing()
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Give time for the sever to be ready, hacky I know, I'll iron this out
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
cp, err := s.PS.MakeConnectionParams()
|
||||
s.Require().NoError(err)
|
||||
|
||||
|
|
|
@ -22,9 +22,12 @@ func (s *ServerURLSuite) SetupSuite() {
|
|||
s.SetupKeyComponents(s.T())
|
||||
|
||||
s.server = &MediaServer{Server: Server{
|
||||
hostname: defaultIP.String(),
|
||||
port: 1337,
|
||||
hostname: defaultIP.String(),
|
||||
portManger: newPortManager(nil),
|
||||
}}
|
||||
err := s.server.SetPort(1337)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.serverNoPort = &MediaServer{Server: Server{
|
||||
hostname: defaultIP.String(),
|
||||
}}
|
||||
|
|
Loading…
Reference in New Issue