Fix race conditions in tests (#857)

This commit is contained in:
Adam Babik 2018-04-23 15:35:48 +02:00 committed by GitHub
parent b85e50cbc9
commit 0473f29a8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 12 deletions

View File

@ -357,6 +357,9 @@ func (n *StatusNode) gethService(serviceInstance interface{}) error {
// LightEthereumService exposes reference to LES service running on top of the node // LightEthereumService exposes reference to LES service running on top of the node
func (n *StatusNode) LightEthereumService() (l *les.LightEthereum, err error) { func (n *StatusNode) LightEthereumService() (l *les.LightEthereum, err error) {
n.mu.RLock()
defer n.mu.RUnlock()
err = n.gethService(&l) err = n.gethService(&l)
if err == node.ErrServiceUnknown { if err == node.ErrServiceUnknown {
err = ErrServiceUnknown err = ErrServiceUnknown
@ -367,6 +370,9 @@ func (n *StatusNode) LightEthereumService() (l *les.LightEthereum, err error) {
// WhisperService exposes reference to Whisper service running on top of the node // WhisperService exposes reference to Whisper service running on top of the node
func (n *StatusNode) WhisperService() (w *whisper.Whisper, err error) { func (n *StatusNode) WhisperService() (w *whisper.Whisper, err error) {
n.mu.RLock()
defer n.mu.RUnlock()
err = n.gethService(&w) err = n.gethService(&w)
if err == node.ErrServiceUnknown { if err == node.ErrServiceUnknown {
err = ErrServiceUnknown err = ErrServiceUnknown

View File

@ -21,6 +21,7 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -1026,7 +1027,7 @@ func testDiscardTransaction(t *testing.T) bool { //nolint: gocyclo
// replace transaction notification handler // replace transaction notification handler
var txID string var txID string
txFailedEventCalled := false txFailedEventCalled := make(chan struct{})
signal.SetDefaultNodeNotificationHandler(func(jsonEvent string) { signal.SetDefaultNodeNotificationHandler(func(jsonEvent string) {
var envelope signal.Envelope var envelope signal.Envelope
if err := json.Unmarshal([]byte(jsonEvent), &envelope); err != nil { if err := json.Unmarshal([]byte(jsonEvent), &envelope); err != nil {
@ -1088,7 +1089,7 @@ func testDiscardTransaction(t *testing.T) bool { //nolint: gocyclo
return return
} }
txFailedEventCalled = true close(txFailedEventCalled)
} }
}) })
@ -1114,12 +1115,13 @@ func testDiscardTransaction(t *testing.T) bool { //nolint: gocyclo
return false return false
} }
if !txFailedEventCalled { select {
case <-txFailedEventCalled:
return true
default:
t.Error("expected tx failure signal is not received") t.Error("expected tx failure signal is not received")
return false return false
} }
return true
} }
func testDiscardMultipleQueuedTransactions(t *testing.T) bool { //nolint: gocyclo func testDiscardMultipleQueuedTransactions(t *testing.T) bool { //nolint: gocyclo
@ -1134,10 +1136,11 @@ func testDiscardMultipleQueuedTransactions(t *testing.T) bool { //nolint: gocycl
// make sure you panic if transaction complete doesn't return // make sure you panic if transaction complete doesn't return
testTxCount := 3 testTxCount := 3
txIDs := make(chan string, testTxCount) txIDs := make(chan string, testTxCount)
allTestTxDiscarded := make(chan struct{}, 1)
var testTxDiscarded sync.WaitGroup
testTxDiscarded.Add(testTxCount)
// replace transaction notification handler // replace transaction notification handler
txFailedEventCallCount := 0
signal.SetDefaultNodeNotificationHandler(func(jsonEvent string) { signal.SetDefaultNodeNotificationHandler(func(jsonEvent string) {
var txID string var txID string
var envelope signal.Envelope var envelope signal.Envelope
@ -1175,10 +1178,7 @@ func testDiscardMultipleQueuedTransactions(t *testing.T) bool { //nolint: gocycl
return return
} }
txFailedEventCallCount++ testTxDiscarded.Done()
if txFailedEventCallCount == testTxCount {
allTestTxDiscarded <- struct{}{}
}
} }
}) })
@ -1275,8 +1275,11 @@ func testDiscardMultipleQueuedTransactions(t *testing.T) bool { //nolint: gocycl
go sendTx() go sendTx()
} }
done := make(chan struct{})
go func() { testTxDiscarded.Wait(); close(done) }()
select { select {
case <-allTestTxDiscarded: case <-done:
// pass // pass
case <-time.After(20 * time.Second): case <-time.After(20 * time.Second):
t.Error("test timed out") t.Error("test timed out")