diff --git a/server/ports.go b/server/ports.go index a76eae3fb..a4b40454c 100644 --- a/server/ports.go +++ b/server/ports.go @@ -2,46 +2,33 @@ package server import ( "fmt" - "sync" + "time" "go.uber.org/zap" ) // portManager is responsible for maintaining segregated access to the port field -// via the use of rwLock sync.RWMutex and mustRead sync.Mutex -// rwLock establishes a standard read write mutex that allows consecutive reads and exclusive writes -// mustRead forces MustGetPort to wait until port has a none 0 value type portManger struct { logger *zap.Logger port int afterPortChanged func(port int) - rwLock *sync.RWMutex - mustRead *sync.Mutex } -// newPortManager returns a newly initialised portManager with a pre-Locked portManger.mustRead sync.Mutex +// newPortManager returns a newly initialised portManager func newPortManager(logger *zap.Logger, afterPortChanged func(int)) portManger { pm := portManger{ logger: logger.Named("portManger"), afterPortChanged: afterPortChanged, - rwLock: new(sync.RWMutex), - mustRead: new(sync.Mutex), } - pm.mustRead.Lock() return pm } // SetPort sets portManger.port field to the given port value // next triggers any given portManger.afterPortChanged function -// additionally portManger.mustRead.Unlock() is called, releasing any calls to MustGetPort func (p *portManger) SetPort(port int) error { l := p.logger.Named("SetPort") l.Debug("fired", zap.Int("port", port)) - p.rwLock.Lock() - defer p.rwLock.Unlock() - l.Debug("acquired rwLock.Lock") - if port == 0 { errMsg := "port can not be `0`, use ResetPort() instead" l.Error(errMsg) @@ -53,58 +40,41 @@ func (p *portManger) SetPort(port int) error { l.Debug("p.afterPortChanged != nil") p.afterPortChanged(port) } - p.mustRead.Unlock() - l.Debug("p.mustRead.Unlock()") return nil } -// ResetPort attempts to reset portManger.port to 0 -// if portManger.mustRead is already locked the function returns after doing nothing -// portManger.mustRead.TryLock() is called because ResetPort may be called multiple times in a row -// and calling multiple times must not cause a deadlock or an infinite hang, but the lock needs to be -// reapplied if it is not present. +// ResetPort resets portManger.port to 0 func (p *portManger) ResetPort() { l := p.logger.Named("ResetPort") l.Debug("fired") - p.rwLock.Lock() - defer p.rwLock.Unlock() - l.Debug("acquired rwLock.Lock") - - if p.mustRead.TryLock() { - l.Debug("able to lock mustRead") - p.port = 0 - return - } - l.Debug("unable to lock mustRead") + p.port = 0 } // GetPort gets the current value of portManager.port without any concern for the state of its value -// and therefore does not block until portManager.mustRead.Unlock() is called +// and therefore does not wait if portManager.port is 0 func (p *portManger) GetPort() int { l := p.logger.Named("GetPort") l.Debug("fired") - p.rwLock.RLock() - defer p.rwLock.RUnlock() - l.Debug("acquired rwLock.RLock") - return p.port } -// MustGetPort only returns portManager.port if portManager.mustRead is unlocked. -// This presupposes that portManger.mustRead has a default state of locked and SetPort unlock portManager.mustRead +// MustGetPort only returns portManager.port if portManager.port is not 0. func (p *portManger) MustGetPort() int { l := p.logger.Named("MustGetPort") l.Debug("fired") - p.mustRead.Lock() - defer p.mustRead.Unlock() - l.Debug("acquired mustRead.Lock") + for { + if p.port != 0 { + port := p.port + if port == 0 { + panic("port is zero, port has reset") + } + return port + } - p.rwLock.RLock() - defer p.rwLock.RUnlock() - l.Debug("acquired rwLock.RLock") - - return p.port + l.Debug("port is zero") + time.Sleep(20 * time.Millisecond) + } } diff --git a/server/server_pairing_test.go b/server/server_pairing_test.go index 35c5a3012..75da4b2ef 100644 --- a/server/server_pairing_test.go +++ b/server/server_pairing_test.go @@ -3,6 +3,7 @@ package server import ( "crypto/ecdsa" "crypto/rand" + "regexp" "testing" "github.com/stretchr/testify/suite" @@ -21,6 +22,18 @@ func (s *PairingServerSuite) SetupTest() { s.SetupPairingServerComponents(s.T()) } +func (s *PairingServerSuite) TestMultiBackgroundForeground() { + err := s.PS.Start() + s.Require().NoError(err) + s.PS.ToBackground() + s.PS.ToForeground() + s.PS.ToBackground() + s.PS.ToBackground() + s.PS.ToForeground() + s.PS.ToForeground() + s.Require().Regexp(regexp.MustCompile("(https://192\\.168\\.0\\.\\d+:\\d+)"), s.PS.MakeBaseURL().String()) +} + func (s *PairingServerSuite) TestPairingServer_StartPairing() { // Replace PairingServer.PayloadManager with a MockEncryptOnlyPayloadManager pm, err := NewMockEncryptOnlyPayloadManager(s.EphemeralAES) diff --git a/server/server_test.go b/server/server_test.go index 9a346511b..8714c4626 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -43,7 +43,7 @@ func (s *ServerURLSuite) SetupTest() { }} go func() { time.Sleep(waitTime) - s.serverNoPort.mustRead.Unlock() + s.serverNoPort.port = 80 }() s.testStart = time.Now() @@ -60,12 +60,12 @@ func (s *ServerURLSuite) testNoPort(expected string, actual string) { func (s *ServerURLSuite) TestServer_MakeBaseURL() { s.Require().Equal("https://127.0.0.1:1337", s.server.MakeBaseURL().String()) - s.testNoPort("https://127.0.0.1:0", s.serverNoPort.MakeBaseURL().String()) + s.testNoPort("https://127.0.0.1:80", s.serverNoPort.MakeBaseURL().String()) } func (s *ServerURLSuite) TestServer_MakeImageServerURL() { s.Require().Equal("https://127.0.0.1:1337/messages/", s.server.MakeImageServerURL()) - s.testNoPort("https://127.0.0.1:0/messages/", s.serverNoPort.MakeImageServerURL()) + s.testNoPort("https://127.0.0.1:80/messages/", s.serverNoPort.MakeImageServerURL()) } func (s *ServerURLSuite) TestServer_MakeIdenticonURL() { @@ -73,7 +73,7 @@ func (s *ServerURLSuite) TestServer_MakeIdenticonURL() { "https://127.0.0.1:1337/messages/identicons?publicKey=0xdaff0d11decade", s.server.MakeIdenticonURL("0xdaff0d11decade")) s.testNoPort( - "https://127.0.0.1:0/messages/identicons?publicKey=0xdaff0d11decade", + "https://127.0.0.1:80/messages/identicons?publicKey=0xdaff0d11decade", s.serverNoPort.MakeIdenticonURL("0xdaff0d11decade")) } @@ -82,7 +82,7 @@ func (s *ServerURLSuite) TestServer_MakeImageURL() { "https://127.0.0.1:1337/messages/images?messageId=0x10aded70ffee", s.server.MakeImageURL("0x10aded70ffee")) s.testNoPort( - "https://127.0.0.1:0/messages/images?messageId=0x10aded70ffee", + "https://127.0.0.1:80/messages/images?messageId=0x10aded70ffee", s.serverNoPort.MakeImageURL("0x10aded70ffee")) } @@ -91,7 +91,7 @@ func (s *ServerURLSuite) TestServer_MakeAudioURL() { "https://127.0.0.1:1337/messages/audio?messageId=0xde1e7ebee71e", s.server.MakeAudioURL("0xde1e7ebee71e")) s.testNoPort( - "https://127.0.0.1:0/messages/audio?messageId=0xde1e7ebee71e", + "https://127.0.0.1:80/messages/audio?messageId=0xde1e7ebee71e", s.serverNoPort.MakeAudioURL("0xde1e7ebee71e")) } @@ -100,6 +100,6 @@ func (s *ServerURLSuite) TestServer_MakeStickerURL() { "https://127.0.0.1:1337/ipfs?hash=0xdeadbeef4ac0", s.server.MakeStickerURL("0xdeadbeef4ac0")) s.testNoPort( - "https://127.0.0.1:0/ipfs?hash=0xdeadbeef4ac0", + "https://127.0.0.1:80/ipfs?hash=0xdeadbeef4ac0", s.serverNoPort.MakeStickerURL("0xdeadbeef4ac0")) }