diff --git a/waku/v2/dnsdisc/enr.go b/waku/v2/dnsdisc/enr.go index 6c6c8442..4228636b 100644 --- a/waku/v2/dnsdisc/enr.go +++ b/waku/v2/dnsdisc/enr.go @@ -2,6 +2,7 @@ package dnsdisc import ( "context" + "errors" "github.com/ethereum/go-ethereum/p2p/dnsdisc" "github.com/ethereum/go-ethereum/p2p/enode" @@ -18,18 +19,28 @@ type dnsDiscoveryParameters struct { resolver dnsdisc.Resolver } -type DNSDiscoveryOption func(*dnsDiscoveryParameters) +type DNSDiscoveryOption func(*dnsDiscoveryParameters) error + +var ErrExclusiveOpts = errors.New("cannot set both nameserver and resolver") // WithNameserver is a DnsDiscoveryOption that configures the nameserver to use func WithNameserver(nameserver string) DNSDiscoveryOption { - return func(params *dnsDiscoveryParameters) { + return func(params *dnsDiscoveryParameters) error { + if params.resolver != nil { + return ErrExclusiveOpts + } params.nameserver = nameserver + return nil } } func WithResolver(resolver dnsdisc.Resolver) DNSDiscoveryOption { - return func(params *dnsDiscoveryParameters) { + return func(params *dnsDiscoveryParameters) error { + if params.nameserver != "" { + return ErrExclusiveOpts + } params.resolver = resolver + return nil } } @@ -56,7 +67,10 @@ func RetrieveNodes(ctx context.Context, url string, opts ...DNSDiscoveryOption) params := new(dnsDiscoveryParameters) for _, opt := range opts { - opt(params) + err := opt(params) + if err != nil { + return nil, err + } } if params.resolver == nil { diff --git a/waku/v2/dnsdisc/enr_test.go b/waku/v2/dnsdisc/enr_test.go index c1c4f0f5..afdb8bc2 100644 --- a/waku/v2/dnsdisc/enr_test.go +++ b/waku/v2/dnsdisc/enr_test.go @@ -63,3 +63,18 @@ func TestRetrieveNodes(t *testing.T) { require.NoError(t, err) require.Equal(t, len(discoveredNodes), 2) } + +func TestExclusiveOpts(t *testing.T) { + var opts []DNSDiscoveryOption + + tree, url := makeTestTree("n", nil, nil) + resolver := mapResolver(tree.ToTXT("n")) + + opts = append(opts, WithNameserver("1.1.1.1"), WithResolver(resolver)) + _, err := RetrieveNodes(context.Background(), url, opts...) + require.Equal(t, err, ErrExclusiveOpts) + + opts = append(opts, WithResolver(resolver), WithNameserver("1.1.1.1")) + _, err = RetrieveNodes(context.Background(), url, opts...) + require.Equal(t, err, ErrExclusiveOpts) +}