diff --git a/services/shhext/api.go b/services/shhext/api.go index 7bac86228..64c292880 100644 --- a/services/shhext/api.go +++ b/services/shhext/api.go @@ -217,22 +217,24 @@ func (api *PublicAPI) RequestMessagesSync(conf RetryConfig, r MessagesRequest) ( shh := api.service.w events := make(chan whisper.EnvelopeEvent, 10) - sub := shh.SubscribeEnvelopeEvents(events) - defer sub.Unsubscribe() var ( requestID hexutil.Bytes err error retries int ) for retries <= conf.MaxRetries { + sub := shh.SubscribeEnvelopeEvents(events) r.Timeout = conf.BaseTimeout + conf.StepTimeout*time.Duration(retries) + timeout := r.Timeout // FIXME this weird conversion is required because MessagesRequest expects seconds but defines time.Duration r.Timeout = time.Duration(int(r.Timeout.Seconds())) requestID, err = api.RequestMessages(context.Background(), r) if err != nil { + sub.Unsubscribe() return resp, err } - mailServerResp, err := waitForExpiredOrCompleted(common.BytesToHash(requestID), events) + mailServerResp, err := waitForExpiredOrCompleted(common.BytesToHash(requestID), events, timeout) + sub.Unsubscribe() if err == nil { resp.Cursor = hex.EncodeToString(mailServerResp.Cursor) resp.Error = mailServerResp.Error @@ -244,9 +246,17 @@ func (api *PublicAPI) RequestMessagesSync(conf RetryConfig, r MessagesRequest) ( return resp, fmt.Errorf("failed to request messages after %d retries", retries) } -func waitForExpiredOrCompleted(requestID common.Hash, events chan whisper.EnvelopeEvent) (*whisper.MailServerResponse, error) { +func waitForExpiredOrCompleted(requestID common.Hash, events chan whisper.EnvelopeEvent, timeout time.Duration) (*whisper.MailServerResponse, error) { + expired := fmt.Errorf("request %x expired", requestID) + after := time.NewTimer(timeout) + defer after.Stop() for { - ev := <-events + var ev whisper.EnvelopeEvent + select { + case ev = <-events: + case <-after.C: + return nil, expired + } if ev.Hash != requestID { continue } @@ -258,7 +268,7 @@ func waitForExpiredOrCompleted(requestID common.Hash, events chan whisper.Envelo } return nil, errors.New("invalid event data type") case whisper.EventMailServerRequestExpired: - return nil, errors.New("request expired") + return nil, expired } } } diff --git a/services/shhext/api_test.go b/services/shhext/api_test.go index 9dd211374..bffa9deec 100644 --- a/services/shhext/api_test.go +++ b/services/shhext/api_test.go @@ -220,3 +220,20 @@ func TestSyncMessagesErrors(t *testing.T) { }) } } + +func TestExpiredOrCompleted(t *testing.T) { + timeout := time.Millisecond + events := make(chan whisper.EnvelopeEvent) + errors := make(chan error, 1) + hash := common.Hash{1} + go func() { + _, err := waitForExpiredOrCompleted(hash, events, timeout) + errors <- err + }() + select { + case <-time.After(time.Second): + require.FailNow(t, "timed out waiting for waitForExpiredOrCompleted to complete") + case err := <-errors: + require.EqualError(t, err, fmt.Sprintf("request %x expired", hash)) + } +}