chore: error handling in params construction (#934)

This commit is contained in:
kaichao 2023-12-01 13:04:32 +08:00 committed by GitHub
parent 0b4df80b98
commit 16d59f37d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 4 deletions

View File

@ -2,6 +2,7 @@ package dnsdisc
import (
"context"
"errors"
"github.com/ethereum/go-ethereum/p2p/dnsdisc"
"github.com/ethereum/go-ethereum/p2p/enode"
@ -18,18 +19,28 @@ type dnsDiscoveryParameters struct {
resolver dnsdisc.Resolver
}
type DNSDiscoveryOption func(*dnsDiscoveryParameters)
type DNSDiscoveryOption func(*dnsDiscoveryParameters) error
var ErrExclusiveOpts = errors.New("cannot set both nameserver and resolver")
// WithNameserver is a DnsDiscoveryOption that configures the nameserver to use
func WithNameserver(nameserver string) DNSDiscoveryOption {
return func(params *dnsDiscoveryParameters) {
return func(params *dnsDiscoveryParameters) error {
if params.resolver != nil {
return ErrExclusiveOpts
}
params.nameserver = nameserver
return nil
}
}
func WithResolver(resolver dnsdisc.Resolver) DNSDiscoveryOption {
return func(params *dnsDiscoveryParameters) {
return func(params *dnsDiscoveryParameters) error {
if params.nameserver != "" {
return ErrExclusiveOpts
}
params.resolver = resolver
return nil
}
}
@ -56,7 +67,10 @@ func RetrieveNodes(ctx context.Context, url string, opts ...DNSDiscoveryOption)
params := new(dnsDiscoveryParameters)
for _, opt := range opts {
opt(params)
err := opt(params)
if err != nil {
return nil, err
}
}
if params.resolver == nil {

View File

@ -63,3 +63,18 @@ func TestRetrieveNodes(t *testing.T) {
require.NoError(t, err)
require.Equal(t, len(discoveredNodes), 2)
}
func TestExclusiveOpts(t *testing.T) {
var opts []DNSDiscoveryOption
tree, url := makeTestTree("n", nil, nil)
resolver := mapResolver(tree.ToTXT("n"))
opts = append(opts, WithNameserver("1.1.1.1"), WithResolver(resolver))
_, err := RetrieveNodes(context.Background(), url, opts...)
require.Equal(t, err, ErrExclusiveOpts)
opts = append(opts, WithResolver(resolver), WithNameserver("1.1.1.1"))
_, err = RetrieveNodes(context.Background(), url, opts...)
require.Equal(t, err, ErrExclusiveOpts)
}