diff --git a/p2pclient/pubsub.go b/p2pclient/pubsub.go new file mode 100644 index 0000000..f56e047 --- /dev/null +++ b/p2pclient/pubsub.go @@ -0,0 +1,150 @@ +package p2pclient + +import ( + "context" + "fmt" + + ggio "github.com/gogo/protobuf/io" + pb "github.com/libp2p/go-libp2p-daemon/pb" + peer "github.com/libp2p/go-libp2p-peer" +) + +func newPubsubReq(req *pb.PSRequest) *pb.Request { + return &pb.Request{ + Type: pb.Request_PUBSUB.Enum(), + Pubsub: req, + } +} + +func (c *Client) doPubsub(psReq *pb.PSRequest) (*pb.PSResponse, error) { + control, err := c.newControlConn() + if err != nil { + return nil, err + } + defer control.Close() + + w := ggio.NewDelimitedWriter(control) + req := newPubsubReq(psReq) + if err = w.WriteMsg(req); err != nil { + return nil, err + } + + r := ggio.NewDelimitedReader(control, MessageSizeMax) + msg := &pb.Response{} + if err = r.ReadMsg(msg); err != nil { + return nil, err + } + + if msg.GetType() == pb.Response_ERROR { + err := fmt.Errorf("error from daemon in %s response: %s", req.GetType().String(), msg.GetError()) + log.Errorf(err.Error()) + return nil, err + } + + return msg.GetPubsub(), nil + +} + +func (c *Client) streamPubsubRequest(ctx context.Context, psReq *pb.PSRequest) (<-chan *pb.PSMessage, error) { + control, err := c.newControlConn() + if err != nil { + return nil, err + } + + w := ggio.NewDelimitedWriter(control) + req := newPubsubReq(psReq) + if err = w.WriteMsg(req); err != nil { + control.Close() + return nil, err + } + + r := ggio.NewDelimitedReader(control, MessageSizeMax) + msg := &pb.Response{} + if err = r.ReadMsg(msg); err != nil { + control.Close() + return nil, err + } + + if msg.GetType() == pb.Response_ERROR { + err := fmt.Errorf("error from daemon in %s response: %s", req.GetType().String(), msg.GetError()) + log.Errorf(err.Error()) + return nil, err + } + + go func() { + <-ctx.Done() + control.Close() + }() + + out := make(chan *pb.PSMessage) + go func() { + defer close(out) + defer control.Close() + + for { + msg := &pb.PSMessage{} + if err := r.ReadMsg(msg); err != nil { + log.Errorf("reading pubsub message: %s", err) + return + } + out <- msg + } + }() + + return out, nil +} + +func (c *Client) GetTopics() ([]string, error) { + req := &pb.PSRequest{ + Type: pb.PSRequest_GET_TOPICS.Enum(), + } + + res, err := c.doPubsub(req) + if err != nil { + return nil, err + } + + return res.GetTopics(), nil +} + +func (c *Client) ListPeers() ([]peer.ID, error) { + req := &pb.PSRequest{ + Type: pb.PSRequest_LIST_PEERS.Enum(), + } + + res, err := c.doPubsub(req) + if err != nil { + return nil, err + } + + ids := make([]peer.ID, len(res.GetPeerIDs())) + for i, idbytes := range res.GetPeerIDs() { + id, err := peer.IDFromBytes(idbytes) + if err != nil { + return nil, err + } + ids[i] = id + } + + return ids, nil +} + +func (c *Client) Publish(topic string, data []byte) error { + req := &pb.PSRequest{ + Type: pb.PSRequest_PUBLISH.Enum(), + Topic: &topic, + Data: data, + } + + _, err := c.doPubsub(req) + return err +} + +func (c *Client) Subscribe(ctx context.Context, topic string) (<-chan *pb.PSMessage, error) { + req := &pb.PSRequest{ + Type: pb.PSRequest_SUBSCRIBE.Enum(), + Topic: &topic, + } + + return c.streamPubsubRequest(ctx, req) +} diff --git a/test/dht_test.go b/test/dht_test.go index 5172b7b..9582585 100644 --- a/test/dht_test.go +++ b/test/dht_test.go @@ -3,7 +3,6 @@ package test import ( "bytes" "context" - "fmt" "reflect" "testing" "time" @@ -248,7 +247,6 @@ func TestDHTGetClosestPeers(t *testing.T) { if !bytes.Equal(req.GetKey(), key) { t.Fatal("request key didn't match expected key") } - fmt.Println("we good") resps := make([]*pb.DHTResponse, 2) for i, id := range ids { diff --git a/test/pubsub_test.go b/test/pubsub_test.go new file mode 100644 index 0000000..7d964ca --- /dev/null +++ b/test/pubsub_test.go @@ -0,0 +1,81 @@ +package test + +import ( + "context" + "testing" + "time" +) + +func TestPubsubGetTopicsAndSubscribe(t *testing.T) { + _, client, closer := createDaemonClientPair(t) + defer closer() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + done := make(chan struct{}) + go func() { + _, err := client.Subscribe(ctx, "test") + if err != nil { + t.Fatal(err) + } + done <- struct{}{} + }() + <-done + topics, err := client.GetTopics() + if err != nil { + t.Fatal(err) + } + if len(topics) != 1 { + t.Fatalf("expected 1 topic, found %d", len(topics)) + } + if topics[0] != "test" { + t.Fatalf("expected topic \"test\", found \"%s\"", topics[0]) + } + cancel() +} + +func TestPubsubMessages(t *testing.T) { + _, sender, senderCloser := createDaemonClientPair(t) + defer senderCloser() + _, receiver, receiverCloser := createDaemonClientPair(t) + defer receiverCloser() + + id, addrs, err := receiver.Identify() + if err != nil { + t.Fatal(err) + } + + if err = sender.Connect(id, addrs); err != nil { + t.Fatal(err) + } + + progress := make(chan struct{}) + done := make(chan struct{}) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + msgs, err := receiver.Subscribe(ctx, "test") + if err != nil { + t.Fatal(err) + } + progress <- struct{}{} + + select { + case msg := <-msgs: + msgstr := string(msg.Data) + if msgstr != "foobar" { + t.Fatalf("expected \"foobar\", got %s", msgstr) + } + done <- struct{}{} + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for message") + } + }() + + go func() { + <-progress + if err := sender.Publish("test", []byte("foobar")); err != nil { + t.Fatal(err) + } + }() + + <-done +} diff --git a/test/utils.go b/test/utils.go index c6cc1cb..5aa3727 100644 --- a/test/utils.go +++ b/test/utils.go @@ -41,6 +41,7 @@ func createTempDir(t *testing.T) (string, string, func()) { func createDaemon(t *testing.T, daemonAddr ma.Multiaddr) (*p2pd.Daemon, func()) { ctx, cancelCtx := context.WithCancel(context.Background()) daemon, err := p2pd.NewDaemon(ctx, daemonAddr, false, false) + daemon.EnablePubsub("gossipsub", false, false) if err != nil { t.Fatal(err) }