diff --git a/protocol/communities/manager.go b/protocol/communities/manager.go index b5f641d57..1cb37eb59 100644 --- a/protocol/communities/manager.go +++ b/protocol/communities/manager.go @@ -102,6 +102,54 @@ type Manager struct { RekeyInterval time.Duration PermissionChecker PermissionChecker keyDistributor KeyDistributor + communityLock *CommunityLock +} + +type CommunityLock struct { + logger *zap.Logger + locks map[string]*sync.Mutex + mutex sync.Mutex +} + +func NewCommunityLock(logger *zap.Logger) *CommunityLock { + return &CommunityLock{ + logger: logger.Named("CommunityLock"), + locks: make(map[string]*sync.Mutex), + } +} + +func (c *CommunityLock) Lock(communityID types.HexBytes) { + c.mutex.Lock() + communityIDStr := types.EncodeHex(communityID) + lock, ok := c.locks[communityIDStr] + if !ok { + lock = &sync.Mutex{} + c.locks[communityIDStr] = lock + } + c.mutex.Unlock() + + lock.Lock() +} + +func (c *CommunityLock) Unlock(communityID types.HexBytes) { + c.mutex.Lock() + communityIDStr := types.EncodeHex(communityID) + lock, ok := c.locks[communityIDStr] + c.mutex.Unlock() + + if ok { + lock.Unlock() + } else { + c.logger.Warn("trying to unlock a non-existent lock", zap.String("communityID", communityIDStr)) + } +} + +func (c *CommunityLock) Init() { + c.locks = make(map[string]*sync.Mutex) +} + +func (c *CommunityLock) Release() { + c.locks = nil } type HistoryArchiveDownloadTask struct { @@ -270,6 +318,7 @@ func NewManager(identity *ecdsa.PrivateKey, installationID string, db *sql.DB, e torrentTasks: make(map[string]metainfo.Hash), historyArchiveDownloadTasks: make(map[string]*HistoryArchiveDownloadTask), keyDistributor: keyDistributor, + communityLock: NewCommunityLock(logger), } manager.persistence = &Persistence{ @@ -375,6 +424,7 @@ func (m *Manager) Subscribe() chan *Subscription { func (m *Manager) Start() error { m.stopped = false + m.communityLock.Init() if m.ensVerifier != nil { m.runENSVerificationLoop() } @@ -521,6 +571,7 @@ func (m *Manager) Stop() error { close(c) } m.StopTorrentClient() + m.communityLock.Release() return nil } @@ -1308,6 +1359,9 @@ func (m *Manager) ImportCommunity(key *ecdsa.PrivateKey, clock uint64) (*Communi } func (m *Manager) CreateChat(communityID types.HexBytes, chat *protobuf.CommunityChat, publish bool, thirdPartyID string) (*CommunityChanges, error) { + m.communityLock.Lock(communityID) + defer m.communityLock.Unlock(communityID) + community, err := m.GetByID(communityID) if err != nil { return nil, err @@ -1378,6 +1432,9 @@ func (m *Manager) DeleteChat(communityID types.HexBytes, chatID string) (*Commun } func (m *Manager) CreateCategory(request *requests.CreateCommunityCategory, publish bool) (*Community, *CommunityChanges, error) { + m.communityLock.Lock(request.CommunityID) + defer m.communityLock.Unlock(request.CommunityID) + community, err := m.GetByID(request.CommunityID) if err != nil { return nil, nil, err @@ -1588,6 +1645,8 @@ func (m *Manager) HandleCommunityDescriptionMessage(signer *ecdsa.PublicKey, des return nil, err } + m.communityLock.Lock(id) + defer m.communityLock.Unlock(id) community, err := m.GetByID(id) if err != nil && err != ErrOrgNotFound { return nil, err @@ -2971,6 +3030,8 @@ func (m *Manager) JoinCommunity(id types.HexBytes, forceJoin bool) (*Community, } func (m *Manager) SpectateCommunity(id types.HexBytes) (*Community, error) { + m.communityLock.Lock(id) + defer m.communityLock.Unlock(id) community, err := m.GetByID(id) if err != nil { return nil, err diff --git a/protocol/messenger_community_for_mobile_testing_test.go b/protocol/messenger_community_for_mobile_testing_test.go new file mode 100644 index 000000000..24ec0120e --- /dev/null +++ b/protocol/messenger_community_for_mobile_testing_test.go @@ -0,0 +1,79 @@ +package protocol + +import ( + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + "github.com/status-im/status-go/eth-node/types" + "github.com/status-im/status-go/protocol/tt" +) + +type MessengerCommunityForMobileTestingTestSuite struct { + MessengerBaseTestSuite +} + +func TestMessengerCommunityForMobileTesting(t *testing.T) { + suite.Run(t, new(MessengerCommunityForMobileTestingTestSuite)) +} + +func (s *MessengerCommunityForMobileTestingTestSuite) SetupTest() { + s.MessengerBaseTestSuite.SetupTest() +} + +func (s *MessengerCommunityForMobileTestingTestSuite) TearDownTest() { + s.MessengerBaseTestSuite.TearDownTest() +} + +func (s *MessengerCommunityForMobileTestingTestSuite) TestCreateClosedCommunity() { + var wg sync.WaitGroup + wg.Add(1) + // simulate invoking `HandleCommunityDescription` + go func() { + err := tt.RetryWithBackOff(func() error { + r, err := s.m.RetrieveAll() + s.Require().NoError(err) + if len(r.Communities()) > 0 { + return nil + } + return errors.New("not receive enough messages relate to community") + }) + wg.Done() + s.Require().NoError(err) + }() + + wg.Add(1) + var communityID types.HexBytes + // simulate frontend(mobile) invoking `CreateClosedCommunity` + go func() { + response, err := s.m.CreateClosedCommunity() + s.Require().NoError(err) + s.Require().Len(response.Communities(), 1) + s.Require().Len(response.Communities()[0].Categories(), 2) + s.Require().Len(response.Chats(), 4) + s.Require().Len(response.Communities()[0].Description().Chats, 4) + communityID = response.Communities()[0].ID() + wg.Done() + }() + + timeout := time.After(20 * time.Second) + done := make(chan bool) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-timeout: + s.Fail("TestCreateClosedCommunity timed out") + case <-done: + // validate concurrent updating community result + lastCommunity, err := s.m.GetCommunityByID(communityID) + s.Require().NoError(err) + s.Require().Len(lastCommunity.Categories(), 2) + s.Require().Len(lastCommunity.Description().Chats, 4) + } +}