From 25b8aad61fe4dad01811c69ed3350b59597fdd3f Mon Sep 17 00:00:00 2001 From: Jan Winkelmann Date: Thu, 17 Nov 2016 11:27:57 +0100 Subject: [PATCH] add ctx to sub.Next for cancellation --- floodsub_test.go | 6 +++--- subscription.go | 19 +++++++++++++------ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/floodsub_test.go b/floodsub_test.go index 6ab4690..f99a88a 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -127,7 +127,7 @@ func TestBasicFloodsub(t *testing.T) { psubs[owner].Publish("foobar", msg) for _, sub := range msgs { - got, err := sub.Next() + got, err := sub.Next(ctx) if err != nil { t.Fatal(sub.err) } @@ -525,7 +525,7 @@ func TestSubscribeMultipleTimes(t *testing.T) { psubs[1].Publish("foo", []byte("bar")) - msg, err := sub1.Next() + msg, err := sub1.Next(ctx) if err != nil { t.Fatalf("unexpected error: %v.", err) } @@ -536,7 +536,7 @@ func TestSubscribeMultipleTimes(t *testing.T) { t.Fatalf("data is %s, expected %s.", data, "bar") } - msg, err = sub2.Next() + msg, err = sub2.Next(ctx) if err != nil { t.Fatalf("unexpected error: %v.", err) } diff --git a/subscription.go b/subscription.go index 6fd01fb..d6e930c 100644 --- a/subscription.go +++ b/subscription.go @@ -1,5 +1,9 @@ package floodsub +import ( + "context" +) + type Subscription struct { topic string ch chan *Message @@ -11,14 +15,17 @@ func (sub *Subscription) Topic() string { return sub.topic } -func (sub *Subscription) Next() (*Message, error) { - msg, ok := <-sub.ch +func (sub *Subscription) Next(ctx context.Context) (*Message, error) { + select { + case msg, ok := <-sub.ch: + if !ok { + return msg, sub.err + } - if !ok { - return msg, sub.err + return msg, nil + case <-ctx.Done(): + return nil, ctx.Err() } - - return msg, nil } func (sub *Subscription) Cancel() {