chore: bump go-waku to fix mem leak
This commit is contained in:
parent
329f5c8316
commit
a6d33b9912
5
go.mod
5
go.mod
|
@ -32,7 +32,7 @@ require (
|
|||
github.com/kilic/bls12-381 v0.0.0-20200607163746-32e1441c8a9f
|
||||
github.com/lib/pq v1.10.4
|
||||
github.com/libp2p/go-libp2p v0.25.1
|
||||
github.com/libp2p/go-libp2p-pubsub v0.9.1
|
||||
github.com/libp2p/go-libp2p-pubsub v0.9.3
|
||||
github.com/lucasb-eyer/go-colorful v1.0.3
|
||||
github.com/mat/besticon v0.0.0-20210314201728-1579f269edb7
|
||||
github.com/multiformats/go-multiaddr v0.8.0
|
||||
|
@ -80,7 +80,7 @@ require (
|
|||
github.com/ipfs/go-log/v2 v2.5.1
|
||||
github.com/ladydascalie/currency v1.6.0
|
||||
github.com/meirf/gopart v0.0.0-20180520194036-37e9492a85a8
|
||||
github.com/waku-org/go-waku v0.5.2-0.20230308135126-4b52983fc483
|
||||
github.com/waku-org/go-waku v0.5.3-0.20230327132601-b540953f74e9
|
||||
github.com/yeqown/go-qrcode/v2 v2.2.1
|
||||
github.com/yeqown/go-qrcode/writer/standard v1.2.1
|
||||
go.uber.org/multierr v1.8.0
|
||||
|
@ -110,6 +110,7 @@ require (
|
|||
github.com/benbjohnson/clock v1.3.0 // indirect
|
||||
github.com/benbjohnson/immutable v0.3.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/berty/go-libp2p-rendezvous v0.4.1 // indirect
|
||||
github.com/bits-and-blooms/bitset v1.2.0 // indirect
|
||||
github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8 // indirect
|
||||
github.com/btcsuite/btcd v0.22.1 // indirect
|
||||
|
|
10
go.sum
10
go.sum
|
@ -397,6 +397,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
|
|||
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/berty/go-libp2p-rendezvous v0.4.1 h1:+yXsKocTxfKt+Sl3JkcPbv1J31QKjYoYSyIpMWGb/Wc=
|
||||
github.com/berty/go-libp2p-rendezvous v0.4.1/go.mod h1:Kc2dtCckvFN44/eCiWXT5YbwVbR7WA5iPhDZIKNkQG0=
|
||||
github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
|
||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
|
||||
github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA=
|
||||
|
@ -1346,8 +1348,8 @@ github.com/libp2p/go-flow-metrics v0.1.0 h1:0iPhMI8PskQwzh57jB9WxIuIOQ0r+15PChFG
|
|||
github.com/libp2p/go-flow-metrics v0.1.0/go.mod h1:4Xi8MX8wj5aWNDAZttg6UPmc0ZrnFNsMtpsYUClFtro=
|
||||
github.com/libp2p/go-libp2p-asn-util v0.2.0 h1:rg3+Os8jbnO5DxkC7K/Utdi+DkY3q/d1/1q+8WeNAsw=
|
||||
github.com/libp2p/go-libp2p-asn-util v0.2.0/go.mod h1:WoaWxbHKBymSN41hWSq/lGKJEca7TNm58+gGJi2WsLI=
|
||||
github.com/libp2p/go-libp2p-pubsub v0.9.1 h1:A6LBg9BaoLf3NwRz+E974sAxTVcbUZYg95IhK2BZz9g=
|
||||
github.com/libp2p/go-libp2p-pubsub v0.9.1/go.mod h1:RYA7aM9jIic5VV47WXu4GkcRxRhrdElWf8xtyli+Dzc=
|
||||
github.com/libp2p/go-libp2p-pubsub v0.9.3 h1:ihcz9oIBMaCK9kcx+yHWm3mLAFBMAUsM4ux42aikDxo=
|
||||
github.com/libp2p/go-libp2p-pubsub v0.9.3/go.mod h1:RYA7aM9jIic5VV47WXu4GkcRxRhrdElWf8xtyli+Dzc=
|
||||
github.com/libp2p/go-libp2p-testing v0.12.0 h1:EPvBb4kKMWO29qP4mZGyhVzUyR25dvfUIK5WDu6iPUA=
|
||||
github.com/libp2p/go-maddr-filter v0.1.0/go.mod h1:VzZhTXkMucEGGEOSKddrwGiOv0tUhgnKqNEmIAz/bPU=
|
||||
github.com/libp2p/go-mplex v0.7.0 h1:BDhFZdlk5tbr0oyFq/xv/NPGfjbnrsDam1EvutpBDbY=
|
||||
|
@ -2100,8 +2102,8 @@ github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1
|
|||
github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
github.com/waku-org/go-discover v0.0.0-20221209174356-61c833f34d98 h1:xwY0kW5XZFimdqfZb9cZwT1S3VJP9j3AE6bdNd9boXM=
|
||||
github.com/waku-org/go-discover v0.0.0-20221209174356-61c833f34d98/go.mod h1:eBHgM6T4EG0RZzxpxKy+rGz/6Dw2Nd8DWxS0lm9ESDw=
|
||||
github.com/waku-org/go-waku v0.5.2-0.20230308135126-4b52983fc483 h1:WB7CnxOpd99PxPE+mpNC4y2sdwDE263O6qgiDyRIYjY=
|
||||
github.com/waku-org/go-waku v0.5.2-0.20230308135126-4b52983fc483/go.mod h1:Uz6WhNbCtbM8fSr0wb8apqhAPQYKvOPoyaGOHdw9DkU=
|
||||
github.com/waku-org/go-waku v0.5.3-0.20230327132601-b540953f74e9 h1:Br5OIct6oaOUfzj01bNaRpDbBNJahiQKLzX94IkxLtw=
|
||||
github.com/waku-org/go-waku v0.5.3-0.20230327132601-b540953f74e9/go.mod h1:13mD936XMysk4WMkyYBmMRcS8HkLQzTY+WSeOuTo3cw=
|
||||
github.com/waku-org/go-zerokit-rln v0.1.7-wakuorg h1:2vVIBCtBih2w1K9ll8YnToTDZvbxcgbsClsPlJS/kkg=
|
||||
github.com/waku-org/go-zerokit-rln v0.1.7-wakuorg/go.mod h1:GlyaVeEWNEBxVJrWC6jFTvb4LNb9d9qnjdS6EiWVUvk=
|
||||
github.com/wealdtech/go-ens/v3 v3.5.0 h1:Huc9GxBgiGweCOGTYomvsg07K2QggAqZpZ5SuiZdC8o=
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
*.pb.go linguist-generated merge=ours -diff
|
||||
go.sum linguist-generated text
|
|
@ -0,0 +1,14 @@
|
|||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, build with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
|
||||
.glide/
|
|
@ -0,0 +1,32 @@
|
|||
run:
|
||||
deadline: 1m
|
||||
tests: false
|
||||
skip-files:
|
||||
- "test/.*"
|
||||
- "test/.*/.*"
|
||||
|
||||
linters-settings:
|
||||
golint:
|
||||
min-confidence: 0
|
||||
maligned:
|
||||
suggest-new: true
|
||||
goconst:
|
||||
min-len: 5
|
||||
min-occurrences: 4
|
||||
misspell:
|
||||
locale: US
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- goconst
|
||||
- misspell
|
||||
- unused
|
||||
- staticcheck
|
||||
- unconvert
|
||||
- gofmt
|
||||
- goimports
|
||||
# @TODO(gfanton): disable revive for now has it generate to many errors,
|
||||
# it should be enable in a dedicated PR
|
||||
# - revive
|
||||
- ineffassign
|
|
@ -0,0 +1,10 @@
|
|||
{
|
||||
"release": {
|
||||
"branches": ["master"]
|
||||
},
|
||||
"plugins": [
|
||||
"@semantic-release/commit-analyzer",
|
||||
"@semantic-release/release-notes-generator",
|
||||
"@semantic-release/github"
|
||||
]
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2018 libp2p
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,418 @@
|
|||
package rendezvous
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
ggio "github.com/gogo/protobuf/io"
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
inet "github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
|
||||
pb "github.com/berty/go-libp2p-rendezvous/pb"
|
||||
)
|
||||
|
||||
var (
|
||||
DiscoverAsyncInterval = 2 * time.Minute
|
||||
)
|
||||
|
||||
type RendezvousPoint interface {
|
||||
Register(ctx context.Context, ns string, ttl int) (time.Duration, error)
|
||||
Unregister(ctx context.Context, ns string) error
|
||||
Discover(ctx context.Context, ns string, limit int, cookie []byte) ([]Registration, []byte, error)
|
||||
DiscoverAsync(ctx context.Context, ns string) (<-chan Registration, error)
|
||||
DiscoverSubscribe(ctx context.Context, ns string, serviceTypes []RendezvousSyncClient) (<-chan peer.AddrInfo, error)
|
||||
}
|
||||
|
||||
type Registration struct {
|
||||
Peer peer.AddrInfo
|
||||
Ns string
|
||||
Ttl int
|
||||
}
|
||||
|
||||
type RendezvousClient interface {
|
||||
Register(ctx context.Context, ns string, ttl int) (time.Duration, error)
|
||||
Unregister(ctx context.Context, ns string) error
|
||||
Discover(ctx context.Context, ns string, limit int, cookie []byte) ([]peer.AddrInfo, []byte, error)
|
||||
DiscoverAsync(ctx context.Context, ns string) (<-chan peer.AddrInfo, error)
|
||||
DiscoverSubscribe(ctx context.Context, ns string) (<-chan peer.AddrInfo, error)
|
||||
}
|
||||
|
||||
func NewRendezvousPoint(host host.Host, p peer.ID, opts ...RendezvousPointOption) RendezvousPoint {
|
||||
cfg := defaultRendezvousPointConfig
|
||||
cfg.apply(opts...)
|
||||
return &rendezvousPoint{
|
||||
addrFactory: cfg.AddrsFactory,
|
||||
host: host,
|
||||
p: p,
|
||||
}
|
||||
}
|
||||
|
||||
type rendezvousPoint struct {
|
||||
addrFactory AddrsFactory
|
||||
host host.Host
|
||||
p peer.ID
|
||||
}
|
||||
|
||||
func NewRendezvousClient(host host.Host, rp peer.ID, sync ...RendezvousSyncClient) RendezvousClient {
|
||||
return NewRendezvousClientWithPoint(NewRendezvousPoint(host, rp), sync...)
|
||||
}
|
||||
|
||||
func NewRendezvousClientWithPoint(rp RendezvousPoint, syncClientList ...RendezvousSyncClient) RendezvousClient {
|
||||
return &rendezvousClient{rp: rp, syncClients: syncClientList}
|
||||
}
|
||||
|
||||
type rendezvousClient struct {
|
||||
rp RendezvousPoint
|
||||
syncClients []RendezvousSyncClient
|
||||
}
|
||||
|
||||
func (rp *rendezvousPoint) Register(ctx context.Context, ns string, ttl int) (time.Duration, error) {
|
||||
s, err := rp.host.NewStream(ctx, rp.p, RendezvousProto)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer s.Reset()
|
||||
|
||||
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
|
||||
w := ggio.NewDelimitedWriter(s)
|
||||
|
||||
addrs := rp.addrFactory(rp.host.Addrs())
|
||||
if len(addrs) == 0 {
|
||||
return 0, fmt.Errorf("no addrs available to advertise: %s", ns)
|
||||
}
|
||||
|
||||
log.Debugf("advertising on `%s` with: %v", ns, addrs)
|
||||
req := newRegisterMessage(ns, peer.AddrInfo{ID: rp.host.ID(), Addrs: addrs}, ttl)
|
||||
err = w.WriteMsg(req)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var res pb.Message
|
||||
err = r.ReadMsg(&res)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if res.GetType() != pb.Message_REGISTER_RESPONSE {
|
||||
return 0, fmt.Errorf("unexpected response: %s", res.GetType().String())
|
||||
}
|
||||
|
||||
response := res.GetRegisterResponse()
|
||||
status := response.GetStatus()
|
||||
if status != pb.Message_OK {
|
||||
return 0, RendezvousError{Status: status, Text: res.GetRegisterResponse().GetStatusText()}
|
||||
}
|
||||
|
||||
return time.Duration(response.Ttl) * time.Second, nil
|
||||
}
|
||||
|
||||
func (rc *rendezvousClient) Register(ctx context.Context, ns string, ttl int) (time.Duration, error) {
|
||||
if ttl < 120 {
|
||||
return 0, fmt.Errorf("registration TTL is too short")
|
||||
}
|
||||
|
||||
returnedTTL, err := rc.rp.Register(ctx, ns, ttl)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
go registerRefresh(ctx, rc.rp, ns, ttl)
|
||||
return returnedTTL, nil
|
||||
}
|
||||
|
||||
func registerRefresh(ctx context.Context, rz RendezvousPoint, ns string, ttl int) {
|
||||
var refresh time.Duration
|
||||
errcount := 0
|
||||
|
||||
for {
|
||||
if errcount > 0 {
|
||||
// do randomized exponential backoff, up to ~4 hours
|
||||
if errcount > 7 {
|
||||
errcount = 7
|
||||
}
|
||||
backoff := 2 << uint(errcount)
|
||||
refresh = 5*time.Minute + time.Duration(rand.Intn(backoff*60000))*time.Millisecond
|
||||
} else {
|
||||
refresh = time.Duration(ttl-30) * time.Second
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(refresh):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
_, err := rz.Register(ctx, ns, ttl)
|
||||
if err != nil {
|
||||
log.Errorf("Error registering [%s]: %s", ns, err.Error())
|
||||
errcount++
|
||||
} else {
|
||||
errcount = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rp *rendezvousPoint) Unregister(ctx context.Context, ns string) error {
|
||||
s, err := rp.host.NewStream(ctx, rp.p, RendezvousProto)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
w := ggio.NewDelimitedWriter(s)
|
||||
req := newUnregisterMessage(ns, rp.host.ID())
|
||||
return w.WriteMsg(req)
|
||||
}
|
||||
|
||||
func (rc *rendezvousClient) Unregister(ctx context.Context, ns string) error {
|
||||
return rc.rp.Unregister(ctx, ns)
|
||||
}
|
||||
|
||||
func (rp *rendezvousPoint) Discover(ctx context.Context, ns string, limit int, cookie []byte) ([]Registration, []byte, error) {
|
||||
s, err := rp.host.NewStream(ctx, rp.p, RendezvousProto)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer s.Reset()
|
||||
|
||||
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
|
||||
w := ggio.NewDelimitedWriter(s)
|
||||
|
||||
return discoverQuery(ns, limit, cookie, r, w)
|
||||
}
|
||||
|
||||
func discoverQuery(ns string, limit int, cookie []byte, r ggio.Reader, w ggio.Writer) ([]Registration, []byte, error) {
|
||||
req := newDiscoverMessage(ns, limit, cookie)
|
||||
err := w.WriteMsg(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var res pb.Message
|
||||
err = r.ReadMsg(&res)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if res.GetType() != pb.Message_DISCOVER_RESPONSE {
|
||||
return nil, nil, fmt.Errorf("Unexpected response: %s", res.GetType().String())
|
||||
}
|
||||
|
||||
status := res.GetDiscoverResponse().GetStatus()
|
||||
if status != pb.Message_OK {
|
||||
return nil, nil, RendezvousError{Status: status, Text: res.GetDiscoverResponse().GetStatusText()}
|
||||
}
|
||||
|
||||
regs := res.GetDiscoverResponse().GetRegistrations()
|
||||
result := make([]Registration, 0, len(regs))
|
||||
for _, reg := range regs {
|
||||
pi, err := pbToPeerInfo(reg.GetPeer())
|
||||
if err != nil {
|
||||
log.Errorf("Invalid peer info: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
result = append(result, Registration{Peer: pi, Ns: reg.GetNs(), Ttl: int(reg.GetTtl())})
|
||||
}
|
||||
|
||||
return result, res.GetDiscoverResponse().GetCookie(), nil
|
||||
}
|
||||
|
||||
func (rp *rendezvousPoint) DiscoverAsync(ctx context.Context, ns string) (<-chan Registration, error) {
|
||||
s, err := rp.host.NewStream(ctx, rp.p, RendezvousProto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch := make(chan Registration)
|
||||
go discoverAsync(ctx, ns, s, ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func discoverAsync(ctx context.Context, ns string, s inet.Stream, ch chan Registration) {
|
||||
defer s.Reset()
|
||||
defer close(ch)
|
||||
|
||||
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
|
||||
w := ggio.NewDelimitedWriter(s)
|
||||
|
||||
const batch = 200
|
||||
|
||||
var (
|
||||
cookie []byte
|
||||
regs []Registration
|
||||
err error
|
||||
)
|
||||
|
||||
for {
|
||||
regs, cookie, err = discoverQuery(ns, batch, cookie, r, w)
|
||||
if err != nil {
|
||||
// TODO robust error recovery
|
||||
// - handle closed streams with backoff + new stream, preserving the cookie
|
||||
// - handle E_INVALID_COOKIE errors in that case to restart the discovery
|
||||
log.Errorf("Error in discovery [%s]: %s", ns, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
for _, reg := range regs {
|
||||
select {
|
||||
case ch <- reg:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(regs) < batch {
|
||||
// TODO adaptive backoff for heavily loaded rendezvous points
|
||||
select {
|
||||
case <-time.After(DiscoverAsyncInterval):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *rendezvousClient) Discover(ctx context.Context, ns string, limit int, cookie []byte) ([]peer.AddrInfo, []byte, error) {
|
||||
regs, cookie, err := rc.rp.Discover(ctx, ns, limit, cookie)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pinfos := make([]peer.AddrInfo, len(regs))
|
||||
for i, reg := range regs {
|
||||
pinfos[i] = reg.Peer
|
||||
}
|
||||
|
||||
return pinfos, cookie, nil
|
||||
}
|
||||
|
||||
func (rc *rendezvousClient) DiscoverAsync(ctx context.Context, ns string) (<-chan peer.AddrInfo, error) {
|
||||
rch, err := rc.rp.DiscoverAsync(ctx, ns)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch := make(chan peer.AddrInfo)
|
||||
go discoverPeersAsync(ctx, rch, ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func discoverPeersAsync(ctx context.Context, rch <-chan Registration, ch chan peer.AddrInfo) {
|
||||
defer close(ch)
|
||||
for {
|
||||
select {
|
||||
case reg, ok := <-rch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case ch <- reg.Peer:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *rendezvousClient) DiscoverSubscribe(ctx context.Context, ns string) (<-chan peer.AddrInfo, error) {
|
||||
return rc.rp.DiscoverSubscribe(ctx, ns, rc.syncClients)
|
||||
}
|
||||
|
||||
func subscribeServiceTypes(serviceTypeClients []RendezvousSyncClient) []string {
|
||||
serviceTypes := []string(nil)
|
||||
for _, serviceType := range serviceTypeClients {
|
||||
serviceTypes = append(serviceTypes, serviceType.GetServiceType())
|
||||
}
|
||||
|
||||
return serviceTypes
|
||||
}
|
||||
|
||||
func (rp *rendezvousPoint) DiscoverSubscribe(ctx context.Context, ns string, serviceTypeClients []RendezvousSyncClient) (<-chan peer.AddrInfo, error) {
|
||||
serviceTypes := subscribeServiceTypes(serviceTypeClients)
|
||||
|
||||
s, err := rp.host.NewStream(ctx, rp.p, RendezvousProto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
|
||||
w := ggio.NewDelimitedWriter(s)
|
||||
|
||||
subType, subDetails, err := discoverSubscribeQuery(ns, serviceTypes, r, w)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("discover subscribe error: %w", err)
|
||||
}
|
||||
|
||||
subClient := RendezvousSyncClient(nil)
|
||||
for _, subClient = range serviceTypeClients {
|
||||
if subClient.GetServiceType() == subType {
|
||||
break
|
||||
}
|
||||
}
|
||||
if subClient == nil {
|
||||
return nil, fmt.Errorf("unrecognized client type")
|
||||
}
|
||||
|
||||
regCh, err := subClient.Subscribe(ctx, subDetails)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to subscribe to updates: %w", err)
|
||||
}
|
||||
|
||||
ch := make(chan peer.AddrInfo)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
for {
|
||||
select {
|
||||
case result, ok := <-regCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ch <- result.Peer
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func discoverSubscribeQuery(ns string, serviceTypes []string, r ggio.Reader, w ggio.Writer) (subType string, subDetails string, err error) {
|
||||
req := &pb.Message{
|
||||
Type: pb.Message_DISCOVER_SUBSCRIBE,
|
||||
DiscoverSubscribe: newDiscoverSubscribeMessage(ns, serviceTypes),
|
||||
}
|
||||
err = w.WriteMsg(req)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("write err: %w", err)
|
||||
}
|
||||
|
||||
var res pb.Message
|
||||
err = r.ReadMsg(&res)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read err: %w", err)
|
||||
}
|
||||
|
||||
if res.GetType() != pb.Message_DISCOVER_SUBSCRIBE_RESPONSE {
|
||||
return "", "", fmt.Errorf("unexpected response: %s", res.GetType().String())
|
||||
}
|
||||
|
||||
status := res.GetDiscoverSubscribeResponse().GetStatus()
|
||||
if status != pb.Message_OK {
|
||||
return "", "", RendezvousError{Status: status, Text: res.GetDiscoverSubscribeResponse().GetStatusText()}
|
||||
}
|
||||
|
||||
subType = res.GetDiscoverSubscribeResponse().GetSubscriptionType()
|
||||
subDetails = res.GetDiscoverSubscribeResponse().GetSubscriptionDetails()
|
||||
|
||||
return subType, subDetails, nil
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package dbi
|
||||
|
||||
import (
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
)
|
||||
|
||||
type RegistrationRecord struct {
|
||||
Id peer.ID
|
||||
Addrs [][]byte
|
||||
Ns string
|
||||
Ttl int
|
||||
}
|
||||
|
||||
type DB interface {
|
||||
Close() error
|
||||
Register(p peer.ID, ns string, addrs [][]byte, ttl int) (uint64, error)
|
||||
Unregister(p peer.ID, ns string) error
|
||||
CountRegistrations(p peer.ID) (int, error)
|
||||
Discover(ns string, cookie []byte, limit int) ([]RegistrationRecord, []byte, error)
|
||||
ValidCookie(ns string, cookie []byte) bool
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
package rendezvous
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/discovery"
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
)
|
||||
|
||||
type rendezvousDiscovery struct {
|
||||
rp RendezvousPoint
|
||||
peerCache map[string]*discoveryCache
|
||||
peerCacheMux sync.RWMutex
|
||||
rng *rand.Rand
|
||||
rngMux sync.Mutex
|
||||
}
|
||||
|
||||
type discoveryCache struct {
|
||||
recs map[peer.ID]*record
|
||||
cookie []byte
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
type record struct {
|
||||
peer peer.AddrInfo
|
||||
expire int64
|
||||
}
|
||||
|
||||
func NewRendezvousDiscovery(host host.Host, rendezvousPeer peer.ID) discovery.Discovery {
|
||||
rp := NewRendezvousPoint(host, rendezvousPeer)
|
||||
return &rendezvousDiscovery{rp: rp, peerCache: make(map[string]*discoveryCache), rng: rand.New(rand.NewSource(rand.Int63()))}
|
||||
}
|
||||
|
||||
func (c *rendezvousDiscovery) Advertise(ctx context.Context, ns string, opts ...discovery.Option) (time.Duration, error) {
|
||||
// Get options
|
||||
var options discovery.Options
|
||||
err := options.Apply(opts...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
ttl := options.Ttl
|
||||
var ttlSeconds int
|
||||
|
||||
if ttl == 0 {
|
||||
ttlSeconds = 7200
|
||||
} else {
|
||||
ttlSeconds = int(math.Round(ttl.Seconds()))
|
||||
}
|
||||
|
||||
if rttl, err := c.rp.Register(ctx, ns, ttlSeconds); err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
return rttl, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *rendezvousDiscovery) FindPeers(ctx context.Context, ns string, opts ...discovery.Option) (<-chan peer.AddrInfo, error) {
|
||||
// Get options
|
||||
var options discovery.Options
|
||||
err := options.Apply(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
const maxLimit = 1000
|
||||
limit := options.Limit
|
||||
if limit == 0 || limit > maxLimit {
|
||||
limit = maxLimit
|
||||
}
|
||||
|
||||
// Get cached peers
|
||||
var cache *discoveryCache
|
||||
|
||||
c.peerCacheMux.RLock()
|
||||
cache, ok := c.peerCache[ns]
|
||||
c.peerCacheMux.RUnlock()
|
||||
if !ok {
|
||||
c.peerCacheMux.Lock()
|
||||
cache, ok = c.peerCache[ns]
|
||||
if !ok {
|
||||
cache = &discoveryCache{recs: make(map[peer.ID]*record)}
|
||||
c.peerCache[ns] = cache
|
||||
}
|
||||
c.peerCacheMux.Unlock()
|
||||
}
|
||||
|
||||
cache.mux.Lock()
|
||||
defer cache.mux.Unlock()
|
||||
|
||||
// Remove all expired entries from cache
|
||||
currentTime := time.Now().Unix()
|
||||
newCacheSize := len(cache.recs)
|
||||
|
||||
for p := range cache.recs {
|
||||
rec := cache.recs[p]
|
||||
if rec.expire < currentTime {
|
||||
newCacheSize--
|
||||
delete(cache.recs, p)
|
||||
}
|
||||
}
|
||||
|
||||
cookie := cache.cookie
|
||||
|
||||
// Discover new records if we don't have enough
|
||||
if newCacheSize < limit {
|
||||
// TODO: Should we return error even if we have valid cached results?
|
||||
var regs []Registration
|
||||
var newCookie []byte
|
||||
if regs, newCookie, err = c.rp.Discover(ctx, ns, limit, cookie); err == nil {
|
||||
for _, reg := range regs {
|
||||
rec := &record{peer: reg.Peer, expire: int64(reg.Ttl) + currentTime}
|
||||
cache.recs[rec.peer.ID] = rec
|
||||
}
|
||||
cache.cookie = newCookie
|
||||
}
|
||||
}
|
||||
|
||||
// Randomize and fill channel with available records
|
||||
count := len(cache.recs)
|
||||
if limit < count {
|
||||
count = limit
|
||||
}
|
||||
|
||||
chPeer := make(chan peer.AddrInfo, count)
|
||||
|
||||
c.rngMux.Lock()
|
||||
perm := c.rng.Perm(len(cache.recs))[0:count]
|
||||
c.rngMux.Unlock()
|
||||
|
||||
permSet := make(map[int]int)
|
||||
for i, v := range perm {
|
||||
permSet[v] = i
|
||||
}
|
||||
|
||||
sendLst := make([]*peer.AddrInfo, count)
|
||||
iter := 0
|
||||
for k := range cache.recs {
|
||||
if sendIndex, ok := permSet[iter]; ok {
|
||||
sendLst[sendIndex] = &cache.recs[k].peer
|
||||
}
|
||||
iter++
|
||||
}
|
||||
|
||||
for _, send := range sendLst {
|
||||
chPeer <- *send
|
||||
}
|
||||
|
||||
close(chPeer)
|
||||
return chPeer, err
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
//go:generate protoc --proto_path=pb/ --gofast_opt="Mrendezvous.proto=.;rendezvous_pb" --gofast_out=./pb ./pb/rendezvous.proto
|
||||
package rendezvous
|
|
@ -0,0 +1,32 @@
|
|||
package rendezvous
|
||||
|
||||
import (
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
type RendezvousPointOption func(cfg *rendezvousPointConfig)
|
||||
|
||||
type AddrsFactory func(addrs []ma.Multiaddr) []ma.Multiaddr
|
||||
|
||||
var DefaultAddrFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs }
|
||||
|
||||
var defaultRendezvousPointConfig = rendezvousPointConfig{
|
||||
AddrsFactory: DefaultAddrFactory,
|
||||
}
|
||||
|
||||
type rendezvousPointConfig struct {
|
||||
AddrsFactory AddrsFactory
|
||||
}
|
||||
|
||||
func (cfg *rendezvousPointConfig) apply(opts ...RendezvousPointOption) {
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// AddrsFactory configures libp2p to use the given address factory.
|
||||
func ClientWithAddrsFactory(factory AddrsFactory) RendezvousPointOption {
|
||||
return func(cfg *rendezvousPointConfig) {
|
||||
cfg.AddrsFactory = factory
|
||||
}
|
||||
}
|
3334
vendor/github.com/berty/go-libp2p-rendezvous/pb/rendezvous.pb.go
generated
vendored
Normal file
3334
vendor/github.com/berty/go-libp2p-rendezvous/pb/rendezvous.pb.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,91 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package rendezvous.pb;
|
||||
|
||||
message Message {
|
||||
enum MessageType {
|
||||
REGISTER = 0;
|
||||
REGISTER_RESPONSE = 1;
|
||||
UNREGISTER = 2;
|
||||
DISCOVER = 3;
|
||||
DISCOVER_RESPONSE = 4;
|
||||
|
||||
DISCOVER_SUBSCRIBE = 100;
|
||||
DISCOVER_SUBSCRIBE_RESPONSE = 101;
|
||||
}
|
||||
|
||||
enum ResponseStatus {
|
||||
OK = 0;
|
||||
E_INVALID_NAMESPACE = 100;
|
||||
E_INVALID_PEER_INFO = 101;
|
||||
E_INVALID_TTL = 102;
|
||||
E_INVALID_COOKIE = 103;
|
||||
E_NOT_AUTHORIZED = 200;
|
||||
E_INTERNAL_ERROR = 300;
|
||||
E_UNAVAILABLE = 400;
|
||||
}
|
||||
|
||||
message PeerInfo {
|
||||
bytes id = 1;
|
||||
repeated bytes addrs = 2;
|
||||
}
|
||||
|
||||
message Register {
|
||||
string ns = 1;
|
||||
PeerInfo peer = 2;
|
||||
int64 ttl = 3; // in seconds
|
||||
}
|
||||
|
||||
message RegisterResponse {
|
||||
ResponseStatus status = 1;
|
||||
string statusText = 2;
|
||||
int64 ttl = 3;
|
||||
}
|
||||
|
||||
message Unregister {
|
||||
string ns = 1;
|
||||
bytes id = 2;
|
||||
}
|
||||
|
||||
message Discover {
|
||||
string ns = 1;
|
||||
int64 limit = 2;
|
||||
bytes cookie = 3;
|
||||
}
|
||||
|
||||
message DiscoverResponse {
|
||||
repeated Register registrations = 1;
|
||||
bytes cookie = 2;
|
||||
ResponseStatus status = 3;
|
||||
string statusText = 4;
|
||||
}
|
||||
|
||||
message DiscoverSubscribe {
|
||||
repeated string supported_subscription_types = 1;
|
||||
string ns = 2;
|
||||
}
|
||||
|
||||
message DiscoverSubscribeResponse {
|
||||
string subscription_type = 1;
|
||||
string subscription_details = 2;
|
||||
ResponseStatus status = 3;
|
||||
string statusText = 4;
|
||||
}
|
||||
|
||||
MessageType type = 1;
|
||||
Register register = 2;
|
||||
RegisterResponse registerResponse = 3;
|
||||
Unregister unregister = 4;
|
||||
Discover discover = 5;
|
||||
DiscoverResponse discoverResponse = 6;
|
||||
|
||||
DiscoverSubscribe discoverSubscribe = 100;
|
||||
DiscoverSubscribeResponse discoverSubscribeResponse = 101;
|
||||
}
|
||||
|
||||
message RegistrationRecord{
|
||||
string id = 1;
|
||||
repeated bytes addrs = 2;
|
||||
string ns = 3;
|
||||
int64 ttl = 4;
|
||||
}
|
|
@ -0,0 +1,178 @@
|
|||
package rendezvous
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
db "github.com/berty/go-libp2p-rendezvous/db"
|
||||
pb "github.com/berty/go-libp2p-rendezvous/pb"
|
||||
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/protocol"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
var log = logging.Logger("rendezvous")
|
||||
|
||||
const (
|
||||
RendezvousProto = protocol.ID("/rendezvous/1.0.0")
|
||||
|
||||
DefaultTTL = 2 * 3600 // 2hr
|
||||
)
|
||||
|
||||
type RendezvousError struct {
|
||||
Status pb.Message_ResponseStatus
|
||||
Text string
|
||||
}
|
||||
|
||||
func (e RendezvousError) Error() string {
|
||||
return fmt.Sprintf("Rendezvous error: %s (%s)", e.Text, e.Status.String())
|
||||
}
|
||||
|
||||
func NewRegisterMessage(ns string, pi peer.AddrInfo, ttl int) *pb.Message {
|
||||
return newRegisterMessage(ns, pi, ttl)
|
||||
}
|
||||
|
||||
func newRegisterMessage(ns string, pi peer.AddrInfo, ttl int) *pb.Message {
|
||||
msg := new(pb.Message)
|
||||
msg.Type = pb.Message_REGISTER
|
||||
msg.Register = new(pb.Message_Register)
|
||||
if ns != "" {
|
||||
msg.Register.Ns = ns
|
||||
}
|
||||
if ttl > 0 {
|
||||
ttl64 := int64(ttl)
|
||||
msg.Register.Ttl = ttl64
|
||||
}
|
||||
msg.Register.Peer = new(pb.Message_PeerInfo)
|
||||
msg.Register.Peer.Id = []byte(pi.ID)
|
||||
msg.Register.Peer.Addrs = make([][]byte, len(pi.Addrs))
|
||||
for i, addr := range pi.Addrs {
|
||||
msg.Register.Peer.Addrs[i] = addr.Bytes()
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func newUnregisterMessage(ns string, pid peer.ID) *pb.Message {
|
||||
msg := new(pb.Message)
|
||||
msg.Type = pb.Message_UNREGISTER
|
||||
msg.Unregister = new(pb.Message_Unregister)
|
||||
if ns != "" {
|
||||
msg.Unregister.Ns = ns
|
||||
}
|
||||
msg.Unregister.Id = []byte(pid)
|
||||
return msg
|
||||
}
|
||||
|
||||
func NewDiscoverMessage(ns string, limit int, cookie []byte) *pb.Message {
|
||||
return newDiscoverMessage(ns, limit, cookie)
|
||||
}
|
||||
|
||||
func newDiscoverMessage(ns string, limit int, cookie []byte) *pb.Message {
|
||||
msg := new(pb.Message)
|
||||
msg.Type = pb.Message_DISCOVER
|
||||
msg.Discover = new(pb.Message_Discover)
|
||||
if ns != "" {
|
||||
msg.Discover.Ns = ns
|
||||
}
|
||||
if limit > 0 {
|
||||
limit64 := int64(limit)
|
||||
msg.Discover.Limit = limit64
|
||||
}
|
||||
if cookie != nil {
|
||||
msg.Discover.Cookie = cookie
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func pbToPeerInfo(p *pb.Message_PeerInfo) (peer.AddrInfo, error) {
|
||||
if p == nil {
|
||||
return peer.AddrInfo{}, errors.New("missing peer info")
|
||||
}
|
||||
|
||||
id, err := peer.IDFromBytes(p.Id)
|
||||
if err != nil {
|
||||
return peer.AddrInfo{}, err
|
||||
}
|
||||
addrs := make([]ma.Multiaddr, 0, len(p.Addrs))
|
||||
for _, bs := range p.Addrs {
|
||||
addr, err := ma.NewMultiaddrBytes(bs)
|
||||
if err != nil {
|
||||
log.Errorf("Error parsing multiaddr: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
|
||||
return peer.AddrInfo{ID: id, Addrs: addrs}, nil
|
||||
}
|
||||
|
||||
func newRegisterResponse(ttl int) *pb.Message_RegisterResponse {
|
||||
ttl64 := int64(ttl)
|
||||
r := new(pb.Message_RegisterResponse)
|
||||
r.Status = pb.Message_OK
|
||||
r.Ttl = ttl64
|
||||
return r
|
||||
}
|
||||
|
||||
func newRegisterResponseError(status pb.Message_ResponseStatus, text string) *pb.Message_RegisterResponse {
|
||||
r := new(pb.Message_RegisterResponse)
|
||||
r.Status = status
|
||||
r.StatusText = text
|
||||
return r
|
||||
}
|
||||
|
||||
func newDiscoverResponse(regs []db.RegistrationRecord, cookie []byte) *pb.Message_DiscoverResponse {
|
||||
r := new(pb.Message_DiscoverResponse)
|
||||
r.Status = pb.Message_OK
|
||||
|
||||
rregs := make([]*pb.Message_Register, len(regs))
|
||||
for i, reg := range regs {
|
||||
rreg := new(pb.Message_Register)
|
||||
rns := reg.Ns
|
||||
rreg.Ns = rns
|
||||
rreg.Peer = new(pb.Message_PeerInfo)
|
||||
rreg.Peer.Id = []byte(reg.Id)
|
||||
rreg.Peer.Addrs = reg.Addrs
|
||||
rttl := int64(reg.Ttl)
|
||||
rreg.Ttl = rttl
|
||||
rregs[i] = rreg
|
||||
}
|
||||
|
||||
r.Registrations = rregs
|
||||
r.Cookie = cookie
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func newDiscoverResponseError(status pb.Message_ResponseStatus, text string) *pb.Message_DiscoverResponse {
|
||||
r := new(pb.Message_DiscoverResponse)
|
||||
r.Status = status
|
||||
r.StatusText = text
|
||||
return r
|
||||
}
|
||||
|
||||
func newDiscoverSubscribeResponse(subscriptionType string, subscriptionDetails string) *pb.Message_DiscoverSubscribeResponse {
|
||||
r := new(pb.Message_DiscoverSubscribeResponse)
|
||||
r.Status = pb.Message_OK
|
||||
|
||||
r.SubscriptionDetails = subscriptionDetails
|
||||
r.SubscriptionType = subscriptionType
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func newDiscoverSubscribeResponseError(status pb.Message_ResponseStatus, text string) *pb.Message_DiscoverSubscribeResponse {
|
||||
r := new(pb.Message_DiscoverSubscribeResponse)
|
||||
r.Status = status
|
||||
r.StatusText = text
|
||||
return r
|
||||
}
|
||||
|
||||
func newDiscoverSubscribeMessage(ns string, supportedSubscriptionTypes []string) *pb.Message_DiscoverSubscribe {
|
||||
r := new(pb.Message_DiscoverSubscribe)
|
||||
r.Ns = ns
|
||||
r.SupportedSubscriptionTypes = supportedSubscriptionTypes
|
||||
return r
|
||||
}
|
|
@ -0,0 +1,261 @@
|
|||
package rendezvous
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
ggio "github.com/gogo/protobuf/io"
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
inet "github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
|
||||
db "github.com/berty/go-libp2p-rendezvous/db"
|
||||
pb "github.com/berty/go-libp2p-rendezvous/pb"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxTTL = 72 * 3600 // 72hr
|
||||
MaxNamespaceLength = 256
|
||||
MaxPeerAddressLength = 2048
|
||||
MaxRegistrations = 1000
|
||||
MaxDiscoverLimit = 1000
|
||||
)
|
||||
|
||||
type RendezvousService struct {
|
||||
DB db.DB
|
||||
rzs []RendezvousSync
|
||||
}
|
||||
|
||||
func NewRendezvousService(host host.Host, db db.DB, rzs ...RendezvousSync) *RendezvousService {
|
||||
rz := &RendezvousService{DB: db, rzs: rzs}
|
||||
host.SetStreamHandler(RendezvousProto, rz.handleStream)
|
||||
return rz
|
||||
}
|
||||
|
||||
func (rz *RendezvousService) handleStream(s inet.Stream) {
|
||||
defer s.Reset()
|
||||
|
||||
pid := s.Conn().RemotePeer()
|
||||
log.Debugf("New stream from %s", pid.Pretty())
|
||||
|
||||
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
|
||||
w := ggio.NewDelimitedWriter(s)
|
||||
|
||||
for {
|
||||
var req pb.Message
|
||||
var res pb.Message
|
||||
|
||||
err := r.ReadMsg(&req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
t := req.GetType()
|
||||
switch t {
|
||||
case pb.Message_REGISTER:
|
||||
r := rz.handleRegister(pid, req.GetRegister())
|
||||
res.Type = pb.Message_REGISTER_RESPONSE
|
||||
res.RegisterResponse = r
|
||||
err = w.WriteMsg(&res)
|
||||
if err != nil {
|
||||
log.Debugf("Error writing response: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
case pb.Message_UNREGISTER:
|
||||
err := rz.handleUnregister(pid, req.GetUnregister())
|
||||
if err != nil {
|
||||
log.Debugf("Error unregistering peer: %s", err.Error())
|
||||
}
|
||||
|
||||
case pb.Message_DISCOVER:
|
||||
r := rz.handleDiscover(pid, req.GetDiscover())
|
||||
res.Type = pb.Message_DISCOVER_RESPONSE
|
||||
res.DiscoverResponse = r
|
||||
err = w.WriteMsg(&res)
|
||||
if err != nil {
|
||||
log.Debugf("Error writing response: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
case pb.Message_DISCOVER_SUBSCRIBE:
|
||||
r := rz.handleDiscoverSubscribe(pid, req.GetDiscoverSubscribe())
|
||||
res.Type = pb.Message_DISCOVER_SUBSCRIBE_RESPONSE
|
||||
res.DiscoverSubscribeResponse = r
|
||||
err = w.WriteMsg(&res)
|
||||
if err != nil {
|
||||
log.Debugf("Error writing response: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
default:
|
||||
log.Debugf("Unexpected message: %s", t.String())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rz *RendezvousService) handleRegister(p peer.ID, m *pb.Message_Register) *pb.Message_RegisterResponse {
|
||||
ns := m.GetNs()
|
||||
if ns == "" {
|
||||
return newRegisterResponseError(pb.Message_E_INVALID_NAMESPACE, "unspecified namespace")
|
||||
}
|
||||
|
||||
if len(ns) > MaxNamespaceLength {
|
||||
return newRegisterResponseError(pb.Message_E_INVALID_NAMESPACE, "namespace too long")
|
||||
}
|
||||
|
||||
mpi := m.GetPeer()
|
||||
if mpi == nil {
|
||||
return newRegisterResponseError(pb.Message_E_INVALID_PEER_INFO, "missing peer info")
|
||||
}
|
||||
|
||||
mpid := mpi.GetId()
|
||||
if mpid != nil {
|
||||
mp, err := peer.IDFromBytes(mpid)
|
||||
if err != nil {
|
||||
return newRegisterResponseError(pb.Message_E_INVALID_PEER_INFO, "bad peer id")
|
||||
}
|
||||
|
||||
if mp != p {
|
||||
return newRegisterResponseError(pb.Message_E_INVALID_PEER_INFO, "peer id mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
maddrs := mpi.GetAddrs()
|
||||
if len(maddrs) == 0 {
|
||||
return newRegisterResponseError(pb.Message_E_INVALID_PEER_INFO, "missing peer addresses")
|
||||
}
|
||||
|
||||
mlen := 0
|
||||
for _, maddr := range maddrs {
|
||||
mlen += len(maddr)
|
||||
}
|
||||
if mlen > MaxPeerAddressLength {
|
||||
return newRegisterResponseError(pb.Message_E_INVALID_PEER_INFO, "peer info too long")
|
||||
}
|
||||
|
||||
// Note:
|
||||
// We don't validate the addresses, because they could include protocols we don't understand
|
||||
// Perhaps we should though.
|
||||
|
||||
mttl := m.GetTtl()
|
||||
if mttl < 0 || mttl > MaxTTL {
|
||||
return newRegisterResponseError(pb.Message_E_INVALID_TTL, "bad ttl")
|
||||
}
|
||||
|
||||
ttl := DefaultTTL
|
||||
if mttl > 0 {
|
||||
ttl = int(mttl)
|
||||
}
|
||||
|
||||
// now check how many registrations we have for this peer -- simple limit to defend
|
||||
// against trivial DoS attacks (eg a peer connects and keeps registering until it
|
||||
// fills our db)
|
||||
rcount, err := rz.DB.CountRegistrations(p)
|
||||
if err != nil {
|
||||
log.Errorf("Error counting registrations: %s", err.Error())
|
||||
return newRegisterResponseError(pb.Message_E_INTERNAL_ERROR, "database error")
|
||||
}
|
||||
|
||||
if rcount > MaxRegistrations {
|
||||
log.Warningf("Too many registrations for %s", p)
|
||||
return newRegisterResponseError(pb.Message_E_NOT_AUTHORIZED, "too many registrations")
|
||||
}
|
||||
|
||||
// ok, seems like we can register
|
||||
counter, err := rz.DB.Register(p, ns, maddrs, ttl)
|
||||
if err != nil {
|
||||
log.Errorf("Error registering: %s", err.Error())
|
||||
return newRegisterResponseError(pb.Message_E_INTERNAL_ERROR, "database error")
|
||||
}
|
||||
|
||||
log.Infof("registered peer %s %s (%d)", p, ns, ttl)
|
||||
|
||||
for _, rzs := range rz.rzs {
|
||||
rzs.Register(p, ns, maddrs, ttl, counter)
|
||||
}
|
||||
|
||||
return newRegisterResponse(ttl)
|
||||
}
|
||||
|
||||
func (rz *RendezvousService) handleUnregister(p peer.ID, m *pb.Message_Unregister) error {
|
||||
ns := m.GetNs()
|
||||
|
||||
mpid := m.GetId()
|
||||
if mpid != nil {
|
||||
mp, err := peer.IDFromBytes(mpid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if mp != p {
|
||||
return fmt.Errorf("peer id mismatch: %s asked to unregister %s", p.Pretty(), mp.Pretty())
|
||||
}
|
||||
}
|
||||
|
||||
err := rz.DB.Unregister(p, ns)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("unregistered peer %s %s", p, ns)
|
||||
|
||||
for _, rzs := range rz.rzs {
|
||||
rzs.Unregister(p, ns)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rz *RendezvousService) handleDiscover(p peer.ID, m *pb.Message_Discover) *pb.Message_DiscoverResponse {
|
||||
ns := m.GetNs()
|
||||
|
||||
if len(ns) > MaxNamespaceLength {
|
||||
return newDiscoverResponseError(pb.Message_E_INVALID_NAMESPACE, "namespace too long")
|
||||
}
|
||||
|
||||
limit := MaxDiscoverLimit
|
||||
mlimit := m.GetLimit()
|
||||
if mlimit > 0 && mlimit < int64(limit) {
|
||||
limit = int(mlimit)
|
||||
}
|
||||
|
||||
cookie := m.GetCookie()
|
||||
if cookie != nil && !rz.DB.ValidCookie(ns, cookie) {
|
||||
return newDiscoverResponseError(pb.Message_E_INVALID_COOKIE, "bad cookie")
|
||||
}
|
||||
|
||||
regs, cookie, err := rz.DB.Discover(ns, cookie, limit)
|
||||
if err != nil {
|
||||
log.Errorf("Error in query: %s", err.Error())
|
||||
return newDiscoverResponseError(pb.Message_E_INTERNAL_ERROR, "database error")
|
||||
}
|
||||
|
||||
log.Infof("discover query: %s %s -> %d", p, ns, len(regs))
|
||||
|
||||
return newDiscoverResponse(regs, cookie)
|
||||
}
|
||||
|
||||
func (rz *RendezvousService) handleDiscoverSubscribe(_ peer.ID, m *pb.Message_DiscoverSubscribe) *pb.Message_DiscoverSubscribeResponse {
|
||||
ns := m.GetNs()
|
||||
|
||||
for _, s := range rz.rzs {
|
||||
rzSub, ok := s.(RendezvousSyncSubscribable)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, supportedSubType := range m.GetSupportedSubscriptionTypes() {
|
||||
if rzSub.GetServiceType() == supportedSubType {
|
||||
sub, err := rzSub.Subscribe(ns)
|
||||
if err != nil {
|
||||
return newDiscoverSubscribeResponseError(pb.Message_E_INTERNAL_ERROR, "error while subscribing")
|
||||
}
|
||||
|
||||
return newDiscoverSubscribeResponse(supportedSubType, sub)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return newDiscoverSubscribeResponseError(pb.Message_E_INTERNAL_ERROR, "subscription type not found")
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
package rendezvous
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
)
|
||||
|
||||
type RendezvousSync interface {
|
||||
Register(p peer.ID, ns string, addrs [][]byte, ttl int, counter uint64)
|
||||
Unregister(p peer.ID, ns string)
|
||||
}
|
||||
|
||||
type RendezvousSyncSubscribable interface {
|
||||
Subscribe(ns string) (syncDetails string, err error)
|
||||
GetServiceType() string
|
||||
}
|
||||
|
||||
type RendezvousSyncClient interface {
|
||||
Subscribe(ctx context.Context, syncDetails string) (<-chan *Registration, error)
|
||||
GetServiceType() string
|
||||
}
|
156
vendor/github.com/berty/go-libp2p-rendezvous/sync_inmem_client.go
generated
vendored
Normal file
156
vendor/github.com/berty/go-libp2p-rendezvous/sync_inmem_client.go
generated
vendored
Normal file
|
@ -0,0 +1,156 @@
|
|||
package rendezvous
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
ggio "github.com/gogo/protobuf/io"
|
||||
"github.com/google/uuid"
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
inet "github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/multiformats/go-multiaddr"
|
||||
|
||||
pb "github.com/berty/go-libp2p-rendezvous/pb"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
ctx context.Context
|
||||
host host.Host
|
||||
mu sync.Mutex
|
||||
streams map[string]inet.Stream
|
||||
subscriptions map[string]map[string]chan *Registration
|
||||
}
|
||||
|
||||
func NewSyncInMemClient(ctx context.Context, h host.Host) *client {
|
||||
return &client{
|
||||
ctx: ctx,
|
||||
host: h,
|
||||
streams: map[string]inet.Stream{},
|
||||
subscriptions: map[string]map[string]chan *Registration{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) getStreamToPeer(pidStr string) (inet.Stream, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if stream, ok := c.streams[pidStr]; ok {
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
pid, err := peer.Decode(pidStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to decode peer id: %w", err)
|
||||
}
|
||||
|
||||
stream, err := c.host.NewStream(c.ctx, pid, ServiceProto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to connect to peer: %w", err)
|
||||
}
|
||||
|
||||
go c.streamListener(stream)
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *client) streamListener(s inet.Stream) {
|
||||
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
|
||||
record := &pb.RegistrationRecord{}
|
||||
|
||||
for {
|
||||
err := r.ReadMsg(record)
|
||||
if err != nil {
|
||||
log.Errorf("unable to decode message: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
pid, err := peer.Decode(record.Id)
|
||||
if err != nil {
|
||||
log.Warnf("invalid peer id: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
maddrs := make([]multiaddr.Multiaddr, len(record.Addrs))
|
||||
for i, addrBytes := range record.Addrs {
|
||||
maddrs[i], err = multiaddr.NewMultiaddrBytes(addrBytes)
|
||||
if err != nil {
|
||||
log.Warnf("invalid multiaddr: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
subscriptions, ok := c.subscriptions[record.Ns]
|
||||
if ok {
|
||||
for _, subscription := range subscriptions {
|
||||
subscription <- &Registration{
|
||||
Peer: peer.AddrInfo{
|
||||
ID: pid,
|
||||
Addrs: maddrs,
|
||||
},
|
||||
Ns: record.Ns,
|
||||
Ttl: int(record.Ttl),
|
||||
}
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Subscribe(ctx context.Context, syncDetails string) (<-chan *Registration, error) {
|
||||
ctxUUID, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to generate uuid: %w", err)
|
||||
}
|
||||
|
||||
psDetails := &PubSubSubscriptionDetails{}
|
||||
|
||||
err = json.Unmarshal([]byte(syncDetails), psDetails)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to decode json: %w", err)
|
||||
}
|
||||
|
||||
s, err := c.getStreamToPeer(psDetails.PeerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get stream to peer: %w", err)
|
||||
}
|
||||
|
||||
w := ggio.NewDelimitedWriter(s)
|
||||
|
||||
err = w.WriteMsg(&pb.Message{
|
||||
Type: pb.Message_DISCOVER_SUBSCRIBE,
|
||||
DiscoverSubscribe: &pb.Message_DiscoverSubscribe{
|
||||
Ns: psDetails.ChannelName,
|
||||
}})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to query server")
|
||||
}
|
||||
|
||||
ch := make(chan *Registration)
|
||||
c.mu.Lock()
|
||||
if _, ok := c.subscriptions[psDetails.ChannelName]; !ok {
|
||||
c.subscriptions[psDetails.ChannelName] = map[string]chan *Registration{}
|
||||
}
|
||||
|
||||
c.subscriptions[psDetails.ChannelName][ctxUUID.String()] = ch
|
||||
c.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
c.mu.Lock()
|
||||
delete(c.subscriptions[psDetails.ChannelName], ctxUUID.String())
|
||||
c.mu.Unlock()
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (c *client) GetServiceType() string {
|
||||
return ServiceType
|
||||
}
|
||||
|
||||
var _ RendezvousSyncClient = (*client)(nil)
|
156
vendor/github.com/berty/go-libp2p-rendezvous/sync_inmem_provider.go
generated
vendored
Normal file
156
vendor/github.com/berty/go-libp2p-rendezvous/sync_inmem_provider.go
generated
vendored
Normal file
|
@ -0,0 +1,156 @@
|
|||
package rendezvous
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
pb "github.com/berty/go-libp2p-rendezvous/pb"
|
||||
ggio "github.com/gogo/protobuf/io"
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
inet "github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
ServiceType = "inmem"
|
||||
ServiceProto = protocol.ID("/rendezvous/sync/inmem/1.0.0")
|
||||
)
|
||||
|
||||
type PubSub struct {
|
||||
mu sync.RWMutex
|
||||
host host.Host
|
||||
topics map[string]*PubSubSubscribers
|
||||
}
|
||||
|
||||
type PubSubSubscribers struct {
|
||||
mu sync.RWMutex
|
||||
subscribers map[peer.ID]ggio.Writer
|
||||
lastAnnouncement *pb.RegistrationRecord
|
||||
}
|
||||
|
||||
type PubSubSubscriptionDetails struct {
|
||||
PeerID string
|
||||
ChannelName string
|
||||
}
|
||||
|
||||
func NewSyncInMemProvider(host host.Host) (*PubSub, error) {
|
||||
ps := &PubSub{
|
||||
host: host,
|
||||
topics: map[string]*PubSubSubscribers{},
|
||||
}
|
||||
|
||||
ps.Listen()
|
||||
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
func (ps *PubSub) Subscribe(ns string) (syncDetails string, err error) {
|
||||
details, err := json.Marshal(&PubSubSubscriptionDetails{
|
||||
PeerID: ps.host.ID().String(),
|
||||
ChannelName: ns,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to marshal subscription details: %w", err)
|
||||
}
|
||||
|
||||
return string(details), nil
|
||||
}
|
||||
|
||||
func (ps *PubSub) GetServiceType() string {
|
||||
return ServiceType
|
||||
}
|
||||
|
||||
func (ps *PubSub) getOrCreateTopic(ns string) *PubSubSubscribers {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
if subscribers, ok := ps.topics[ns]; ok {
|
||||
return subscribers
|
||||
}
|
||||
|
||||
ps.topics[ns] = &PubSubSubscribers{
|
||||
subscribers: map[peer.ID]ggio.Writer{},
|
||||
lastAnnouncement: nil,
|
||||
}
|
||||
return ps.topics[ns]
|
||||
}
|
||||
|
||||
func (ps *PubSub) Register(pid peer.ID, ns string, addrs [][]byte, ttlAsSeconds int, counter uint64) {
|
||||
topic := ps.getOrCreateTopic(ns)
|
||||
dataToSend := &pb.RegistrationRecord{
|
||||
Id: pid.String(),
|
||||
Addrs: addrs,
|
||||
Ns: ns,
|
||||
Ttl: time.Now().Add(time.Duration(ttlAsSeconds) * time.Second).UnixMilli(),
|
||||
}
|
||||
|
||||
topic.mu.Lock()
|
||||
topic.lastAnnouncement = dataToSend
|
||||
toNotify := topic.subscribers
|
||||
for _, stream := range toNotify {
|
||||
if err := stream.WriteMsg(dataToSend); err != nil {
|
||||
log.Errorf("unable to notify rendezvous data update: %s", err.Error())
|
||||
}
|
||||
}
|
||||
topic.mu.Unlock()
|
||||
}
|
||||
|
||||
func (ps *PubSub) Unregister(p peer.ID, ns string) {
|
||||
// TODO: unsupported
|
||||
}
|
||||
|
||||
func (ps *PubSub) Listen() {
|
||||
ps.host.SetStreamHandler(ServiceProto, ps.handleStream)
|
||||
}
|
||||
|
||||
func (ps *PubSub) handleStream(s inet.Stream) {
|
||||
defer s.Reset()
|
||||
|
||||
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
|
||||
w := ggio.NewDelimitedWriter(s)
|
||||
|
||||
subscribedTopics := map[string]struct{}{}
|
||||
|
||||
for {
|
||||
var req pb.Message
|
||||
|
||||
err := r.ReadMsg(&req)
|
||||
if err != nil {
|
||||
for ns := range subscribedTopics {
|
||||
topic := ps.getOrCreateTopic(ns)
|
||||
topic.mu.Lock()
|
||||
delete(topic.subscribers, s.Conn().RemotePeer())
|
||||
topic.mu.Unlock()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if req.Type != pb.Message_DISCOVER_SUBSCRIBE {
|
||||
continue
|
||||
}
|
||||
|
||||
topic := ps.getOrCreateTopic(req.DiscoverSubscribe.Ns)
|
||||
topic.mu.Lock()
|
||||
if _, ok := topic.subscribers[s.Conn().RemotePeer()]; ok {
|
||||
topic.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
topic.subscribers[s.Conn().RemotePeer()] = w
|
||||
subscribedTopics[req.DiscoverSubscribe.Ns] = struct{}{}
|
||||
lastAnnouncement := topic.lastAnnouncement
|
||||
if lastAnnouncement != nil {
|
||||
if err := w.WriteMsg(lastAnnouncement); err != nil {
|
||||
log.Errorf("unable to write announcement: %s", err.Error())
|
||||
}
|
||||
}
|
||||
topic.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
var _ RendezvousSync = (*PubSub)(nil)
|
||||
var _ RendezvousSyncSubscribable = (*PubSub)(nil)
|
|
@ -0,0 +1,102 @@
|
|||
// Protocol Buffers for Go with Gadgets
|
||||
//
|
||||
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
|
||||
// http://github.com/gogo/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package io
|
||||
|
||||
import (
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"io"
|
||||
)
|
||||
|
||||
func NewFullWriter(w io.Writer) WriteCloser {
|
||||
return &fullWriter{w, nil}
|
||||
}
|
||||
|
||||
type fullWriter struct {
|
||||
w io.Writer
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
func (this *fullWriter) WriteMsg(msg proto.Message) (err error) {
|
||||
var data []byte
|
||||
if m, ok := msg.(marshaler); ok {
|
||||
n, ok := getSize(m)
|
||||
if !ok {
|
||||
data, err = proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if n >= len(this.buffer) {
|
||||
this.buffer = make([]byte, n)
|
||||
}
|
||||
_, err = m.MarshalTo(this.buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = this.buffer[:n]
|
||||
} else {
|
||||
data, err = proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err = this.w.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *fullWriter) Close() error {
|
||||
if closer, ok := this.w.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type fullReader struct {
|
||||
r io.Reader
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func NewFullReader(r io.Reader, maxSize int) ReadCloser {
|
||||
return &fullReader{r, make([]byte, maxSize)}
|
||||
}
|
||||
|
||||
func (this *fullReader) ReadMsg(msg proto.Message) error {
|
||||
length, err := this.r.Read(this.buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return proto.Unmarshal(this.buf[:length], msg)
|
||||
}
|
||||
|
||||
func (this *fullReader) Close() error {
|
||||
if closer, ok := this.r.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
// Protocol Buffers for Go with Gadgets
|
||||
//
|
||||
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
|
||||
// http://github.com/gogo/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package io
|
||||
|
||||
import (
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Writer interface {
|
||||
WriteMsg(proto.Message) error
|
||||
}
|
||||
|
||||
type WriteCloser interface {
|
||||
Writer
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type Reader interface {
|
||||
ReadMsg(msg proto.Message) error
|
||||
}
|
||||
|
||||
type ReadCloser interface {
|
||||
Reader
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type marshaler interface {
|
||||
MarshalTo(data []byte) (n int, err error)
|
||||
}
|
||||
|
||||
func getSize(v interface{}) (int, bool) {
|
||||
if sz, ok := v.(interface {
|
||||
Size() (n int)
|
||||
}); ok {
|
||||
return sz.Size(), true
|
||||
} else if sz, ok := v.(interface {
|
||||
ProtoSize() (n int)
|
||||
}); ok {
|
||||
return sz.ProtoSize(), true
|
||||
} else {
|
||||
return 0, false
|
||||
}
|
||||
}
|
|
@ -0,0 +1,138 @@
|
|||
// Protocol Buffers for Go with Gadgets
|
||||
//
|
||||
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
|
||||
// http://github.com/gogo/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package io
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
)
|
||||
|
||||
const uint32BinaryLen = 4
|
||||
|
||||
func NewUint32DelimitedWriter(w io.Writer, byteOrder binary.ByteOrder) WriteCloser {
|
||||
return &uint32Writer{w, byteOrder, nil, make([]byte, uint32BinaryLen)}
|
||||
}
|
||||
|
||||
func NewSizeUint32DelimitedWriter(w io.Writer, byteOrder binary.ByteOrder, size int) WriteCloser {
|
||||
return &uint32Writer{w, byteOrder, make([]byte, size), make([]byte, uint32BinaryLen)}
|
||||
}
|
||||
|
||||
type uint32Writer struct {
|
||||
w io.Writer
|
||||
byteOrder binary.ByteOrder
|
||||
buffer []byte
|
||||
lenBuf []byte
|
||||
}
|
||||
|
||||
func (this *uint32Writer) writeFallback(msg proto.Message) error {
|
||||
data, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
length := uint32(len(data))
|
||||
this.byteOrder.PutUint32(this.lenBuf, length)
|
||||
if _, err = this.w.Write(this.lenBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = this.w.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *uint32Writer) WriteMsg(msg proto.Message) error {
|
||||
m, ok := msg.(marshaler)
|
||||
if !ok {
|
||||
return this.writeFallback(msg)
|
||||
}
|
||||
|
||||
n, ok := getSize(m)
|
||||
if !ok {
|
||||
return this.writeFallback(msg)
|
||||
}
|
||||
|
||||
size := n + uint32BinaryLen
|
||||
if size > len(this.buffer) {
|
||||
this.buffer = make([]byte, size)
|
||||
}
|
||||
|
||||
this.byteOrder.PutUint32(this.buffer, uint32(n))
|
||||
if _, err := m.MarshalTo(this.buffer[uint32BinaryLen:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := this.w.Write(this.buffer[:size])
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *uint32Writer) Close() error {
|
||||
if closer, ok := this.w.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type uint32Reader struct {
|
||||
r io.Reader
|
||||
byteOrder binary.ByteOrder
|
||||
lenBuf []byte
|
||||
buf []byte
|
||||
maxSize int
|
||||
}
|
||||
|
||||
func NewUint32DelimitedReader(r io.Reader, byteOrder binary.ByteOrder, maxSize int) ReadCloser {
|
||||
return &uint32Reader{r, byteOrder, make([]byte, 4), nil, maxSize}
|
||||
}
|
||||
|
||||
func (this *uint32Reader) ReadMsg(msg proto.Message) error {
|
||||
if _, err := io.ReadFull(this.r, this.lenBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
length32 := this.byteOrder.Uint32(this.lenBuf)
|
||||
length := int(length32)
|
||||
if length < 0 || length > this.maxSize {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
if length > len(this.buf) {
|
||||
this.buf = make([]byte, length)
|
||||
}
|
||||
_, err := io.ReadFull(this.r, this.buf[:length])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return proto.Unmarshal(this.buf[:length], msg)
|
||||
}
|
||||
|
||||
func (this *uint32Reader) Close() error {
|
||||
if closer, ok := this.r.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,133 @@
|
|||
// Protocol Buffers for Go with Gadgets
|
||||
//
|
||||
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
|
||||
// http://github.com/gogo/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package io
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
errSmallBuffer = errors.New("Buffer Too Small")
|
||||
errLargeValue = errors.New("Value is Larger than 64 bits")
|
||||
)
|
||||
|
||||
func NewDelimitedWriter(w io.Writer) WriteCloser {
|
||||
return &varintWriter{w, make([]byte, binary.MaxVarintLen64), nil}
|
||||
}
|
||||
|
||||
type varintWriter struct {
|
||||
w io.Writer
|
||||
lenBuf []byte
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
func (this *varintWriter) WriteMsg(msg proto.Message) (err error) {
|
||||
var data []byte
|
||||
if m, ok := msg.(marshaler); ok {
|
||||
n, ok := getSize(m)
|
||||
if ok {
|
||||
if n+binary.MaxVarintLen64 >= len(this.buffer) {
|
||||
this.buffer = make([]byte, n+binary.MaxVarintLen64)
|
||||
}
|
||||
lenOff := binary.PutUvarint(this.buffer, uint64(n))
|
||||
_, err = m.MarshalTo(this.buffer[lenOff:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = this.w.Write(this.buffer[:lenOff+n])
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// fallback
|
||||
data, err = proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
length := uint64(len(data))
|
||||
n := binary.PutUvarint(this.lenBuf, length)
|
||||
_, err = this.w.Write(this.lenBuf[:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = this.w.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *varintWriter) Close() error {
|
||||
if closer, ok := this.w.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewDelimitedReader(r io.Reader, maxSize int) ReadCloser {
|
||||
var closer io.Closer
|
||||
if c, ok := r.(io.Closer); ok {
|
||||
closer = c
|
||||
}
|
||||
return &varintReader{bufio.NewReader(r), nil, maxSize, closer}
|
||||
}
|
||||
|
||||
type varintReader struct {
|
||||
r *bufio.Reader
|
||||
buf []byte
|
||||
maxSize int
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func (this *varintReader) ReadMsg(msg proto.Message) error {
|
||||
length64, err := binary.ReadUvarint(this.r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
length := int(length64)
|
||||
if length < 0 || length > this.maxSize {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
if len(this.buf) < length {
|
||||
this.buf = make([]byte, length)
|
||||
}
|
||||
buf := this.buf[:length]
|
||||
if _, err := io.ReadFull(this.r, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return proto.Unmarshal(buf, msg)
|
||||
}
|
||||
|
||||
func (this *varintReader) Close() error {
|
||||
if this.closer != nil {
|
||||
return this.closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -9,7 +9,7 @@ import (
|
|||
var backgroundSweepInterval = time.Minute
|
||||
|
||||
func background(ctx context.Context, lk sync.Locker, m map[string]time.Time) {
|
||||
ticker := time.NewTimer(backgroundSweepInterval)
|
||||
ticker := time.NewTicker(backgroundSweepInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
|
|
|
@ -70,7 +70,10 @@ type validation struct {
|
|||
// mx protects the validator map
|
||||
mx sync.Mutex
|
||||
// topicVals tracks per topic validators
|
||||
topicVals map[string]*topicVal
|
||||
topicVals map[string]*validatorImpl
|
||||
|
||||
// defaultVals tracks default validators applicable to all topics
|
||||
defaultVals []*validatorImpl
|
||||
|
||||
// validateQ is the front-end to the validation pipeline
|
||||
validateQ chan *validateReq
|
||||
|
@ -84,13 +87,13 @@ type validation struct {
|
|||
|
||||
// validation requests
|
||||
type validateReq struct {
|
||||
vals []*topicVal
|
||||
vals []*validatorImpl
|
||||
src peer.ID
|
||||
msg *Message
|
||||
}
|
||||
|
||||
// representation of topic validators
|
||||
type topicVal struct {
|
||||
type validatorImpl struct {
|
||||
topic string
|
||||
validate ValidatorEx
|
||||
validateTimeout time.Duration
|
||||
|
@ -117,7 +120,7 @@ type rmValReq struct {
|
|||
// newValidation creates a new validation pipeline
|
||||
func newValidation() *validation {
|
||||
return &validation{
|
||||
topicVals: make(map[string]*topicVal),
|
||||
topicVals: make(map[string]*validatorImpl),
|
||||
validateQ: make(chan *validateReq, defaultValidateQueueSize),
|
||||
validateThrottle: make(chan struct{}, defaultValidateThrottle),
|
||||
validateWorkers: runtime.NumCPU(),
|
||||
|
@ -136,10 +139,16 @@ func (v *validation) Start(p *PubSub) {
|
|||
|
||||
// AddValidator adds a new validator
|
||||
func (v *validation) AddValidator(req *addValReq) {
|
||||
val, err := v.makeValidator(req)
|
||||
if err != nil {
|
||||
req.resp <- err
|
||||
return
|
||||
}
|
||||
|
||||
v.mx.Lock()
|
||||
defer v.mx.Unlock()
|
||||
|
||||
topic := req.topic
|
||||
topic := val.topic
|
||||
|
||||
_, ok := v.topicVals[topic]
|
||||
if ok {
|
||||
|
@ -147,6 +156,11 @@ func (v *validation) AddValidator(req *addValReq) {
|
|||
return
|
||||
}
|
||||
|
||||
v.topicVals[topic] = val
|
||||
req.resp <- nil
|
||||
}
|
||||
|
||||
func (v *validation) makeValidator(req *addValReq) (*validatorImpl, error) {
|
||||
makeValidatorEx := func(v Validator) ValidatorEx {
|
||||
return func(ctx context.Context, p peer.ID, msg *Message) ValidationResult {
|
||||
if v(ctx, p, msg) {
|
||||
|
@ -170,12 +184,15 @@ func (v *validation) AddValidator(req *addValReq) {
|
|||
validator = v
|
||||
|
||||
default:
|
||||
req.resp <- fmt.Errorf("unknown validator type for topic %s; must be an instance of Validator or ValidatorEx", topic)
|
||||
return
|
||||
topic := req.topic
|
||||
if req.topic == "" {
|
||||
topic = "(default)"
|
||||
}
|
||||
return nil, fmt.Errorf("unknown validator type for topic %s; must be an instance of Validator or ValidatorEx", topic)
|
||||
}
|
||||
|
||||
val := &topicVal{
|
||||
topic: topic,
|
||||
val := &validatorImpl{
|
||||
topic: req.topic,
|
||||
validate: validator,
|
||||
validateTimeout: 0,
|
||||
validateThrottle: make(chan struct{}, defaultValidateConcurrency),
|
||||
|
@ -190,8 +207,7 @@ func (v *validation) AddValidator(req *addValReq) {
|
|||
val.validateThrottle = make(chan struct{}, req.throttle)
|
||||
}
|
||||
|
||||
v.topicVals[topic] = val
|
||||
req.resp <- nil
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// RemoveValidator removes an existing validator
|
||||
|
@ -244,18 +260,21 @@ func (v *validation) Push(src peer.ID, msg *Message) bool {
|
|||
}
|
||||
|
||||
// getValidators returns all validators that apply to a given message
|
||||
func (v *validation) getValidators(msg *Message) []*topicVal {
|
||||
func (v *validation) getValidators(msg *Message) []*validatorImpl {
|
||||
v.mx.Lock()
|
||||
defer v.mx.Unlock()
|
||||
|
||||
var vals []*validatorImpl
|
||||
vals = append(vals, v.defaultVals...)
|
||||
|
||||
topic := msg.GetTopic()
|
||||
|
||||
val, ok := v.topicVals[topic]
|
||||
if !ok {
|
||||
return nil
|
||||
return vals
|
||||
}
|
||||
|
||||
return []*topicVal{val}
|
||||
return append(vals, val)
|
||||
}
|
||||
|
||||
// validateWorker is an active goroutine performing inline validation
|
||||
|
@ -271,7 +290,7 @@ func (v *validation) validateWorker() {
|
|||
}
|
||||
|
||||
// validate performs validation and only sends the message if all validators succeed
|
||||
func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message, synchronous bool) error {
|
||||
func (v *validation) validate(vals []*validatorImpl, src peer.ID, msg *Message, synchronous bool) error {
|
||||
// If signature verification is enabled, but signing is disabled,
|
||||
// the Signature is required to be nil upon receiving the message in PubSub.pushMsg.
|
||||
if msg.Signature != nil {
|
||||
|
@ -292,7 +311,7 @@ func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message, synch
|
|||
v.tracer.ValidateMessage(msg)
|
||||
}
|
||||
|
||||
var inline, async []*topicVal
|
||||
var inline, async []*validatorImpl
|
||||
for _, val := range vals {
|
||||
if val.validateInline || synchronous {
|
||||
inline = append(inline, val)
|
||||
|
@ -360,7 +379,7 @@ func (v *validation) validateSignature(msg *Message) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (v *validation) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message, r ValidationResult) {
|
||||
func (v *validation) doValidateTopic(vals []*validatorImpl, src peer.ID, msg *Message, r ValidationResult) {
|
||||
result := v.validateTopic(vals, src, msg)
|
||||
|
||||
if result == ValidationAccept && r != ValidationAccept {
|
||||
|
@ -388,7 +407,7 @@ func (v *validation) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message
|
|||
}
|
||||
}
|
||||
|
||||
func (v *validation) validateTopic(vals []*topicVal, src peer.ID, msg *Message) ValidationResult {
|
||||
func (v *validation) validateTopic(vals []*validatorImpl, src peer.ID, msg *Message) ValidationResult {
|
||||
if len(vals) == 1 {
|
||||
return v.validateSingleTopic(vals[0], src, msg)
|
||||
}
|
||||
|
@ -404,7 +423,7 @@ func (v *validation) validateTopic(vals []*topicVal, src peer.ID, msg *Message)
|
|||
|
||||
select {
|
||||
case val.validateThrottle <- struct{}{}:
|
||||
go func(val *topicVal) {
|
||||
go func(val *validatorImpl) {
|
||||
rch <- val.validateMsg(ctx, src, msg)
|
||||
<-val.validateThrottle
|
||||
}(val)
|
||||
|
@ -438,7 +457,7 @@ loop:
|
|||
}
|
||||
|
||||
// fast path for single topic validation that avoids the extra goroutine
|
||||
func (v *validation) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) ValidationResult {
|
||||
func (v *validation) validateSingleTopic(val *validatorImpl, src peer.ID, msg *Message) ValidationResult {
|
||||
select {
|
||||
case val.validateThrottle <- struct{}{}:
|
||||
res := val.validateMsg(v.p.ctx, src, msg)
|
||||
|
@ -451,7 +470,7 @@ func (v *validation) validateSingleTopic(val *topicVal, src peer.ID, msg *Messag
|
|||
}
|
||||
}
|
||||
|
||||
func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) ValidationResult {
|
||||
func (val *validatorImpl) validateMsg(ctx context.Context, src peer.ID, msg *Message) ValidationResult {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.Debugf("validation done; took %s", time.Since(start))
|
||||
|
@ -479,6 +498,31 @@ func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message)
|
|||
}
|
||||
|
||||
// / Options
|
||||
// WithDefaultValidator adds a validator that applies to all topics by default; it can be used
|
||||
// more than once and add multiple validators. Having a defult validator does not inhibit registering
|
||||
// a per topic validator.
|
||||
func WithDefaultValidator(val interface{}, opts ...ValidatorOpt) Option {
|
||||
return func(ps *PubSub) error {
|
||||
addVal := &addValReq{
|
||||
validate: val,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
err := opt(addVal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
val, err := ps.val.makeValidator(addVal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ps.val.defaultVals = append(ps.val.defaultVals, val)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithValidateQueueSize sets the buffer of validate queue. Defaults to 32.
|
||||
// When queue is full, validation is throttled and new messages are dropped.
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
)
|
||||
|
||||
// PeerMetadataStore is an interface for storing and retrieving per peer metadata
|
||||
type PeerMetadataStore interface {
|
||||
// Get retrieves the metadata associated with a peer;
|
||||
// It should return nil if there is no metadata associated with the peer and not an error.
|
||||
Get(context.Context, peer.ID) ([]byte, error)
|
||||
// Put sets the metadata associated with a peer.
|
||||
Put(context.Context, peer.ID, []byte) error
|
||||
}
|
||||
|
||||
// BasicSeqnoValidator is a basic validator, usable as a default validator, that ignores replayed
|
||||
// messages outside the seen cache window. The validator uses the message seqno as a peer-specific
|
||||
// nonce to decide whether the message should be propagated, comparing to the maximal nonce store
|
||||
// in the peer metadata store. This is useful to ensure that there can be no infinitely propagating
|
||||
// messages in the network regardless of the seen cache span and network diameter.
|
||||
// It requires that pubsub is instantiated with a strict message signing policy and that seqnos
|
||||
// are not disabled, ie it doesn't support anonymous mode.
|
||||
//
|
||||
// Warning: See https://github.com/libp2p/rust-libp2p/issues/3453
|
||||
// TL;DR: rust is currently violating the spec by issuing a random seqno, which creates an
|
||||
// interoperability hazard. We expect this issue to be addressed in the not so distant future,
|
||||
// but keep this in mind if you are in a mixed environment with (older) rust nodes.
|
||||
type BasicSeqnoValidator struct {
|
||||
mx sync.RWMutex
|
||||
meta PeerMetadataStore
|
||||
}
|
||||
|
||||
// NewBasicSeqnoValidator constructs a BasicSeqnoValidator using the givven PeerMetadataStore.
|
||||
func NewBasicSeqnoValidator(meta PeerMetadataStore) ValidatorEx {
|
||||
val := &BasicSeqnoValidator{
|
||||
meta: meta,
|
||||
}
|
||||
return val.validate
|
||||
}
|
||||
|
||||
func (v *BasicSeqnoValidator) validate(ctx context.Context, _ peer.ID, m *Message) ValidationResult {
|
||||
p := m.GetFrom()
|
||||
|
||||
v.mx.RLock()
|
||||
nonceBytes, err := v.meta.Get(ctx, p)
|
||||
v.mx.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
log.Warn("error retrieving peer nonce: %s", err)
|
||||
return ValidationIgnore
|
||||
}
|
||||
|
||||
var nonce uint64
|
||||
if len(nonceBytes) > 0 {
|
||||
nonce = binary.BigEndian.Uint64(nonceBytes)
|
||||
}
|
||||
|
||||
var seqno uint64
|
||||
seqnoBytes := m.GetSeqno()
|
||||
if len(seqnoBytes) > 0 {
|
||||
seqno = binary.BigEndian.Uint64(seqnoBytes)
|
||||
}
|
||||
|
||||
// compare against the largest seen nonce
|
||||
if seqno <= nonce {
|
||||
return ValidationIgnore
|
||||
}
|
||||
|
||||
// get the nonce and compare again with an exclusive lock before commiting (cf concurrent validation)
|
||||
v.mx.Lock()
|
||||
defer v.mx.Unlock()
|
||||
|
||||
nonceBytes, err = v.meta.Get(ctx, p)
|
||||
if err != nil {
|
||||
log.Warn("error retrieving peer nonce: %s", err)
|
||||
return ValidationIgnore
|
||||
}
|
||||
|
||||
if len(nonceBytes) > 0 {
|
||||
nonce = binary.BigEndian.Uint64(nonceBytes)
|
||||
}
|
||||
|
||||
if seqno <= nonce {
|
||||
return ValidationIgnore
|
||||
}
|
||||
|
||||
// update the nonce
|
||||
nonceBytes = make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(nonceBytes, seqno)
|
||||
|
||||
err = v.meta.Put(ctx, p, nonceBytes)
|
||||
if err != nil {
|
||||
log.Warn("error storing peer nonce: %s", err)
|
||||
}
|
||||
|
||||
return ValidationAccept
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -21,7 +22,7 @@ type MessageProvider interface {
|
|||
Put(env *protocol.Envelope) error
|
||||
Query(query *pb.HistoryQuery) ([]StoredMessage, error)
|
||||
MostRecentTimestamp() (int64, error)
|
||||
Start(timesource timesource.Timesource) error
|
||||
Start(ctx context.Context, timesource timesource.Timesource) error
|
||||
Stop()
|
||||
}
|
||||
|
||||
|
@ -45,8 +46,8 @@ type DBStore struct {
|
|||
|
||||
enableMigrations bool
|
||||
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
wg sync.WaitGroup
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type StoredMessage struct {
|
||||
|
@ -124,7 +125,6 @@ func DefaultOptions() []DBOption {
|
|||
func NewDBStore(log *zap.Logger, options ...DBOption) (*DBStore, error) {
|
||||
result := new(DBStore)
|
||||
result.log = log.Named("dbstore")
|
||||
result.quit = make(chan struct{})
|
||||
|
||||
optList := DefaultOptions()
|
||||
optList = append(optList, options...)
|
||||
|
@ -146,7 +146,10 @@ func NewDBStore(log *zap.Logger, options ...DBOption) (*DBStore, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (d *DBStore) Start(timesource timesource.Timesource) error {
|
||||
func (d *DBStore) Start(ctx context.Context, timesource timesource.Timesource) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
d.cancel = cancel
|
||||
d.timesource = timesource
|
||||
|
||||
err := d.cleanOlderRecords()
|
||||
|
@ -155,7 +158,7 @@ func (d *DBStore) Start(timesource timesource.Timesource) error {
|
|||
}
|
||||
|
||||
d.wg.Add(1)
|
||||
go d.checkForOlderRecords(60 * time.Second)
|
||||
go d.checkForOlderRecords(ctx, 60*time.Second)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -192,7 +195,7 @@ func (d *DBStore) cleanOlderRecords() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *DBStore) checkForOlderRecords(t time.Duration) {
|
||||
func (d *DBStore) checkForOlderRecords(ctx context.Context, t time.Duration) {
|
||||
defer d.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(t)
|
||||
|
@ -200,7 +203,7 @@ func (d *DBStore) checkForOlderRecords(t time.Duration) {
|
|||
|
||||
for {
|
||||
select {
|
||||
case <-d.quit:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := d.cleanOlderRecords()
|
||||
|
@ -213,7 +216,11 @@ func (d *DBStore) checkForOlderRecords(t time.Duration) {
|
|||
|
||||
// Stop closes a DB connection
|
||||
func (d *DBStore) Stop() {
|
||||
d.quit <- struct{}{}
|
||||
if d.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
d.cancel()
|
||||
d.wg.Wait()
|
||||
d.db.Close()
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ import (
|
|||
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
|
||||
"github.com/waku-org/go-waku/waku/v2/protocol/store"
|
||||
"github.com/waku-org/go-waku/waku/v2/protocol/swap"
|
||||
"github.com/waku-org/go-waku/waku/v2/rendezvous"
|
||||
"github.com/waku-org/go-waku/waku/v2/timesource"
|
||||
|
||||
"github.com/waku-org/go-waku/waku/v2/utils"
|
||||
|
@ -80,6 +81,7 @@ type WakuNode struct {
|
|||
peerConnector PeerConnectorService
|
||||
discoveryV5 Service
|
||||
peerExchange Service
|
||||
rendezvous Service
|
||||
filter ReceptorService
|
||||
filterV2Full ReceptorService
|
||||
filterV2Light Service
|
||||
|
@ -212,6 +214,7 @@ func New(opts ...WakuNodeOption) (*WakuNode, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
w.rendezvous = rendezvous.NewRendezvous(w.host, w.opts.rendezvousDB, w.peerConnector, w.log)
|
||||
w.relay = relay.NewWakuRelay(w.host, w.bcaster, w.opts.minRelayPeersToPublish, w.timesource, w.log, w.opts.wOpts...)
|
||||
w.filter = filter.NewWakuFilter(w.host, w.bcaster, w.opts.isFilterFullNode, w.timesource, w.log, w.opts.filterOpts...)
|
||||
w.filterV2Full = filterv2.NewWakuFilterFullnode(w.host, w.bcaster, w.timesource, w.log, w.opts.filterV2Opts...)
|
||||
|
@ -390,6 +393,13 @@ func (w *WakuNode) Start(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
if w.opts.enableRendezvous {
|
||||
err := w.rendezvous.Start(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if w.opts.enableRLN {
|
||||
err = w.mountRlnRelay(ctx)
|
||||
if err != nil {
|
||||
|
@ -415,6 +425,10 @@ func (w *WakuNode) Stop() {
|
|||
defer w.identificationEventSub.Close()
|
||||
defer w.addressChangesSub.Close()
|
||||
|
||||
if w.opts.enableRendezvous {
|
||||
w.rendezvous.Stop()
|
||||
}
|
||||
|
||||
w.relay.Stop()
|
||||
w.lightPush.Stop()
|
||||
w.store.Stop()
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/waku-org/go-waku/waku/v2/protocol/filterv2"
|
||||
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
|
||||
"github.com/waku-org/go-waku/waku/v2/protocol/store"
|
||||
"github.com/waku-org/go-waku/waku/v2/rendezvous"
|
||||
"github.com/waku-org/go-waku/waku/v2/timesource"
|
||||
"github.com/waku-org/go-waku/waku/v2/utils"
|
||||
"go.uber.org/zap"
|
||||
|
@ -79,6 +80,9 @@ type WakuNodeParameters struct {
|
|||
resumeNodes []multiaddr.Multiaddr
|
||||
messageProvider store.MessageProvider
|
||||
|
||||
enableRendezvous bool
|
||||
rendezvousDB *rendezvous.DB
|
||||
|
||||
swapMode int
|
||||
swapDisconnectThreshold int
|
||||
swapPaymentThreshold int
|
||||
|
@ -433,6 +437,16 @@ func WithWebsockets(address string, port int) WakuNodeOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithRendezvousServer is a WakuOption used to set the node as a rendezvous
|
||||
// point, using an specific storage for the peer information
|
||||
func WithRendezvousServer(db *rendezvous.DB) WakuNodeOption {
|
||||
return func(params *WakuNodeParameters) error {
|
||||
params.enableRendezvous = true
|
||||
params.rendezvousDB = db
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithSecureWebsockets is a WakuNodeOption used to enable secure websockets support
|
||||
func WithSecureWebsockets(address string, port int, certPath string, keyPath string) WakuNodeOption {
|
||||
return func(params *WakuNodeParameters) error {
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
@ -82,9 +81,6 @@ func (wf *WakuFilterLightnode) Start(ctx context.Context) error {
|
|||
|
||||
wf.h.SetStreamHandlerMatch(FilterPushID_v20beta1, protocol.PrefixTextMatch(string(FilterPushID_v20beta1)), wf.onRequest(ctx))
|
||||
|
||||
// wf.wg.Add(1)
|
||||
// TODO: go wf.keepAliveSubscriptions(ctx)
|
||||
|
||||
wf.log.Info("filter protocol (light) started")
|
||||
|
||||
return nil
|
||||
|
@ -175,7 +171,8 @@ func (wf *WakuFilterLightnode) request(ctx context.Context, params *FilterSubscr
|
|||
}
|
||||
|
||||
if filterSubscribeResponse.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("filter err: %d, %s", filterSubscribeResponse.StatusCode, filterSubscribeResponse.StatusDesc)
|
||||
err := NewFilterError(int(filterSubscribeResponse.StatusCode), filterSubscribeResponse.StatusDesc)
|
||||
return &err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -210,12 +207,16 @@ func (wf *WakuFilterLightnode) Subscribe(ctx context.Context, contentFilter Cont
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return wf.FilterSubscription(params.selectedPeer, contentFilter), nil
|
||||
return wf.subscriptions.NewSubscription(params.selectedPeer, contentFilter.Topic, contentFilter.ContentTopics), nil
|
||||
}
|
||||
|
||||
// FilterSubscription is used to obtain an object from which you could receive messages received via filter protocol
|
||||
func (wf *WakuFilterLightnode) FilterSubscription(peerID peer.ID, contentFilter ContentFilter) *SubscriptionDetails {
|
||||
return wf.subscriptions.NewSubscription(peerID, contentFilter.Topic, contentFilter.ContentTopics)
|
||||
func (wf *WakuFilterLightnode) FilterSubscription(peerID peer.ID, contentFilter ContentFilter) (*SubscriptionDetails, error) {
|
||||
if !wf.subscriptions.Has(peerID, contentFilter.Topic, contentFilter.ContentTopics) {
|
||||
return nil, errors.New("subscription does not exist")
|
||||
}
|
||||
|
||||
return wf.subscriptions.NewSubscription(peerID, contentFilter.Topic, contentFilter.ContentTopics), nil
|
||||
}
|
||||
|
||||
func (wf *WakuFilterLightnode) getUnsubscribeParameters(opts ...FilterUnsubscribeOption) (*FilterUnsubscribeParameters, error) {
|
||||
|
@ -232,6 +233,18 @@ func (wf *WakuFilterLightnode) getUnsubscribeParameters(opts ...FilterUnsubscrib
|
|||
return params, nil
|
||||
}
|
||||
|
||||
func (wf *WakuFilterLightnode) Ping(ctx context.Context, peerID peer.ID) error {
|
||||
return wf.request(
|
||||
ctx,
|
||||
&FilterSubscribeParameters{selectedPeer: peerID},
|
||||
pb.FilterSubscribeRequest_SUBSCRIBER_PING,
|
||||
ContentFilter{})
|
||||
}
|
||||
|
||||
func (wf *WakuFilterLightnode) IsSubscriptionAlive(ctx context.Context, subscription *SubscriptionDetails) error {
|
||||
return wf.Ping(ctx, subscription.peerID)
|
||||
}
|
||||
|
||||
// Unsubscribe is used to stop receiving messages from a peer that match a content filter
|
||||
func (wf *WakuFilterLightnode) Unsubscribe(ctx context.Context, contentFilter ContentFilter, opts ...FilterUnsubscribeOption) (<-chan WakuFilterPushResult, error) {
|
||||
if contentFilter.Topic == "" {
|
||||
|
|
|
@ -1,4 +1,22 @@
|
|||
package filterv2
|
||||
|
||||
import "fmt"
|
||||
|
||||
const DefaultMaxSubscriptions = 1000
|
||||
const MaxCriteriaPerSubscription = 1000
|
||||
|
||||
type FilterError struct {
|
||||
Code int
|
||||
Message string
|
||||
}
|
||||
|
||||
func NewFilterError(code int, message string) FilterError {
|
||||
return FilterError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *FilterError) Error() string {
|
||||
return fmt.Sprintf("error %d: %s", e.Code, e.Message)
|
||||
}
|
||||
|
|
33
vendor/github.com/waku-org/go-waku/waku/v2/protocol/filterv2/subscriptions_map.go
generated
vendored
33
vendor/github.com/waku-org/go-waku/waku/v2/protocol/filterv2/subscriptions_map.go
generated
vendored
|
@ -76,6 +76,37 @@ func (sub *SubscriptionsMap) NewSubscription(peerID peer.ID, topic string, conte
|
|||
return details
|
||||
}
|
||||
|
||||
func (sub *SubscriptionsMap) Has(peerID peer.ID, topic string, contentTopics []string) bool {
|
||||
// Check if peer exits
|
||||
peerSubscription, ok := sub.items[peerID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if pubsub topic exists
|
||||
subscriptions, ok := peerSubscription.subscriptionsPerTopic[topic]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the content topic exists within the list of subscriptions for this peer
|
||||
for _, ct := range contentTopics {
|
||||
found := false
|
||||
for _, subscription := range subscriptions {
|
||||
_, exists := subscription.contentTopics[ct]
|
||||
if exists {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (sub *SubscriptionsMap) Delete(subscription *SubscriptionDetails) error {
|
||||
sub.Lock()
|
||||
defer sub.Unlock()
|
||||
|
@ -116,7 +147,6 @@ func (s *SubscriptionDetails) closeC() {
|
|||
s.closed = true
|
||||
close(s.C)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (s *SubscriptionDetails) Close() error {
|
||||
|
@ -185,6 +215,7 @@ func iterateSubscriptionSet(subscriptions SubscriptionSet, envelope *protocol.En
|
|||
}
|
||||
|
||||
if !subscription.closed {
|
||||
// TODO: consider pushing or dropping if subscription is not available
|
||||
subscription.C <- envelope
|
||||
}
|
||||
}(subscription)
|
||||
|
|
|
@ -49,6 +49,7 @@ type WakuSwap interface {
|
|||
|
||||
type WakuStore struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
timesource timesource.Timesource
|
||||
MsgC chan *protocol.Envelope
|
||||
wg *sync.WaitGroup
|
||||
|
@ -56,7 +57,6 @@ type WakuStore struct {
|
|||
log *zap.Logger
|
||||
|
||||
started bool
|
||||
quit chan struct{}
|
||||
|
||||
msgProvider MessageProvider
|
||||
h host.Host
|
||||
|
@ -71,7 +71,6 @@ func NewWakuStore(host host.Host, swap WakuSwap, p MessageProvider, timesource t
|
|||
wakuStore.swap = swap
|
||||
wakuStore.wg = &sync.WaitGroup{}
|
||||
wakuStore.log = log.Named("store")
|
||||
wakuStore.quit = make(chan struct{})
|
||||
wakuStore.timesource = timesource
|
||||
|
||||
return wakuStore
|
||||
|
|
21
vendor/github.com/waku-org/go-waku/waku/v2/protocol/store/waku_store_protocol.go
generated
vendored
21
vendor/github.com/waku-org/go-waku/waku/v2/protocol/store/waku_store_protocol.go
generated
vendored
|
@ -79,7 +79,7 @@ type MessageProvider interface {
|
|||
Query(query *pb.HistoryQuery) (*pb.Index, []persistence.StoredMessage, error)
|
||||
Put(env *protocol.Envelope) error
|
||||
MostRecentTimestamp() (int64, error)
|
||||
Start(timesource timesource.Timesource) error
|
||||
Start(ctx context.Context, timesource timesource.Timesource) error
|
||||
Stop()
|
||||
Count() (int, error)
|
||||
}
|
||||
|
@ -110,21 +110,21 @@ func (store *WakuStore) Start(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
err := store.msgProvider.Start(store.timesource)
|
||||
err := store.msgProvider.Start(ctx, store.timesource) // TODO: store protocol should not start a message provider
|
||||
if err != nil {
|
||||
store.log.Error("Error starting message provider", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
|
||||
store.started = true
|
||||
store.ctx = ctx
|
||||
store.ctx, store.cancel = context.WithCancel(ctx)
|
||||
store.MsgC = make(chan *protocol.Envelope, 1024)
|
||||
|
||||
store.h.SetStreamHandlerMatch(StoreID_v20beta4, protocol.PrefixTextMatch(string(StoreID_v20beta4)), store.onRequest)
|
||||
|
||||
store.wg.Add(2)
|
||||
go store.storeIncomingMessages(ctx)
|
||||
go store.updateMetrics(ctx)
|
||||
go store.storeIncomingMessages(store.ctx)
|
||||
go store.updateMetrics(store.ctx)
|
||||
|
||||
store.log.Info("Store protocol started")
|
||||
|
||||
|
@ -174,7 +174,7 @@ func (store *WakuStore) updateMetrics(ctx context.Context) {
|
|||
} else {
|
||||
metrics.RecordMessage(store.ctx, "stored", msgCount)
|
||||
}
|
||||
case <-store.quit:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -229,6 +229,12 @@ func (store *WakuStore) MessageChannel() chan *protocol.Envelope {
|
|||
|
||||
// Stop closes the store message channel and removes the protocol stream handler
|
||||
func (store *WakuStore) Stop() {
|
||||
if store.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
store.cancel()
|
||||
|
||||
store.started = false
|
||||
|
||||
if store.MsgC != nil {
|
||||
|
@ -236,8 +242,7 @@ func (store *WakuStore) Stop() {
|
|||
}
|
||||
|
||||
if store.msgProvider != nil {
|
||||
store.msgProvider.Stop()
|
||||
store.quit <- struct{}{}
|
||||
store.msgProvider.Stop() // TODO: StoreProtocol should not stop a message provider
|
||||
}
|
||||
|
||||
if store.h != nil {
|
||||
|
|
|
@ -0,0 +1,439 @@
|
|||
package rendezvous
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dbi "github.com/berty/go-libp2p-rendezvous/db"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
logger *zap.Logger
|
||||
|
||||
insertPeerRegistration *sql.Stmt
|
||||
deletePeerRegistrations *sql.Stmt
|
||||
deletePeerRegistrationsNs *sql.Stmt
|
||||
countPeerRegistrations *sql.Stmt
|
||||
selectPeerRegistrations *sql.Stmt
|
||||
selectPeerRegistrationsNS *sql.Stmt
|
||||
selectPeerRegistrationsC *sql.Stmt
|
||||
selectPeerRegistrationsNSC *sql.Stmt
|
||||
deleteExpiredRegistrations *sql.Stmt
|
||||
getCounter *sql.Stmt
|
||||
|
||||
nonce []byte
|
||||
|
||||
cancel func()
|
||||
}
|
||||
|
||||
func NewDB(ctx context.Context, db *sql.DB, logger *zap.Logger) *DB {
|
||||
rdb := &DB{
|
||||
db: db,
|
||||
logger: logger.Named("rendezvous/db"),
|
||||
}
|
||||
|
||||
return rdb
|
||||
}
|
||||
|
||||
func (db *DB) Start(ctx context.Context) error {
|
||||
err := db.loadNonce()
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
err = db.prepareStmts()
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
bgctx, cancel := context.WithCancel(ctx)
|
||||
db.cancel = cancel
|
||||
go db.background(bgctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Close() error {
|
||||
db.cancel()
|
||||
return db.db.Close()
|
||||
}
|
||||
|
||||
func (db *DB) insertNonce() error {
|
||||
nonce := make([]byte, 32)
|
||||
_, err := rand.Read(nonce)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.db.Exec("INSERT INTO nonce VALUES (?)", nonce)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db.nonce = nonce
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) loadNonce() error {
|
||||
var nonce []byte
|
||||
row := db.db.QueryRow("SELECT nonce FROM nonce")
|
||||
err := row.Scan(&nonce)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return db.insertNonce()
|
||||
}
|
||||
return err
|
||||
}
|
||||
db.nonce = nonce
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) prepareStmts() error {
|
||||
stmt, err := db.db.Prepare("INSERT INTO registrations VALUES (NULL, ?, ?, ?, ?)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.insertPeerRegistration = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("DELETE FROM registrations WHERE peer = ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.deletePeerRegistrations = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("DELETE FROM registrations WHERE peer = ? AND ns = ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.deletePeerRegistrationsNs = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("SELECT COUNT(*) FROM registrations WHERE peer = ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.countPeerRegistrations = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("SELECT * FROM registrations WHERE expire > ? LIMIT ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.selectPeerRegistrations = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("SELECT * FROM registrations WHERE ns = ? AND expire > ? LIMIT ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.selectPeerRegistrationsNS = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("SELECT * FROM registrations WHERE counter > ? AND expire > ? LIMIT ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.selectPeerRegistrationsC = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("SELECT * FROM registrations WHERE counter > ? AND ns = ? AND expire > ? LIMIT ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.selectPeerRegistrationsNSC = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("DELETE FROM registrations WHERE expire < ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.deleteExpiredRegistrations = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("SELECT MAX(counter) FROM registrations")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.getCounter = stmt
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Register(p peer.ID, ns string, addrs [][]byte, ttl int) (uint64, error) {
|
||||
pid := p.Pretty()
|
||||
maddrs := packAddrs(addrs)
|
||||
expire := time.Now().Unix() + int64(ttl)
|
||||
|
||||
tx, err := db.db.Begin()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
delOld := tx.Stmt(db.deletePeerRegistrationsNs)
|
||||
insertNew := tx.Stmt(db.insertPeerRegistration)
|
||||
getCounter := tx.Stmt(db.getCounter)
|
||||
|
||||
_, err = delOld.Exec(pid, ns)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
_, err = insertNew.Exec(pid, ns, expire, maddrs)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var counter uint64
|
||||
row := getCounter.QueryRow()
|
||||
err = row.Scan(&counter)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
return counter, err
|
||||
}
|
||||
|
||||
func (db *DB) CountRegistrations(p peer.ID) (int, error) {
|
||||
pid := p.Pretty()
|
||||
|
||||
row := db.countPeerRegistrations.QueryRow(pid)
|
||||
|
||||
var count int
|
||||
err := row.Scan(&count)
|
||||
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (db *DB) Unregister(p peer.ID, ns string) error {
|
||||
pid := p.Pretty()
|
||||
|
||||
var err error
|
||||
|
||||
if ns == "" {
|
||||
_, err = db.deletePeerRegistrations.Exec(pid)
|
||||
} else {
|
||||
_, err = db.deletePeerRegistrationsNs.Exec(pid, ns)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *DB) Discover(ns string, cookie []byte, limit int) ([]dbi.RegistrationRecord, []byte, error) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
var (
|
||||
counter int64
|
||||
rows *sql.Rows
|
||||
err error
|
||||
)
|
||||
|
||||
if cookie != nil {
|
||||
counter, err = unpackCookie(cookie)
|
||||
if err != nil {
|
||||
db.logger.Error("unpacking cookie", zap.Error(err))
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if counter > 0 {
|
||||
if ns == "" {
|
||||
rows, err = db.selectPeerRegistrationsC.Query(counter, now, limit)
|
||||
} else {
|
||||
rows, err = db.selectPeerRegistrationsNSC.Query(counter, ns, now, limit)
|
||||
}
|
||||
} else {
|
||||
if ns == "" {
|
||||
rows, err = db.selectPeerRegistrations.Query(now, limit)
|
||||
} else {
|
||||
rows, err = db.selectPeerRegistrationsNS.Query(ns, now, limit)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
db.logger.Error("query", zap.Error(err))
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
regs := make([]dbi.RegistrationRecord, 0, limit)
|
||||
for rows.Next() {
|
||||
var (
|
||||
reg dbi.RegistrationRecord
|
||||
rid string
|
||||
rns string
|
||||
expire int64
|
||||
raddrs []byte
|
||||
addrs [][]byte
|
||||
p peer.ID
|
||||
)
|
||||
|
||||
err = rows.Scan(&counter, &rid, &rns, &expire, &raddrs)
|
||||
if err != nil {
|
||||
db.logger.Error("row scan error", zap.Error(err))
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
p, err = peer.Decode(rid)
|
||||
if err != nil {
|
||||
db.logger.Error("error decoding peer id", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
addrs, err := unpackAddrs(raddrs)
|
||||
if err != nil {
|
||||
db.logger.Error("error unpacking address", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
reg.Id = p
|
||||
reg.Addrs = addrs
|
||||
reg.Ttl = int(expire - now)
|
||||
|
||||
if ns == "" {
|
||||
reg.Ns = rns
|
||||
}
|
||||
|
||||
regs = append(regs, reg)
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if counter > 0 {
|
||||
cookie = packCookie(counter, ns, db.nonce)
|
||||
}
|
||||
|
||||
return regs, cookie, nil
|
||||
}
|
||||
|
||||
func (db *DB) ValidCookie(ns string, cookie []byte) bool {
|
||||
return validCookie(cookie, ns, db.nonce)
|
||||
}
|
||||
|
||||
func (db *DB) background(ctx context.Context) {
|
||||
for {
|
||||
db.cleanupExpired()
|
||||
|
||||
select {
|
||||
case <-time.After(15 * time.Minute):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) cleanupExpired() {
|
||||
now := time.Now().Unix()
|
||||
_, err := db.deleteExpiredRegistrations.Exec(now)
|
||||
if err != nil {
|
||||
db.logger.Error("deleting expired registrations", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func packAddrs(addrs [][]byte) []byte {
|
||||
packlen := 0
|
||||
for _, addr := range addrs {
|
||||
packlen = packlen + 2 + len(addr)
|
||||
}
|
||||
|
||||
packed := make([]byte, packlen)
|
||||
buf := packed
|
||||
for _, addr := range addrs {
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(addr)))
|
||||
buf = buf[2:]
|
||||
copy(buf, addr)
|
||||
buf = buf[len(addr):]
|
||||
}
|
||||
|
||||
return packed
|
||||
}
|
||||
|
||||
func unpackAddrs(packed []byte) ([][]byte, error) {
|
||||
var addrs [][]byte
|
||||
|
||||
buf := packed
|
||||
for len(buf) > 1 {
|
||||
l := binary.BigEndian.Uint16(buf)
|
||||
buf = buf[2:]
|
||||
if len(buf) < int(l) {
|
||||
return nil, fmt.Errorf("bad packed address: not enough bytes %v %v", packed, buf)
|
||||
}
|
||||
addr := make([]byte, l)
|
||||
copy(addr, buf[:l])
|
||||
buf = buf[l:]
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
|
||||
if len(buf) > 0 {
|
||||
return nil, fmt.Errorf("bad packed address: unprocessed bytes: %v %v", packed, buf)
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// cookie: counter:SHA256(nonce + ns + counter)
|
||||
func packCookie(counter int64, ns string, nonce []byte) []byte {
|
||||
cbits := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(cbits, uint64(counter))
|
||||
|
||||
hash := sha256.New()
|
||||
_, err := hash.Write(nonce)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = hash.Write([]byte(ns))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = hash.Write(cbits)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return hash.Sum(cbits)
|
||||
}
|
||||
|
||||
func unpackCookie(cookie []byte) (int64, error) {
|
||||
if len(cookie) < 8 {
|
||||
return 0, fmt.Errorf("bad packed cookie: not enough bytes: %v", cookie)
|
||||
}
|
||||
|
||||
counter := binary.BigEndian.Uint64(cookie[:8])
|
||||
return int64(counter), nil
|
||||
}
|
||||
|
||||
func validCookie(cookie []byte, ns string, nonce []byte) bool {
|
||||
if len(cookie) != 40 {
|
||||
return false
|
||||
}
|
||||
|
||||
cbits := cookie[:8]
|
||||
hash := sha256.New()
|
||||
_, err := hash.Write(nonce)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = hash.Write([]byte(ns))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = hash.Write(cbits)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
hbits := hash.Sum(nil)
|
||||
|
||||
return bytes.Equal(cookie[8:], hbits)
|
||||
}
|
52
vendor/github.com/waku-org/go-waku/waku/v2/rendezvous/rendezvous.go
generated
vendored
Normal file
52
vendor/github.com/waku-org/go-waku/waku/v2/rendezvous/rendezvous.go
generated
vendored
Normal file
|
@ -0,0 +1,52 @@
|
|||
package rendezvous
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
rvs "github.com/berty/go-libp2p-rendezvous"
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const RendezvousID = rvs.RendezvousProto
|
||||
|
||||
type Rendezvous struct {
|
||||
host host.Host
|
||||
peerConnector PeerConnector
|
||||
db *DB
|
||||
rendezvousSvc *rvs.RendezvousService
|
||||
|
||||
log *zap.Logger
|
||||
}
|
||||
|
||||
type PeerConnector interface {
|
||||
PeerChannel() chan<- peer.AddrInfo
|
||||
}
|
||||
|
||||
func NewRendezvous(host host.Host, db *DB, peerConnector PeerConnector, log *zap.Logger) *Rendezvous {
|
||||
logger := log.Named("rendezvous")
|
||||
|
||||
return &Rendezvous{
|
||||
host: host,
|
||||
db: db,
|
||||
peerConnector: peerConnector,
|
||||
log: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Rendezvous) Start(ctx context.Context) error {
|
||||
err := r.db.Start(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.rendezvousSvc = rvs.NewRendezvousService(r.host, r.db)
|
||||
r.log.Info("rendezvous protocol started")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Rendezvous) Stop() {
|
||||
r.host.RemoveStreamHandler(rvs.RendezvousProto)
|
||||
r.rendezvousSvc = nil
|
||||
}
|
|
@ -124,6 +124,11 @@ github.com/benbjohnson/immutable
|
|||
# github.com/beorn7/perks v1.0.1
|
||||
## explicit; go 1.11
|
||||
github.com/beorn7/perks/quantile
|
||||
# github.com/berty/go-libp2p-rendezvous v0.4.1
|
||||
## explicit; go 1.18
|
||||
github.com/berty/go-libp2p-rendezvous
|
||||
github.com/berty/go-libp2p-rendezvous/db
|
||||
github.com/berty/go-libp2p-rendezvous/pb
|
||||
# github.com/bits-and-blooms/bitset v1.2.0
|
||||
## explicit; go 1.14
|
||||
github.com/bits-and-blooms/bitset
|
||||
|
@ -305,6 +310,7 @@ github.com/godbus/dbus/v5
|
|||
# github.com/gogo/protobuf v1.3.2
|
||||
## explicit; go 1.15
|
||||
github.com/gogo/protobuf/gogoproto
|
||||
github.com/gogo/protobuf/io
|
||||
github.com/gogo/protobuf/proto
|
||||
github.com/gogo/protobuf/protoc-gen-gogo/descriptor
|
||||
# github.com/golang-jwt/jwt/v4 v4.3.0
|
||||
|
@ -535,7 +541,7 @@ github.com/libp2p/go-libp2p/p2p/transport/webtransport
|
|||
# github.com/libp2p/go-libp2p-asn-util v0.2.0
|
||||
## explicit; go 1.17
|
||||
github.com/libp2p/go-libp2p-asn-util
|
||||
# github.com/libp2p/go-libp2p-pubsub v0.9.1
|
||||
# github.com/libp2p/go-libp2p-pubsub v0.9.3
|
||||
## explicit; go 1.19
|
||||
github.com/libp2p/go-libp2p-pubsub
|
||||
github.com/libp2p/go-libp2p-pubsub/pb
|
||||
|
@ -971,7 +977,7 @@ github.com/vacp2p/mvds/transport
|
|||
github.com/waku-org/go-discover/discover
|
||||
github.com/waku-org/go-discover/discover/v4wire
|
||||
github.com/waku-org/go-discover/discover/v5wire
|
||||
# github.com/waku-org/go-waku v0.5.2-0.20230308135126-4b52983fc483
|
||||
# github.com/waku-org/go-waku v0.5.3-0.20230327132601-b540953f74e9
|
||||
## explicit; go 1.19
|
||||
github.com/waku-org/go-waku/logging
|
||||
github.com/waku-org/go-waku/waku/persistence
|
||||
|
@ -999,6 +1005,7 @@ github.com/waku-org/go-waku/waku/v2/protocol/rln/contracts
|
|||
github.com/waku-org/go-waku/waku/v2/protocol/store
|
||||
github.com/waku-org/go-waku/waku/v2/protocol/store/pb
|
||||
github.com/waku-org/go-waku/waku/v2/protocol/swap
|
||||
github.com/waku-org/go-waku/waku/v2/rendezvous
|
||||
github.com/waku-org/go-waku/waku/v2/timesource
|
||||
github.com/waku-org/go-waku/waku/v2/utils
|
||||
# github.com/waku-org/go-zerokit-rln v0.1.7-wakuorg
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -28,8 +29,8 @@ type DBStore struct {
|
|||
maxMessages int
|
||||
maxDuration time.Duration
|
||||
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
wg sync.WaitGroup
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// DBOption is an optional setting that can be used to configure the DBStore
|
||||
|
@ -59,7 +60,6 @@ func WithRetentionPolicy(maxMessages int, maxDuration time.Duration) DBOption {
|
|||
func NewDBStore(log *zap.Logger, options ...DBOption) (*DBStore, error) {
|
||||
result := new(DBStore)
|
||||
result.log = log.Named("dbstore")
|
||||
result.quit = make(chan struct{})
|
||||
|
||||
for _, opt := range options {
|
||||
err := opt(result)
|
||||
|
@ -71,14 +71,18 @@ func NewDBStore(log *zap.Logger, options ...DBOption) (*DBStore, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (d *DBStore) Start(timesource timesource.Timesource) error {
|
||||
func (d *DBStore) Start(ctx context.Context, timesource timesource.Timesource) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
d.cancel = cancel
|
||||
|
||||
err := d.cleanOlderRecords()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.wg.Add(1)
|
||||
go d.checkForOlderRecords(60 * time.Second)
|
||||
go d.checkForOlderRecords(ctx, 60*time.Second)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -113,7 +117,7 @@ func (d *DBStore) cleanOlderRecords() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *DBStore) checkForOlderRecords(t time.Duration) {
|
||||
func (d *DBStore) checkForOlderRecords(ctx context.Context, t time.Duration) {
|
||||
defer d.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(t)
|
||||
|
@ -121,7 +125,7 @@ func (d *DBStore) checkForOlderRecords(t time.Duration) {
|
|||
|
||||
for {
|
||||
select {
|
||||
case <-d.quit:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := d.cleanOlderRecords()
|
||||
|
@ -134,7 +138,11 @@ func (d *DBStore) checkForOlderRecords(t time.Duration) {
|
|||
|
||||
// Stop closes a DB connection
|
||||
func (d *DBStore) Stop() {
|
||||
d.quit <- struct{}{}
|
||||
if d.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
d.cancel()
|
||||
d.wg.Wait()
|
||||
d.db.Close()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue