mirror of https://github.com/status-im/op-geth.git
Merge branch 'fjl-poc8-net-integration' into develop
This commit is contained in:
commit
5c251b6928
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
This file is part of go-ethereum
|
||||
|
||||
go-ethereum is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
go-ethereum is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
// Command bootnode runs a bootstrap node for the Discovery Protocol.
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"encoding/hex"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var (
|
||||
listenAddr = flag.String("addr", ":30301", "listen address")
|
||||
genKey = flag.String("genkey", "", "generate a node key and quit")
|
||||
nodeKeyFile = flag.String("nodekey", "", "private key filename")
|
||||
nodeKeyHex = flag.String("nodekeyhex", "", "private key as hex (for testing)")
|
||||
natdesc = flag.String("nat", "none", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
|
||||
|
||||
nodeKey *ecdsa.PrivateKey
|
||||
err error
|
||||
)
|
||||
flag.Parse()
|
||||
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
|
||||
|
||||
if *genKey != "" {
|
||||
writeKey(*genKey)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
natm, err := nat.Parse(*natdesc)
|
||||
if err != nil {
|
||||
log.Fatalf("-nat: %v", err)
|
||||
}
|
||||
switch {
|
||||
case *nodeKeyFile == "" && *nodeKeyHex == "":
|
||||
log.Fatal("Use -nodekey or -nodekeyhex to specify a private key")
|
||||
case *nodeKeyFile != "" && *nodeKeyHex != "":
|
||||
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive")
|
||||
case *nodeKeyFile != "":
|
||||
if nodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil {
|
||||
log.Fatalf("-nodekey: %v", err)
|
||||
}
|
||||
case *nodeKeyHex != "":
|
||||
if nodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil {
|
||||
log.Fatalf("-nodekeyhex: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
select {}
|
||||
}
|
||||
|
||||
func writeKey(target string) {
|
||||
key, err := crypto.GenerateKey()
|
||||
if err != nil {
|
||||
log.Fatal("could not generate key: %v", err)
|
||||
}
|
||||
b := crypto.FromECDSA(key)
|
||||
if target == "-" {
|
||||
fmt.Println(hex.EncodeToString(b))
|
||||
} else {
|
||||
if err := ioutil.WriteFile(target, b, 0600); err != nil {
|
||||
log.Fatal("write error: ", err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -21,6 +21,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
|
@ -28,7 +29,9 @@ import (
|
|||
"os/user"
|
||||
"path"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||
"github.com/ethereum/go-ethereum/vm"
|
||||
)
|
||||
|
||||
|
@ -42,14 +45,14 @@ var (
|
|||
StartWebSockets bool
|
||||
RpcPort int
|
||||
WsPort int
|
||||
NatType string
|
||||
PMPGateway string
|
||||
OutboundPort string
|
||||
ShowGenesis bool
|
||||
AddPeer string
|
||||
MaxPeer int
|
||||
GenAddr bool
|
||||
SeedNode string
|
||||
BootNodes string
|
||||
NodeKey *ecdsa.PrivateKey
|
||||
NAT nat.Interface
|
||||
SecretFile string
|
||||
ExportDir string
|
||||
NonInteractive bool
|
||||
|
@ -84,6 +87,7 @@ func defaultDataDir() string {
|
|||
var defaultConfigFile = path.Join(defaultDataDir(), "conf.ini")
|
||||
|
||||
func Init() {
|
||||
// TODO: move common flag processing to cmd/util
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0])
|
||||
flag.PrintDefaults()
|
||||
|
@ -93,18 +97,12 @@ func Init() {
|
|||
flag.StringVar(&Identifier, "id", "", "Custom client identifier")
|
||||
flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use")
|
||||
flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file (db)")
|
||||
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
|
||||
flag.StringVar(&NatType, "nat", "", "NAT support (UPNP|PMP) (none)")
|
||||
flag.StringVar(&PMPGateway, "pmp", "", "Gateway IP for PMP")
|
||||
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
|
||||
|
||||
flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on")
|
||||
flag.IntVar(&WsPort, "wsport", 40404, "port to start websocket rpc server on")
|
||||
flag.BoolVar(&StartRpc, "rpc", false, "start rpc server")
|
||||
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
|
||||
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
|
||||
flag.StringVar(&SeedNode, "seednode", "poc-8.ethdev.com:30303", "ip:port of seed node to connect to. Set to blank for skip")
|
||||
flag.BoolVar(&SHH, "shh", true, "whisper protocol (on)")
|
||||
flag.BoolVar(&Dial, "dial", true, "dial out connections (on)")
|
||||
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
|
||||
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
|
||||
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
|
||||
|
@ -127,8 +125,38 @@ func Init() {
|
|||
flag.BoolVar(&StartJsConsole, "js", false, "launches javascript console")
|
||||
flag.BoolVar(&PrintVersion, "version", false, "prints version number")
|
||||
|
||||
// Network stuff
|
||||
var (
|
||||
nodeKeyFile = flag.String("nodekey", "", "network private key file")
|
||||
nodeKeyHex = flag.String("nodekeyhex", "", "network private key (for testing)")
|
||||
natstr = flag.String("nat", "any", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
|
||||
)
|
||||
flag.BoolVar(&Dial, "dial", true, "dial out connections (default on)")
|
||||
flag.BoolVar(&SHH, "shh", true, "run whisper protocol (default on)")
|
||||
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
|
||||
|
||||
flag.StringVar(&BootNodes, "bootnodes", "", "space-separated node URLs for discovery bootstrap")
|
||||
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
var err error
|
||||
if NAT, err = nat.Parse(*natstr); err != nil {
|
||||
log.Fatalf("-nat: %v", err)
|
||||
}
|
||||
switch {
|
||||
case *nodeKeyFile != "" && *nodeKeyHex != "":
|
||||
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive")
|
||||
case *nodeKeyFile != "":
|
||||
if NodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil {
|
||||
log.Fatalf("-nodekey: %v", err)
|
||||
}
|
||||
case *nodeKeyHex != "":
|
||||
if NodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil {
|
||||
log.Fatalf("-nodekeyhex: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if VmType >= int(vm.MaxVmTy) {
|
||||
log.Fatal("Invalid VM type ", VmType)
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/ethereum/go-ethereum/eth"
|
||||
"github.com/ethereum/go-ethereum/ethutil"
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p"
|
||||
"github.com/ethereum/go-ethereum/state"
|
||||
)
|
||||
|
||||
|
@ -61,21 +62,19 @@ func main() {
|
|||
utils.InitConfig(VmType, ConfigFile, Datadir, "ETH")
|
||||
|
||||
ethereum, err := eth.New(ð.Config{
|
||||
Name: ClientIdentifier,
|
||||
Version: Version,
|
||||
KeyStore: KeyStore,
|
||||
DataDir: Datadir,
|
||||
LogFile: LogFile,
|
||||
LogLevel: LogLevel,
|
||||
LogFormat: LogFormat,
|
||||
Identifier: Identifier,
|
||||
MaxPeers: MaxPeer,
|
||||
Port: OutboundPort,
|
||||
NATType: PMPGateway,
|
||||
PMPGateway: PMPGateway,
|
||||
KeyRing: KeyRing,
|
||||
Shh: SHH,
|
||||
Dial: Dial,
|
||||
Name: p2p.MakeName(ClientIdentifier, Version),
|
||||
KeyStore: KeyStore,
|
||||
DataDir: Datadir,
|
||||
LogFile: LogFile,
|
||||
LogLevel: LogLevel,
|
||||
MaxPeers: MaxPeer,
|
||||
Port: OutboundPort,
|
||||
NAT: NAT,
|
||||
KeyRing: KeyRing,
|
||||
Shh: SHH,
|
||||
Dial: Dial,
|
||||
BootNodes: BootNodes,
|
||||
NodeKey: NodeKey,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
@ -135,7 +134,7 @@ func main() {
|
|||
utils.StartWebSockets(ethereum, WsPort)
|
||||
}
|
||||
|
||||
utils.StartEthereum(ethereum, SeedNode)
|
||||
utils.StartEthereum(ethereum)
|
||||
|
||||
if StartJsConsole {
|
||||
InitJsConsole(ethereum)
|
||||
|
|
|
@ -79,6 +79,12 @@
|
|||
contract.received({from: eth.coinbase}).changed(function() {
|
||||
refresh();
|
||||
});
|
||||
|
||||
var ev = contract.SingleTransact({})
|
||||
ev.watch(function(log) {
|
||||
someElement.innerHTML += "tnaheousnthaoeu";
|
||||
});
|
||||
|
||||
eth.watch('chain').changed(function() {
|
||||
refresh();
|
||||
});
|
||||
|
|
|
@ -32,18 +32,6 @@ Rectangle {
|
|||
width: 500
|
||||
}
|
||||
|
||||
Label {
|
||||
text: "Client ID"
|
||||
}
|
||||
TextField {
|
||||
text: gui.getCustomIdentifier()
|
||||
width: 500
|
||||
placeholderText: "Anonymous"
|
||||
onTextChanged: {
|
||||
gui.setCustomIdentifier(text)
|
||||
}
|
||||
}
|
||||
|
||||
TextArea {
|
||||
objectName: "statsPane"
|
||||
width: parent.width
|
||||
|
|
|
@ -64,15 +64,6 @@ func (gui *Gui) Transact(recipient, value, gas, gasPrice, d string) (string, err
|
|||
return gui.xeth.Transact(recipient, value, gas, gasPrice, data)
|
||||
}
|
||||
|
||||
func (gui *Gui) SetCustomIdentifier(customIdentifier string) {
|
||||
gui.clientIdentity.SetCustomIdentifier(customIdentifier)
|
||||
gui.config.Save("id", customIdentifier)
|
||||
}
|
||||
|
||||
func (gui *Gui) GetCustomIdentifier() string {
|
||||
return gui.clientIdentity.GetCustomIdentifier()
|
||||
}
|
||||
|
||||
// functions that allow Gui to implement interface guilogger.LogSystem
|
||||
func (gui *Gui) SetLogLevel(level logger.LogLevel) {
|
||||
gui.logLevel = level
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
|
@ -31,7 +32,9 @@ import (
|
|||
"runtime"
|
||||
|
||||
"bitbucket.org/kardianos/osext"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||
"github.com/ethereum/go-ethereum/vm"
|
||||
)
|
||||
|
||||
|
@ -39,19 +42,18 @@ var (
|
|||
Identifier string
|
||||
KeyRing string
|
||||
KeyStore string
|
||||
PMPGateway string
|
||||
StartRpc bool
|
||||
StartWebSockets bool
|
||||
RpcPort int
|
||||
WsPort int
|
||||
UseUPnP bool
|
||||
NatType string
|
||||
OutboundPort string
|
||||
ShowGenesis bool
|
||||
AddPeer string
|
||||
MaxPeer int
|
||||
GenAddr bool
|
||||
SeedNode string
|
||||
BootNodes string
|
||||
NodeKey *ecdsa.PrivateKey
|
||||
NAT nat.Interface
|
||||
SecretFile string
|
||||
ExportDir string
|
||||
NonInteractive bool
|
||||
|
@ -99,6 +101,7 @@ func defaultDataDir() string {
|
|||
var defaultConfigFile = path.Join(defaultDataDir(), "conf.ini")
|
||||
|
||||
func Init() {
|
||||
// TODO: move common flag processing to cmd/utils
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0])
|
||||
flag.PrintDefaults()
|
||||
|
@ -108,30 +111,51 @@ func Init() {
|
|||
flag.StringVar(&Identifier, "id", "", "Custom client identifier")
|
||||
flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use")
|
||||
flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file (db)")
|
||||
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
|
||||
flag.BoolVar(&UseUPnP, "upnp", true, "enable UPnP support")
|
||||
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
|
||||
flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on")
|
||||
flag.IntVar(&WsPort, "wsport", 40404, "port to start websocket rpc server on")
|
||||
flag.BoolVar(&StartRpc, "rpc", true, "start rpc server")
|
||||
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
|
||||
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
|
||||
flag.StringVar(&SeedNode, "seednode", "poc-8.ethdev.com:30303", "ip:port of seed node to connect to. Set to blank for skip")
|
||||
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
|
||||
flag.StringVar(&NatType, "nat", "", "NAT support (UPNP|PMP) (none)")
|
||||
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
|
||||
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
|
||||
flag.StringVar(&LogFile, "logfile", "", "log file (defaults to standard output)")
|
||||
flag.StringVar(&Datadir, "datadir", defaultDataDir(), "specifies the datadir to use")
|
||||
flag.StringVar(&PMPGateway, "pmp", "", "Gateway IP for PMP")
|
||||
flag.StringVar(&ConfigFile, "conf", defaultConfigFile, "config file")
|
||||
flag.StringVar(&DebugFile, "debug", "", "debug file (no debugging if not set)")
|
||||
flag.IntVar(&LogLevel, "loglevel", int(logger.InfoLevel), "loglevel: 0-5: silent,error,warn,info,debug,debug detail)")
|
||||
|
||||
flag.StringVar(&AssetPath, "asset_path", defaultAssetPath(), "absolute path to GUI assets directory")
|
||||
|
||||
// Network stuff
|
||||
var (
|
||||
nodeKeyFile = flag.String("nodekey", "", "network private key file")
|
||||
nodeKeyHex = flag.String("nodekeyhex", "", "network private key (for testing)")
|
||||
natstr = flag.String("nat", "any", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
|
||||
)
|
||||
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
|
||||
flag.StringVar(&BootNodes, "bootnodes", "", "space-separated node URLs for discovery bootstrap")
|
||||
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
var err error
|
||||
if NAT, err = nat.Parse(*natstr); err != nil {
|
||||
log.Fatalf("-nat: %v", err)
|
||||
}
|
||||
switch {
|
||||
case *nodeKeyFile != "" && *nodeKeyHex != "":
|
||||
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive")
|
||||
case *nodeKeyFile != "":
|
||||
if NodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil {
|
||||
log.Fatalf("-nodekey: %v", err)
|
||||
}
|
||||
case *nodeKeyHex != "":
|
||||
if NodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil {
|
||||
log.Fatalf("-nodekeyhex: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if VmType >= int(vm.MaxVmTy) {
|
||||
log.Fatal("Invalid VM type ", VmType)
|
||||
}
|
||||
|
|
|
@ -41,7 +41,6 @@ import (
|
|||
"github.com/ethereum/go-ethereum/ethutil"
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/miner"
|
||||
"github.com/ethereum/go-ethereum/p2p"
|
||||
"github.com/ethereum/go-ethereum/ui/qt/qwhisper"
|
||||
"github.com/ethereum/go-ethereum/xeth"
|
||||
"github.com/obscuren/qml"
|
||||
|
@ -77,9 +76,8 @@ type Gui struct {
|
|||
|
||||
xeth *xeth.XEth
|
||||
|
||||
Session string
|
||||
clientIdentity *p2p.SimpleClientIdentity
|
||||
config *ethutil.ConfigManager
|
||||
Session string
|
||||
config *ethutil.ConfigManager
|
||||
|
||||
plugins map[string]plugin
|
||||
|
||||
|
@ -87,7 +85,7 @@ type Gui struct {
|
|||
}
|
||||
|
||||
// Create GUI, but doesn't start it
|
||||
func NewWindow(ethereum *eth.Ethereum, config *ethutil.ConfigManager, clientIdentity *p2p.SimpleClientIdentity, session string, logLevel int) *Gui {
|
||||
func NewWindow(ethereum *eth.Ethereum, config *ethutil.ConfigManager, session string, logLevel int) *Gui {
|
||||
db, err := ethdb.NewLDBDatabase("tx_database")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -95,15 +93,14 @@ func NewWindow(ethereum *eth.Ethereum, config *ethutil.ConfigManager, clientIden
|
|||
|
||||
xeth := xeth.New(ethereum)
|
||||
gui := &Gui{eth: ethereum,
|
||||
txDb: db,
|
||||
xeth: xeth,
|
||||
logLevel: logger.LogLevel(logLevel),
|
||||
Session: session,
|
||||
open: false,
|
||||
clientIdentity: clientIdentity,
|
||||
config: config,
|
||||
plugins: make(map[string]plugin),
|
||||
serviceEvents: make(chan ServEv, 1),
|
||||
txDb: db,
|
||||
xeth: xeth,
|
||||
logLevel: logger.LogLevel(logLevel),
|
||||
Session: session,
|
||||
open: false,
|
||||
config: config,
|
||||
plugins: make(map[string]plugin),
|
||||
serviceEvents: make(chan ServEv, 1),
|
||||
}
|
||||
data, _ := ethutil.ReadAllFile(path.Join(ethutil.Config.ExecPath, "plugins.json"))
|
||||
json.Unmarshal([]byte(data), &gui.plugins)
|
||||
|
|
|
@ -52,19 +52,18 @@ func run() error {
|
|||
config := utils.InitConfig(VmType, ConfigFile, Datadir, "ETH")
|
||||
|
||||
ethereum, err := eth.New(ð.Config{
|
||||
Name: ClientIdentifier,
|
||||
Version: Version,
|
||||
KeyStore: KeyStore,
|
||||
DataDir: Datadir,
|
||||
LogFile: LogFile,
|
||||
LogLevel: LogLevel,
|
||||
Identifier: Identifier,
|
||||
MaxPeers: MaxPeer,
|
||||
Port: OutboundPort,
|
||||
NATType: PMPGateway,
|
||||
PMPGateway: PMPGateway,
|
||||
KeyRing: KeyRing,
|
||||
Dial: true,
|
||||
Name: p2p.MakeName(ClientIdentifier, Version),
|
||||
KeyStore: KeyStore,
|
||||
DataDir: Datadir,
|
||||
LogFile: LogFile,
|
||||
LogLevel: LogLevel,
|
||||
MaxPeers: MaxPeer,
|
||||
Port: OutboundPort,
|
||||
NAT: NAT,
|
||||
BootNodes: BootNodes,
|
||||
NodeKey: NodeKey,
|
||||
KeyRing: KeyRing,
|
||||
Dial: true,
|
||||
})
|
||||
if err != nil {
|
||||
mainlogger.Fatalln(err)
|
||||
|
@ -79,12 +78,12 @@ func run() error {
|
|||
utils.StartWebSockets(ethereum, WsPort)
|
||||
}
|
||||
|
||||
gui := NewWindow(ethereum, config, ethereum.ClientIdentity().(*p2p.SimpleClientIdentity), KeyRing, LogLevel)
|
||||
gui := NewWindow(ethereum, config, KeyRing, LogLevel)
|
||||
|
||||
utils.RegisterInterrupt(func(os.Signal) {
|
||||
gui.Stop()
|
||||
})
|
||||
go utils.StartEthereum(ethereum, SeedNode)
|
||||
go utils.StartEthereum(ethereum)
|
||||
|
||||
fmt.Println("ETH stack took", time.Since(tstart))
|
||||
|
||||
|
|
|
@ -136,15 +136,15 @@ func (ui *UiLib) Muted(content string) {
|
|||
|
||||
func (ui *UiLib) Connect(button qml.Object) {
|
||||
if !ui.connected {
|
||||
ui.eth.Start(SeedNode)
|
||||
ui.eth.Start()
|
||||
ui.connected = true
|
||||
button.Set("enabled", false)
|
||||
}
|
||||
}
|
||||
|
||||
func (ui *UiLib) ConnectToPeer(addr string) {
|
||||
if err := ui.eth.SuggestPeer(addr); err != nil {
|
||||
guilogger.Infoln(err)
|
||||
func (ui *UiLib) ConnectToPeer(nodeURL string) {
|
||||
if err := ui.eth.SuggestPeer(nodeURL); err != nil {
|
||||
guilogger.Infoln("SuggestPeer error: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,58 +0,0 @@
|
|||
/*
|
||||
This file is part of go-ethereum
|
||||
|
||||
go-ethereum is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
go-ethereum is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p"
|
||||
)
|
||||
|
||||
var (
|
||||
natType = flag.String("nat", "", "NAT traversal implementation")
|
||||
pmpGateway = flag.String("gateway", "", "gateway address for NAT-PMP")
|
||||
listenAddr = flag.String("addr", ":30301", "listen address")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
nat, err := p2p.ParseNAT(*natType, *pmpGateway)
|
||||
if err != nil {
|
||||
log.Fatal("invalid nat:", err)
|
||||
}
|
||||
|
||||
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.InfoLevel))
|
||||
key, _ := crypto.GenerateKey()
|
||||
marshaled := elliptic.Marshal(crypto.S256(), key.PublicKey.X, key.PublicKey.Y)
|
||||
|
||||
srv := p2p.Server{
|
||||
MaxPeers: 100,
|
||||
Identity: p2p.NewSimpleClientIdentity("Ethereum(G)", "0.1", "Peer Server Two", marshaled),
|
||||
ListenAddr: *listenAddr,
|
||||
NAT: nat,
|
||||
NoDial: true,
|
||||
}
|
||||
if err := srv.Start(); err != nil {
|
||||
log.Fatal("could not start server:", err)
|
||||
}
|
||||
select {}
|
||||
}
|
|
@ -121,13 +121,11 @@ func exit(err error) {
|
|||
os.Exit(status)
|
||||
}
|
||||
|
||||
func StartEthereum(ethereum *eth.Ethereum, SeedNode string) {
|
||||
clilogger.Infof("Starting %s", ethereum.ClientIdentity())
|
||||
err := ethereum.Start(SeedNode)
|
||||
if err != nil {
|
||||
func StartEthereum(ethereum *eth.Ethereum) {
|
||||
clilogger.Infoln("Starting ", ethereum.Name())
|
||||
if err := ethereum.Start(); err != nil {
|
||||
exit(err)
|
||||
}
|
||||
|
||||
RegisterInterrupt(func(sig os.Signal) {
|
||||
ethereum.Stop()
|
||||
logger.Flush()
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"github.com/ethereum/go-ethereum/ethdb"
|
||||
"github.com/ethereum/go-ethereum/ethutil"
|
||||
"github.com/ethereum/go-ethereum/event"
|
||||
"github.com/ethereum/go-ethereum/p2p"
|
||||
)
|
||||
|
||||
// Implement our EthTest Manager
|
||||
|
@ -54,13 +53,6 @@ func (tm *TestManager) TxPool() *TxPool {
|
|||
func (tm *TestManager) EventMux() *event.TypeMux {
|
||||
return tm.eventMux
|
||||
}
|
||||
func (tm *TestManager) Broadcast(msgType p2p.Msg, data []interface{}) {
|
||||
fmt.Println("Broadcast not implemented")
|
||||
}
|
||||
|
||||
func (tm *TestManager) ClientIdentity() p2p.ClientIdentity {
|
||||
return nil
|
||||
}
|
||||
func (tm *TestManager) KeyManager() *crypto.KeyManager {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -16,7 +16,6 @@ type EthManager interface {
|
|||
IsListening() bool
|
||||
Peers() []*p2p.Peer
|
||||
KeyManager() *crypto.KeyManager
|
||||
ClientIdentity() p2p.ClientIdentity
|
||||
Db() ethutil.Database
|
||||
EventMux() *event.TypeMux
|
||||
}
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
@ -27,10 +29,11 @@ func init() {
|
|||
ecies.AddParamsForCurve(S256(), ecies.ECIES_AES128_SHA256)
|
||||
}
|
||||
|
||||
func Sha3(data []byte) []byte {
|
||||
func Sha3(data ...[]byte) []byte {
|
||||
d := sha3.NewKeccak256()
|
||||
d.Write(data)
|
||||
|
||||
for _, b := range data {
|
||||
d.Write(b)
|
||||
}
|
||||
return d.Sum(nil)
|
||||
}
|
||||
|
||||
|
@ -98,6 +101,32 @@ func FromECDSAPub(pub *ecdsa.PublicKey) []byte {
|
|||
return elliptic.Marshal(S256(), pub.X, pub.Y)
|
||||
}
|
||||
|
||||
// HexToECDSA parses a secp256k1 private key.
|
||||
func HexToECDSA(hexkey string) (*ecdsa.PrivateKey, error) {
|
||||
b, err := hex.DecodeString(hexkey)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid hex string")
|
||||
}
|
||||
if len(b) != 32 {
|
||||
return nil, errors.New("invalid length, need 256 bits")
|
||||
}
|
||||
return ToECDSA(b), nil
|
||||
}
|
||||
|
||||
// LoadECDSA loads a secp256k1 private key from the given file.
|
||||
func LoadECDSA(file string) (*ecdsa.PrivateKey, error) {
|
||||
buf := make([]byte, 32)
|
||||
fd, err := os.Open(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer fd.Close()
|
||||
if _, err := io.ReadFull(fd, buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ToECDSA(buf), nil
|
||||
}
|
||||
|
||||
func GenerateKey() (*ecdsa.PrivateKey, error) {
|
||||
return ecdsa.GenerateKey(S256(), rand.Reader)
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ import (
|
|||
func TestSha3(t *testing.T) {
|
||||
msg := []byte("abc")
|
||||
exp, _ := hex.DecodeString("4e03657aea45a94fc7d47ba826c8d667c0d1e6e33a64a036ec44f58fa12d6c45")
|
||||
checkhash(t, "Sha3-256", Sha3, msg, exp)
|
||||
checkhash(t, "Sha3-256", func(in []byte) []byte { return Sha3(in) }, msg, exp)
|
||||
}
|
||||
|
||||
func TestSha256(t *testing.T) {
|
||||
|
|
|
@ -25,11 +25,12 @@ package crypto
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"code.google.com/p/go-uuid/uuid"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
"code.google.com/p/go-uuid/uuid"
|
||||
)
|
||||
|
||||
type Key struct {
|
||||
|
|
134
eth/backend.go
134
eth/backend.go
|
@ -1,9 +1,9 @@
|
|||
package eth
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"strings"
|
||||
|
||||
"github.com/ethereum/go-ethereum/core"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
|
@ -12,35 +12,58 @@ import (
|
|||
"github.com/ethereum/go-ethereum/event"
|
||||
ethlogger "github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||
"github.com/ethereum/go-ethereum/pow/ezp"
|
||||
"github.com/ethereum/go-ethereum/rpc"
|
||||
"github.com/ethereum/go-ethereum/whisper"
|
||||
)
|
||||
|
||||
var logger = ethlogger.NewLogger("SERV")
|
||||
var jsonlogger = ethlogger.NewJsonLogger()
|
||||
|
||||
type Config struct {
|
||||
Name string
|
||||
Version string
|
||||
Identifier string
|
||||
KeyStore string
|
||||
DataDir string
|
||||
LogFile string
|
||||
LogLevel int
|
||||
LogFormat string
|
||||
KeyRing string
|
||||
Name string
|
||||
KeyStore string
|
||||
DataDir string
|
||||
LogFile string
|
||||
LogLevel int
|
||||
KeyRing string
|
||||
LogFormat string
|
||||
|
||||
MaxPeers int
|
||||
Port string
|
||||
NATType string
|
||||
PMPGateway string
|
||||
MaxPeers int
|
||||
Port string
|
||||
|
||||
// This should be a space-separated list of
|
||||
// discovery node URLs.
|
||||
BootNodes string
|
||||
|
||||
// This key is used to identify the node on the network.
|
||||
// If nil, an ephemeral key is used.
|
||||
NodeKey *ecdsa.PrivateKey
|
||||
|
||||
NAT nat.Interface
|
||||
Shh bool
|
||||
Dial bool
|
||||
|
||||
KeyManager *crypto.KeyManager
|
||||
}
|
||||
|
||||
var logger = ethlogger.NewLogger("SERV")
|
||||
var jsonlogger = ethlogger.NewJsonLogger()
|
||||
func (cfg *Config) parseBootNodes() []*discover.Node {
|
||||
var ns []*discover.Node
|
||||
for _, url := range strings.Split(cfg.BootNodes, " ") {
|
||||
if url == "" {
|
||||
continue
|
||||
}
|
||||
n, err := discover.ParseNode(url)
|
||||
if err != nil {
|
||||
logger.Errorf("Bootstrap URL %s: %v\n", url, err)
|
||||
continue
|
||||
}
|
||||
ns = append(ns, n)
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
type Ethereum struct {
|
||||
// Channel for shutting down the ethereum
|
||||
|
@ -68,11 +91,7 @@ type Ethereum struct {
|
|||
WsServer rpc.RpcServer
|
||||
keyManager *crypto.KeyManager
|
||||
|
||||
clientIdentity p2p.ClientIdentity
|
||||
logger ethlogger.LogSystem
|
||||
|
||||
synclock sync.Mutex
|
||||
syncGroup sync.WaitGroup
|
||||
logger ethlogger.LogSystem
|
||||
|
||||
Mining bool
|
||||
}
|
||||
|
@ -105,21 +124,17 @@ func New(config *Config) (*Ethereum, error) {
|
|||
// Initialise the keyring
|
||||
keyManager.Init(config.KeyRing, 0, false)
|
||||
|
||||
// Create a new client id for this instance. This will help identifying the node on the network
|
||||
clientId := p2p.NewSimpleClientIdentity(config.Name, config.Version, config.Identifier, keyManager.PublicKey())
|
||||
|
||||
saveProtocolVersion(db)
|
||||
//ethutil.Config.Db = db
|
||||
|
||||
eth := &Ethereum{
|
||||
shutdownChan: make(chan bool),
|
||||
quit: make(chan bool),
|
||||
db: db,
|
||||
keyManager: keyManager,
|
||||
clientIdentity: clientId,
|
||||
blacklist: p2p.NewBlacklist(),
|
||||
eventMux: &event.TypeMux{},
|
||||
logger: logger,
|
||||
shutdownChan: make(chan bool),
|
||||
quit: make(chan bool),
|
||||
db: db,
|
||||
keyManager: keyManager,
|
||||
blacklist: p2p.NewBlacklist(),
|
||||
eventMux: &event.TypeMux{},
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
eth.chainManager = core.NewChainManager(db, eth.EventMux())
|
||||
|
@ -134,21 +149,22 @@ func New(config *Config) (*Ethereum, error) {
|
|||
|
||||
ethProto := EthProtocol(eth.txPool, eth.chainManager, eth.blockPool)
|
||||
protocols := []p2p.Protocol{ethProto, eth.whisper.Protocol()}
|
||||
|
||||
nat, err := p2p.ParseNAT(config.NATType, config.PMPGateway)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
netprv := config.NodeKey
|
||||
if netprv == nil {
|
||||
if netprv, err = crypto.GenerateKey(); err != nil {
|
||||
return nil, fmt.Errorf("could not generate server key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
eth.net = &p2p.Server{
|
||||
Identity: clientId,
|
||||
MaxPeers: config.MaxPeers,
|
||||
Protocols: protocols,
|
||||
Blacklist: eth.blacklist,
|
||||
NAT: nat,
|
||||
NoDial: !config.Dial,
|
||||
PrivateKey: netprv,
|
||||
Name: config.Name,
|
||||
MaxPeers: config.MaxPeers,
|
||||
Protocols: protocols,
|
||||
Blacklist: eth.blacklist,
|
||||
NAT: config.NAT,
|
||||
NoDial: !config.Dial,
|
||||
BootstrapNodes: config.parseBootNodes(),
|
||||
}
|
||||
|
||||
if len(config.Port) > 0 {
|
||||
eth.net.ListenAddr = ":" + config.Port
|
||||
}
|
||||
|
@ -164,8 +180,8 @@ func (s *Ethereum) Logger() ethlogger.LogSystem {
|
|||
return s.logger
|
||||
}
|
||||
|
||||
func (s *Ethereum) ClientIdentity() p2p.ClientIdentity {
|
||||
return s.clientIdentity
|
||||
func (s *Ethereum) Name() string {
|
||||
return s.net.Name
|
||||
}
|
||||
|
||||
func (s *Ethereum) ChainManager() *core.ChainManager {
|
||||
|
@ -221,12 +237,12 @@ func (s *Ethereum) Coinbase() []byte {
|
|||
}
|
||||
|
||||
// Start the ethereum
|
||||
func (s *Ethereum) Start(seedNode string) error {
|
||||
func (s *Ethereum) Start() error {
|
||||
jsonlogger.LogJson(ðlogger.LogStarting{
|
||||
ClientString: s.ClientIdentity().String(),
|
||||
ClientString: s.net.Name,
|
||||
Coinbase: ethutil.Bytes2Hex(s.KeyManager().Address()),
|
||||
ProtocolVersion: ProtocolVersion,
|
||||
LogEvent: ethlogger.LogEvent{Guid: ethutil.Bytes2Hex(s.ClientIdentity().Pubkey())},
|
||||
LogEvent: ethlogger.LogEvent{Guid: ethutil.Bytes2Hex(crypto.FromECDSAPub(&s.net.PrivateKey.PublicKey))},
|
||||
})
|
||||
|
||||
err := s.net.Start()
|
||||
|
@ -250,26 +266,16 @@ func (s *Ethereum) Start(seedNode string) error {
|
|||
s.blockSub = s.eventMux.Subscribe(core.NewMinedBlockEvent{})
|
||||
go s.blockBroadcastLoop()
|
||||
|
||||
// TODO: read peers here
|
||||
if len(seedNode) > 0 {
|
||||
logger.Infof("Connect to seed node %v", seedNode)
|
||||
if err := s.SuggestPeer(seedNode); err != nil {
|
||||
logger.Infoln(err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Infoln("Server started")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (self *Ethereum) SuggestPeer(addr string) error {
|
||||
netaddr, err := net.ResolveTCPAddr("tcp", addr)
|
||||
func (self *Ethereum) SuggestPeer(nodeURL string) error {
|
||||
n, err := discover.ParseNode(nodeURL)
|
||||
if err != nil {
|
||||
logger.Errorf("couldn't resolve %s:", addr, err)
|
||||
return err
|
||||
return fmt.Errorf("invalid node URL: %v", err)
|
||||
}
|
||||
|
||||
self.net.SuggestPeer(netaddr.IP, netaddr.Port, nil)
|
||||
self.net.SuggestPeer(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -92,13 +92,14 @@ func EthProtocol(txPool txPool, chainManager chainManager, blockPool blockPool)
|
|||
// the main loop that handles incoming messages
|
||||
// note RemovePeer in the post-disconnect hook
|
||||
func runEthProtocol(txPool txPool, chainManager chainManager, blockPool blockPool, peer *p2p.Peer, rw p2p.MsgReadWriter) (err error) {
|
||||
id := peer.ID()
|
||||
self := ðProtocol{
|
||||
txPool: txPool,
|
||||
chainManager: chainManager,
|
||||
blockPool: blockPool,
|
||||
rw: rw,
|
||||
peer: peer,
|
||||
id: fmt.Sprintf("%x", peer.Identity().Pubkey()[:8]),
|
||||
id: fmt.Sprintf("%x", id[:8]),
|
||||
}
|
||||
err = self.handleStatus()
|
||||
if err == nil {
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/ethereum/go-ethereum/ethutil"
|
||||
ethlogger "github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
)
|
||||
|
||||
var sys = ethlogger.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlogger.LogLevel(ethlogger.DebugDetailLevel))
|
||||
|
@ -128,26 +129,11 @@ func (self *testBlockPool) RemovePeer(peerId string) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: refactor this into p2p/client_identity
|
||||
type peerId struct {
|
||||
pubkey []byte
|
||||
}
|
||||
|
||||
func (self *peerId) String() string {
|
||||
return "test peer"
|
||||
}
|
||||
|
||||
func (self *peerId) Pubkey() (pubkey []byte) {
|
||||
pubkey = self.pubkey
|
||||
if len(pubkey) == 0 {
|
||||
pubkey = crypto.GenerateNewKeyPair().PublicKey
|
||||
self.pubkey = pubkey
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func testPeer() *p2p.Peer {
|
||||
return p2p.NewPeer(&peerId{}, []p2p.Cap{})
|
||||
var id discover.NodeID
|
||||
pk := crypto.GenerateNewKeyPair().PublicKey
|
||||
copy(id[:], pk)
|
||||
return p2p.NewPeer(id, "test peer", []p2p.Cap{})
|
||||
}
|
||||
|
||||
type ethProtocolTester struct {
|
||||
|
|
|
@ -197,12 +197,13 @@ func (self *JSRE) watch(call otto.FunctionCall) otto.Value {
|
|||
}
|
||||
|
||||
func (self *JSRE) addPeer(call otto.FunctionCall) otto.Value {
|
||||
host, err := call.Argument(0).ToString()
|
||||
nodeURL, err := call.Argument(0).ToString()
|
||||
if err != nil {
|
||||
return otto.FalseValue()
|
||||
}
|
||||
self.ethereum.SuggestPeer(host)
|
||||
|
||||
if err := self.ethereum.SuggestPeer(nodeURL); err != nil {
|
||||
return otto.FalseValue()
|
||||
}
|
||||
return otto.TrueValue()
|
||||
}
|
||||
|
||||
|
|
|
@ -1,63 +0,0 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// ClientIdentity represents the identity of a peer.
|
||||
type ClientIdentity interface {
|
||||
String() string // human readable identity
|
||||
Pubkey() []byte // 512-bit public key
|
||||
}
|
||||
|
||||
type SimpleClientIdentity struct {
|
||||
clientIdentifier string
|
||||
version string
|
||||
customIdentifier string
|
||||
os string
|
||||
implementation string
|
||||
pubkey []byte
|
||||
}
|
||||
|
||||
func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey []byte) *SimpleClientIdentity {
|
||||
clientIdentity := &SimpleClientIdentity{
|
||||
clientIdentifier: clientIdentifier,
|
||||
version: version,
|
||||
customIdentifier: customIdentifier,
|
||||
os: runtime.GOOS,
|
||||
implementation: runtime.Version(),
|
||||
pubkey: pubkey,
|
||||
}
|
||||
|
||||
return clientIdentity
|
||||
}
|
||||
|
||||
func (c *SimpleClientIdentity) init() {
|
||||
}
|
||||
|
||||
func (c *SimpleClientIdentity) String() string {
|
||||
var id string
|
||||
if len(c.customIdentifier) > 0 {
|
||||
id = "/" + c.customIdentifier
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/v%s%s/%s/%s",
|
||||
c.clientIdentifier,
|
||||
c.version,
|
||||
id,
|
||||
c.os,
|
||||
c.implementation)
|
||||
}
|
||||
|
||||
func (c *SimpleClientIdentity) Pubkey() []byte {
|
||||
return []byte(c.pubkey)
|
||||
}
|
||||
|
||||
func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) {
|
||||
c.customIdentifier = customIdentifier
|
||||
}
|
||||
|
||||
func (c *SimpleClientIdentity) GetCustomIdentifier() string {
|
||||
return c.customIdentifier
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClientIdentity(t *testing.T) {
|
||||
clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey"))
|
||||
clientString := clientIdentity.String()
|
||||
expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
|
||||
if clientString != expected {
|
||||
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
|
||||
}
|
||||
customIdentifier := clientIdentity.GetCustomIdentifier()
|
||||
if customIdentifier != "test" {
|
||||
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier)
|
||||
}
|
||||
clientIdentity.SetCustomIdentifier("test2")
|
||||
customIdentifier = clientIdentity.GetCustomIdentifier()
|
||||
if customIdentifier != "test2" {
|
||||
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier)
|
||||
}
|
||||
clientString = clientIdentity.String()
|
||||
expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version())
|
||||
if clientString != expected {
|
||||
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,363 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
// "binary"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
||||
ethlogger "github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/obscuren/ecies"
|
||||
)
|
||||
|
||||
var clogger = ethlogger.NewLogger("CRYPTOID")
|
||||
|
||||
const (
|
||||
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
|
||||
sigLen = 65 // elliptic S256
|
||||
pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
|
||||
shaLen = 32 // hash length (for nonce etc)
|
||||
|
||||
authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
|
||||
authRespLen = pubLen + shaLen + 1
|
||||
|
||||
eciesBytes = 65 + 16 + 32
|
||||
iHSLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
|
||||
rHSLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
|
||||
)
|
||||
|
||||
type hexkey []byte
|
||||
|
||||
func (self hexkey) String() string {
|
||||
return fmt.Sprintf("(%d) %x", len(self), []byte(self))
|
||||
}
|
||||
|
||||
func encHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, dial *discover.Node) (
|
||||
remoteID discover.NodeID,
|
||||
sessionToken []byte,
|
||||
err error,
|
||||
) {
|
||||
if dial == nil {
|
||||
var remotePubkey []byte
|
||||
sessionToken, remotePubkey, err = inboundEncHandshake(conn, prv, nil)
|
||||
copy(remoteID[:], remotePubkey)
|
||||
} else {
|
||||
remoteID = dial.ID
|
||||
sessionToken, err = outboundEncHandshake(conn, prv, remoteID[:], nil)
|
||||
}
|
||||
return remoteID, sessionToken, err
|
||||
}
|
||||
|
||||
// outboundEncHandshake negotiates a session token on conn.
|
||||
// it should be called on the dialing side of the connection.
|
||||
//
|
||||
// privateKey is the local client's private key
|
||||
// remotePublicKey is the remote peer's node ID
|
||||
// sessionToken is the token from a previous session with this node.
|
||||
func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePublicKey []byte, sessionToken []byte) (
|
||||
newSessionToken []byte,
|
||||
err error,
|
||||
) {
|
||||
auth, initNonce, randomPrivKey, err := authMsg(prvKey, remotePublicKey, sessionToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sessionToken != nil {
|
||||
clogger.Debugf("session-token: %v", hexkey(sessionToken))
|
||||
}
|
||||
|
||||
clogger.Debugf("initiator-nonce: %v", hexkey(initNonce))
|
||||
clogger.Debugf("initiator-random-private-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
|
||||
randomPublicKeyS, _ := exportPublicKey(&randomPrivKey.PublicKey)
|
||||
clogger.Debugf("initiator-random-public-key: %v", hexkey(randomPublicKeyS))
|
||||
if _, err = conn.Write(auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clogger.Debugf("initiator handshake: %v", hexkey(auth))
|
||||
|
||||
response := make([]byte, rHSLen)
|
||||
if _, err = io.ReadFull(conn, response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prvKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
|
||||
remoteRandomPubKeyS, _ := exportPublicKey(remoteRandomPubKey)
|
||||
clogger.Debugf("receiver-random-public-key: %v", hexkey(remoteRandomPubKeyS))
|
||||
return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
|
||||
}
|
||||
|
||||
// authMsg creates the initiator handshake.
|
||||
func authMsg(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) (
|
||||
auth, initNonce []byte,
|
||||
randomPrvKey *ecdsa.PrivateKey,
|
||||
err error,
|
||||
) {
|
||||
// session init, common to both parties
|
||||
remotePubKey, err := importPublicKey(remotePubKeyS)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var tokenFlag byte // = 0x00
|
||||
if sessionToken == nil {
|
||||
// no session token found means we need to generate shared secret.
|
||||
// ecies shared secret is used as initial session token for new peers
|
||||
// generate shared key from prv and remote pubkey
|
||||
if sessionToken, err = ecies.ImportECDSA(prvKey).GenerateShared(ecies.ImportECDSAPublic(remotePubKey), sskLen, sskLen); err != nil {
|
||||
return
|
||||
}
|
||||
// tokenFlag = 0x00 // redundant
|
||||
} else {
|
||||
// for known peers, we use stored token from the previous session
|
||||
tokenFlag = 0x01
|
||||
}
|
||||
|
||||
//E(remote-pubk, S(ecdhe-random, ecdh-shared-secret^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x0)
|
||||
// E(remote-pubk, S(ecdhe-random, token^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x1)
|
||||
// allocate msgLen long message,
|
||||
var msg []byte = make([]byte, authMsgLen)
|
||||
initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
|
||||
if _, err = rand.Read(initNonce); err != nil {
|
||||
return
|
||||
}
|
||||
// create known message
|
||||
// ecdh-shared-secret^nonce for new peers
|
||||
// token^nonce for old peers
|
||||
var sharedSecret = xor(sessionToken, initNonce)
|
||||
|
||||
// generate random keypair to use for signing
|
||||
if randomPrvKey, err = crypto.GenerateKey(); err != nil {
|
||||
return
|
||||
}
|
||||
// sign shared secret (message known to both parties): shared-secret
|
||||
var signature []byte
|
||||
// signature = sign(ecdhe-random, shared-secret)
|
||||
// uses secp256k1.Sign
|
||||
if signature, err = crypto.Sign(sharedSecret, randomPrvKey); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// message
|
||||
// signed-shared-secret || H(ecdhe-random-pubk) || pubk || nonce || 0x0
|
||||
copy(msg, signature) // copy signed-shared-secret
|
||||
// H(ecdhe-random-pubk)
|
||||
var randomPubKey64 []byte
|
||||
if randomPubKey64, err = exportPublicKey(&randomPrvKey.PublicKey); err != nil {
|
||||
return
|
||||
}
|
||||
var pubKey64 []byte
|
||||
if pubKey64, err = exportPublicKey(&prvKey.PublicKey); err != nil {
|
||||
return
|
||||
}
|
||||
copy(msg[sigLen:sigLen+shaLen], crypto.Sha3(randomPubKey64))
|
||||
// pubkey copied to the correct segment.
|
||||
copy(msg[sigLen+shaLen:sigLen+shaLen+pubLen], pubKey64)
|
||||
// nonce is already in the slice
|
||||
// stick tokenFlag byte to the end
|
||||
msg[authMsgLen-1] = tokenFlag
|
||||
|
||||
// encrypt using remote-pubk
|
||||
// auth = eciesEncrypt(remote-pubk, msg)
|
||||
if auth, err = crypto.Encrypt(remotePubKey, msg); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// completeHandshake is called when the initiator receives an
|
||||
// authentication response (aka receiver handshake). It completes the
|
||||
// handshake by reading off parameters the remote peer provides needed
|
||||
// to set up the secure session.
|
||||
func completeHandshake(auth []byte, prvKey *ecdsa.PrivateKey) (
|
||||
respNonce []byte,
|
||||
remoteRandomPubKey *ecdsa.PublicKey,
|
||||
tokenFlag bool,
|
||||
err error,
|
||||
) {
|
||||
var msg []byte
|
||||
// they prove that msg is meant for me,
|
||||
// I prove I possess private key if i can read it
|
||||
if msg, err = crypto.Decrypt(prvKey, auth); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
respNonce = msg[pubLen : pubLen+shaLen]
|
||||
var remoteRandomPubKeyS = msg[:pubLen]
|
||||
if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil {
|
||||
return
|
||||
}
|
||||
if msg[authRespLen-1] == 0x01 {
|
||||
tokenFlag = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// inboundEncHandshake negotiates a session token on conn.
|
||||
// it should be called on the listening side of the connection.
|
||||
//
|
||||
// privateKey is the local client's private key
|
||||
// sessionToken is the token from a previous session with this node.
|
||||
func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionToken []byte) (
|
||||
token, remotePubKey []byte,
|
||||
err error,
|
||||
) {
|
||||
// we are listening connection. we are responders in the
|
||||
// handshake. Extract info from the authentication. The initiator
|
||||
// starts by sending us a handshake that we need to respond to. so
|
||||
// we read auth message first, then respond.
|
||||
auth := make([]byte, iHSLen)
|
||||
if _, err := io.ReadFull(conn, auth); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
response, recNonce, initNonce, remotePubKey, randomPrivKey, remoteRandomPubKey, err := authResp(auth, sessionToken, prvKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
|
||||
clogger.Debugf("receiver-random-priv-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
|
||||
if _, err = conn.Write(response); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
clogger.Debugf("receiver handshake:\n%v", hexkey(response))
|
||||
token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
|
||||
return token, remotePubKey, err
|
||||
}
|
||||
|
||||
// authResp is called by peer if it accepted (but not
|
||||
// initiated) the connection from the remote. It is passed the initiator
|
||||
// handshake received and the session token belonging to the
|
||||
// remote initiator.
|
||||
//
|
||||
// The first return value is the authentication response (aka receiver
|
||||
// handshake) that is to be sent to the remote initiator.
|
||||
func authResp(auth, sessionToken []byte, prvKey *ecdsa.PrivateKey) (
|
||||
authResp, respNonce, initNonce, remotePubKeyS []byte,
|
||||
randomPrivKey *ecdsa.PrivateKey,
|
||||
remoteRandomPubKey *ecdsa.PublicKey,
|
||||
err error,
|
||||
) {
|
||||
// they prove that msg is meant for me,
|
||||
// I prove I possess private key if i can read it
|
||||
msg, err := crypto.Decrypt(prvKey, auth)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
remotePubKeyS = msg[sigLen+shaLen : sigLen+shaLen+pubLen]
|
||||
remotePubKey, _ := importPublicKey(remotePubKeyS)
|
||||
|
||||
var tokenFlag byte
|
||||
if sessionToken == nil {
|
||||
// no session token found means we need to generate shared secret.
|
||||
// ecies shared secret is used as initial session token for new peers
|
||||
// generate shared key from prv and remote pubkey
|
||||
if sessionToken, err = ecies.ImportECDSA(prvKey).GenerateShared(ecies.ImportECDSAPublic(remotePubKey), sskLen, sskLen); err != nil {
|
||||
return
|
||||
}
|
||||
// tokenFlag = 0x00 // redundant
|
||||
} else {
|
||||
// for known peers, we use stored token from the previous session
|
||||
tokenFlag = 0x01
|
||||
}
|
||||
|
||||
// the initiator nonce is read off the end of the message
|
||||
initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
|
||||
// I prove that i own prv key (to derive shared secret, and read
|
||||
// nonce off encrypted msg) and that I own shared secret they
|
||||
// prove they own the private key belonging to ecdhe-random-pubk
|
||||
// we can now reconstruct the signed message and recover the peers
|
||||
// pubkey
|
||||
var signedMsg = xor(sessionToken, initNonce)
|
||||
var remoteRandomPubKeyS []byte
|
||||
if remoteRandomPubKeyS, err = secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]); err != nil {
|
||||
return
|
||||
}
|
||||
// convert to ECDSA standard
|
||||
if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// now we find ourselves a long task too, fill it random
|
||||
var resp = make([]byte, authRespLen)
|
||||
// generate shaLen long nonce
|
||||
respNonce = resp[pubLen : pubLen+shaLen]
|
||||
if _, err = rand.Read(respNonce); err != nil {
|
||||
return
|
||||
}
|
||||
// generate random keypair for session
|
||||
if randomPrivKey, err = crypto.GenerateKey(); err != nil {
|
||||
return
|
||||
}
|
||||
// responder auth message
|
||||
// E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
|
||||
var randomPubKeyS []byte
|
||||
if randomPubKeyS, err = exportPublicKey(&randomPrivKey.PublicKey); err != nil {
|
||||
return
|
||||
}
|
||||
copy(resp[:pubLen], randomPubKeyS)
|
||||
// nonce is already in the slice
|
||||
resp[authRespLen-1] = tokenFlag
|
||||
|
||||
// encrypt using remote-pubk
|
||||
// auth = eciesEncrypt(remote-pubk, msg)
|
||||
// why not encrypt with ecdhe-random-remote
|
||||
if authResp, err = crypto.Encrypt(remotePubKey, resp); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// newSession is called after the handshake is completed. The
|
||||
// arguments are values negotiated in the handshake. The return value
|
||||
// is a new session Token to be remembered for the next time we
|
||||
// connect with this peer.
|
||||
func newSession(initNonce, respNonce []byte, privKey *ecdsa.PrivateKey, remoteRandomPubKey *ecdsa.PublicKey) ([]byte, error) {
|
||||
// 3) Now we can trust ecdhe-random-pubk to derive new keys
|
||||
//ecdhe-shared-secret = ecdh.agree(ecdhe-random, remote-ecdhe-random-pubk)
|
||||
pubKey := ecies.ImportECDSAPublic(remoteRandomPubKey)
|
||||
dhSharedSecret, err := ecies.ImportECDSA(privKey).GenerateShared(pubKey, sskLen, sskLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sharedSecret := crypto.Sha3(dhSharedSecret, crypto.Sha3(respNonce, initNonce))
|
||||
sessionToken := crypto.Sha3(sharedSecret)
|
||||
return sessionToken, nil
|
||||
}
|
||||
|
||||
// importPublicKey unmarshals 512 bit public keys.
|
||||
func importPublicKey(pubKey []byte) (pubKeyEC *ecdsa.PublicKey, err error) {
|
||||
var pubKey65 []byte
|
||||
switch len(pubKey) {
|
||||
case 64:
|
||||
// add 'uncompressed key' flag
|
||||
pubKey65 = append([]byte{0x04}, pubKey...)
|
||||
case 65:
|
||||
pubKey65 = pubKey
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
|
||||
}
|
||||
return crypto.ToECDSAPub(pubKey65), nil
|
||||
}
|
||||
|
||||
func exportPublicKey(pubKeyEC *ecdsa.PublicKey) (pubKey []byte, err error) {
|
||||
if pubKeyEC == nil {
|
||||
return nil, fmt.Errorf("no ECDSA public key given")
|
||||
}
|
||||
return crypto.FromECDSAPub(pubKeyEC)[1:], nil
|
||||
}
|
||||
|
||||
func xor(one, other []byte) (xor []byte) {
|
||||
xor = make([]byte, len(one))
|
||||
for i := 0; i < len(one); i++ {
|
||||
xor[i] = one[i] ^ other[i]
|
||||
}
|
||||
return xor
|
||||
}
|
|
@ -0,0 +1,167 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/obscuren/ecies"
|
||||
)
|
||||
|
||||
func TestPublicKeyEncoding(t *testing.T) {
|
||||
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
||||
pub0 := &prv0.PublicKey
|
||||
pub0s := crypto.FromECDSAPub(pub0)
|
||||
pub1, err := importPublicKey(pub0s)
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
eciesPub1 := ecies.ImportECDSAPublic(pub1)
|
||||
if eciesPub1 == nil {
|
||||
t.Errorf("invalid ecdsa public key")
|
||||
}
|
||||
pub1s, err := exportPublicKey(pub1)
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
if len(pub1s) != 64 {
|
||||
t.Errorf("wrong length expect 64, got", len(pub1s))
|
||||
}
|
||||
pub2, err := importPublicKey(pub1s)
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
pub2s, err := exportPublicKey(pub2)
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
if !bytes.Equal(pub1s, pub2s) {
|
||||
t.Errorf("exports dont match")
|
||||
}
|
||||
pub2sEC := crypto.FromECDSAPub(pub2)
|
||||
if !bytes.Equal(pub0s, pub2sEC) {
|
||||
t.Errorf("exports dont match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharedSecret(t *testing.T) {
|
||||
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
||||
pub0 := &prv0.PublicKey
|
||||
prv1, _ := crypto.GenerateKey()
|
||||
pub1 := &prv1.PublicKey
|
||||
|
||||
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
|
||||
if !bytes.Equal(ss0, ss1) {
|
||||
t.Errorf("dont match :(")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCryptoHandshake(t *testing.T) {
|
||||
testCryptoHandshake(newkey(), newkey(), nil, t)
|
||||
}
|
||||
|
||||
func TestCryptoHandshakeWithToken(t *testing.T) {
|
||||
sessionToken := make([]byte, shaLen)
|
||||
rand.Read(sessionToken)
|
||||
testCryptoHandshake(newkey(), newkey(), sessionToken, t)
|
||||
}
|
||||
|
||||
func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *testing.T) {
|
||||
var err error
|
||||
// pub0 := &prv0.PublicKey
|
||||
pub1 := &prv1.PublicKey
|
||||
|
||||
// pub0s := crypto.FromECDSAPub(pub0)
|
||||
pub1s := crypto.FromECDSAPub(pub1)
|
||||
|
||||
// simulate handshake by feeding output to input
|
||||
// initiator sends handshake 'auth'
|
||||
auth, initNonce, randomPrivKey, err := authMsg(prv0, pub1s, sessionToken)
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
t.Logf("-> %v", hexkey(auth))
|
||||
|
||||
// receiver reads auth and responds with response
|
||||
response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
t.Logf("<- %v\n", hexkey(response))
|
||||
|
||||
// initiator reads receiver's response and the key exchange completes
|
||||
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
|
||||
if err != nil {
|
||||
t.Errorf("completeHandshake error: %v", err)
|
||||
}
|
||||
|
||||
// now both parties should have the same session parameters
|
||||
initSessionToken, err := newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
|
||||
if err != nil {
|
||||
t.Errorf("newSession error: %v", err)
|
||||
}
|
||||
|
||||
recSessionToken, err := newSession(remoteInitNonce, remoteRecNonce, remoteRandomPrivKey, remoteInitRandomPubKey)
|
||||
if err != nil {
|
||||
t.Errorf("newSession error: %v", err)
|
||||
}
|
||||
|
||||
// fmt.Printf("\nauth (%v) %x\n\nresp (%v) %x\n\n", len(auth), auth, len(response), response)
|
||||
|
||||
// fmt.Printf("\nauth %x\ninitNonce %x\nresponse%x\nremoteRecNonce %x\nremoteInitNonce %x\nremoteRandomPubKey %x\nrecNonce %x\nremoteInitRandomPubKey %x\ninitSessionToken %x\n\n", auth, initNonce, response, remoteRecNonce, remoteInitNonce, remoteRandomPubKey, recNonce, remoteInitRandomPubKey, initSessionToken)
|
||||
|
||||
if !bytes.Equal(initNonce, remoteInitNonce) {
|
||||
t.Errorf("nonces do not match")
|
||||
}
|
||||
if !bytes.Equal(recNonce, remoteRecNonce) {
|
||||
t.Errorf("receiver nonces do not match")
|
||||
}
|
||||
if !bytes.Equal(initSessionToken, recSessionToken) {
|
||||
t.Errorf("session tokens do not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
|
||||
prv0, _ := crypto.GenerateKey()
|
||||
prv1, _ := crypto.GenerateKey()
|
||||
pub0s, _ := exportPublicKey(&prv0.PublicKey)
|
||||
pub1s, _ := exportPublicKey(&prv1.PublicKey)
|
||||
rw0, rw1 := net.Pipe()
|
||||
tokens := make(chan []byte)
|
||||
|
||||
go func() {
|
||||
token, err := outboundEncHandshake(rw0, prv0, pub1s, nil)
|
||||
if err != nil {
|
||||
t.Errorf("outbound side error: %v", err)
|
||||
}
|
||||
tokens <- token
|
||||
}()
|
||||
go func() {
|
||||
token, remotePubkey, err := inboundEncHandshake(rw1, prv1, nil)
|
||||
if err != nil {
|
||||
t.Errorf("inbound side error: %v", err)
|
||||
}
|
||||
if !bytes.Equal(remotePubkey, pub0s) {
|
||||
t.Errorf("inbound side returned wrong remote pubkey\n got: %x\n want: %x", remotePubkey, pub0s)
|
||||
}
|
||||
tokens <- token
|
||||
}()
|
||||
|
||||
t1, t2 := <-tokens, <-tokens
|
||||
if !bytes.Equal(t1, t2) {
|
||||
t.Error("session token mismatch")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,291 @@
|
|||
package discover
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
const nodeIDBits = 512
|
||||
|
||||
// Node represents a host on the network.
|
||||
type Node struct {
|
||||
ID NodeID
|
||||
IP net.IP
|
||||
|
||||
DiscPort int // UDP listening port for discovery protocol
|
||||
TCPPort int // TCP listening port for RLPx
|
||||
|
||||
active time.Time
|
||||
}
|
||||
|
||||
func newNode(id NodeID, addr *net.UDPAddr) *Node {
|
||||
return &Node{
|
||||
ID: id,
|
||||
IP: addr.IP,
|
||||
DiscPort: addr.Port,
|
||||
TCPPort: addr.Port,
|
||||
active: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Node) isValid() bool {
|
||||
// TODO: don't accept localhost, LAN addresses from internet hosts
|
||||
return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
|
||||
}
|
||||
|
||||
// The string representation of a Node is a URL.
|
||||
// Please see ParseNode for a description of the format.
|
||||
func (n *Node) String() string {
|
||||
addr := net.TCPAddr{IP: n.IP, Port: n.TCPPort}
|
||||
u := url.URL{
|
||||
Scheme: "enode",
|
||||
User: url.User(fmt.Sprintf("%x", n.ID[:])),
|
||||
Host: addr.String(),
|
||||
}
|
||||
if n.DiscPort != n.TCPPort {
|
||||
u.RawQuery = "discport=" + strconv.Itoa(n.DiscPort)
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// ParseNode parses a node URL.
|
||||
//
|
||||
// A node URL has scheme "enode".
|
||||
//
|
||||
// The hexadecimal node ID is encoded in the username portion of the
|
||||
// URL, separated from the host by an @ sign. The hostname can only be
|
||||
// given as an IP address, DNS domain names are not allowed. The port
|
||||
// in the host name section is the TCP listening port. If the TCP and
|
||||
// UDP (discovery) ports differ, the UDP port is specified as query
|
||||
// parameter "discport".
|
||||
//
|
||||
// In the following example, the node URL describes
|
||||
// a node with IP address 10.3.58.6, TCP listening port 30303
|
||||
// and UDP discovery port 30301.
|
||||
//
|
||||
// enode://<hex node id>@10.3.58.6:30303?discport=30301
|
||||
func ParseNode(rawurl string) (*Node, error) {
|
||||
var n Node
|
||||
u, err := url.Parse(rawurl)
|
||||
if u.Scheme != "enode" {
|
||||
return nil, errors.New("invalid URL scheme, want \"enode\"")
|
||||
}
|
||||
if u.User == nil {
|
||||
return nil, errors.New("does not contain node ID")
|
||||
}
|
||||
if n.ID, err = HexID(u.User.String()); err != nil {
|
||||
return nil, fmt.Errorf("invalid node ID (%v)", err)
|
||||
}
|
||||
ip, port, err := net.SplitHostPort(u.Host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid host: %v", err)
|
||||
}
|
||||
if n.IP = net.ParseIP(ip); n.IP == nil {
|
||||
return nil, errors.New("invalid IP address")
|
||||
}
|
||||
if n.TCPPort, err = strconv.Atoi(port); err != nil {
|
||||
return nil, errors.New("invalid port")
|
||||
}
|
||||
qv := u.Query()
|
||||
if qv.Get("discport") == "" {
|
||||
n.DiscPort = n.TCPPort
|
||||
} else {
|
||||
if n.DiscPort, err = strconv.Atoi(qv.Get("discport")); err != nil {
|
||||
return nil, errors.New("invalid discport in query")
|
||||
}
|
||||
}
|
||||
return &n, nil
|
||||
}
|
||||
|
||||
// MustParseNode parses a node URL. It panics if the URL is not valid.
|
||||
func MustParseNode(rawurl string) *Node {
|
||||
n, err := ParseNode(rawurl)
|
||||
if err != nil {
|
||||
panic("invalid node URL: " + err.Error())
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (n Node) EncodeRLP(w io.Writer) error {
|
||||
return rlp.Encode(w, rpcNode{IP: n.IP.String(), Port: uint16(n.TCPPort), ID: n.ID})
|
||||
}
|
||||
func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
|
||||
var ext rpcNode
|
||||
if err = s.Decode(&ext); err == nil {
|
||||
n.TCPPort = int(ext.Port)
|
||||
n.DiscPort = int(ext.Port)
|
||||
n.ID = ext.ID
|
||||
if n.IP = net.ParseIP(ext.IP); n.IP == nil {
|
||||
return errors.New("invalid IP string")
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// NodeID is a unique identifier for each node.
|
||||
// The node identifier is a marshaled elliptic curve public key.
|
||||
type NodeID [nodeIDBits / 8]byte
|
||||
|
||||
// NodeID prints as a long hexadecimal number.
|
||||
func (n NodeID) String() string {
|
||||
return fmt.Sprintf("%#x", n[:])
|
||||
}
|
||||
|
||||
// The Go syntax representation of a NodeID is a call to HexID.
|
||||
func (n NodeID) GoString() string {
|
||||
return fmt.Sprintf("discover.HexID(\"%#x\")", n[:])
|
||||
}
|
||||
|
||||
// HexID converts a hex string to a NodeID.
|
||||
// The string may be prefixed with 0x.
|
||||
func HexID(in string) (NodeID, error) {
|
||||
if strings.HasPrefix(in, "0x") {
|
||||
in = in[2:]
|
||||
}
|
||||
var id NodeID
|
||||
b, err := hex.DecodeString(in)
|
||||
if err != nil {
|
||||
return id, err
|
||||
} else if len(b) != len(id) {
|
||||
return id, fmt.Errorf("wrong length, need %d hex bytes", len(id))
|
||||
}
|
||||
copy(id[:], b)
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// MustHexID converts a hex string to a NodeID.
|
||||
// It panics if the string is not a valid NodeID.
|
||||
func MustHexID(in string) NodeID {
|
||||
id, err := HexID(in)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// PubkeyID returns a marshaled representation of the given public key.
|
||||
func PubkeyID(pub *ecdsa.PublicKey) NodeID {
|
||||
var id NodeID
|
||||
pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
|
||||
if len(pbytes)-1 != len(id) {
|
||||
panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
|
||||
}
|
||||
copy(id[:], pbytes[1:])
|
||||
return id
|
||||
}
|
||||
|
||||
// recoverNodeID computes the public key used to sign the
|
||||
// given hash from the signature.
|
||||
func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
|
||||
pubkey, err := secp256k1.RecoverPubkey(hash, sig)
|
||||
if err != nil {
|
||||
return id, err
|
||||
}
|
||||
if len(pubkey)-1 != len(id) {
|
||||
return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
|
||||
}
|
||||
for i := range id {
|
||||
id[i] = pubkey[i+1]
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// distcmp compares the distances a->target and b->target.
|
||||
// Returns -1 if a is closer to target, 1 if b is closer to target
|
||||
// and 0 if they are equal.
|
||||
func distcmp(target, a, b NodeID) int {
|
||||
for i := range target {
|
||||
da := a[i] ^ target[i]
|
||||
db := b[i] ^ target[i]
|
||||
if da > db {
|
||||
return 1
|
||||
} else if da < db {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// table of leading zero counts for bytes [0..255]
|
||||
var lzcount = [256]int{
|
||||
8, 7, 6, 6, 5, 5, 5, 5,
|
||||
4, 4, 4, 4, 4, 4, 4, 4,
|
||||
3, 3, 3, 3, 3, 3, 3, 3,
|
||||
3, 3, 3, 3, 3, 3, 3, 3,
|
||||
2, 2, 2, 2, 2, 2, 2, 2,
|
||||
2, 2, 2, 2, 2, 2, 2, 2,
|
||||
2, 2, 2, 2, 2, 2, 2, 2,
|
||||
2, 2, 2, 2, 2, 2, 2, 2,
|
||||
1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}
|
||||
|
||||
// logdist returns the logarithmic distance between a and b, log2(a ^ b).
|
||||
func logdist(a, b NodeID) int {
|
||||
lz := 0
|
||||
for i := range a {
|
||||
x := a[i] ^ b[i]
|
||||
if x == 0 {
|
||||
lz += 8
|
||||
} else {
|
||||
lz += lzcount[x]
|
||||
break
|
||||
}
|
||||
}
|
||||
return len(a)*8 - lz
|
||||
}
|
||||
|
||||
// randomID returns a random NodeID such that logdist(a, b) == n
|
||||
func randomID(a NodeID, n int) (b NodeID) {
|
||||
if n == 0 {
|
||||
return a
|
||||
}
|
||||
// flip bit at position n, fill the rest with random bits
|
||||
b = a
|
||||
pos := len(a) - n/8 - 1
|
||||
bit := byte(0x01) << (byte(n%8) - 1)
|
||||
if bit == 0 {
|
||||
pos++
|
||||
bit = 0x80
|
||||
}
|
||||
b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
|
||||
for i := pos + 1; i < len(a); i++ {
|
||||
b[i] = byte(rand.Intn(255))
|
||||
}
|
||||
return b
|
||||
}
|
|
@ -0,0 +1,201 @@
|
|||
package discover
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
)
|
||||
|
||||
var (
|
||||
quickrand = rand.New(rand.NewSource(time.Now().Unix()))
|
||||
quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand}
|
||||
)
|
||||
|
||||
var parseNodeTests = []struct {
|
||||
rawurl string
|
||||
wantError string
|
||||
wantResult *Node
|
||||
}{
|
||||
{
|
||||
rawurl: "http://foobar",
|
||||
wantError: `invalid URL scheme, want "enode"`,
|
||||
},
|
||||
{
|
||||
rawurl: "enode://foobar",
|
||||
wantError: `does not contain node ID`,
|
||||
},
|
||||
{
|
||||
rawurl: "enode://01010101@123.124.125.126:3",
|
||||
wantError: `invalid node ID (wrong length, need 64 hex bytes)`,
|
||||
},
|
||||
{
|
||||
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3",
|
||||
wantError: `invalid IP address`,
|
||||
},
|
||||
{
|
||||
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo",
|
||||
wantError: `invalid port`,
|
||||
},
|
||||
{
|
||||
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo",
|
||||
wantError: `invalid discport in query`,
|
||||
},
|
||||
{
|
||||
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150",
|
||||
wantResult: &Node{
|
||||
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
DiscPort: 52150,
|
||||
TCPPort: 52150,
|
||||
},
|
||||
},
|
||||
{
|
||||
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150",
|
||||
wantResult: &Node{
|
||||
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
|
||||
IP: net.ParseIP("::"),
|
||||
DiscPort: 52150,
|
||||
TCPPort: 52150,
|
||||
},
|
||||
},
|
||||
{
|
||||
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=223344",
|
||||
wantResult: &Node{
|
||||
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
DiscPort: 223344,
|
||||
TCPPort: 52150,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestParseNode(t *testing.T) {
|
||||
for i, test := range parseNodeTests {
|
||||
n, err := ParseNode(test.rawurl)
|
||||
if err == nil && test.wantError != "" {
|
||||
t.Errorf("test %d: got nil error, expected %#q", i, test.wantError)
|
||||
continue
|
||||
}
|
||||
if err != nil && err.Error() != test.wantError {
|
||||
t.Errorf("test %d: got error %#q, expected %#q", i, err.Error(), test.wantError)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(n, test.wantResult) {
|
||||
t.Errorf("test %d: result mismatch:\ngot: %#v, want: %#v", i, n, test.wantResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeString(t *testing.T) {
|
||||
for i, test := range parseNodeTests {
|
||||
if test.wantError != "" {
|
||||
continue
|
||||
}
|
||||
str := test.wantResult.String()
|
||||
if str != test.rawurl {
|
||||
t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHexID(t *testing.T) {
|
||||
ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
|
||||
id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
|
||||
id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
|
||||
|
||||
if id1 != ref {
|
||||
t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
|
||||
}
|
||||
if id2 != ref {
|
||||
t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeID_recover(t *testing.T) {
|
||||
prv := newkey()
|
||||
hash := make([]byte, 32)
|
||||
sig, err := crypto.Sign(hash, prv)
|
||||
if err != nil {
|
||||
t.Fatalf("signing error: %v", err)
|
||||
}
|
||||
|
||||
pub := PubkeyID(&prv.PublicKey)
|
||||
recpub, err := recoverNodeID(hash, sig)
|
||||
if err != nil {
|
||||
t.Fatalf("recovery error: %v", err)
|
||||
}
|
||||
if pub != recpub {
|
||||
t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeID_distcmp(t *testing.T) {
|
||||
distcmpBig := func(target, a, b NodeID) int {
|
||||
tbig := new(big.Int).SetBytes(target[:])
|
||||
abig := new(big.Int).SetBytes(a[:])
|
||||
bbig := new(big.Int).SetBytes(b[:])
|
||||
return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
|
||||
}
|
||||
if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// the random tests is likely to miss the case where they're equal.
|
||||
func TestNodeID_distcmpEqual(t *testing.T) {
|
||||
base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
|
||||
if distcmp(base, x, x) != 0 {
|
||||
t.Errorf("distcmp(base, x, x) != 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeID_logdist(t *testing.T) {
|
||||
logdistBig := func(a, b NodeID) int {
|
||||
abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
|
||||
return new(big.Int).Xor(abig, bbig).BitLen()
|
||||
}
|
||||
if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// the random tests is likely to miss the case where they're equal.
|
||||
func TestNodeID_logdistEqual(t *testing.T) {
|
||||
x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
if logdist(x, x) != 0 {
|
||||
t.Errorf("logdist(x, x) != 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeID_randomID(t *testing.T) {
|
||||
// we don't use quick.Check here because its output isn't
|
||||
// very helpful when the test fails.
|
||||
for i := 0; i < quickcfg.MaxCount; i++ {
|
||||
a := gen(NodeID{}, quickrand).(NodeID)
|
||||
dist := quickrand.Intn(len(NodeID{}) * 8)
|
||||
result := randomID(a, dist)
|
||||
actualdist := logdist(result, a)
|
||||
|
||||
if dist != actualdist {
|
||||
t.Log("a: ", a)
|
||||
t.Log("result:", result)
|
||||
t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
var id NodeID
|
||||
m := rand.Intn(len(id))
|
||||
for i := len(id) - 1; i > m; i-- {
|
||||
id[i] = byte(rand.Uint32())
|
||||
}
|
||||
return reflect.ValueOf(id)
|
||||
}
|
|
@ -0,0 +1,280 @@
|
|||
// Package discover implements the Node Discovery Protocol.
|
||||
//
|
||||
// The Node Discovery protocol provides a way to find RLPx nodes that
|
||||
// can be connected to. It uses a Kademlia-like protocol to maintain a
|
||||
// distributed database of the IDs and endpoints of all listening
|
||||
// nodes.
|
||||
package discover
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
alpha = 3 // Kademlia concurrency factor
|
||||
bucketSize = 16 // Kademlia bucket size
|
||||
nBuckets = nodeIDBits + 1 // Number of buckets
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
mutex sync.Mutex // protects buckets, their content, and nursery
|
||||
buckets [nBuckets]*bucket // index of known nodes by distance
|
||||
nursery []*Node // bootstrap nodes
|
||||
|
||||
net transport
|
||||
self *Node // metadata of the local node
|
||||
}
|
||||
|
||||
// transport is implemented by the UDP transport.
|
||||
// it is an interface so we can test without opening lots of UDP
|
||||
// sockets and without generating a private key.
|
||||
type transport interface {
|
||||
ping(*Node) error
|
||||
findnode(e *Node, target NodeID) ([]*Node, error)
|
||||
close()
|
||||
}
|
||||
|
||||
// bucket contains nodes, ordered by their last activity.
|
||||
type bucket struct {
|
||||
lastLookup time.Time
|
||||
entries []*Node
|
||||
}
|
||||
|
||||
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
|
||||
tab := &Table{net: t, self: newNode(ourID, ourAddr)}
|
||||
for i := range tab.buckets {
|
||||
tab.buckets[i] = new(bucket)
|
||||
}
|
||||
return tab
|
||||
}
|
||||
|
||||
// Self returns the local node ID.
|
||||
func (tab *Table) Self() NodeID {
|
||||
return tab.self.ID
|
||||
}
|
||||
|
||||
// Close terminates the network listener.
|
||||
func (tab *Table) Close() {
|
||||
tab.net.close()
|
||||
}
|
||||
|
||||
// Bootstrap sets the bootstrap nodes. These nodes are used to connect
|
||||
// to the network if the table is empty. Bootstrap will also attempt to
|
||||
// fill the table by performing random lookup operations on the
|
||||
// network.
|
||||
func (tab *Table) Bootstrap(nodes []*Node) {
|
||||
tab.mutex.Lock()
|
||||
// TODO: maybe filter nodes with bad fields (nil, etc.) to avoid strange crashes
|
||||
tab.nursery = make([]*Node, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
cpy := *n
|
||||
tab.nursery = append(tab.nursery, &cpy)
|
||||
}
|
||||
tab.mutex.Unlock()
|
||||
tab.refresh()
|
||||
}
|
||||
|
||||
// Lookup performs a network search for nodes close
|
||||
// to the given target. It approaches the target by querying
|
||||
// nodes that are closer to it on each iteration.
|
||||
func (tab *Table) Lookup(target NodeID) []*Node {
|
||||
var (
|
||||
asked = make(map[NodeID]bool)
|
||||
seen = make(map[NodeID]bool)
|
||||
reply = make(chan []*Node, alpha)
|
||||
pendingQueries = 0
|
||||
)
|
||||
// don't query further if we hit the target or ourself.
|
||||
// unlikely to happen often in practice.
|
||||
asked[target] = true
|
||||
asked[tab.self.ID] = true
|
||||
|
||||
tab.mutex.Lock()
|
||||
// update last lookup stamp (for refresh logic)
|
||||
tab.buckets[logdist(tab.self.ID, target)].lastLookup = time.Now()
|
||||
// generate initial result set
|
||||
result := tab.closest(target, bucketSize)
|
||||
tab.mutex.Unlock()
|
||||
|
||||
for {
|
||||
// ask the alpha closest nodes that we haven't asked yet
|
||||
for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
|
||||
n := result.entries[i]
|
||||
if !asked[n.ID] {
|
||||
asked[n.ID] = true
|
||||
pendingQueries++
|
||||
go func() {
|
||||
result, _ := tab.net.findnode(n, target)
|
||||
reply <- result
|
||||
}()
|
||||
}
|
||||
}
|
||||
if pendingQueries == 0 {
|
||||
// we have asked all closest nodes, stop the search
|
||||
break
|
||||
}
|
||||
|
||||
// wait for the next reply
|
||||
for _, n := range <-reply {
|
||||
cn := n
|
||||
if !seen[n.ID] {
|
||||
seen[n.ID] = true
|
||||
result.push(cn, bucketSize)
|
||||
}
|
||||
}
|
||||
pendingQueries--
|
||||
}
|
||||
return result.entries
|
||||
}
|
||||
|
||||
// refresh performs a lookup for a random target to keep buckets full.
|
||||
func (tab *Table) refresh() {
|
||||
ld := -1 // logdist of chosen bucket
|
||||
tab.mutex.Lock()
|
||||
for i, b := range tab.buckets {
|
||||
if i > 0 && b.lastLookup.Before(time.Now().Add(-1*time.Hour)) {
|
||||
ld = i
|
||||
break
|
||||
}
|
||||
}
|
||||
tab.mutex.Unlock()
|
||||
|
||||
result := tab.Lookup(randomID(tab.self.ID, ld))
|
||||
if len(result) == 0 {
|
||||
// bootstrap the table with a self lookup
|
||||
tab.mutex.Lock()
|
||||
tab.add(tab.nursery)
|
||||
tab.mutex.Unlock()
|
||||
tab.Lookup(tab.self.ID)
|
||||
// TODO: the Kademlia paper says that we're supposed to perform
|
||||
// random lookups in all buckets further away than our closest neighbor.
|
||||
}
|
||||
}
|
||||
|
||||
// closest returns the n nodes in the table that are closest to the
|
||||
// given id. The caller must hold tab.mutex.
|
||||
func (tab *Table) closest(target NodeID, nresults int) *nodesByDistance {
|
||||
// This is a very wasteful way to find the closest nodes but
|
||||
// obviously correct. I believe that tree-based buckets would make
|
||||
// this easier to implement efficiently.
|
||||
close := &nodesByDistance{target: target}
|
||||
for _, b := range tab.buckets {
|
||||
for _, n := range b.entries {
|
||||
close.push(n, nresults)
|
||||
}
|
||||
}
|
||||
return close
|
||||
}
|
||||
|
||||
func (tab *Table) len() (n int) {
|
||||
for _, b := range tab.buckets {
|
||||
n += len(b.entries)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// bumpOrAdd updates the activity timestamp for the given node and
|
||||
// attempts to insert the node into a bucket. The returned Node might
|
||||
// not be part of the table. The caller must hold tab.mutex.
|
||||
func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
|
||||
b := tab.buckets[logdist(tab.self.ID, node)]
|
||||
if n = b.bump(node); n == nil {
|
||||
n = newNode(node, from)
|
||||
if len(b.entries) == bucketSize {
|
||||
tab.pingReplace(n, b)
|
||||
} else {
|
||||
b.entries = append(b.entries, n)
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (tab *Table) pingReplace(n *Node, b *bucket) {
|
||||
old := b.entries[bucketSize-1]
|
||||
go func() {
|
||||
if err := tab.net.ping(old); err == nil {
|
||||
// it responded, we don't need to replace it.
|
||||
return
|
||||
}
|
||||
// it didn't respond, replace the node if it is still the oldest node.
|
||||
tab.mutex.Lock()
|
||||
if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
|
||||
// slide down other entries and put the new one in front.
|
||||
// TODO: insert in correct position to keep the order
|
||||
copy(b.entries[1:], b.entries)
|
||||
b.entries[0] = n
|
||||
}
|
||||
tab.mutex.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// bump updates the activity timestamp for the given node.
|
||||
// The caller must hold tab.mutex.
|
||||
func (tab *Table) bump(node NodeID) {
|
||||
tab.buckets[logdist(tab.self.ID, node)].bump(node)
|
||||
}
|
||||
|
||||
// add puts the entries into the table if their corresponding
|
||||
// bucket is not full. The caller must hold tab.mutex.
|
||||
func (tab *Table) add(entries []*Node) {
|
||||
outer:
|
||||
for _, n := range entries {
|
||||
if n == nil || n.ID == tab.self.ID {
|
||||
// skip bad entries. The RLP decoder returns nil for empty
|
||||
// input lists.
|
||||
continue
|
||||
}
|
||||
bucket := tab.buckets[logdist(tab.self.ID, n.ID)]
|
||||
for i := range bucket.entries {
|
||||
if bucket.entries[i].ID == n.ID {
|
||||
// already in bucket
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
if len(bucket.entries) < bucketSize {
|
||||
bucket.entries = append(bucket.entries, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bucket) bump(id NodeID) *Node {
|
||||
for i, n := range b.entries {
|
||||
if n.ID == id {
|
||||
n.active = time.Now()
|
||||
// move it to the front
|
||||
copy(b.entries[1:], b.entries[:i+1])
|
||||
b.entries[0] = n
|
||||
return n
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// nodesByDistance is a list of nodes, ordered by
|
||||
// distance to target.
|
||||
type nodesByDistance struct {
|
||||
entries []*Node
|
||||
target NodeID
|
||||
}
|
||||
|
||||
// push adds the given node to the list, keeping the total size below maxElems.
|
||||
func (h *nodesByDistance) push(n *Node, maxElems int) {
|
||||
ix := sort.Search(len(h.entries), func(i int) bool {
|
||||
return distcmp(h.target, h.entries[i].ID, n.ID) > 0
|
||||
})
|
||||
if len(h.entries) < maxElems {
|
||||
h.entries = append(h.entries, n)
|
||||
}
|
||||
if ix == len(h.entries) {
|
||||
// farther away than all nodes we already have.
|
||||
// if there was room for it, the node is now the last element.
|
||||
} else {
|
||||
// slide existing entries down to make room
|
||||
// this will overwrite the entry we just appended.
|
||||
copy(h.entries[ix+1:], h.entries[ix:])
|
||||
h.entries[ix] = n
|
||||
}
|
||||
}
|
|
@ -0,0 +1,311 @@
|
|||
package discover
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
)
|
||||
|
||||
func TestTable_bumpOrAddBucketAssign(t *testing.T) {
|
||||
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
|
||||
for i := 1; i < len(tab.buckets); i++ {
|
||||
tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
|
||||
}
|
||||
for i, b := range tab.buckets {
|
||||
if i > 0 && len(b.entries) != 1 {
|
||||
t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_bumpOrAddPingReplace(t *testing.T) {
|
||||
pingC := make(pingC)
|
||||
tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
|
||||
last := fillBucket(tab, 200)
|
||||
|
||||
// this bumpOrAdd should not replace the last node
|
||||
// because the node replies to ping.
|
||||
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
|
||||
|
||||
pinged := <-pingC
|
||||
if pinged != last.ID {
|
||||
t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
|
||||
}
|
||||
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
if l := len(tab.buckets[200].entries); l != bucketSize {
|
||||
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
|
||||
}
|
||||
if !contains(tab.buckets[200].entries, last.ID) {
|
||||
t.Error("last entry was removed")
|
||||
}
|
||||
if contains(tab.buckets[200].entries, new.ID) {
|
||||
t.Error("new entry was added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_bumpOrAddPingTimeout(t *testing.T) {
|
||||
tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
|
||||
last := fillBucket(tab, 200)
|
||||
|
||||
// this bumpOrAdd should replace the last node
|
||||
// because the node does not reply to ping.
|
||||
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
|
||||
|
||||
// wait for async bucket update. damn. this needs to go away.
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
if l := len(tab.buckets[200].entries); l != bucketSize {
|
||||
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
|
||||
}
|
||||
if contains(tab.buckets[200].entries, last.ID) {
|
||||
t.Error("last entry was not removed")
|
||||
}
|
||||
if !contains(tab.buckets[200].entries, new.ID) {
|
||||
t.Error("new entry was not added")
|
||||
}
|
||||
}
|
||||
|
||||
func fillBucket(tab *Table, ld int) (last *Node) {
|
||||
b := tab.buckets[ld]
|
||||
for len(b.entries) < bucketSize {
|
||||
b.entries = append(b.entries, &Node{ID: randomID(tab.self.ID, ld)})
|
||||
}
|
||||
return b.entries[bucketSize-1]
|
||||
}
|
||||
|
||||
type pingC chan NodeID
|
||||
|
||||
func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
|
||||
panic("findnode called on pingRecorder")
|
||||
}
|
||||
func (t pingC) close() {
|
||||
panic("close called on pingRecorder")
|
||||
}
|
||||
func (t pingC) ping(n *Node) error {
|
||||
if t == nil {
|
||||
return errTimeout
|
||||
}
|
||||
t <- n.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestTable_bump(t *testing.T) {
|
||||
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
|
||||
|
||||
// add an old entry and two recent ones
|
||||
oldactive := time.Now().Add(-2 * time.Minute)
|
||||
old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
|
||||
others := []*Node{
|
||||
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
|
||||
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
|
||||
}
|
||||
tab.add(append(others, old))
|
||||
if tab.buckets[200].entries[0] == old {
|
||||
t.Fatal("old entry is at front of bucket")
|
||||
}
|
||||
|
||||
// bumping the old entry should move it to the front
|
||||
tab.bump(old.ID)
|
||||
if old.active == oldactive {
|
||||
t.Error("activity timestamp not updated")
|
||||
}
|
||||
if tab.buckets[200].entries[0] != old {
|
||||
t.Errorf("bumped entry did not move to the front of bucket")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_closest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
test := func(test *closeTest) bool {
|
||||
// for any node table, Target and N
|
||||
tab := newTable(nil, test.Self, &net.UDPAddr{})
|
||||
tab.add(test.All)
|
||||
|
||||
// check that doClosest(Target, N) returns nodes
|
||||
result := tab.closest(test.Target, test.N).entries
|
||||
if hasDuplicates(result) {
|
||||
t.Errorf("result contains duplicates")
|
||||
return false
|
||||
}
|
||||
if !sortedByDistanceTo(test.Target, result) {
|
||||
t.Errorf("result is not sorted by distance to target")
|
||||
return false
|
||||
}
|
||||
|
||||
// check that the number of results is min(N, tablen)
|
||||
wantN := test.N
|
||||
if tlen := tab.len(); tlen < test.N {
|
||||
wantN = tlen
|
||||
}
|
||||
if len(result) != wantN {
|
||||
t.Errorf("wrong number of nodes: got %d, want %d", len(result), wantN)
|
||||
return false
|
||||
} else if len(result) == 0 {
|
||||
return true // no need to check distance
|
||||
}
|
||||
|
||||
// check that the result nodes have minimum distance to target.
|
||||
for _, b := range tab.buckets {
|
||||
for _, n := range b.entries {
|
||||
if contains(result, n.ID) {
|
||||
continue // don't run the check below for nodes in result
|
||||
}
|
||||
farthestResult := result[len(result)-1].ID
|
||||
if distcmp(test.Target, n.ID, farthestResult) < 0 {
|
||||
t.Errorf("table contains node that is closer to target but it's not in result")
|
||||
t.Logf(" Target: %v", test.Target)
|
||||
t.Logf(" Farthest Result: %v", farthestResult)
|
||||
t.Logf(" ID: %v", n.ID)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
if err := quick.Check(test, quickcfg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
type closeTest struct {
|
||||
Self NodeID
|
||||
Target NodeID
|
||||
All []*Node
|
||||
N int
|
||||
}
|
||||
|
||||
func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
t := &closeTest{
|
||||
Self: gen(NodeID{}, rand).(NodeID),
|
||||
Target: gen(NodeID{}, rand).(NodeID),
|
||||
N: rand.Intn(bucketSize),
|
||||
}
|
||||
for _, id := range gen([]NodeID{}, rand).([]NodeID) {
|
||||
t.All = append(t.All, &Node{ID: id})
|
||||
}
|
||||
return reflect.ValueOf(t)
|
||||
}
|
||||
|
||||
func TestTable_Lookup(t *testing.T) {
|
||||
self := gen(NodeID{}, quickrand).(NodeID)
|
||||
target := randomID(self, 200)
|
||||
transport := findnodeOracle{t, target}
|
||||
tab := newTable(transport, self, &net.UDPAddr{})
|
||||
|
||||
// lookup on empty table returns no nodes
|
||||
if results := tab.Lookup(target); len(results) > 0 {
|
||||
t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
|
||||
}
|
||||
// seed table with initial node (otherwise lookup will terminate immediately)
|
||||
tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
|
||||
|
||||
results := tab.Lookup(target)
|
||||
t.Logf("results:")
|
||||
for _, e := range results {
|
||||
t.Logf(" ld=%d, %v", logdist(target, e.ID), e.ID)
|
||||
}
|
||||
if len(results) != bucketSize {
|
||||
t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize)
|
||||
}
|
||||
if hasDuplicates(results) {
|
||||
t.Errorf("result set contains duplicate entries")
|
||||
}
|
||||
if !sortedByDistanceTo(target, results) {
|
||||
t.Errorf("result set not sorted by distance to target")
|
||||
}
|
||||
if !contains(results, target) {
|
||||
t.Errorf("result set does not contain target")
|
||||
}
|
||||
}
|
||||
|
||||
// findnode on this transport always returns at least one node
|
||||
// that is one bucket closer to the target.
|
||||
type findnodeOracle struct {
|
||||
t *testing.T
|
||||
target NodeID
|
||||
}
|
||||
|
||||
func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
|
||||
t.t.Logf("findnode query at dist %d", n.DiscPort)
|
||||
// current log distance is encoded in port number
|
||||
var result []*Node
|
||||
switch n.DiscPort {
|
||||
case 0:
|
||||
panic("query to node at distance 0")
|
||||
default:
|
||||
// TODO: add more randomness to distances
|
||||
next := n.DiscPort - 1
|
||||
for i := 0; i < bucketSize; i++ {
|
||||
result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t findnodeOracle) close() {}
|
||||
|
||||
func (t findnodeOracle) ping(n *Node) error {
|
||||
return errors.New("ping is not supported by this transport")
|
||||
}
|
||||
|
||||
func hasDuplicates(slice []*Node) bool {
|
||||
seen := make(map[NodeID]bool)
|
||||
for _, e := range slice {
|
||||
if seen[e.ID] {
|
||||
return true
|
||||
}
|
||||
seen[e.ID] = true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sortedByDistanceTo(distbase NodeID, slice []*Node) bool {
|
||||
var last NodeID
|
||||
for i, e := range slice {
|
||||
if i > 0 && distcmp(distbase, e.ID, last) < 0 {
|
||||
return false
|
||||
}
|
||||
last = e.ID
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func contains(ns []*Node, id NodeID) bool {
|
||||
for _, n := range ns {
|
||||
if n.ID == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// gen wraps quick.Value so it's easier to use.
|
||||
// it generates a random value of the given value's type.
|
||||
func gen(typ interface{}, rand *rand.Rand) interface{} {
|
||||
v, ok := quick.Value(reflect.TypeOf(typ), rand)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("couldn't generate random value of type %T", typ))
|
||||
}
|
||||
return v.Interface()
|
||||
}
|
||||
|
||||
func newkey() *ecdsa.PrivateKey {
|
||||
key, err := crypto.GenerateKey()
|
||||
if err != nil {
|
||||
panic("couldn't generate key: " + err.Error())
|
||||
}
|
||||
return key
|
||||
}
|
|
@ -0,0 +1,431 @@
|
|||
package discover
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
var log = logger.NewLogger("P2P Discovery")
|
||||
|
||||
// Errors
|
||||
var (
|
||||
errPacketTooSmall = errors.New("too small")
|
||||
errBadHash = errors.New("bad hash")
|
||||
errExpired = errors.New("expired")
|
||||
errTimeout = errors.New("RPC timeout")
|
||||
errClosed = errors.New("socket closed")
|
||||
)
|
||||
|
||||
// Timeouts
|
||||
const (
|
||||
respTimeout = 300 * time.Millisecond
|
||||
sendTimeout = 300 * time.Millisecond
|
||||
expiration = 20 * time.Second
|
||||
|
||||
refreshInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// RPC packet types
|
||||
const (
|
||||
pingPacket = iota + 1 // zero is 'reserved'
|
||||
pongPacket
|
||||
findnodePacket
|
||||
neighborsPacket
|
||||
)
|
||||
|
||||
// RPC request structures
|
||||
type (
|
||||
ping struct {
|
||||
IP string // our IP
|
||||
Port uint16 // our port
|
||||
Expiration uint64
|
||||
}
|
||||
|
||||
// reply to Ping
|
||||
pong struct {
|
||||
ReplyTok []byte
|
||||
Expiration uint64
|
||||
}
|
||||
|
||||
findnode struct {
|
||||
// Id to look up. The responding node will send back nodes
|
||||
// closest to the target.
|
||||
Target NodeID
|
||||
Expiration uint64
|
||||
}
|
||||
|
||||
// reply to findnode
|
||||
neighbors struct {
|
||||
Nodes []*Node
|
||||
Expiration uint64
|
||||
}
|
||||
)
|
||||
|
||||
type rpcNode struct {
|
||||
IP string
|
||||
Port uint16
|
||||
ID NodeID
|
||||
}
|
||||
|
||||
// udp implements the RPC protocol.
|
||||
type udp struct {
|
||||
conn *net.UDPConn
|
||||
priv *ecdsa.PrivateKey
|
||||
addpending chan *pending
|
||||
replies chan reply
|
||||
closing chan struct{}
|
||||
nat nat.Interface
|
||||
|
||||
*Table
|
||||
}
|
||||
|
||||
// pending represents a pending reply.
|
||||
//
|
||||
// some implementations of the protocol wish to send more than one
|
||||
// reply packet to findnode. in general, any neighbors packet cannot
|
||||
// be matched up with a specific findnode packet.
|
||||
//
|
||||
// our implementation handles this by storing a callback function for
|
||||
// each pending reply. incoming packets from a node are dispatched
|
||||
// to all the callback functions for that node.
|
||||
type pending struct {
|
||||
// these fields must match in the reply.
|
||||
from NodeID
|
||||
ptype byte
|
||||
|
||||
// time when the request must complete
|
||||
deadline time.Time
|
||||
|
||||
// callback is called when a matching reply arrives. if it returns
|
||||
// true, the callback is removed from the pending reply queue.
|
||||
// if it returns false, the reply is considered incomplete and
|
||||
// the callback will be invoked again for the next matching reply.
|
||||
callback func(resp interface{}) (done bool)
|
||||
|
||||
// errc receives nil when the callback indicates completion or an
|
||||
// error if no further reply is received within the timeout.
|
||||
errc chan<- error
|
||||
}
|
||||
|
||||
type reply struct {
|
||||
from NodeID
|
||||
ptype byte
|
||||
data interface{}
|
||||
}
|
||||
|
||||
// ListenUDP returns a new table that listens for UDP packets on laddr.
|
||||
func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) {
|
||||
addr, err := net.ResolveUDPAddr("udp", laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udp := &udp{
|
||||
conn: conn,
|
||||
priv: priv,
|
||||
closing: make(chan struct{}),
|
||||
addpending: make(chan *pending),
|
||||
replies: make(chan reply),
|
||||
}
|
||||
|
||||
realaddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
if natm != nil {
|
||||
if !realaddr.IP.IsLoopback() {
|
||||
go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
|
||||
}
|
||||
// TODO: react to external IP changes over time.
|
||||
if ext, err := natm.ExternalIP(); err == nil {
|
||||
realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
|
||||
}
|
||||
}
|
||||
udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
|
||||
|
||||
go udp.loop()
|
||||
go udp.readLoop()
|
||||
log.Infoln("Listening, ", udp.self)
|
||||
return udp.Table, nil
|
||||
}
|
||||
|
||||
func (t *udp) close() {
|
||||
close(t.closing)
|
||||
t.conn.Close()
|
||||
// TODO: wait for the loops to end.
|
||||
}
|
||||
|
||||
// ping sends a ping message to the given node and waits for a reply.
|
||||
func (t *udp) ping(e *Node) error {
|
||||
// TODO: maybe check for ReplyTo field in callback to measure RTT
|
||||
errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
|
||||
t.send(e, pingPacket, ping{
|
||||
IP: t.self.IP.String(),
|
||||
Port: uint16(t.self.TCPPort),
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
})
|
||||
return <-errc
|
||||
}
|
||||
|
||||
// findnode sends a findnode request to the given node and waits until
|
||||
// the node has sent up to k neighbors.
|
||||
func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
|
||||
nodes := make([]*Node, 0, bucketSize)
|
||||
nreceived := 0
|
||||
errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
|
||||
reply := r.(*neighbors)
|
||||
for _, n := range reply.Nodes {
|
||||
nreceived++
|
||||
if n.isValid() {
|
||||
nodes = append(nodes, n)
|
||||
}
|
||||
}
|
||||
return nreceived >= bucketSize
|
||||
})
|
||||
|
||||
t.send(to, findnodePacket, findnode{
|
||||
Target: target,
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
})
|
||||
err := <-errc
|
||||
return nodes, err
|
||||
}
|
||||
|
||||
// pending adds a reply callback to the pending reply queue.
|
||||
// see the documentation of type pending for a detailed explanation.
|
||||
func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
|
||||
ch := make(chan error, 1)
|
||||
p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
|
||||
select {
|
||||
case t.addpending <- p:
|
||||
// loop will handle it
|
||||
case <-t.closing:
|
||||
ch <- errClosed
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
// loop runs in its own goroutin. it keeps track of
|
||||
// the refresh timer and the pending reply queue.
|
||||
func (t *udp) loop() {
|
||||
var (
|
||||
pending []*pending
|
||||
nextDeadline time.Time
|
||||
timeout = time.NewTimer(0)
|
||||
refresh = time.NewTicker(refreshInterval)
|
||||
)
|
||||
<-timeout.C // ignore first timeout
|
||||
defer refresh.Stop()
|
||||
defer timeout.Stop()
|
||||
|
||||
rearmTimeout := func() {
|
||||
if len(pending) == 0 || nextDeadline == pending[0].deadline {
|
||||
return
|
||||
}
|
||||
nextDeadline = pending[0].deadline
|
||||
timeout.Reset(nextDeadline.Sub(time.Now()))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-refresh.C:
|
||||
go t.refresh()
|
||||
|
||||
case <-t.closing:
|
||||
for _, p := range pending {
|
||||
p.errc <- errClosed
|
||||
}
|
||||
return
|
||||
|
||||
case p := <-t.addpending:
|
||||
p.deadline = time.Now().Add(respTimeout)
|
||||
pending = append(pending, p)
|
||||
rearmTimeout()
|
||||
|
||||
case reply := <-t.replies:
|
||||
// run matching callbacks, remove if they return false.
|
||||
for i, p := range pending {
|
||||
if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
|
||||
p.errc <- nil
|
||||
copy(pending[i:], pending[i+1:])
|
||||
pending = pending[:len(pending)-1]
|
||||
i--
|
||||
}
|
||||
}
|
||||
rearmTimeout()
|
||||
|
||||
case now := <-timeout.C:
|
||||
// notify and remove callbacks whose deadline is in the past.
|
||||
i := 0
|
||||
for ; i < len(pending) && now.After(pending[i].deadline); i++ {
|
||||
pending[i].errc <- errTimeout
|
||||
}
|
||||
if i > 0 {
|
||||
copy(pending, pending[i:])
|
||||
pending = pending[:len(pending)-i]
|
||||
}
|
||||
rearmTimeout()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
macSize = 256 / 8
|
||||
sigSize = 520 / 8
|
||||
headSize = macSize + sigSize // space of packet frame data
|
||||
)
|
||||
|
||||
var headSpace = make([]byte, headSize)
|
||||
|
||||
func (t *udp) send(to *Node, ptype byte, req interface{}) error {
|
||||
b := new(bytes.Buffer)
|
||||
b.Write(headSpace)
|
||||
b.WriteByte(ptype)
|
||||
if err := rlp.Encode(b, req); err != nil {
|
||||
log.Errorln("error encoding packet:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
packet := b.Bytes()
|
||||
sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
|
||||
if err != nil {
|
||||
log.Errorln("could not sign packet:", err)
|
||||
return err
|
||||
}
|
||||
copy(packet[macSize:], sig)
|
||||
// add the hash to the front. Note: this doesn't protect the
|
||||
// packet in any way. Our public key will be part of this hash in
|
||||
// the future.
|
||||
copy(packet, crypto.Sha3(packet[macSize:]))
|
||||
|
||||
toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
|
||||
log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
|
||||
if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
|
||||
log.DebugDetailln("UDP send failed:", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// readLoop runs in its own goroutine. it handles incoming UDP packets.
|
||||
func (t *udp) readLoop() {
|
||||
defer t.conn.Close()
|
||||
buf := make([]byte, 4096) // TODO: good buffer size
|
||||
for {
|
||||
nbytes, from, err := t.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := t.packetIn(from, buf[:nbytes]); err != nil {
|
||||
log.Debugf("Bad packet from %v: %v\n", from, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
|
||||
if len(buf) < headSize+1 {
|
||||
return errPacketTooSmall
|
||||
}
|
||||
hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
|
||||
shouldhash := crypto.Sha3(buf[macSize:])
|
||||
if !bytes.Equal(hash, shouldhash) {
|
||||
return errBadHash
|
||||
}
|
||||
fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req interface {
|
||||
handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
|
||||
}
|
||||
switch ptype := sigdata[0]; ptype {
|
||||
case pingPacket:
|
||||
req = new(ping)
|
||||
case pongPacket:
|
||||
req = new(pong)
|
||||
case findnodePacket:
|
||||
req = new(findnode)
|
||||
case neighborsPacket:
|
||||
req = new(neighbors)
|
||||
default:
|
||||
return fmt.Errorf("unknown type: %d", ptype)
|
||||
}
|
||||
if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
|
||||
return err
|
||||
}
|
||||
log.DebugDetailf("<<< %v %T %v\n", from, req, req)
|
||||
return req.handle(t, from, fromID, hash)
|
||||
}
|
||||
|
||||
func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
|
||||
if expired(req.Expiration) {
|
||||
return errExpired
|
||||
}
|
||||
t.mutex.Lock()
|
||||
// Note: we're ignoring the provided IP address right now
|
||||
n := t.bumpOrAdd(fromID, from)
|
||||
if req.Port != 0 {
|
||||
n.TCPPort = int(req.Port)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.send(n, pongPacket, pong{
|
||||
ReplyTok: mac,
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
|
||||
if expired(req.Expiration) {
|
||||
return errExpired
|
||||
}
|
||||
t.mutex.Lock()
|
||||
t.bump(fromID)
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.replies <- reply{fromID, pongPacket, req}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
|
||||
if expired(req.Expiration) {
|
||||
return errExpired
|
||||
}
|
||||
t.mutex.Lock()
|
||||
e := t.bumpOrAdd(fromID, from)
|
||||
closest := t.closest(req.Target, bucketSize).entries
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.send(e, neighborsPacket, neighbors{
|
||||
Nodes: closest,
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
|
||||
if expired(req.Expiration) {
|
||||
return errExpired
|
||||
}
|
||||
t.mutex.Lock()
|
||||
t.bump(fromID)
|
||||
t.add(req.Nodes)
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.replies <- reply{fromID, neighborsPacket, req}
|
||||
return nil
|
||||
}
|
||||
|
||||
func expired(ts uint64) bool {
|
||||
return time.Unix(int64(ts), 0).Before(time.Now())
|
||||
}
|
|
@ -0,0 +1,211 @@
|
|||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
logpkg "log"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
|
||||
}
|
||||
|
||||
func TestUDP_ping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
defer n1.Close()
|
||||
defer n2.Close()
|
||||
|
||||
if err := n1.net.ping(n2.self); err != nil {
|
||||
t.Fatalf("ping error: %v", err)
|
||||
}
|
||||
if find(n2, n1.self.ID) == nil {
|
||||
t.Errorf("node 2 does not contain id of node 1")
|
||||
}
|
||||
if e := find(n1, n2.self.ID); e != nil {
|
||||
t.Errorf("node 1 does contains id of node 2: %v", e)
|
||||
}
|
||||
}
|
||||
|
||||
func find(tab *Table, id NodeID) *Node {
|
||||
for _, b := range tab.buckets {
|
||||
for _, e := range b.entries {
|
||||
if e.ID == id {
|
||||
return e
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUDP_findnode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
defer n1.Close()
|
||||
defer n2.Close()
|
||||
|
||||
// put a few nodes into n2. the exact distribution shouldn't
|
||||
// matter much, altough we need to take care not to overflow
|
||||
// any bucket.
|
||||
target := randomID(n1.self.ID, 100)
|
||||
nodes := &nodesByDistance{target: target}
|
||||
for i := 0; i < bucketSize; i++ {
|
||||
n2.add([]*Node{&Node{
|
||||
IP: net.IP{1, 2, 3, byte(i)},
|
||||
DiscPort: i + 2,
|
||||
TCPPort: i + 2,
|
||||
ID: randomID(n2.self.ID, i+2),
|
||||
}})
|
||||
}
|
||||
n2.add(nodes.entries)
|
||||
n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
|
||||
expected := n2.closest(target, bucketSize)
|
||||
|
||||
err := runUDP(10, func() error {
|
||||
result, _ := n1.net.findnode(n2.self, target)
|
||||
if len(result) != bucketSize {
|
||||
return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
|
||||
}
|
||||
for i := range result {
|
||||
if result[i].ID != expected.entries[i].ID {
|
||||
return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDP_replytimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// reserve a port so we don't talk to an existing service by accident
|
||||
addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
fd, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fd.Close()
|
||||
|
||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
defer n1.Close()
|
||||
n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
|
||||
|
||||
if err := n1.net.ping(n2); err != errTimeout {
|
||||
t.Error("expected timeout error, got", err)
|
||||
}
|
||||
|
||||
if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
|
||||
t.Error("expected timeout error, got", err)
|
||||
} else if len(result) > 0 {
|
||||
t.Error("expected empty result, got", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDP_findnodeMultiReply(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||
udp2 := n2.net.(*udp)
|
||||
defer n1.Close()
|
||||
defer n2.Close()
|
||||
|
||||
err := runUDP(10, func() error {
|
||||
nodes := make([]*Node, bucketSize)
|
||||
for i := range nodes {
|
||||
nodes[i] = &Node{
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
DiscPort: i + 1,
|
||||
TCPPort: i + 1,
|
||||
ID: randomID(n2.self.ID, i+1),
|
||||
}
|
||||
}
|
||||
|
||||
// ask N2 for neighbors. it will send an empty reply back.
|
||||
// the request will wait for up to bucketSize replies.
|
||||
resultc := make(chan []*Node)
|
||||
errc := make(chan error)
|
||||
go func() {
|
||||
ns, err := n1.net.findnode(n2.self, n1.self.ID)
|
||||
if err != nil {
|
||||
errc <- err
|
||||
} else {
|
||||
resultc <- ns
|
||||
}
|
||||
}()
|
||||
|
||||
// send a few more neighbors packets to N1.
|
||||
// it should collect those.
|
||||
for end := 0; end < len(nodes); {
|
||||
off := end
|
||||
if end = end + 5; end > len(nodes) {
|
||||
end = len(nodes)
|
||||
}
|
||||
udp2.send(n1.self, neighborsPacket, neighbors{
|
||||
Nodes: nodes[off:end],
|
||||
Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
|
||||
})
|
||||
}
|
||||
|
||||
// check that they are all returned. we cannot just check for
|
||||
// equality because they might not be returned in the order they
|
||||
// were sent.
|
||||
var result []*Node
|
||||
select {
|
||||
case result = <-resultc:
|
||||
case err := <-errc:
|
||||
return err
|
||||
}
|
||||
if hasDuplicates(result) {
|
||||
return fmt.Errorf("result slice contains duplicates")
|
||||
}
|
||||
if len(result) != len(nodes) {
|
||||
return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
|
||||
}
|
||||
matched := make(map[NodeID]bool)
|
||||
for _, n := range result {
|
||||
for _, expn := range nodes {
|
||||
if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
|
||||
matched[n.ID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(matched) != len(nodes) {
|
||||
return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// runUDP runs a test n times and returns an error if the test failed
|
||||
// in all n runs. This is necessary because UDP is unreliable even for
|
||||
// connections on the local machine, causing test failures.
|
||||
func runUDP(n int, test func() error) error {
|
||||
errcount := 0
|
||||
errors := ""
|
||||
for i := 0; i < n; i++ {
|
||||
if err := test(); err != nil {
|
||||
errors += fmt.Sprintf("\n#%d: %v", i, err)
|
||||
errcount++
|
||||
}
|
||||
}
|
||||
if errcount == n {
|
||||
return fmt.Errorf("failed on all %d iterations:%s", n, errors)
|
||||
}
|
||||
return nil
|
||||
}
|
143
p2p/message.go
143
p2p/message.go
|
@ -1,6 +1,7 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
@ -8,12 +9,37 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/ethutil"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
// parameters for frameRW
|
||||
const (
|
||||
// maximum time allowed for reading a message header.
|
||||
// this is effectively the amount of time a connection can be idle.
|
||||
frameReadTimeout = 1 * time.Minute
|
||||
|
||||
// maximum time allowed for reading the payload data of a message.
|
||||
// this is shorter than (and distinct from) frameReadTimeout because
|
||||
// the connection is not considered idle while a message is transferred.
|
||||
// this also limits the payload size of messages to how much the connection
|
||||
// can transfer within the timeout.
|
||||
payloadReadTimeout = 5 * time.Second
|
||||
|
||||
// maximum amount of time allowed for writing a complete message.
|
||||
msgWriteTimeout = 5 * time.Second
|
||||
|
||||
// messages smaller than this many bytes will be read at
|
||||
// once before passing them to a protocol. this increases
|
||||
// concurrency in the processing.
|
||||
wholePayloadSize = 64 * 1024
|
||||
)
|
||||
|
||||
// Msg defines the structure of a p2p message.
|
||||
//
|
||||
// Note that a Msg can only be sent once since the Payload reader is
|
||||
|
@ -74,11 +100,14 @@ type MsgWriter interface {
|
|||
// WriteMsg sends a message. It will block until the message's
|
||||
// Payload has been consumed by the other end.
|
||||
//
|
||||
// Note that messages can be sent only once.
|
||||
// Note that messages can be sent only once because their
|
||||
// payload reader is drained.
|
||||
WriteMsg(Msg) error
|
||||
}
|
||||
|
||||
// MsgReadWriter provides reading and writing of encoded messages.
|
||||
// Implementations should ensure that ReadMsg and WriteMsg can be
|
||||
// called simultaneously from multiple goroutines.
|
||||
type MsgReadWriter interface {
|
||||
MsgReader
|
||||
MsgWriter
|
||||
|
@ -90,8 +119,45 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
|
|||
return w.WriteMsg(NewMsg(code, data...))
|
||||
}
|
||||
|
||||
// frameRW is a MsgReadWriter that reads and writes devp2p message frames.
|
||||
// As required by the interface, ReadMsg and WriteMsg can be called from
|
||||
// multiple goroutines.
|
||||
type frameRW struct {
|
||||
net.Conn // make Conn methods available. be careful.
|
||||
bufconn *bufio.ReadWriter
|
||||
|
||||
// this channel is used to 'lend' bufconn to a caller of ReadMsg
|
||||
// until the message payload has been consumed. the channel
|
||||
// receives a value when EOF is reached on the payload, unblocking
|
||||
// a pending call to ReadMsg.
|
||||
rsync chan struct{}
|
||||
|
||||
// this mutex guards writes to bufconn.
|
||||
writeMu sync.Mutex
|
||||
}
|
||||
|
||||
func newFrameRW(conn net.Conn, timeout time.Duration) *frameRW {
|
||||
rsync := make(chan struct{}, 1)
|
||||
rsync <- struct{}{}
|
||||
return &frameRW{
|
||||
Conn: conn,
|
||||
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
|
||||
rsync: rsync,
|
||||
}
|
||||
}
|
||||
|
||||
var magicToken = []byte{34, 64, 8, 145}
|
||||
|
||||
func (rw *frameRW) WriteMsg(msg Msg) error {
|
||||
rw.writeMu.Lock()
|
||||
defer rw.writeMu.Unlock()
|
||||
rw.SetWriteDeadline(time.Now().Add(msgWriteTimeout))
|
||||
if err := writeMsg(rw.bufconn, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
return rw.bufconn.Flush()
|
||||
}
|
||||
|
||||
func writeMsg(w io.Writer, msg Msg) error {
|
||||
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
|
||||
code := ethutil.Encode(uint32(msg.Code))
|
||||
|
@ -120,31 +186,51 @@ func makeListHeader(length uint32) []byte {
|
|||
return append([]byte{lenb}, enc...)
|
||||
}
|
||||
|
||||
// readMsg reads a message header from r.
|
||||
// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer.
|
||||
func readMsg(r rlp.ByteReader) (msg Msg, err error) {
|
||||
func (rw *frameRW) ReadMsg() (msg Msg, err error) {
|
||||
<-rw.rsync // wait until bufconn is ours
|
||||
|
||||
rw.SetReadDeadline(time.Now().Add(frameReadTimeout))
|
||||
|
||||
// read magic and payload size
|
||||
start := make([]byte, 8)
|
||||
if _, err = io.ReadFull(r, start); err != nil {
|
||||
return msg, newPeerError(errRead, "%v", err)
|
||||
if _, err = io.ReadFull(rw.bufconn, start); err != nil {
|
||||
return msg, err
|
||||
}
|
||||
if !bytes.HasPrefix(start, magicToken) {
|
||||
return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
|
||||
return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken)
|
||||
}
|
||||
size := binary.BigEndian.Uint32(start[4:])
|
||||
|
||||
// decode start of RLP message to get the message code
|
||||
posr := &postrack{r, 0}
|
||||
posr := &postrack{rw.bufconn, 0}
|
||||
s := rlp.NewStream(posr)
|
||||
if _, err := s.List(); err != nil {
|
||||
return msg, err
|
||||
}
|
||||
code, err := s.Uint()
|
||||
msg.Code, err = s.Uint()
|
||||
if err != nil {
|
||||
return msg, err
|
||||
}
|
||||
payloadsize := size - posr.p
|
||||
return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
|
||||
msg.Size = size - posr.p
|
||||
|
||||
rw.SetReadDeadline(time.Now().Add(payloadReadTimeout))
|
||||
|
||||
if msg.Size <= wholePayloadSize {
|
||||
// msg is small, read all of it and move on to the next message.
|
||||
pbuf := make([]byte, msg.Size)
|
||||
if _, err := io.ReadFull(rw.bufconn, pbuf); err != nil {
|
||||
return msg, err
|
||||
}
|
||||
rw.rsync <- struct{}{} // bufconn is available again
|
||||
msg.Payload = bytes.NewReader(pbuf)
|
||||
} else {
|
||||
// lend bufconn to the caller until it has
|
||||
// consumed the payload. eofSignal will send a value
|
||||
// on rw.rsync when EOF is reached.
|
||||
pr := &eofSignal{rw.bufconn, msg.Size, rw.rsync}
|
||||
msg.Payload = pr
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// postrack wraps an rlp.ByteReader with a position counter.
|
||||
|
@ -167,6 +253,39 @@ func (r *postrack) ReadByte() (byte, error) {
|
|||
return b, err
|
||||
}
|
||||
|
||||
// eofSignal wraps a reader with eof signaling. the eof channel is
|
||||
// closed when the wrapped reader returns an error or when count bytes
|
||||
// have been read.
|
||||
type eofSignal struct {
|
||||
wrapped io.Reader
|
||||
count uint32 // number of bytes left
|
||||
eof chan<- struct{}
|
||||
}
|
||||
|
||||
// note: when using eofSignal to detect whether a message payload
|
||||
// has been read, Read might not be called for zero sized messages.
|
||||
func (r *eofSignal) Read(buf []byte) (int, error) {
|
||||
if r.count == 0 {
|
||||
if r.eof != nil {
|
||||
r.eof <- struct{}{}
|
||||
r.eof = nil
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
max := len(buf)
|
||||
if int(r.count) < len(buf) {
|
||||
max = int(r.count)
|
||||
}
|
||||
n, err := r.wrapped.Read(buf[:max])
|
||||
r.count -= uint32(n)
|
||||
if (err != nil || r.count == 0) && r.eof != nil {
|
||||
r.eof <- struct{}{} // tell Peer that msg has been consumed
|
||||
r.eof = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// MsgPipe creates a message pipe. Reads on one end are matched
|
||||
// with writes on the other. The pipe is full-duplex, both ends
|
||||
// implement MsgReadWriter.
|
||||
|
@ -198,7 +317,7 @@ type MsgPipeRW struct {
|
|||
func (p *MsgPipeRW) WriteMsg(msg Msg) error {
|
||||
if atomic.LoadInt32(p.closed) == 0 {
|
||||
consumed := make(chan struct{}, 1)
|
||||
msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed}
|
||||
msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed}
|
||||
select {
|
||||
case p.w <- msg:
|
||||
if msg.Size > 0 {
|
||||
|
|
|
@ -3,12 +3,11 @@ package p2p
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/ethutil"
|
||||
)
|
||||
|
||||
func TestNewMsg(t *testing.T) {
|
||||
|
@ -26,51 +25,51 @@ func TestNewMsg(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecodeMsg(t *testing.T) {
|
||||
msg := NewMsg(3, 1, "000")
|
||||
buf := new(bytes.Buffer)
|
||||
if err := writeMsg(buf, msg); err != nil {
|
||||
t.Fatalf("encodeMsg error: %v", err)
|
||||
}
|
||||
// t.Logf("encoded: %x", buf.Bytes())
|
||||
// func TestEncodeDecodeMsg(t *testing.T) {
|
||||
// msg := NewMsg(3, 1, "000")
|
||||
// buf := new(bytes.Buffer)
|
||||
// if err := writeMsg(buf, msg); err != nil {
|
||||
// t.Fatalf("encodeMsg error: %v", err)
|
||||
// }
|
||||
// // t.Logf("encoded: %x", buf.Bytes())
|
||||
|
||||
decmsg, err := readMsg(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("readMsg error: %v", err)
|
||||
}
|
||||
if decmsg.Code != 3 {
|
||||
t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
|
||||
}
|
||||
if decmsg.Size != 5 {
|
||||
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
|
||||
}
|
||||
// decmsg, err := readMsg(buf)
|
||||
// if err != nil {
|
||||
// t.Fatalf("readMsg error: %v", err)
|
||||
// }
|
||||
// if decmsg.Code != 3 {
|
||||
// t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
|
||||
// }
|
||||
// if decmsg.Size != 5 {
|
||||
// t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
|
||||
// }
|
||||
|
||||
var data struct {
|
||||
I uint
|
||||
S string
|
||||
}
|
||||
if err := decmsg.Decode(&data); err != nil {
|
||||
t.Fatalf("Decode error: %v", err)
|
||||
}
|
||||
if data.I != 1 {
|
||||
t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
|
||||
}
|
||||
if data.S != "000" {
|
||||
t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
|
||||
}
|
||||
}
|
||||
// var data struct {
|
||||
// I uint
|
||||
// S string
|
||||
// }
|
||||
// if err := decmsg.Decode(&data); err != nil {
|
||||
// t.Fatalf("Decode error: %v", err)
|
||||
// }
|
||||
// if data.I != 1 {
|
||||
// t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
|
||||
// }
|
||||
// if data.S != "000" {
|
||||
// t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
|
||||
// }
|
||||
// }
|
||||
|
||||
func TestDecodeRealMsg(t *testing.T) {
|
||||
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
|
||||
msg, err := readMsg(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// func TestDecodeRealMsg(t *testing.T) {
|
||||
// data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
|
||||
// msg, err := readMsg(bytes.NewReader(data))
|
||||
// if err != nil {
|
||||
// t.Fatalf("unexpected error: %v", err)
|
||||
// }
|
||||
|
||||
if msg.Code != 0 {
|
||||
t.Errorf("incorrect code %d, want %d", msg.Code, 0)
|
||||
}
|
||||
}
|
||||
// if msg.Code != 0 {
|
||||
// t.Errorf("incorrect code %d, want %d", msg.Code, 0)
|
||||
// }
|
||||
// }
|
||||
|
||||
func ExampleMsgPipe() {
|
||||
rw1, rw2 := MsgPipe()
|
||||
|
@ -131,3 +130,58 @@ func TestMsgPipeConcurrentClose(t *testing.T) {
|
|||
go rw1.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestEOFSignal(t *testing.T) {
|
||||
rb := make([]byte, 10)
|
||||
|
||||
// empty reader
|
||||
eof := make(chan struct{}, 1)
|
||||
sig := &eofSignal{new(bytes.Buffer), 0, eof}
|
||||
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
select {
|
||||
case <-eof:
|
||||
default:
|
||||
t.Error("EOF chan not signaled")
|
||||
}
|
||||
|
||||
// count before error
|
||||
eof = make(chan struct{}, 1)
|
||||
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
|
||||
if n, err := sig.Read(rb); n != 4 || err != nil {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
select {
|
||||
case <-eof:
|
||||
default:
|
||||
t.Error("EOF chan not signaled")
|
||||
}
|
||||
|
||||
// error before count
|
||||
eof = make(chan struct{}, 1)
|
||||
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
|
||||
if n, err := sig.Read(rb); n != 4 || err != nil {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
select {
|
||||
case <-eof:
|
||||
default:
|
||||
t.Error("EOF chan not signaled")
|
||||
}
|
||||
|
||||
// no signal if neither occurs
|
||||
eof = make(chan struct{}, 1)
|
||||
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
|
||||
if n, err := sig.Read(rb); n != 10 || err != nil {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
select {
|
||||
case <-eof:
|
||||
t.Error("unexpected EOF signal")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
|
23
p2p/nat.go
23
p2p/nat.go
|
@ -1,23 +0,0 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
func ParseNAT(natType string, gateway string) (nat NAT, err error) {
|
||||
switch natType {
|
||||
case "UPNP":
|
||||
nat = UPNP()
|
||||
case "PMP":
|
||||
ip := net.ParseIP(gateway)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("cannot resolve PMP gateway IP %s", gateway)
|
||||
}
|
||||
nat = PMP(ip)
|
||||
case "":
|
||||
default:
|
||||
return nil, fmt.Errorf("unrecognised NAT type '%s'", natType)
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,235 @@
|
|||
// Package nat provides access to common port mapping protocols.
|
||||
package nat
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/jackpal/go-nat-pmp"
|
||||
)
|
||||
|
||||
var log = logger.NewLogger("P2P NAT")
|
||||
|
||||
// An implementation of nat.Interface can map local ports to ports
|
||||
// accessible from the Internet.
|
||||
type Interface interface {
|
||||
// These methods manage a mapping between a port on the local
|
||||
// machine to a port that can be connected to from the internet.
|
||||
//
|
||||
// protocol is "UDP" or "TCP". Some implementations allow setting
|
||||
// a display name for the mapping. The mapping may be removed by
|
||||
// the gateway when its lifetime ends.
|
||||
AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
|
||||
DeleteMapping(protocol string, extport, intport int) error
|
||||
|
||||
// This method should return the external (Internet-facing)
|
||||
// address of the gateway device.
|
||||
ExternalIP() (net.IP, error)
|
||||
|
||||
// Should return name of the method. This is used for logging.
|
||||
String() string
|
||||
}
|
||||
|
||||
// Parse parses a NAT interface description.
|
||||
// The following formats are currently accepted.
|
||||
// Note that mechanism names are not case-sensitive.
|
||||
//
|
||||
// "" or "none" return nil
|
||||
// "extip:77.12.33.4" will assume the local machine is reachable on the given IP
|
||||
// "any" uses the first auto-detected mechanism
|
||||
// "upnp" uses the Universal Plug and Play protocol
|
||||
// "pmp" uses NAT-PMP with an auto-detected gateway address
|
||||
// "pmp:192.168.0.1" uses NAT-PMP with the given gateway address
|
||||
func Parse(spec string) (Interface, error) {
|
||||
var (
|
||||
parts = strings.SplitN(spec, ":", 2)
|
||||
mech = strings.ToLower(parts[0])
|
||||
ip net.IP
|
||||
)
|
||||
if len(parts) > 1 {
|
||||
ip = net.ParseIP(parts[1])
|
||||
if ip == nil {
|
||||
return nil, errors.New("invalid IP address")
|
||||
}
|
||||
}
|
||||
switch mech {
|
||||
case "", "none", "off":
|
||||
return nil, nil
|
||||
case "any", "auto", "on":
|
||||
return Any(), nil
|
||||
case "extip", "ip":
|
||||
if ip == nil {
|
||||
return nil, errors.New("missing IP address")
|
||||
}
|
||||
return ExtIP(ip), nil
|
||||
case "upnp":
|
||||
return UPnP(), nil
|
||||
case "pmp", "natpmp", "nat-pmp":
|
||||
return PMP(ip), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown mechanism %q", parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
mapTimeout = 20 * time.Minute
|
||||
mapUpdateInterval = 15 * time.Minute
|
||||
)
|
||||
|
||||
// Map adds a port mapping on m and keeps it alive until c is closed.
|
||||
// This function is typically invoked in its own goroutine.
|
||||
func Map(m Interface, c chan struct{}, protocol string, extport, intport int, name string) {
|
||||
refresh := time.NewTimer(mapUpdateInterval)
|
||||
defer func() {
|
||||
refresh.Stop()
|
||||
log.Debugf("Deleting port mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
|
||||
m.DeleteMapping(protocol, extport, intport)
|
||||
}()
|
||||
log.Debugf("add mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
|
||||
if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil {
|
||||
log.Errorf("mapping error: %v\n", err)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-c:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case <-refresh.C:
|
||||
log.DebugDetailf("refresh mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
|
||||
if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil {
|
||||
log.Errorf("mapping error: %v\n", err)
|
||||
}
|
||||
refresh.Reset(mapUpdateInterval)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ExtIP assumes that the local machine is reachable on the given
|
||||
// external IP address, and that any required ports were mapped manually.
|
||||
// Mapping operations will not return an error but won't actually do anything.
|
||||
func ExtIP(ip net.IP) Interface {
|
||||
if ip == nil {
|
||||
panic("IP must not be nil")
|
||||
}
|
||||
return extIP(ip)
|
||||
}
|
||||
|
||||
type extIP net.IP
|
||||
|
||||
func (n extIP) ExternalIP() (net.IP, error) { return net.IP(n), nil }
|
||||
func (n extIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) }
|
||||
|
||||
// These do nothing.
|
||||
func (extIP) AddMapping(string, int, int, string, time.Duration) error { return nil }
|
||||
func (extIP) DeleteMapping(string, int, int) error { return nil }
|
||||
|
||||
// Any returns a port mapper that tries to discover any supported
|
||||
// mechanism on the local network.
|
||||
func Any() Interface {
|
||||
// TODO: attempt to discover whether the local machine has an
|
||||
// Internet-class address. Return ExtIP in this case.
|
||||
return startautodisc("UPnP or NAT-PMP", func() Interface {
|
||||
found := make(chan Interface, 2)
|
||||
go func() { found <- discoverUPnP() }()
|
||||
go func() { found <- discoverPMP() }()
|
||||
for i := 0; i < cap(found); i++ {
|
||||
if c := <-found; c != nil {
|
||||
return c
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UPnP returns a port mapper that uses UPnP. It will attempt to
|
||||
// discover the address of your router using UDP broadcasts.
|
||||
func UPnP() Interface {
|
||||
return startautodisc("UPnP", discoverUPnP)
|
||||
}
|
||||
|
||||
// PMP returns a port mapper that uses NAT-PMP. The provided gateway
|
||||
// address should be the IP of your router. If the given gateway
|
||||
// address is nil, PMP will attempt to auto-discover the router.
|
||||
func PMP(gateway net.IP) Interface {
|
||||
if gateway != nil {
|
||||
return &pmp{gw: gateway, c: natpmp.NewClient(gateway)}
|
||||
}
|
||||
return startautodisc("NAT-PMP", discoverPMP)
|
||||
}
|
||||
|
||||
// autodisc represents a port mapping mechanism that is still being
|
||||
// auto-discovered. Calls to the Interface methods on this type will
|
||||
// wait until the discovery is done and then call the method on the
|
||||
// discovered mechanism.
|
||||
//
|
||||
// This type is useful because discovery can take a while but we
|
||||
// want return an Interface value from UPnP, PMP and Auto immediately.
|
||||
type autodisc struct {
|
||||
what string
|
||||
done <-chan Interface
|
||||
|
||||
mu sync.Mutex
|
||||
found Interface
|
||||
}
|
||||
|
||||
func startautodisc(what string, doit func() Interface) Interface {
|
||||
// TODO: monitor network configuration and rerun doit when it changes.
|
||||
done := make(chan Interface)
|
||||
ad := &autodisc{what: what, done: done}
|
||||
go func() { done <- doit(); close(done) }()
|
||||
return ad
|
||||
}
|
||||
|
||||
func (n *autodisc) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
|
||||
if err := n.wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
return n.found.AddMapping(protocol, extport, intport, name, lifetime)
|
||||
}
|
||||
|
||||
func (n *autodisc) DeleteMapping(protocol string, extport, intport int) error {
|
||||
if err := n.wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
return n.found.DeleteMapping(protocol, extport, intport)
|
||||
}
|
||||
|
||||
func (n *autodisc) ExternalIP() (net.IP, error) {
|
||||
if err := n.wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return n.found.ExternalIP()
|
||||
}
|
||||
|
||||
func (n *autodisc) String() string {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
if n.found == nil {
|
||||
return n.what
|
||||
} else {
|
||||
return n.found.String()
|
||||
}
|
||||
}
|
||||
|
||||
func (n *autodisc) wait() error {
|
||||
n.mu.Lock()
|
||||
found := n.found
|
||||
n.mu.Unlock()
|
||||
if found != nil {
|
||||
// already discovered
|
||||
return nil
|
||||
}
|
||||
if found = <-n.done; found == nil {
|
||||
return errors.New("no devices discovered")
|
||||
}
|
||||
n.mu.Lock()
|
||||
n.found = found
|
||||
n.mu.Unlock()
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
package nat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackpal/go-nat-pmp"
|
||||
)
|
||||
|
||||
// natPMPClient adapts the NAT-PMP protocol implementation so it conforms to
|
||||
// the common interface.
|
||||
type pmp struct {
|
||||
gw net.IP
|
||||
c *natpmp.Client
|
||||
}
|
||||
|
||||
func (n *pmp) String() string {
|
||||
return fmt.Sprintf("NAT-PMP(%v)", n.gw)
|
||||
}
|
||||
|
||||
func (n *pmp) ExternalIP() (net.IP, error) {
|
||||
response, err := n.c.GetExternalAddress()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response.ExternalIPAddress[:], nil
|
||||
}
|
||||
|
||||
func (n *pmp) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
|
||||
if lifetime <= 0 {
|
||||
return fmt.Errorf("lifetime must not be <= 0")
|
||||
}
|
||||
// Note order of port arguments is switched between our
|
||||
// AddMapping and the client's AddPortMapping.
|
||||
_, err := n.c.AddPortMapping(strings.ToLower(protocol), intport, extport, int(lifetime/time.Second))
|
||||
return err
|
||||
}
|
||||
|
||||
func (n *pmp) DeleteMapping(protocol string, extport, intport int) (err error) {
|
||||
// To destroy a mapping, send an add-port with an internalPort of
|
||||
// the internal port to destroy, an external port of zero and a
|
||||
// time of zero.
|
||||
_, err = n.c.AddPortMapping(strings.ToLower(protocol), intport, 0, 0)
|
||||
return err
|
||||
}
|
||||
|
||||
func discoverPMP() Interface {
|
||||
// run external address lookups on all potential gateways
|
||||
gws := potentialGateways()
|
||||
found := make(chan *pmp, len(gws))
|
||||
for i := range gws {
|
||||
gw := gws[i]
|
||||
go func() {
|
||||
c := natpmp.NewClient(gw)
|
||||
if _, err := c.GetExternalAddress(); err != nil {
|
||||
found <- nil
|
||||
} else {
|
||||
found <- &pmp{gw, c}
|
||||
}
|
||||
}()
|
||||
}
|
||||
// return the one that responds first.
|
||||
// discovery needs to be quick, so we stop caring about
|
||||
// any responses after a very short timeout.
|
||||
timeout := time.NewTimer(1 * time.Second)
|
||||
defer timeout.Stop()
|
||||
for _ = range gws {
|
||||
select {
|
||||
case c := <-found:
|
||||
if c != nil {
|
||||
return c
|
||||
}
|
||||
case <-timeout.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
// LAN IP ranges
|
||||
_, lan10, _ = net.ParseCIDR("10.0.0.0/8")
|
||||
_, lan176, _ = net.ParseCIDR("172.16.0.0/12")
|
||||
_, lan192, _ = net.ParseCIDR("192.168.0.0/16")
|
||||
)
|
||||
|
||||
// TODO: improve this. We currently assume that (on most networks)
|
||||
// the router is X.X.X.1 in a local LAN range.
|
||||
func potentialGateways() (gws []net.IP) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
for _, iface := range ifaces {
|
||||
ifaddrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return gws
|
||||
}
|
||||
for _, addr := range ifaddrs {
|
||||
switch x := addr.(type) {
|
||||
case *net.IPNet:
|
||||
if lan10.Contains(x.IP) || lan176.Contains(x.IP) || lan192.Contains(x.IP) {
|
||||
ip := x.IP.Mask(x.Mask).To4()
|
||||
if ip != nil {
|
||||
ip[3] = ip[3] | 0x01
|
||||
gws = append(gws, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return gws
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
package nat
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fjl/goupnp"
|
||||
"github.com/fjl/goupnp/dcps/internetgateway1"
|
||||
"github.com/fjl/goupnp/dcps/internetgateway2"
|
||||
)
|
||||
|
||||
type upnp struct {
|
||||
dev *goupnp.RootDevice
|
||||
service string
|
||||
client upnpClient
|
||||
}
|
||||
|
||||
type upnpClient interface {
|
||||
GetExternalIPAddress() (string, error)
|
||||
AddPortMapping(string, uint16, string, uint16, string, bool, string, uint32) error
|
||||
DeletePortMapping(string, uint16, string) error
|
||||
GetNATRSIPStatus() (sip bool, nat bool, err error)
|
||||
}
|
||||
|
||||
func (n *upnp) ExternalIP() (addr net.IP, err error) {
|
||||
ipString, err := n.client.GetExternalIPAddress()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ip := net.ParseIP(ipString)
|
||||
if ip == nil {
|
||||
return nil, errors.New("bad IP in response")
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
func (n *upnp) AddMapping(protocol string, extport, intport int, desc string, lifetime time.Duration) error {
|
||||
ip, err := n.internalAddress()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
protocol = strings.ToUpper(protocol)
|
||||
lifetimeS := uint32(lifetime / time.Second)
|
||||
return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS)
|
||||
}
|
||||
|
||||
func (n *upnp) internalAddress() (net.IP, error) {
|
||||
devaddr, err := net.ResolveUDPAddr("udp4", n.dev.URLBase.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, iface := range ifaces {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
switch x := addr.(type) {
|
||||
case *net.IPNet:
|
||||
if x.Contains(devaddr.IP) {
|
||||
return x.IP, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("could not find local address in same net as %v", devaddr)
|
||||
}
|
||||
|
||||
func (n *upnp) DeleteMapping(protocol string, extport, intport int) error {
|
||||
return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol))
|
||||
}
|
||||
|
||||
func (n *upnp) String() string {
|
||||
return "UPNP " + n.service
|
||||
}
|
||||
|
||||
// discoverUPnP searches for Internet Gateway Devices
|
||||
// and returns the first one it can find on the local network.
|
||||
func discoverUPnP() Interface {
|
||||
found := make(chan *upnp, 2)
|
||||
// IGDv1
|
||||
go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
|
||||
switch sc.Service.ServiceType {
|
||||
case internetgateway1.URN_WANIPConnection_1:
|
||||
return &upnp{dev, "IGDv1-IP1", &internetgateway1.WANIPConnection1{sc}}
|
||||
case internetgateway1.URN_WANPPPConnection_1:
|
||||
return &upnp{dev, "IGDv1-PPP1", &internetgateway1.WANPPPConnection1{sc}}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
// IGDv2
|
||||
go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
|
||||
switch sc.Service.ServiceType {
|
||||
case internetgateway2.URN_WANIPConnection_1:
|
||||
return &upnp{dev, "IGDv2-IP1", &internetgateway2.WANIPConnection1{sc}}
|
||||
case internetgateway2.URN_WANIPConnection_2:
|
||||
return &upnp{dev, "IGDv2-IP2", &internetgateway2.WANIPConnection2{sc}}
|
||||
case internetgateway2.URN_WANPPPConnection_1:
|
||||
return &upnp{dev, "IGDv2-PPP1", &internetgateway2.WANPPPConnection1{sc}}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
for i := 0; i < cap(found); i++ {
|
||||
if c := <-found; c != nil {
|
||||
return c
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice, goupnp.ServiceClient) *upnp) {
|
||||
devs, err := goupnp.DiscoverDevices(target)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
found := false
|
||||
for i := 0; i < len(devs) && !found; i++ {
|
||||
if devs[i].Root == nil {
|
||||
continue
|
||||
}
|
||||
devs[i].Root.Device.VisitServices(func(service *goupnp.Service) {
|
||||
if found {
|
||||
return
|
||||
}
|
||||
// check for a matching IGD service
|
||||
sc := goupnp.ServiceClient{service.NewSOAPClient(), devs[i].Root, service}
|
||||
upnp := matcher(devs[i].Root, sc)
|
||||
if upnp == nil {
|
||||
return
|
||||
}
|
||||
// check whether port mapping is enabled
|
||||
if _, nat, err := upnp.client.GetNATRSIPStatus(); err != nil || !nat {
|
||||
return
|
||||
}
|
||||
out <- upnp
|
||||
found = true
|
||||
})
|
||||
}
|
||||
if !found {
|
||||
out <- nil
|
||||
}
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
natpmp "github.com/jackpal/go-nat-pmp"
|
||||
)
|
||||
|
||||
// Adapt the NAT-PMP protocol to the NAT interface
|
||||
|
||||
// TODO:
|
||||
// + Register for changes to the external address.
|
||||
// + Re-register port mapping when router reboots.
|
||||
// + A mechanism for keeping a port mapping registered.
|
||||
// + Discover gateway address automatically.
|
||||
|
||||
type natPMPClient struct {
|
||||
client *natpmp.Client
|
||||
}
|
||||
|
||||
// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
|
||||
// address should be the IP of your router.
|
||||
func PMP(gateway net.IP) (nat NAT) {
|
||||
return &natPMPClient{natpmp.NewClient(gateway)}
|
||||
}
|
||||
|
||||
func (*natPMPClient) String() string {
|
||||
return "NAT-PMP"
|
||||
}
|
||||
|
||||
func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
|
||||
response, err := n.client.GetExternalAddress()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response.ExternalIPAddress[:], nil
|
||||
}
|
||||
|
||||
func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
|
||||
if lifetime <= 0 {
|
||||
return fmt.Errorf("lifetime must not be <= 0")
|
||||
}
|
||||
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
|
||||
_, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
|
||||
return err
|
||||
}
|
||||
|
||||
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
|
||||
// To destroy a mapping, send an add-port with
|
||||
// an internalPort of the internal port to destroy, an external port of zero and a time of zero.
|
||||
_, err = n.client.AddPortMapping(protocol, internalPort, 0, 0)
|
||||
return
|
||||
}
|
341
p2p/natupnp.go
341
p2p/natupnp.go
|
@ -1,341 +0,0 @@
|
|||
package p2p
|
||||
|
||||
// Just enough UPnP to be able to forward ports
|
||||
//
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
upnpDiscoverAttempts = 3
|
||||
upnpDiscoverTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
|
||||
// discover the address of your router using UDP broadcasts.
|
||||
func UPNP() NAT {
|
||||
return &upnpNAT{}
|
||||
}
|
||||
|
||||
type upnpNAT struct {
|
||||
serviceURL string
|
||||
ourIP string
|
||||
}
|
||||
|
||||
func (n *upnpNAT) String() string {
|
||||
return "UPNP"
|
||||
}
|
||||
|
||||
func (n *upnpNAT) discover() error {
|
||||
if n.serviceURL != "" {
|
||||
// already discovered
|
||||
return nil
|
||||
}
|
||||
|
||||
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: try on all network interfaces simultaneously.
|
||||
// Broadcasting on 0.0.0.0 could select a random interface
|
||||
// to send on (platform specific).
|
||||
conn, err := net.ListenPacket("udp4", ":0")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetDeadline(time.Now().Add(10 * time.Second))
|
||||
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
|
||||
buf := bytes.NewBufferString(
|
||||
"M-SEARCH * HTTP/1.1\r\n" +
|
||||
"HOST: 239.255.255.250:1900\r\n" +
|
||||
st +
|
||||
"MAN: \"ssdp:discover\"\r\n" +
|
||||
"MX: 2\r\n\r\n")
|
||||
message := buf.Bytes()
|
||||
answerBytes := make([]byte, 1024)
|
||||
for i := 0; i < upnpDiscoverAttempts; i++ {
|
||||
_, err = conn.WriteTo(message, ssdp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nn, _, err := conn.ReadFrom(answerBytes)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
answer := string(answerBytes[0:nn])
|
||||
if strings.Index(answer, "\r\n"+st) < 0 {
|
||||
continue
|
||||
}
|
||||
// HTTP header field names are case-insensitive.
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
|
||||
locString := "\r\nlocation: "
|
||||
answer = strings.ToLower(answer)
|
||||
locIndex := strings.Index(answer, locString)
|
||||
if locIndex < 0 {
|
||||
continue
|
||||
}
|
||||
loc := answer[locIndex+len(locString):]
|
||||
endIndex := strings.Index(loc, "\r\n")
|
||||
if endIndex < 0 {
|
||||
continue
|
||||
}
|
||||
locURL := loc[0:endIndex]
|
||||
var serviceURL string
|
||||
serviceURL, err = getServiceURL(locURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var ourIP string
|
||||
ourIP, err = getOurIP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.serviceURL = serviceURL
|
||||
n.ourIP = ourIP
|
||||
return nil
|
||||
}
|
||||
return errors.New("UPnP port discovery failed.")
|
||||
}
|
||||
|
||||
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
|
||||
if err := n.discover(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, err := n.getStatusInfo()
|
||||
return net.ParseIP(info.externalIpAddress), err
|
||||
}
|
||||
|
||||
func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
|
||||
if err := n.discover(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// A single concatenation would break ARM compilation.
|
||||
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
||||
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
|
||||
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
|
||||
message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
|
||||
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
|
||||
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
|
||||
message += description +
|
||||
"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
|
||||
"</NewLeaseDuration></u:AddPortMapping>"
|
||||
|
||||
// TODO: check response to see if the port was forwarded
|
||||
_, err := soapRequest(n.serviceURL, "AddPortMapping", message)
|
||||
return err
|
||||
}
|
||||
|
||||
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
|
||||
if err := n.discover(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
||||
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
|
||||
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
|
||||
"</u:DeletePortMapping>"
|
||||
|
||||
// TODO: check response to see if the port was deleted
|
||||
_, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
|
||||
return err
|
||||
}
|
||||
|
||||
type statusInfo struct {
|
||||
externalIpAddress string
|
||||
}
|
||||
|
||||
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
|
||||
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
||||
"</u:GetStatusInfo>"
|
||||
|
||||
var response *http.Response
|
||||
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
|
||||
|
||||
response.Body.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// service represents the Service type in an UPnP xml description.
|
||||
// Only the parts we care about are present and thus the xml may have more
|
||||
// fields than present in the structure.
|
||||
type service struct {
|
||||
ServiceType string `xml:"serviceType"`
|
||||
ControlURL string `xml:"controlURL"`
|
||||
}
|
||||
|
||||
// deviceList represents the deviceList type in an UPnP xml description.
|
||||
// Only the parts we care about are present and thus the xml may have more
|
||||
// fields than present in the structure.
|
||||
type deviceList struct {
|
||||
XMLName xml.Name `xml:"deviceList"`
|
||||
Device []device `xml:"device"`
|
||||
}
|
||||
|
||||
// serviceList represents the serviceList type in an UPnP xml description.
|
||||
// Only the parts we care about are present and thus the xml may have more
|
||||
// fields than present in the structure.
|
||||
type serviceList struct {
|
||||
XMLName xml.Name `xml:"serviceList"`
|
||||
Service []service `xml:"service"`
|
||||
}
|
||||
|
||||
// device represents the device type in an UPnP xml description.
|
||||
// Only the parts we care about are present and thus the xml may have more
|
||||
// fields than present in the structure.
|
||||
type device struct {
|
||||
XMLName xml.Name `xml:"device"`
|
||||
DeviceType string `xml:"deviceType"`
|
||||
DeviceList deviceList `xml:"deviceList"`
|
||||
ServiceList serviceList `xml:"serviceList"`
|
||||
}
|
||||
|
||||
// specVersion represents the specVersion in a UPnP xml description.
|
||||
// Only the parts we care about are present and thus the xml may have more
|
||||
// fields than present in the structure.
|
||||
type specVersion struct {
|
||||
XMLName xml.Name `xml:"specVersion"`
|
||||
Major int `xml:"major"`
|
||||
Minor int `xml:"minor"`
|
||||
}
|
||||
|
||||
// root represents the Root document for a UPnP xml description.
|
||||
// Only the parts we care about are present and thus the xml may have more
|
||||
// fields than present in the structure.
|
||||
type root struct {
|
||||
XMLName xml.Name `xml:"root"`
|
||||
SpecVersion specVersion
|
||||
Device device
|
||||
}
|
||||
|
||||
func getChildDevice(d *device, deviceType string) *device {
|
||||
dl := d.DeviceList.Device
|
||||
for i := 0; i < len(dl); i++ {
|
||||
if dl[i].DeviceType == deviceType {
|
||||
return &dl[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getChildService(d *device, serviceType string) *service {
|
||||
sl := d.ServiceList.Service
|
||||
for i := 0; i < len(sl); i++ {
|
||||
if sl[i].ServiceType == serviceType {
|
||||
return &sl[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getOurIP() (ip string, err error) {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p, err := net.LookupIP(hostname)
|
||||
if err != nil && len(p) > 0 {
|
||||
return
|
||||
}
|
||||
return p[0].String(), nil
|
||||
}
|
||||
|
||||
func getServiceURL(rootURL string) (url string, err error) {
|
||||
r, err := http.Get(rootURL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
if r.StatusCode >= 400 {
|
||||
err = errors.New(string(r.StatusCode))
|
||||
return
|
||||
}
|
||||
var root root
|
||||
err = xml.NewDecoder(r.Body).Decode(&root)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
a := &root.Device
|
||||
if a.DeviceType != "urn:schemas-upnp-org:device:InternetGatewayDevice:1" {
|
||||
err = errors.New("No InternetGatewayDevice")
|
||||
return
|
||||
}
|
||||
b := getChildDevice(a, "urn:schemas-upnp-org:device:WANDevice:1")
|
||||
if b == nil {
|
||||
err = errors.New("No WANDevice")
|
||||
return
|
||||
}
|
||||
c := getChildDevice(b, "urn:schemas-upnp-org:device:WANConnectionDevice:1")
|
||||
if c == nil {
|
||||
err = errors.New("No WANConnectionDevice")
|
||||
return
|
||||
}
|
||||
d := getChildService(c, "urn:schemas-upnp-org:service:WANIPConnection:1")
|
||||
if d == nil {
|
||||
err = errors.New("No WANIPConnection")
|
||||
return
|
||||
}
|
||||
url = combineURL(rootURL, d.ControlURL)
|
||||
return
|
||||
}
|
||||
|
||||
func combineURL(rootURL, subURL string) string {
|
||||
protocolEnd := "://"
|
||||
protoEndIndex := strings.Index(rootURL, protocolEnd)
|
||||
a := rootURL[protoEndIndex+len(protocolEnd):]
|
||||
rootIndex := strings.Index(a, "/")
|
||||
return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
|
||||
}
|
||||
|
||||
func soapRequest(url, function, message string) (r *http.Response, err error) {
|
||||
fullMessage := "<?xml version=\"1.0\" ?>" +
|
||||
"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" +
|
||||
"<s:Body>" + message + "</s:Body></s:Envelope>"
|
||||
|
||||
req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
|
||||
req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
|
||||
//req.Header.Set("Transfer-Encoding", "chunked")
|
||||
req.Header.Set("SOAPAction", "\"urn:schemas-upnp-org:service:WANIPConnection:1#"+function+"\"")
|
||||
req.Header.Set("Connection", "Close")
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
req.Header.Set("Pragma", "no-cache")
|
||||
|
||||
r, err = http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if r.Body != nil {
|
||||
defer r.Body.Close()
|
||||
}
|
||||
|
||||
if r.StatusCode >= 400 {
|
||||
// log.Stderr(function, r.StatusCode)
|
||||
err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
|
||||
r = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
480
p2p/peer.go
480
p2p/peer.go
|
@ -1,8 +1,7 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
@ -11,159 +10,109 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/event"
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
// peerAddr is the structure of a peer list element.
|
||||
// It is also a valid net.Addr.
|
||||
type peerAddr struct {
|
||||
IP net.IP
|
||||
Port uint64
|
||||
Pubkey []byte // optional
|
||||
const (
|
||||
baseProtocolVersion = 3
|
||||
baseProtocolLength = uint64(16)
|
||||
baseProtocolMaxMsgSize = 10 * 1024 * 1024
|
||||
|
||||
disconnectGracePeriod = 2 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
// devp2p message codes
|
||||
handshakeMsg = 0x00
|
||||
discMsg = 0x01
|
||||
pingMsg = 0x02
|
||||
pongMsg = 0x03
|
||||
getPeersMsg = 0x04
|
||||
peersMsg = 0x05
|
||||
)
|
||||
|
||||
// handshake is the RLP structure of the protocol handshake.
|
||||
type handshake struct {
|
||||
Version uint64
|
||||
Name string
|
||||
Caps []Cap
|
||||
ListenPort uint64
|
||||
NodeID discover.NodeID
|
||||
}
|
||||
|
||||
func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
|
||||
n := addr.Network()
|
||||
if n != "tcp" && n != "tcp4" && n != "tcp6" {
|
||||
// for testing with non-TCP
|
||||
return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
|
||||
}
|
||||
ta := addr.(*net.TCPAddr)
|
||||
return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
|
||||
}
|
||||
|
||||
func (d peerAddr) Network() string {
|
||||
if d.IP.To4() != nil {
|
||||
return "tcp4"
|
||||
} else {
|
||||
return "tcp6"
|
||||
}
|
||||
}
|
||||
|
||||
func (d peerAddr) String() string {
|
||||
return fmt.Sprintf("%v:%d", d.IP, d.Port)
|
||||
}
|
||||
|
||||
func (d *peerAddr) RlpData() interface{} {
|
||||
return []interface{}{string(d.IP), d.Port, d.Pubkey}
|
||||
}
|
||||
|
||||
// Peer represents a remote peer.
|
||||
// Peer represents a connected remote node.
|
||||
type Peer struct {
|
||||
// Peers have all the log methods.
|
||||
// Use them to display messages related to the peer.
|
||||
*logger.Logger
|
||||
|
||||
infolock sync.Mutex
|
||||
identity ClientIdentity
|
||||
caps []Cap
|
||||
listenAddr *peerAddr // what remote peer is listening on
|
||||
dialAddr *peerAddr // non-nil if dialing
|
||||
infoMu sync.Mutex
|
||||
name string
|
||||
caps []Cap
|
||||
|
||||
// The mutex protects the connection
|
||||
// so only one protocol can write at a time.
|
||||
writeMu sync.Mutex
|
||||
conn net.Conn
|
||||
bufconn *bufio.ReadWriter
|
||||
ourID, remoteID *discover.NodeID
|
||||
ourName string
|
||||
|
||||
rw *frameRW
|
||||
|
||||
// These fields maintain the running protocols.
|
||||
protocols []Protocol
|
||||
runBaseProtocol bool // for testing
|
||||
protocols []Protocol
|
||||
runlock sync.RWMutex // protects running
|
||||
running map[string]*proto
|
||||
|
||||
runlock sync.RWMutex // protects running
|
||||
running map[string]*proto
|
||||
// disables protocol handshake, for testing
|
||||
noHandshake bool
|
||||
|
||||
protoWG sync.WaitGroup
|
||||
protoErr chan error
|
||||
closed chan struct{}
|
||||
disc chan DiscReason
|
||||
|
||||
activity event.TypeMux // for activity events
|
||||
|
||||
slot int // index into Server peer list
|
||||
|
||||
// These fields are kept so base protocol can access them.
|
||||
// TODO: this should be one or more interfaces
|
||||
ourID ClientIdentity // client id of the Server
|
||||
ourListenAddr *peerAddr // listen addr of Server, nil if not listening
|
||||
newPeerAddr chan<- *peerAddr // tell server about received peers
|
||||
otherPeers func() []*Peer // should return the list of all peers
|
||||
pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
|
||||
}
|
||||
|
||||
// NewPeer returns a peer for testing purposes.
|
||||
func NewPeer(id ClientIdentity, caps []Cap) *Peer {
|
||||
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
|
||||
conn, _ := net.Pipe()
|
||||
peer := newPeer(conn, nil, nil)
|
||||
peer.setHandshakeInfo(id, nil, caps)
|
||||
close(peer.closed)
|
||||
peer := newPeer(conn, nil, "", nil, &id)
|
||||
peer.setHandshakeInfo(name, caps)
|
||||
close(peer.closed) // ensures Disconnect doesn't block
|
||||
return peer
|
||||
}
|
||||
|
||||
func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
|
||||
p := newPeer(conn, server.Protocols, dialAddr)
|
||||
p.ourID = server.Identity
|
||||
p.newPeerAddr = server.peerConnect
|
||||
p.otherPeers = server.Peers
|
||||
p.pubkeyHook = server.verifyPeer
|
||||
p.runBaseProtocol = true
|
||||
|
||||
// laddr can be updated concurrently by NAT traversal.
|
||||
// newServerPeer must be called with the server lock held.
|
||||
if server.laddr != nil {
|
||||
p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
|
||||
}
|
||||
return p
|
||||
// ID returns the node's public key.
|
||||
func (p *Peer) ID() discover.NodeID {
|
||||
return *p.remoteID
|
||||
}
|
||||
|
||||
func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
|
||||
p := &Peer{
|
||||
Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()),
|
||||
conn: conn,
|
||||
dialAddr: dialAddr,
|
||||
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
|
||||
protocols: protocols,
|
||||
running: make(map[string]*proto),
|
||||
disc: make(chan DiscReason),
|
||||
protoErr: make(chan error),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Identity returns the client identity of the remote peer. The
|
||||
// identity can be nil if the peer has not yet completed the
|
||||
// handshake.
|
||||
func (p *Peer) Identity() ClientIdentity {
|
||||
p.infolock.Lock()
|
||||
defer p.infolock.Unlock()
|
||||
return p.identity
|
||||
// Name returns the node name that the remote node advertised.
|
||||
func (p *Peer) Name() string {
|
||||
// this needs a lock because the information is part of the
|
||||
// protocol handshake.
|
||||
p.infoMu.Lock()
|
||||
name := p.name
|
||||
p.infoMu.Unlock()
|
||||
return name
|
||||
}
|
||||
|
||||
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
||||
func (p *Peer) Caps() []Cap {
|
||||
p.infolock.Lock()
|
||||
defer p.infolock.Unlock()
|
||||
return p.caps
|
||||
}
|
||||
|
||||
func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) {
|
||||
p.infolock.Lock()
|
||||
p.identity = id
|
||||
p.listenAddr = laddr
|
||||
p.caps = caps
|
||||
p.infolock.Unlock()
|
||||
// this needs a lock because the information is part of the
|
||||
// protocol handshake.
|
||||
p.infoMu.Lock()
|
||||
caps := p.caps
|
||||
p.infoMu.Unlock()
|
||||
return caps
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address of the network connection.
|
||||
func (p *Peer) RemoteAddr() net.Addr {
|
||||
return p.conn.RemoteAddr()
|
||||
return p.rw.RemoteAddr()
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address of the network connection.
|
||||
func (p *Peer) LocalAddr() net.Addr {
|
||||
return p.conn.LocalAddr()
|
||||
return p.rw.LocalAddr()
|
||||
}
|
||||
|
||||
// Disconnect terminates the peer connection with the given reason.
|
||||
|
@ -177,149 +126,177 @@ func (p *Peer) Disconnect(reason DiscReason) {
|
|||
|
||||
// String implements fmt.Stringer.
|
||||
func (p *Peer) String() string {
|
||||
kind := "inbound"
|
||||
p.infolock.Lock()
|
||||
if p.dialAddr != nil {
|
||||
kind = "outbound"
|
||||
}
|
||||
p.infolock.Unlock()
|
||||
return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
|
||||
return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
|
||||
}
|
||||
|
||||
const (
|
||||
// maximum amount of time allowed for reading a message
|
||||
msgReadTimeout = 5 * time.Second
|
||||
// maximum amount of time allowed for writing a message
|
||||
msgWriteTimeout = 5 * time.Second
|
||||
// messages smaller than this many bytes will be read at
|
||||
// once before passing them to a protocol.
|
||||
wholePayloadSize = 64 * 1024
|
||||
)
|
||||
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
|
||||
logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
|
||||
return &Peer{
|
||||
Logger: logger.NewLogger(logtag),
|
||||
rw: newFrameRW(conn, msgWriteTimeout),
|
||||
ourID: ourID,
|
||||
ourName: ourName,
|
||||
remoteID: remoteID,
|
||||
protocols: protocols,
|
||||
running: make(map[string]*proto),
|
||||
disc: make(chan DiscReason),
|
||||
protoErr: make(chan error),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
inactivityTimeout = 2 * time.Second
|
||||
disconnectGracePeriod = 2 * time.Second
|
||||
)
|
||||
func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
|
||||
p.infoMu.Lock()
|
||||
p.name = name
|
||||
p.caps = caps
|
||||
p.infoMu.Unlock()
|
||||
}
|
||||
|
||||
func (p *Peer) loop() (reason DiscReason, err error) {
|
||||
defer p.activity.Stop()
|
||||
func (p *Peer) run() DiscReason {
|
||||
var readErr = make(chan error, 1)
|
||||
defer p.closeProtocols()
|
||||
defer close(p.closed)
|
||||
defer p.conn.Close()
|
||||
|
||||
// read loop
|
||||
readMsg := make(chan Msg)
|
||||
readErr := make(chan error)
|
||||
readNext := make(chan bool, 1)
|
||||
protoDone := make(chan struct{}, 1)
|
||||
go p.readLoop(readMsg, readErr, readNext)
|
||||
readNext <- true
|
||||
go func() { readErr <- p.readLoop() }()
|
||||
|
||||
if p.runBaseProtocol {
|
||||
p.startBaseProtocol()
|
||||
}
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case msg := <-readMsg:
|
||||
// a new message has arrived.
|
||||
var wait bool
|
||||
if wait, err = p.dispatch(msg, protoDone); err != nil {
|
||||
p.Errorf("msg dispatch error: %v\n", err)
|
||||
reason = discReasonForError(err)
|
||||
break loop
|
||||
}
|
||||
if !wait {
|
||||
// Msg has already been read completely, continue with next message.
|
||||
readNext <- true
|
||||
}
|
||||
p.activity.Post(time.Now())
|
||||
case <-protoDone:
|
||||
// protocol has consumed the message payload,
|
||||
// we can continue reading from the socket.
|
||||
readNext <- true
|
||||
|
||||
case err := <-readErr:
|
||||
// read failed. there is no need to run the
|
||||
// polite disconnect sequence because the connection
|
||||
// is probably dead anyway.
|
||||
// TODO: handle write errors as well
|
||||
return DiscNetworkError, err
|
||||
case err = <-p.protoErr:
|
||||
reason = discReasonForError(err)
|
||||
break loop
|
||||
case reason = <-p.disc:
|
||||
break loop
|
||||
if !p.noHandshake {
|
||||
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
|
||||
p.DebugDetailf("Protocol handshake error: %v\n", err)
|
||||
p.rw.Close()
|
||||
return DiscProtocolError
|
||||
}
|
||||
}
|
||||
|
||||
// wait for read loop to return.
|
||||
close(readNext)
|
||||
// Wait for an error or disconnect.
|
||||
var reason DiscReason
|
||||
select {
|
||||
case err := <-readErr:
|
||||
// We rely on protocols to abort if there is a write error. It
|
||||
// might be more robust to handle them here as well.
|
||||
p.DebugDetailf("Read error: %v\n", err)
|
||||
p.rw.Close()
|
||||
return DiscNetworkError
|
||||
|
||||
case err := <-p.protoErr:
|
||||
reason = discReasonForError(err)
|
||||
case reason = <-p.disc:
|
||||
}
|
||||
p.politeDisconnect(reason)
|
||||
|
||||
// Wait for readLoop. It will end because conn is now closed.
|
||||
<-readErr
|
||||
// tell the remote end to disconnect
|
||||
p.Debugf("Disconnected: %v\n", reason)
|
||||
return reason
|
||||
}
|
||||
|
||||
func (p *Peer) politeDisconnect(reason DiscReason) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod))
|
||||
p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod)
|
||||
io.Copy(ioutil.Discard, p.conn)
|
||||
EncodeMsg(p.rw, discMsg, uint(reason))
|
||||
// Wait for the other side to close the connection.
|
||||
// Discard any data that they send until then.
|
||||
io.Copy(ioutil.Discard, p.rw)
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(disconnectGracePeriod):
|
||||
}
|
||||
return reason, err
|
||||
p.rw.Close()
|
||||
}
|
||||
|
||||
func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) {
|
||||
for _ = range unblock {
|
||||
p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
|
||||
if msg, err := readMsg(p.bufconn); err != nil {
|
||||
errc <- err
|
||||
} else {
|
||||
msgc <- msg
|
||||
func (p *Peer) readLoop() error {
|
||||
if !p.noHandshake {
|
||||
if err := readProtocolHandshake(p, p.rw); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
close(errc)
|
||||
}
|
||||
|
||||
func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) {
|
||||
proto, err := p.getProto(msg.Code)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if msg.Size <= wholePayloadSize {
|
||||
// optimization: msg is small enough, read all
|
||||
// of it and move on to the next message
|
||||
buf, err := ioutil.ReadAll(msg.Payload)
|
||||
for {
|
||||
msg, err := p.rw.ReadMsg()
|
||||
if err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
if err = p.handle(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
msg.Payload = bytes.NewReader(buf)
|
||||
proto.in <- msg
|
||||
} else {
|
||||
wait = true
|
||||
pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
|
||||
msg.Payload = pr
|
||||
proto.in <- msg
|
||||
}
|
||||
return wait, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Peer) startBaseProtocol() {
|
||||
p.runlock.Lock()
|
||||
defer p.runlock.Unlock()
|
||||
p.running[""] = p.startProto(0, Protocol{
|
||||
Length: baseProtocolLength,
|
||||
Run: runBaseProtocol,
|
||||
})
|
||||
func (p *Peer) handle(msg Msg) error {
|
||||
switch {
|
||||
case msg.Code == pingMsg:
|
||||
msg.Discard()
|
||||
go EncodeMsg(p.rw, pongMsg)
|
||||
case msg.Code == discMsg:
|
||||
var reason DiscReason
|
||||
// no need to discard or for error checking, we'll close the
|
||||
// connection after this.
|
||||
rlp.Decode(msg.Payload, &reason)
|
||||
p.Disconnect(DiscRequested)
|
||||
return discRequestedError(reason)
|
||||
case msg.Code < baseProtocolLength:
|
||||
// ignore other base protocol messages
|
||||
return msg.Discard()
|
||||
default:
|
||||
// it's a subprotocol message
|
||||
proto, err := p.getProto(msg.Code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("msg code out of range: %v", msg.Code)
|
||||
}
|
||||
proto.in <- msg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
|
||||
// read and handle remote handshake
|
||||
msg, err := rw.ReadMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Code == discMsg {
|
||||
// disconnect before protocol handshake is valid according to the
|
||||
// spec and we send it ourself if Server.addPeer fails.
|
||||
var reason DiscReason
|
||||
rlp.Decode(msg.Payload, &reason)
|
||||
return discRequestedError(reason)
|
||||
}
|
||||
if msg.Code != handshakeMsg {
|
||||
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
|
||||
}
|
||||
if msg.Size > baseProtocolMaxMsgSize {
|
||||
return newPeerError(errInvalidMsg, "message too big")
|
||||
}
|
||||
var hs handshake
|
||||
if err := msg.Decode(&hs); err != nil {
|
||||
return err
|
||||
}
|
||||
// validate handshake info
|
||||
if hs.Version != baseProtocolVersion {
|
||||
return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
|
||||
baseProtocolVersion, hs.Version)
|
||||
}
|
||||
if hs.NodeID == *p.remoteID {
|
||||
return newPeerError(errPubkeyForbidden, "node ID mismatch")
|
||||
}
|
||||
// TODO: remove Caps with empty name
|
||||
p.setHandshakeInfo(hs.Name, hs.Caps)
|
||||
p.startSubprotocols(hs.Caps)
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
|
||||
var caps []interface{}
|
||||
for _, proto := range ps {
|
||||
caps = append(caps, proto.cap())
|
||||
}
|
||||
return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
|
||||
}
|
||||
|
||||
// startProtocols starts matching named subprotocols.
|
||||
func (p *Peer) startSubprotocols(caps []Cap) {
|
||||
sort.Sort(capsByName(caps))
|
||||
|
||||
p.runlock.Lock()
|
||||
defer p.runlock.Unlock()
|
||||
offset := baseProtocolLength
|
||||
|
@ -338,20 +315,22 @@ outer:
|
|||
}
|
||||
|
||||
func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
|
||||
p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
|
||||
rw := &proto{
|
||||
name: impl.Name,
|
||||
in: make(chan Msg),
|
||||
offset: offset,
|
||||
maxcode: impl.Length,
|
||||
peer: p,
|
||||
w: p.rw,
|
||||
}
|
||||
p.protoWG.Add(1)
|
||||
go func() {
|
||||
err := impl.Run(p, rw)
|
||||
if err == nil {
|
||||
p.Infof("protocol %q returned", impl.Name)
|
||||
err = newPeerError(errMisc, "protocol returned")
|
||||
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
|
||||
err = errors.New("protocol returned")
|
||||
} else {
|
||||
p.Errorf("protocol %q error: %v\n", impl.Name, err)
|
||||
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
|
||||
}
|
||||
select {
|
||||
case p.protoErr <- err:
|
||||
|
@ -385,6 +364,7 @@ func (p *Peer) closeProtocols() {
|
|||
}
|
||||
|
||||
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
||||
// this exists because of Server.Broadcast.
|
||||
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
|
||||
p.runlock.RLock()
|
||||
proto, ok := p.running[protoName]
|
||||
|
@ -396,25 +376,14 @@ func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
|
|||
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
|
||||
}
|
||||
msg.Code += proto.offset
|
||||
return p.writeMsg(msg, msgWriteTimeout)
|
||||
}
|
||||
|
||||
// writeMsg writes a message to the connection.
|
||||
func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
|
||||
p.writeMu.Lock()
|
||||
defer p.writeMu.Unlock()
|
||||
p.conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
if err := writeMsg(p.bufconn, msg); err != nil {
|
||||
return newPeerError(errWrite, "%v", err)
|
||||
}
|
||||
return p.bufconn.Flush()
|
||||
return p.rw.WriteMsg(msg)
|
||||
}
|
||||
|
||||
type proto struct {
|
||||
name string
|
||||
in chan Msg
|
||||
maxcode, offset uint64
|
||||
peer *Peer
|
||||
w MsgWriter
|
||||
}
|
||||
|
||||
func (rw *proto) WriteMsg(msg Msg) error {
|
||||
|
@ -422,11 +391,7 @@ func (rw *proto) WriteMsg(msg Msg) error {
|
|||
return newPeerError(errInvalidMsgCode, "not handled")
|
||||
}
|
||||
msg.Code += rw.offset
|
||||
return rw.peer.writeMsg(msg, msgWriteTimeout)
|
||||
}
|
||||
|
||||
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
|
||||
return rw.WriteMsg(NewMsg(code, data...))
|
||||
return rw.w.WriteMsg(msg)
|
||||
}
|
||||
|
||||
func (rw *proto) ReadMsg() (Msg, error) {
|
||||
|
@ -437,26 +402,3 @@ func (rw *proto) ReadMsg() (Msg, error) {
|
|||
msg.Code -= rw.offset
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// eofSignal wraps a reader with eof signaling. the eof channel is
|
||||
// closed when the wrapped reader returns an error or when count bytes
|
||||
// have been read.
|
||||
//
|
||||
type eofSignal struct {
|
||||
wrapped io.Reader
|
||||
count int64
|
||||
eof chan<- struct{}
|
||||
}
|
||||
|
||||
// note: when using eofSignal to detect whether a message payload
|
||||
// has been read, Read might not be called for zero sized messages.
|
||||
|
||||
func (r *eofSignal) Read(buf []byte) (int, error) {
|
||||
n, err := r.wrapped.Read(buf)
|
||||
r.count -= int64(n)
|
||||
if (err != nil || r.count <= 0) && r.eof != nil {
|
||||
r.eof <- struct{}{} // tell Peer that msg has been consumed
|
||||
r.eof = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
|
|
@ -12,7 +12,6 @@ const (
|
|||
errInvalidMsgCode
|
||||
errInvalidMsg
|
||||
errP2PVersionMismatch
|
||||
errPubkeyMissing
|
||||
errPubkeyInvalid
|
||||
errPubkeyForbidden
|
||||
errProtocolBreach
|
||||
|
@ -22,20 +21,19 @@ const (
|
|||
)
|
||||
|
||||
var errorToString = map[int]string{
|
||||
errMagicTokenMismatch: "Magic token mismatch",
|
||||
errRead: "Read error",
|
||||
errWrite: "Write error",
|
||||
errMisc: "Misc error",
|
||||
errInvalidMsgCode: "Invalid message code",
|
||||
errInvalidMsg: "Invalid message",
|
||||
errMagicTokenMismatch: "magic token mismatch",
|
||||
errRead: "read error",
|
||||
errWrite: "write error",
|
||||
errMisc: "misc error",
|
||||
errInvalidMsgCode: "invalid message code",
|
||||
errInvalidMsg: "invalid message",
|
||||
errP2PVersionMismatch: "P2P Version Mismatch",
|
||||
errPubkeyMissing: "Public key missing",
|
||||
errPubkeyInvalid: "Public key invalid",
|
||||
errPubkeyForbidden: "Public key forbidden",
|
||||
errProtocolBreach: "Protocol Breach",
|
||||
errPingTimeout: "Ping timeout",
|
||||
errInvalidNetworkId: "Invalid network id",
|
||||
errInvalidProtocolVersion: "Invalid protocol version",
|
||||
errPubkeyInvalid: "public key invalid",
|
||||
errPubkeyForbidden: "public key forbidden",
|
||||
errProtocolBreach: "protocol Breach",
|
||||
errPingTimeout: "ping timeout",
|
||||
errInvalidNetworkId: "invalid network id",
|
||||
errInvalidProtocolVersion: "invalid protocol version",
|
||||
}
|
||||
|
||||
type peerError struct {
|
||||
|
@ -62,22 +60,22 @@ func (self *peerError) Error() string {
|
|||
type DiscReason byte
|
||||
|
||||
const (
|
||||
DiscRequested DiscReason = 0x00
|
||||
DiscNetworkError = 0x01
|
||||
DiscProtocolError = 0x02
|
||||
DiscUselessPeer = 0x03
|
||||
DiscTooManyPeers = 0x04
|
||||
DiscAlreadyConnected = 0x05
|
||||
DiscIncompatibleVersion = 0x06
|
||||
DiscInvalidIdentity = 0x07
|
||||
DiscQuitting = 0x08
|
||||
DiscUnexpectedIdentity = 0x09
|
||||
DiscSelf = 0x0a
|
||||
DiscReadTimeout = 0x0b
|
||||
DiscSubprotocolError = 0x10
|
||||
DiscRequested DiscReason = iota
|
||||
DiscNetworkError
|
||||
DiscProtocolError
|
||||
DiscUselessPeer
|
||||
DiscTooManyPeers
|
||||
DiscAlreadyConnected
|
||||
DiscIncompatibleVersion
|
||||
DiscInvalidIdentity
|
||||
DiscQuitting
|
||||
DiscUnexpectedIdentity
|
||||
DiscSelf
|
||||
DiscReadTimeout
|
||||
DiscSubprotocolError
|
||||
)
|
||||
|
||||
var discReasonToString = [DiscSubprotocolError + 1]string{
|
||||
var discReasonToString = [...]string{
|
||||
DiscRequested: "Disconnect requested",
|
||||
DiscNetworkError: "Network error",
|
||||
DiscProtocolError: "Breach of protocol",
|
||||
|
@ -117,7 +115,7 @@ func discReasonForError(err error) DiscReason {
|
|||
switch peerError.Code {
|
||||
case errP2PVersionMismatch:
|
||||
return DiscIncompatibleVersion
|
||||
case errPubkeyMissing, errPubkeyInvalid:
|
||||
case errPubkeyInvalid:
|
||||
return DiscInvalidIdentity
|
||||
case errPubkeyForbidden:
|
||||
return DiscUselessPeer
|
||||
|
@ -125,7 +123,7 @@ func discReasonForError(err error) DiscReason {
|
|||
return DiscProtocolError
|
||||
case errPingTimeout:
|
||||
return DiscReadTimeout
|
||||
case errRead, errWrite, errMisc:
|
||||
case errRead, errWrite:
|
||||
return DiscNetworkError
|
||||
default:
|
||||
return DiscSubprotocolError
|
||||
|
|
307
p2p/peer_test.go
307
p2p/peer_test.go
|
@ -1,15 +1,17 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
var discard = Protocol{
|
||||
|
@ -28,17 +30,13 @@ var discard = Protocol{
|
|||
},
|
||||
}
|
||||
|
||||
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
|
||||
func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
|
||||
conn1, conn2 := net.Pipe()
|
||||
peer := newPeer(conn1, protos, nil)
|
||||
peer.ourID = &peerId{}
|
||||
peer.pubkeyHook = func(*peerAddr) error { return nil }
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := peer.loop()
|
||||
errc <- err
|
||||
}()
|
||||
return conn2, peer, errc
|
||||
peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
|
||||
peer.noHandshake = noHandshake
|
||||
errc := make(chan DiscReason, 1)
|
||||
go func() { errc <- peer.run() }()
|
||||
return newFrameRW(conn2, msgWriteTimeout), peer, errc
|
||||
}
|
||||
|
||||
func TestPeerProtoReadMsg(t *testing.T) {
|
||||
|
@ -49,31 +47,28 @@ func TestPeerProtoReadMsg(t *testing.T) {
|
|||
Name: "a",
|
||||
Length: 5,
|
||||
Run: func(peer *Peer, rw MsgReadWriter) error {
|
||||
msg, err := rw.ReadMsg()
|
||||
if err != nil {
|
||||
t.Errorf("read error: %v", err)
|
||||
if err := expectMsg(rw, 2, []uint{1}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if msg.Code != 2 {
|
||||
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
|
||||
if err := expectMsg(rw, 3, []uint{2}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
data, err := ioutil.ReadAll(msg.Payload)
|
||||
if err != nil {
|
||||
t.Errorf("payload read error: %v", err)
|
||||
}
|
||||
expdata, _ := hex.DecodeString("0183303030")
|
||||
if !bytes.Equal(expdata, data) {
|
||||
t.Errorf("incorrect msg data %x", data)
|
||||
if err := expectMsg(rw, 4, []uint{3}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
close(done)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
net, peer, errc := testPeer([]Protocol{proto})
|
||||
defer net.Close()
|
||||
rw, peer, errc := testPeer(true, []Protocol{proto})
|
||||
defer rw.Close()
|
||||
peer.startSubprotocols([]Cap{proto.cap()})
|
||||
|
||||
writeMsg(net, NewMsg(18, 1, "000"))
|
||||
EncodeMsg(rw, baseProtocolLength+2, 1)
|
||||
EncodeMsg(rw, baseProtocolLength+3, 2)
|
||||
EncodeMsg(rw, baseProtocolLength+4, 3)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case err := <-errc:
|
||||
|
@ -105,11 +100,11 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
net, peer, errc := testPeer([]Protocol{proto})
|
||||
defer net.Close()
|
||||
rw, peer, errc := testPeer(true, []Protocol{proto})
|
||||
defer rw.Close()
|
||||
peer.startSubprotocols([]Cap{proto.cap()})
|
||||
|
||||
writeMsg(net, NewMsg(18, make([]byte, msgsize)))
|
||||
EncodeMsg(rw, 18, make([]byte, msgsize))
|
||||
select {
|
||||
case <-done:
|
||||
case err := <-errc:
|
||||
|
@ -135,32 +130,20 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
|||
return nil
|
||||
},
|
||||
}
|
||||
net, peer, _ := testPeer([]Protocol{proto})
|
||||
defer net.Close()
|
||||
rw, peer, _ := testPeer(true, []Protocol{proto})
|
||||
defer rw.Close()
|
||||
peer.startSubprotocols([]Cap{proto.cap()})
|
||||
|
||||
bufr := bufio.NewReader(net)
|
||||
msg, err := readMsg(bufr)
|
||||
if err != nil {
|
||||
t.Errorf("read error: %v", err)
|
||||
}
|
||||
if msg.Code != 17 {
|
||||
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
|
||||
}
|
||||
var data []string
|
||||
if err := msg.Decode(&data); err != nil {
|
||||
t.Errorf("payload decode error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(data, []string{"foo", "bar"}) {
|
||||
t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"})
|
||||
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerWrite(t *testing.T) {
|
||||
func TestPeerWriteForBroadcast(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
|
||||
net, peer, peerErr := testPeer([]Protocol{discard})
|
||||
defer net.Close()
|
||||
rw, peer, peerErr := testPeer(true, []Protocol{discard})
|
||||
defer rw.Close()
|
||||
peer.startSubprotocols([]Cap{discard.cap()})
|
||||
|
||||
// test write errors
|
||||
|
@ -176,18 +159,13 @@ func TestPeerWrite(t *testing.T) {
|
|||
// setup for reading the message on the other end
|
||||
read := make(chan struct{})
|
||||
go func() {
|
||||
bufr := bufio.NewReader(net)
|
||||
msg, err := readMsg(bufr)
|
||||
if err != nil {
|
||||
t.Errorf("read error: %v", err)
|
||||
} else if msg.Code != 16 {
|
||||
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
|
||||
if err := expectMsg(rw, 16, nil); err != nil {
|
||||
t.Error()
|
||||
}
|
||||
msg.Discard()
|
||||
close(read)
|
||||
}()
|
||||
|
||||
// test succcessful write
|
||||
// test successful write
|
||||
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
|
||||
t.Errorf("expect no error for known protocol: %v", err)
|
||||
}
|
||||
|
@ -198,104 +176,153 @@ func TestPeerWrite(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPeerActivity(t *testing.T) {
|
||||
// shorten inactivityTimeout while this test is running
|
||||
oldT := inactivityTimeout
|
||||
defer func() { inactivityTimeout = oldT }()
|
||||
inactivityTimeout = 20 * time.Millisecond
|
||||
func TestPeerPing(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
|
||||
net, peer, peerErr := testPeer([]Protocol{discard})
|
||||
defer net.Close()
|
||||
peer.startSubprotocols([]Cap{discard.cap()})
|
||||
rw, _, _ := testPeer(true, nil)
|
||||
defer rw.Close()
|
||||
if err := EncodeMsg(rw, pingMsg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectMsg(rw, pongMsg, nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
sub := peer.activity.Subscribe(time.Time{})
|
||||
defer sub.Unsubscribe()
|
||||
func TestPeerDisconnect(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
|
||||
for i := 0; i < 6; i++ {
|
||||
writeMsg(net, NewMsg(16))
|
||||
rw, _, disc := testPeer(true, nil)
|
||||
defer rw.Close()
|
||||
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
rw.Close() // make test end faster
|
||||
if reason := <-disc; reason != DiscRequested {
|
||||
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerHandshake(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
|
||||
// remote has two matching protocols: a and c
|
||||
remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}})
|
||||
remoteID := randomID()
|
||||
remote.ourID = &remoteID
|
||||
remote.ourName = "remote peer"
|
||||
|
||||
start := make(chan string)
|
||||
stop := make(chan struct{})
|
||||
run := func(p *Peer, rw MsgReadWriter) error {
|
||||
name := rw.(*proto).name
|
||||
if name != "a" && name != "c" {
|
||||
t.Errorf("protocol %q should not be started", name)
|
||||
} else {
|
||||
start <- name
|
||||
}
|
||||
<-stop
|
||||
return nil
|
||||
}
|
||||
protocols := []Protocol{
|
||||
{Name: "a", Version: 1, Length: 1, Run: run},
|
||||
{Name: "b", Version: 2, Length: 1, Run: run},
|
||||
{Name: "c", Version: 3, Length: 1, Run: run},
|
||||
{Name: "d", Version: 4, Length: 1, Run: run},
|
||||
}
|
||||
rw, p, disc := testPeer(false, protocols)
|
||||
p.remoteID = remote.ourID
|
||||
defer rw.Close()
|
||||
|
||||
// run the handshake
|
||||
remoteProtocols := []Protocol{protocols[0], protocols[2]}
|
||||
if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil {
|
||||
t.Fatalf("handshake write error: %v", err)
|
||||
}
|
||||
if err := readProtocolHandshake(remote, rw); err != nil {
|
||||
t.Fatalf("handshake read error: %v", err)
|
||||
}
|
||||
|
||||
// check that all protocols have been started
|
||||
var started []string
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case <-sub.Chan():
|
||||
case <-time.After(inactivityTimeout / 2):
|
||||
t.Fatal("no event within ", inactivityTimeout/2)
|
||||
case err := <-peerErr:
|
||||
t.Fatal("peer error", err)
|
||||
case name := <-start:
|
||||
started = append(started, name)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(inactivityTimeout * 2):
|
||||
case <-sub.Chan():
|
||||
t.Fatal("got activity event while connection was inactive")
|
||||
case err := <-peerErr:
|
||||
t.Fatal("peer error", err)
|
||||
sort.Strings(started)
|
||||
if !reflect.DeepEqual(started, []string{"a", "c"}) {
|
||||
t.Errorf("wrong protocols started: %v", started)
|
||||
}
|
||||
|
||||
// check that metadata has been set
|
||||
if p.ID() != remoteID {
|
||||
t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID)
|
||||
}
|
||||
if p.Name() != remote.ourName {
|
||||
t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName)
|
||||
}
|
||||
|
||||
close(stop)
|
||||
expectMsg(rw, discMsg, nil)
|
||||
t.Logf("disc reason: %v", <-disc)
|
||||
}
|
||||
|
||||
func TestNewPeer(t *testing.T) {
|
||||
name := "nodename"
|
||||
caps := []Cap{{"foo", 2}, {"bar", 3}}
|
||||
id := &peerId{}
|
||||
p := NewPeer(id, caps)
|
||||
id := randomID()
|
||||
p := NewPeer(id, name, caps)
|
||||
if p.ID() != id {
|
||||
t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id)
|
||||
}
|
||||
if p.Name() != name {
|
||||
t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name)
|
||||
}
|
||||
if !reflect.DeepEqual(p.Caps(), caps) {
|
||||
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
|
||||
}
|
||||
if p.Identity() != id {
|
||||
t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id)
|
||||
}
|
||||
// Should not hang.
|
||||
p.Disconnect(DiscAlreadyConnected)
|
||||
|
||||
p.Disconnect(DiscAlreadyConnected) // Should not hang
|
||||
}
|
||||
|
||||
func TestEOFSignal(t *testing.T) {
|
||||
rb := make([]byte, 10)
|
||||
// expectMsg reads a message from r and verifies that its
|
||||
// code and encoded RLP content match the provided values.
|
||||
// If content is nil, the payload is discarded and not verified.
|
||||
func expectMsg(r MsgReader, code uint64, content interface{}) error {
|
||||
msg, err := r.ReadMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Code != code {
|
||||
return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code)
|
||||
}
|
||||
if content == nil {
|
||||
return msg.Discard()
|
||||
} else {
|
||||
contentEnc, err := rlp.EncodeToBytes(content)
|
||||
if err != nil {
|
||||
panic("content encode error: " + err.Error())
|
||||
}
|
||||
// skip over list header in encoded value. this is temporary.
|
||||
contentEncR := bytes.NewReader(contentEnc)
|
||||
if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil {
|
||||
panic("content must encode as RLP list")
|
||||
}
|
||||
contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():]
|
||||
|
||||
// empty reader
|
||||
eof := make(chan struct{}, 1)
|
||||
sig := &eofSignal{new(bytes.Buffer), 0, eof}
|
||||
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
select {
|
||||
case <-eof:
|
||||
default:
|
||||
t.Error("EOF chan not signaled")
|
||||
}
|
||||
|
||||
// count before error
|
||||
eof = make(chan struct{}, 1)
|
||||
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
|
||||
if n, err := sig.Read(rb); n != 8 || err != nil {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
select {
|
||||
case <-eof:
|
||||
default:
|
||||
t.Error("EOF chan not signaled")
|
||||
}
|
||||
|
||||
// error before count
|
||||
eof = make(chan struct{}, 1)
|
||||
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
|
||||
if n, err := sig.Read(rb); n != 4 || err != nil {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
select {
|
||||
case <-eof:
|
||||
default:
|
||||
t.Error("EOF chan not signaled")
|
||||
}
|
||||
|
||||
// no signal if neither occurs
|
||||
eof = make(chan struct{}, 1)
|
||||
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
|
||||
if n, err := sig.Read(rb); n != 10 || err != nil {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
select {
|
||||
case <-eof:
|
||||
t.Error("unexpected EOF signal")
|
||||
default:
|
||||
actualContent, err := ioutil.ReadAll(msg.Payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !bytes.Equal(actualContent, contentEnc) {
|
||||
return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
244
p2p/protocol.go
244
p2p/protocol.go
|
@ -1,10 +1,5 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Protocol represents a P2P subprotocol implementation.
|
||||
type Protocol struct {
|
||||
// Name should contain the official protocol name,
|
||||
|
@ -32,38 +27,6 @@ func (p Protocol) cap() Cap {
|
|||
return Cap{p.Name, p.Version}
|
||||
}
|
||||
|
||||
const (
|
||||
baseProtocolVersion = 2
|
||||
baseProtocolLength = uint64(16)
|
||||
baseProtocolMaxMsgSize = 10 * 1024 * 1024
|
||||
)
|
||||
|
||||
const (
|
||||
// devp2p message codes
|
||||
handshakeMsg = 0x00
|
||||
discMsg = 0x01
|
||||
pingMsg = 0x02
|
||||
pongMsg = 0x03
|
||||
getPeersMsg = 0x04
|
||||
peersMsg = 0x05
|
||||
)
|
||||
|
||||
// handshake is the structure of a handshake list.
|
||||
type handshake struct {
|
||||
Version uint64
|
||||
ID string
|
||||
Caps []Cap
|
||||
ListenPort uint64
|
||||
NodeID []byte
|
||||
}
|
||||
|
||||
func (h *handshake) String() string {
|
||||
return h.ID
|
||||
}
|
||||
func (h *handshake) Pubkey() []byte {
|
||||
return h.NodeID
|
||||
}
|
||||
|
||||
// Cap is the structure of a peer capability.
|
||||
type Cap struct {
|
||||
Name string
|
||||
|
@ -79,210 +42,3 @@ type capsByName []Cap
|
|||
func (cs capsByName) Len() int { return len(cs) }
|
||||
func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
|
||||
func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
|
||||
|
||||
type baseProtocol struct {
|
||||
rw MsgReadWriter
|
||||
peer *Peer
|
||||
}
|
||||
|
||||
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
|
||||
bp := &baseProtocol{rw, peer}
|
||||
errc := make(chan error, 1)
|
||||
go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }()
|
||||
if err := bp.readHandshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
// handle write error
|
||||
if err := <-errc; err != nil {
|
||||
return err
|
||||
}
|
||||
// run main loop
|
||||
go func() {
|
||||
for {
|
||||
if err := bp.handle(rw); err != nil {
|
||||
errc <- err
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
return bp.loop(errc)
|
||||
}
|
||||
|
||||
var pingTimeout = 2 * time.Second
|
||||
|
||||
func (bp *baseProtocol) loop(quit <-chan error) error {
|
||||
ping := time.NewTimer(pingTimeout)
|
||||
activity := bp.peer.activity.Subscribe(time.Time{})
|
||||
lastActive := time.Time{}
|
||||
defer ping.Stop()
|
||||
defer activity.Unsubscribe()
|
||||
|
||||
getPeersTick := time.NewTicker(10 * time.Second)
|
||||
defer getPeersTick.Stop()
|
||||
err := EncodeMsg(bp.rw, getPeersMsg)
|
||||
|
||||
for err == nil {
|
||||
select {
|
||||
case err = <-quit:
|
||||
return err
|
||||
case <-getPeersTick.C:
|
||||
err = EncodeMsg(bp.rw, getPeersMsg)
|
||||
case event := <-activity.Chan():
|
||||
ping.Reset(pingTimeout)
|
||||
lastActive = event.(time.Time)
|
||||
case t := <-ping.C:
|
||||
if lastActive.Add(pingTimeout * 2).Before(t) {
|
||||
err = newPeerError(errPingTimeout, "")
|
||||
} else if lastActive.Add(pingTimeout).Before(t) {
|
||||
err = EncodeMsg(bp.rw, pingMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (bp *baseProtocol) handle(rw MsgReadWriter) error {
|
||||
msg, err := rw.ReadMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Size > baseProtocolMaxMsgSize {
|
||||
return newPeerError(errMisc, "message too big")
|
||||
}
|
||||
// make sure that the payload has been fully consumed
|
||||
defer msg.Discard()
|
||||
|
||||
switch msg.Code {
|
||||
case handshakeMsg:
|
||||
return newPeerError(errProtocolBreach, "extra handshake received")
|
||||
|
||||
case discMsg:
|
||||
var reason [1]DiscReason
|
||||
if err := msg.Decode(&reason); err != nil {
|
||||
return err
|
||||
}
|
||||
return discRequestedError(reason[0])
|
||||
|
||||
case pingMsg:
|
||||
return EncodeMsg(bp.rw, pongMsg)
|
||||
|
||||
case pongMsg:
|
||||
|
||||
case getPeersMsg:
|
||||
peers := bp.peerList()
|
||||
// this is dangerous. the spec says that we should _delay_
|
||||
// sending the response if no new information is available.
|
||||
// this means that would need to send a response later when
|
||||
// new peers become available.
|
||||
//
|
||||
// TODO: add event mechanism to notify baseProtocol for new peers
|
||||
if len(peers) > 0 {
|
||||
return EncodeMsg(bp.rw, peersMsg, peers...)
|
||||
}
|
||||
|
||||
case peersMsg:
|
||||
var peers []*peerAddr
|
||||
if err := msg.Decode(&peers); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, addr := range peers {
|
||||
bp.peer.Debugf("received peer suggestion: %v", addr)
|
||||
bp.peer.newPeerAddr <- addr
|
||||
}
|
||||
|
||||
default:
|
||||
return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bp *baseProtocol) readHandshake() error {
|
||||
// read and handle remote handshake
|
||||
msg, err := bp.rw.ReadMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Code != handshakeMsg {
|
||||
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
|
||||
}
|
||||
if msg.Size > baseProtocolMaxMsgSize {
|
||||
return newPeerError(errMisc, "message too big")
|
||||
}
|
||||
var hs handshake
|
||||
if err := msg.Decode(&hs); err != nil {
|
||||
return err
|
||||
}
|
||||
// validate handshake info
|
||||
if hs.Version != baseProtocolVersion {
|
||||
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
|
||||
baseProtocolVersion, hs.Version)
|
||||
}
|
||||
if len(hs.NodeID) == 0 {
|
||||
return newPeerError(errPubkeyMissing, "")
|
||||
}
|
||||
if len(hs.NodeID) != 64 {
|
||||
return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
|
||||
}
|
||||
if da := bp.peer.dialAddr; da != nil {
|
||||
// verify that the peer we wanted to connect to
|
||||
// actually holds the target public key.
|
||||
if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
|
||||
return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
|
||||
}
|
||||
}
|
||||
pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
|
||||
if err := bp.peer.pubkeyHook(pa); err != nil {
|
||||
return newPeerError(errPubkeyForbidden, "%v", err)
|
||||
}
|
||||
// TODO: remove Caps with empty name
|
||||
var addr *peerAddr
|
||||
if hs.ListenPort != 0 {
|
||||
addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
|
||||
addr.Port = hs.ListenPort
|
||||
}
|
||||
bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
|
||||
bp.peer.startSubprotocols(hs.Caps)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bp *baseProtocol) handshakeMsg() Msg {
|
||||
var (
|
||||
port uint64
|
||||
caps []interface{}
|
||||
)
|
||||
if bp.peer.ourListenAddr != nil {
|
||||
port = bp.peer.ourListenAddr.Port
|
||||
}
|
||||
for _, proto := range bp.peer.protocols {
|
||||
caps = append(caps, proto.cap())
|
||||
}
|
||||
return NewMsg(handshakeMsg,
|
||||
baseProtocolVersion,
|
||||
bp.peer.ourID.String(),
|
||||
caps,
|
||||
port,
|
||||
bp.peer.ourID.Pubkey()[1:],
|
||||
)
|
||||
}
|
||||
|
||||
func (bp *baseProtocol) peerList() []interface{} {
|
||||
peers := bp.peer.otherPeers()
|
||||
ds := make([]interface{}, 0, len(peers))
|
||||
for _, p := range peers {
|
||||
p.infolock.Lock()
|
||||
addr := p.listenAddr
|
||||
p.infolock.Unlock()
|
||||
// filter out this peer and peers that are not listening or
|
||||
// have not completed the handshake.
|
||||
// TODO: track previously sent peers and exclude them as well.
|
||||
if p == bp.peer || addr == nil {
|
||||
continue
|
||||
}
|
||||
ds = append(ds, addr)
|
||||
}
|
||||
ourAddr := bp.peer.ourListenAddr
|
||||
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
|
||||
ds = append(ds, ourAddr)
|
||||
}
|
||||
return ds
|
||||
}
|
||||
|
|
|
@ -1,158 +0,0 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
)
|
||||
|
||||
type peerId struct {
|
||||
pubkey []byte
|
||||
}
|
||||
|
||||
func (self *peerId) String() string {
|
||||
return fmt.Sprintf("test peer %x", self.Pubkey()[:4])
|
||||
}
|
||||
|
||||
func (self *peerId) Pubkey() (pubkey []byte) {
|
||||
pubkey = self.pubkey
|
||||
if len(pubkey) == 0 {
|
||||
pubkey = crypto.GenerateNewKeyPair().PublicKey
|
||||
self.pubkey = pubkey
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func newTestPeer() (peer *Peer) {
|
||||
peer = NewPeer(&peerId{}, []Cap{})
|
||||
peer.pubkeyHook = func(*peerAddr) error { return nil }
|
||||
peer.ourID = &peerId{}
|
||||
peer.listenAddr = &peerAddr{}
|
||||
peer.otherPeers = func() []*Peer { return nil }
|
||||
return
|
||||
}
|
||||
|
||||
func TestBaseProtocolPeers(t *testing.T) {
|
||||
peerList := []*peerAddr{
|
||||
{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
|
||||
{IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
|
||||
}
|
||||
listenAddr := &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
|
||||
rw1, rw2 := MsgPipe()
|
||||
defer rw1.Close()
|
||||
wg := new(sync.WaitGroup)
|
||||
|
||||
// run matcher, close pipe when addresses have arrived
|
||||
numPeers := len(peerList) + 1
|
||||
addrChan := make(chan *peerAddr)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
i := 0
|
||||
for got := range addrChan {
|
||||
var want *peerAddr
|
||||
switch {
|
||||
case i < len(peerList):
|
||||
want = peerList[i]
|
||||
case i == len(peerList):
|
||||
want = listenAddr // listenAddr should be the last thing sent
|
||||
}
|
||||
t.Logf("got peer %d/%d: %v", i+1, numPeers, got)
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("mismatch: got %+v, want %+v", got, want)
|
||||
}
|
||||
i++
|
||||
if i == numPeers {
|
||||
break
|
||||
}
|
||||
}
|
||||
if i != numPeers {
|
||||
t.Errorf("wrong number of peers received: got %d, want %d", i, numPeers)
|
||||
}
|
||||
rw1.Close()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// run first peer (in background)
|
||||
peer1 := newTestPeer()
|
||||
peer1.ourListenAddr = listenAddr
|
||||
peer1.otherPeers = func() []*Peer {
|
||||
pl := make([]*Peer, len(peerList))
|
||||
for i, addr := range peerList {
|
||||
pl[i] = &Peer{listenAddr: addr}
|
||||
}
|
||||
return pl
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
runBaseProtocol(peer1, rw1)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// run second peer
|
||||
peer2 := newTestPeer()
|
||||
peer2.newPeerAddr = addrChan // feed peer suggestions into matcher
|
||||
if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed {
|
||||
t.Errorf("peer2 terminated with unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// terminate matcher
|
||||
close(addrChan)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestBaseProtocolDisconnect(t *testing.T) {
|
||||
peer := NewPeer(&peerId{}, nil)
|
||||
peer.ourID = &peerId{}
|
||||
peer.pubkeyHook = func(*peerAddr) error { return nil }
|
||||
|
||||
rw1, rw2 := MsgPipe()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
if err := expectMsg(rw2, handshakeMsg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err := EncodeMsg(rw2, handshakeMsg,
|
||||
baseProtocolVersion,
|
||||
"",
|
||||
[]interface{}{},
|
||||
0,
|
||||
make([]byte, 64),
|
||||
)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := expectMsg(rw2, getPeersMsg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := EncodeMsg(rw2, discMsg, DiscQuitting); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
close(done)
|
||||
}()
|
||||
|
||||
if err := runBaseProtocol(peer, rw1); err == nil {
|
||||
t.Errorf("base protocol returned without error")
|
||||
} else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting {
|
||||
t.Errorf("base protocol returned wrong error: %v", err)
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
||||
func expectMsg(r MsgReader, code uint64) error {
|
||||
msg, err := r.ReadMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := msg.Discard(); err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Code != code {
|
||||
return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code)
|
||||
}
|
||||
return nil
|
||||
}
|
420
p2p/server.go
420
p2p/server.go
|
@ -2,37 +2,56 @@ package p2p
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||
)
|
||||
|
||||
const (
|
||||
outboundAddressPoolSize = 500
|
||||
defaultDialTimeout = 10 * time.Second
|
||||
portMappingUpdateInterval = 15 * time.Minute
|
||||
portMappingTimeout = 20 * time.Minute
|
||||
handshakeTimeout = 5 * time.Second
|
||||
defaultDialTimeout = 10 * time.Second
|
||||
refreshPeersInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
var srvlog = logger.NewLogger("P2P Server")
|
||||
|
||||
// MakeName creates a node name that follows the ethereum convention
|
||||
// for such names. It adds the operation system name and Go runtime version
|
||||
// the name.
|
||||
func MakeName(name, version string) string {
|
||||
return fmt.Sprintf("%s/v%s/%s/%s", name, version, runtime.GOOS, runtime.Version())
|
||||
}
|
||||
|
||||
// Server manages all peer connections.
|
||||
//
|
||||
// The fields of Server are used as configuration parameters.
|
||||
// You should set them before starting the Server. Fields may not be
|
||||
// modified while the server is running.
|
||||
type Server struct {
|
||||
// This field must be set to a valid client identity.
|
||||
Identity ClientIdentity
|
||||
// This field must be set to a valid secp256k1 private key.
|
||||
PrivateKey *ecdsa.PrivateKey
|
||||
|
||||
// MaxPeers is the maximum number of peers that can be
|
||||
// connected. It must be greater than zero.
|
||||
MaxPeers int
|
||||
|
||||
// Name sets the node name of this server.
|
||||
// Use MakeName to create a name that follows existing conventions.
|
||||
Name string
|
||||
|
||||
// Bootstrap nodes are used to establish connectivity
|
||||
// with the rest of the network.
|
||||
BootstrapNodes []*discover.Node
|
||||
|
||||
// Protocols should contain the protocols supported
|
||||
// by the server. Matching protocols are launched for
|
||||
// each peer.
|
||||
|
@ -53,7 +72,7 @@ type Server struct {
|
|||
// If set to a non-nil value, the given NAT port mapper
|
||||
// is used to make the listening port available to the
|
||||
// Internet.
|
||||
NAT NAT
|
||||
NAT nat.Interface
|
||||
|
||||
// If Dialer is set to a non-nil value, the given Dialer
|
||||
// is used to dial outbound peer connections.
|
||||
|
@ -62,35 +81,26 @@ type Server struct {
|
|||
// If NoDial is true, the server will not dial any peers.
|
||||
NoDial bool
|
||||
|
||||
// Hook for testing. This is useful because we can inhibit
|
||||
// Hooks for testing. These are useful because we can inhibit
|
||||
// the whole protocol stack.
|
||||
newPeerFunc peerFunc
|
||||
handshakeFunc
|
||||
newPeerHook
|
||||
|
||||
lock sync.RWMutex
|
||||
running bool
|
||||
listener net.Listener
|
||||
laddr *net.TCPAddr // real listen addr
|
||||
peers []*Peer
|
||||
peerSlots chan int
|
||||
peerCount int
|
||||
lock sync.RWMutex
|
||||
running bool
|
||||
listener net.Listener
|
||||
peers map[discover.NodeID]*Peer
|
||||
|
||||
quit chan struct{}
|
||||
wg sync.WaitGroup
|
||||
peerConnect chan *peerAddr
|
||||
peerDisconnect chan *Peer
|
||||
ntab *discover.Table
|
||||
|
||||
quit chan struct{}
|
||||
loopWG sync.WaitGroup // {dial,listen,nat}Loop
|
||||
peerWG sync.WaitGroup // active peer goroutines
|
||||
peerConnect chan *discover.Node
|
||||
}
|
||||
|
||||
// NAT is implemented by NAT traversal methods.
|
||||
type NAT interface {
|
||||
GetExternalAddress() (net.IP, error)
|
||||
AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
|
||||
DeletePortMapping(protocol string, extport, intport int) error
|
||||
|
||||
// Should return name of the method.
|
||||
String() string
|
||||
}
|
||||
|
||||
type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer
|
||||
type handshakeFunc func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (discover.NodeID, []byte, error)
|
||||
type newPeerHook func(*Peer)
|
||||
|
||||
// Peers returns all connected peers.
|
||||
func (srv *Server) Peers() (peers []*Peer) {
|
||||
|
@ -107,18 +117,15 @@ func (srv *Server) Peers() (peers []*Peer) {
|
|||
// PeerCount returns the number of connected peers.
|
||||
func (srv *Server) PeerCount() int {
|
||||
srv.lock.RLock()
|
||||
defer srv.lock.RUnlock()
|
||||
return srv.peerCount
|
||||
n := len(srv.peers)
|
||||
srv.lock.RUnlock()
|
||||
return n
|
||||
}
|
||||
|
||||
// SuggestPeer injects an address into the outbound address pool.
|
||||
func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
|
||||
addr := &peerAddr{ip, uint64(port), nodeID}
|
||||
select {
|
||||
case srv.peerConnect <- addr:
|
||||
default: // don't block
|
||||
srvlog.Warnf("peer suggestion %v ignored", addr)
|
||||
}
|
||||
// SuggestPeer creates a connection to the given Node if it
|
||||
// is not already connected.
|
||||
func (srv *Server) SuggestPeer(n *discover.Node) {
|
||||
srv.peerConnect <- n
|
||||
}
|
||||
|
||||
// Broadcast sends an RLP-encoded message to all connected peers.
|
||||
|
@ -152,47 +159,46 @@ func (srv *Server) Start() (err error) {
|
|||
}
|
||||
srvlog.Infoln("Starting Server")
|
||||
|
||||
// initialize fields
|
||||
if srv.Identity == nil {
|
||||
return fmt.Errorf("Server.Identity must be set to a non-nil identity")
|
||||
// initialize all the fields
|
||||
if srv.PrivateKey == nil {
|
||||
return fmt.Errorf("Server.PrivateKey must be set to a non-nil key")
|
||||
}
|
||||
if srv.MaxPeers <= 0 {
|
||||
return fmt.Errorf("Server.MaxPeers must be > 0")
|
||||
}
|
||||
srv.quit = make(chan struct{})
|
||||
srv.peers = make([]*Peer, srv.MaxPeers)
|
||||
srv.peerSlots = make(chan int, srv.MaxPeers)
|
||||
srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
|
||||
srv.peerDisconnect = make(chan *Peer)
|
||||
if srv.newPeerFunc == nil {
|
||||
srv.newPeerFunc = newServerPeer
|
||||
srv.peers = make(map[discover.NodeID]*Peer)
|
||||
srv.peerConnect = make(chan *discover.Node)
|
||||
|
||||
if srv.handshakeFunc == nil {
|
||||
srv.handshakeFunc = encHandshake
|
||||
}
|
||||
if srv.Blacklist == nil {
|
||||
srv.Blacklist = NewBlacklist()
|
||||
}
|
||||
if srv.Dialer == nil {
|
||||
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
|
||||
}
|
||||
|
||||
if srv.ListenAddr != "" {
|
||||
if err := srv.startListening(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// dial stuff
|
||||
dt, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv.ntab = dt
|
||||
if srv.Dialer == nil {
|
||||
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
|
||||
}
|
||||
if !srv.NoDial {
|
||||
srv.wg.Add(1)
|
||||
srv.loopWG.Add(1)
|
||||
go srv.dialLoop()
|
||||
}
|
||||
if srv.NoDial && srv.ListenAddr == "" {
|
||||
srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
|
||||
}
|
||||
|
||||
// make all slots available
|
||||
for i := range srv.peers {
|
||||
srv.peerSlots <- i
|
||||
}
|
||||
// note: discLoop is not part of WaitGroup
|
||||
go srv.discLoop()
|
||||
srv.running = true
|
||||
return nil
|
||||
}
|
||||
|
@ -202,14 +208,17 @@ func (srv *Server) startListening() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv.ListenAddr = listener.Addr().String()
|
||||
srv.laddr = listener.Addr().(*net.TCPAddr)
|
||||
laddr := listener.Addr().(*net.TCPAddr)
|
||||
srv.ListenAddr = laddr.String()
|
||||
srv.listener = listener
|
||||
srv.wg.Add(1)
|
||||
srv.loopWG.Add(1)
|
||||
go srv.listenLoop()
|
||||
if !srv.laddr.IP.IsLoopback() && srv.NAT != nil {
|
||||
srv.wg.Add(1)
|
||||
go srv.natLoop(srv.laddr.Port)
|
||||
if !laddr.IP.IsLoopback() && srv.NAT != nil {
|
||||
srv.loopWG.Add(1)
|
||||
go func() {
|
||||
nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p")
|
||||
srv.loopWG.Done()
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -225,200 +234,171 @@ func (srv *Server) Stop() {
|
|||
srv.running = false
|
||||
srv.lock.Unlock()
|
||||
|
||||
srvlog.Infoln("Stopping server")
|
||||
srvlog.Infoln("Stopping Server")
|
||||
srv.ntab.Close()
|
||||
if srv.listener != nil {
|
||||
// this unblocks listener Accept
|
||||
srv.listener.Close()
|
||||
}
|
||||
close(srv.quit)
|
||||
for _, peer := range srv.Peers() {
|
||||
srv.loopWG.Wait()
|
||||
|
||||
// No new peers can be added at this point because dialLoop and
|
||||
// listenLoop are down. It is safe to call peerWG.Wait because
|
||||
// peerWG.Add is not called outside of those loops.
|
||||
for _, peer := range srv.peers {
|
||||
peer.Disconnect(DiscQuitting)
|
||||
}
|
||||
srv.wg.Wait()
|
||||
|
||||
// wait till they actually disconnect
|
||||
// this is checked by claiming all peerSlots.
|
||||
// slots become available as the peers disconnect.
|
||||
for i := 0; i < cap(srv.peerSlots); i++ {
|
||||
<-srv.peerSlots
|
||||
}
|
||||
// terminate discLoop
|
||||
close(srv.peerDisconnect)
|
||||
}
|
||||
|
||||
func (srv *Server) discLoop() {
|
||||
for peer := range srv.peerDisconnect {
|
||||
srv.removePeer(peer)
|
||||
}
|
||||
srv.peerWG.Wait()
|
||||
}
|
||||
|
||||
// main loop for adding connections via listening
|
||||
func (srv *Server) listenLoop() {
|
||||
defer srv.wg.Done()
|
||||
|
||||
defer srv.loopWG.Done()
|
||||
srvlog.Infoln("Listening on", srv.listener.Addr())
|
||||
for {
|
||||
select {
|
||||
case slot := <-srv.peerSlots:
|
||||
srvlog.Debugf("grabbed slot %v for listening", slot)
|
||||
conn, err := srv.listener.Accept()
|
||||
if err != nil {
|
||||
srv.peerSlots <- slot
|
||||
return
|
||||
}
|
||||
srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
|
||||
srv.addPeer(conn, nil, slot)
|
||||
case <-srv.quit:
|
||||
conn, err := srv.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
srvlog.Debugf("Accepted conn %v\n", conn.RemoteAddr())
|
||||
srv.peerWG.Add(1)
|
||||
go srv.startPeer(conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) natLoop(port int) {
|
||||
defer srv.wg.Done()
|
||||
for {
|
||||
srv.updatePortMapping(port)
|
||||
select {
|
||||
case <-time.After(portMappingUpdateInterval):
|
||||
// one more round
|
||||
case <-srv.quit:
|
||||
srv.removePortMapping(port)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) updatePortMapping(port int) {
|
||||
srvlog.Infoln("Attempting to map port", port, "with", srv.NAT)
|
||||
err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout)
|
||||
if err != nil {
|
||||
srvlog.Errorln("Port mapping error:", err)
|
||||
return
|
||||
}
|
||||
extip, err := srv.NAT.GetExternalAddress()
|
||||
if err != nil {
|
||||
srvlog.Errorln("Error getting external IP:", err)
|
||||
return
|
||||
}
|
||||
srv.lock.Lock()
|
||||
extaddr := *(srv.listener.Addr().(*net.TCPAddr))
|
||||
extaddr.IP = extip
|
||||
srvlog.Infoln("Mapped port, external addr is", &extaddr)
|
||||
srv.laddr = &extaddr
|
||||
srv.lock.Unlock()
|
||||
}
|
||||
|
||||
func (srv *Server) removePortMapping(port int) {
|
||||
srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT)
|
||||
srv.NAT.DeletePortMapping("tcp", port, port)
|
||||
}
|
||||
|
||||
func (srv *Server) dialLoop() {
|
||||
defer srv.wg.Done()
|
||||
var (
|
||||
suggest chan *peerAddr
|
||||
slot *int
|
||||
slots = srv.peerSlots
|
||||
)
|
||||
defer srv.loopWG.Done()
|
||||
refresh := time.NewTicker(refreshPeersInterval)
|
||||
defer refresh.Stop()
|
||||
|
||||
srv.ntab.Bootstrap(srv.BootstrapNodes)
|
||||
go srv.findPeers()
|
||||
|
||||
dialed := make(chan *discover.Node)
|
||||
dialing := make(map[discover.NodeID]bool)
|
||||
|
||||
// TODO: limit number of active dials
|
||||
// TODO: ensure only one findPeers goroutine is running
|
||||
// TODO: pause findPeers when we're at capacity
|
||||
|
||||
for {
|
||||
select {
|
||||
case i := <-slots:
|
||||
// we need a peer in slot i, slot reserved
|
||||
slot = &i
|
||||
// now we can watch for candidate peers in the next loop
|
||||
suggest = srv.peerConnect
|
||||
// do not consume more until candidate peer is found
|
||||
slots = nil
|
||||
case <-refresh.C:
|
||||
|
||||
case desc := <-suggest:
|
||||
// candidate peer found, will dial out asyncronously
|
||||
// if connection fails slot will be released
|
||||
srvlog.DebugDetailf("dial %v (%v)", desc, *slot)
|
||||
go srv.dialPeer(desc, *slot)
|
||||
// we can watch if more peers needed in the next loop
|
||||
slots = srv.peerSlots
|
||||
// until then we dont care about candidate peers
|
||||
suggest = nil
|
||||
go srv.findPeers()
|
||||
|
||||
case dest := <-srv.peerConnect:
|
||||
// avoid dialing nodes that are already connected.
|
||||
// there is another check for this in addPeer,
|
||||
// which runs after the handshake.
|
||||
srv.lock.Lock()
|
||||
_, isconnected := srv.peers[dest.ID]
|
||||
srv.lock.Unlock()
|
||||
if isconnected || dialing[dest.ID] || dest.ID == srv.ntab.Self() {
|
||||
continue
|
||||
}
|
||||
|
||||
dialing[dest.ID] = true
|
||||
srv.peerWG.Add(1)
|
||||
go func() {
|
||||
srv.dialNode(dest)
|
||||
// at this point, the peer has been added
|
||||
// or discarded. either way, we're not dialing it anymore.
|
||||
dialed <- dest
|
||||
}()
|
||||
|
||||
case dest := <-dialed:
|
||||
delete(dialing, dest.ID)
|
||||
|
||||
case <-srv.quit:
|
||||
// give back the currently reserved slot
|
||||
if slot != nil {
|
||||
srv.peerSlots <- *slot
|
||||
}
|
||||
// TODO: maybe wait for active dials
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// connect to peer via dial out
|
||||
func (srv *Server) dialPeer(desc *peerAddr, slot int) {
|
||||
srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot)
|
||||
conn, err := srv.Dialer.Dial(desc.Network(), desc.String())
|
||||
func (srv *Server) dialNode(dest *discover.Node) {
|
||||
addr := &net.TCPAddr{IP: dest.IP, Port: dest.TCPPort}
|
||||
srvlog.Debugf("Dialing %v\n", dest)
|
||||
conn, err := srv.Dialer.Dial("tcp", addr.String())
|
||||
if err != nil {
|
||||
srvlog.DebugDetailf("dial error: %v", err)
|
||||
srv.peerSlots <- slot
|
||||
return
|
||||
}
|
||||
go srv.addPeer(conn, desc, slot)
|
||||
srv.startPeer(conn, dest)
|
||||
}
|
||||
|
||||
// creates the new peer object and inserts it into its slot
|
||||
func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer {
|
||||
srv.lock.Lock()
|
||||
defer srv.lock.Unlock()
|
||||
if !srv.running {
|
||||
conn.Close()
|
||||
srv.peerSlots <- slot // release slot
|
||||
return nil
|
||||
func (srv *Server) findPeers() {
|
||||
far := srv.ntab.Self()
|
||||
for i := range far {
|
||||
far[i] = ^far[i]
|
||||
}
|
||||
peer := srv.newPeerFunc(srv, conn, desc)
|
||||
peer.slot = slot
|
||||
srv.peers[slot] = peer
|
||||
srv.peerCount++
|
||||
go func() {
|
||||
peer.loop()
|
||||
srv.peerDisconnect <- peer
|
||||
}()
|
||||
return peer
|
||||
}
|
||||
closeToSelf := srv.ntab.Lookup(srv.ntab.Self())
|
||||
farFromSelf := srv.ntab.Lookup(far)
|
||||
|
||||
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
|
||||
func (srv *Server) removePeer(peer *Peer) {
|
||||
srv.lock.Lock()
|
||||
defer srv.lock.Unlock()
|
||||
srvlog.Debugf("Removing %v (slot %v)\n", peer, peer.slot)
|
||||
if srv.peers[peer.slot] != peer {
|
||||
srvlog.Warnln("Invalid peer to remove:", peer)
|
||||
return
|
||||
}
|
||||
// remove from list and index
|
||||
srv.peerCount--
|
||||
srv.peers[peer.slot] = nil
|
||||
// release slot to signal need for a new peer, last!
|
||||
srv.peerSlots <- peer.slot
|
||||
}
|
||||
|
||||
func (srv *Server) verifyPeer(addr *peerAddr) error {
|
||||
if srv.Blacklist.Exists(addr.Pubkey) {
|
||||
return errors.New("blacklisted")
|
||||
}
|
||||
if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) {
|
||||
return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
|
||||
}
|
||||
srv.lock.RLock()
|
||||
defer srv.lock.RUnlock()
|
||||
for _, peer := range srv.peers {
|
||||
if peer != nil {
|
||||
id := peer.Identity()
|
||||
if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
|
||||
return errors.New("already connected")
|
||||
}
|
||||
for i := 0; i < len(closeToSelf) || i < len(farFromSelf); i++ {
|
||||
if i < len(closeToSelf) {
|
||||
srv.peerConnect <- closeToSelf[i]
|
||||
}
|
||||
if i < len(farFromSelf) {
|
||||
srv.peerConnect <- farFromSelf[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO replace with "Set"
|
||||
func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) {
|
||||
// TODO: handle/store session token
|
||||
conn.SetDeadline(time.Now().Add(handshakeTimeout))
|
||||
remoteID, _, err := srv.handshakeFunc(conn, srv.PrivateKey, dest)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
srvlog.Debugf("Encryption Handshake with %v failed: %v", conn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
ourID := srv.ntab.Self()
|
||||
p := newPeer(conn, srv.Protocols, srv.Name, &ourID, &remoteID)
|
||||
if ok, reason := srv.addPeer(remoteID, p); !ok {
|
||||
srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
|
||||
p.politeDisconnect(reason)
|
||||
return
|
||||
}
|
||||
srvlog.Debugf("Added %v\n", p)
|
||||
|
||||
if srv.newPeerHook != nil {
|
||||
srv.newPeerHook(p)
|
||||
}
|
||||
discreason := p.run()
|
||||
srv.removePeer(p)
|
||||
srvlog.Debugf("Removed %v (%v)\n", p, discreason)
|
||||
}
|
||||
|
||||
func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) {
|
||||
srv.lock.Lock()
|
||||
defer srv.lock.Unlock()
|
||||
switch {
|
||||
case !srv.running:
|
||||
return false, DiscQuitting
|
||||
case len(srv.peers) >= srv.MaxPeers:
|
||||
return false, DiscTooManyPeers
|
||||
case srv.peers[id] != nil:
|
||||
return false, DiscAlreadyConnected
|
||||
case srv.Blacklist.Exists(id[:]):
|
||||
return false, DiscUselessPeer
|
||||
case id == srv.ntab.Self():
|
||||
return false, DiscSelf
|
||||
}
|
||||
srv.peers[id] = p
|
||||
return true, 0
|
||||
}
|
||||
|
||||
func (srv *Server) removePeer(p *Peer) {
|
||||
srv.lock.Lock()
|
||||
delete(srv.peers, *p.remoteID)
|
||||
srv.lock.Unlock()
|
||||
srv.peerWG.Done()
|
||||
}
|
||||
|
||||
type Blacklist interface {
|
||||
Get([]byte) (bool, error)
|
||||
Put([]byte) error
|
||||
|
|
|
@ -2,19 +2,28 @@ package p2p
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
)
|
||||
|
||||
func startTestServer(t *testing.T, pf peerFunc) *Server {
|
||||
func startTestServer(t *testing.T, pf newPeerHook) *Server {
|
||||
server := &Server{
|
||||
Identity: &peerId{},
|
||||
Name: "test",
|
||||
MaxPeers: 10,
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
newPeerFunc: pf,
|
||||
PrivateKey: newkey(),
|
||||
newPeerHook: pf,
|
||||
handshakeFunc: func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (id discover.NodeID, st []byte, err error) {
|
||||
return randomID(), nil, err
|
||||
},
|
||||
}
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatalf("Could not start server: %v", err)
|
||||
|
@ -27,16 +36,11 @@ func TestServerListen(t *testing.T) {
|
|||
|
||||
// start the test server
|
||||
connected := make(chan *Peer)
|
||||
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
|
||||
if conn == nil {
|
||||
srv := startTestServer(t, func(p *Peer) {
|
||||
if p == nil {
|
||||
t.Error("peer func called with nil conn")
|
||||
}
|
||||
if dialAddr != nil {
|
||||
t.Error("peer func called with non-nil dialAddr")
|
||||
}
|
||||
peer := newPeer(conn, nil, dialAddr)
|
||||
connected <- peer
|
||||
return peer
|
||||
connected <- p
|
||||
})
|
||||
defer close(connected)
|
||||
defer srv.Stop()
|
||||
|
@ -50,9 +54,9 @@ func TestServerListen(t *testing.T) {
|
|||
|
||||
select {
|
||||
case peer := <-connected:
|
||||
if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() {
|
||||
if peer.LocalAddr().String() != conn.RemoteAddr().String() {
|
||||
t.Errorf("peer started with wrong conn: got %v, want %v",
|
||||
peer.conn.LocalAddr(), conn.RemoteAddr())
|
||||
peer.LocalAddr(), conn.RemoteAddr())
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("server did not accept within one second")
|
||||
|
@ -62,7 +66,7 @@ func TestServerListen(t *testing.T) {
|
|||
func TestServerDial(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
|
||||
// run a fake TCP server to handle the connection.
|
||||
// run a one-shot TCP server to handle the connection.
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("could not setup listener: %v")
|
||||
|
@ -72,41 +76,32 @@ func TestServerDial(t *testing.T) {
|
|||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
t.Error("acccept error:", err)
|
||||
t.Error("accept error:", err)
|
||||
return
|
||||
}
|
||||
conn.Close()
|
||||
accepted <- conn
|
||||
}()
|
||||
|
||||
// start the test server
|
||||
// start the server
|
||||
connected := make(chan *Peer)
|
||||
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
|
||||
if conn == nil {
|
||||
t.Error("peer func called with nil conn")
|
||||
}
|
||||
peer := newPeer(conn, nil, dialAddr)
|
||||
connected <- peer
|
||||
return peer
|
||||
})
|
||||
srv := startTestServer(t, func(p *Peer) { connected <- p })
|
||||
defer close(connected)
|
||||
defer srv.Stop()
|
||||
|
||||
// tell the server to connect.
|
||||
connAddr := newPeerAddr(listener.Addr(), nil)
|
||||
srv.peerConnect <- connAddr
|
||||
// tell the server to connect
|
||||
tcpAddr := listener.Addr().(*net.TCPAddr)
|
||||
srv.SuggestPeer(&discover.Node{IP: tcpAddr.IP, TCPPort: tcpAddr.Port})
|
||||
|
||||
select {
|
||||
case conn := <-accepted:
|
||||
select {
|
||||
case peer := <-connected:
|
||||
if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() {
|
||||
if peer.RemoteAddr().String() != conn.LocalAddr().String() {
|
||||
t.Errorf("peer started with wrong conn: got %v, want %v",
|
||||
peer.conn.RemoteAddr(), conn.LocalAddr())
|
||||
}
|
||||
if peer.dialAddr != connAddr {
|
||||
t.Errorf("peer started with wrong dialAddr: got %v, want %v",
|
||||
peer.dialAddr, connAddr)
|
||||
peer.RemoteAddr(), conn.LocalAddr())
|
||||
}
|
||||
// TODO: validate more fields
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("server did not launch peer within one second")
|
||||
}
|
||||
|
@ -118,16 +113,17 @@ func TestServerDial(t *testing.T) {
|
|||
|
||||
func TestServerBroadcast(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
|
||||
var connected sync.WaitGroup
|
||||
srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer {
|
||||
peer := newPeer(c, []Protocol{discard}, dialAddr)
|
||||
peer.startSubprotocols([]Cap{discard.cap()})
|
||||
srv := startTestServer(t, func(p *Peer) {
|
||||
p.protocols = []Protocol{discard}
|
||||
p.startSubprotocols([]Cap{discard.cap()})
|
||||
p.noHandshake = true
|
||||
connected.Done()
|
||||
return peer
|
||||
})
|
||||
defer srv.Stop()
|
||||
|
||||
// dial a bunch of conns
|
||||
// create a few peers
|
||||
var conns = make([]net.Conn, 8)
|
||||
connected.Add(len(conns))
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
|
@ -159,3 +155,18 @@ func TestServerBroadcast(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newkey() *ecdsa.PrivateKey {
|
||||
key, err := crypto.GenerateKey()
|
||||
if err != nil {
|
||||
panic("couldn't generate key: " + err.Error())
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func randomID() (id discover.NodeID) {
|
||||
for i := range id {
|
||||
id[i] = byte(rand.Intn(255))
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ func testlog(t *testing.T) testLogger {
|
|||
return l
|
||||
}
|
||||
|
||||
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
|
||||
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugDetailLevel }
|
||||
func (testLogger) SetLogLevel(logger.LogLevel) {}
|
||||
|
||||
func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
// +build none
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/p2p"
|
||||
)
|
||||
|
||||
func main() {
|
||||
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
|
||||
|
||||
pub, _ := secp256k1.GenerateKeyPair()
|
||||
srv := p2p.Server{
|
||||
MaxPeers: 10,
|
||||
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
|
||||
ListenAddr: ":30303",
|
||||
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
|
||||
}
|
||||
if err := srv.Start(); err != nil {
|
||||
fmt.Println("could not start server:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// add seed peers
|
||||
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
|
||||
if err != nil {
|
||||
fmt.Println("couldn't resolve:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
srv.SuggestPeer(seed.IP, seed.Port, nil)
|
||||
|
||||
select {}
|
||||
}
|
|
@ -350,8 +350,10 @@ func makeWriter(typ reflect.Type) (writer, error) {
|
|||
return writeUint, nil
|
||||
case kind == reflect.String:
|
||||
return writeString, nil
|
||||
case kind == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 && !typ.Elem().Implements(encoderInterface):
|
||||
case kind == reflect.Slice && isByte(typ.Elem()):
|
||||
return writeBytes, nil
|
||||
case kind == reflect.Array && isByte(typ.Elem()):
|
||||
return writeByteArray, nil
|
||||
case kind == reflect.Slice || kind == reflect.Array:
|
||||
return makeSliceWriter(typ)
|
||||
case kind == reflect.Struct:
|
||||
|
@ -363,6 +365,10 @@ func makeWriter(typ reflect.Type) (writer, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func isByte(typ reflect.Type) bool {
|
||||
return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface)
|
||||
}
|
||||
|
||||
func writeUint(val reflect.Value, w *encbuf) error {
|
||||
i := val.Uint()
|
||||
if i == 0 {
|
||||
|
@ -407,6 +413,20 @@ func writeBytes(val reflect.Value, w *encbuf) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func writeByteArray(val reflect.Value, w *encbuf) error {
|
||||
if !val.CanAddr() {
|
||||
// Slice requires the value to be addressable.
|
||||
// Make it addressable by copying.
|
||||
copy := reflect.New(val.Type()).Elem()
|
||||
copy.Set(val)
|
||||
val = copy
|
||||
}
|
||||
size := val.Len()
|
||||
slice := val.Slice(0, size).Bytes()
|
||||
w.encodeString(slice)
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeString(val reflect.Value, w *encbuf) error {
|
||||
s := val.String()
|
||||
w.encodeStringHeader(len(s))
|
||||
|
|
|
@ -40,6 +40,8 @@ func (e *encodableReader) Read(b []byte) (int, error) {
|
|||
panic("called")
|
||||
}
|
||||
|
||||
type namedByteType byte
|
||||
|
||||
var (
|
||||
_ = Encoder(&testEncoder{})
|
||||
_ = Encoder(byteEncoder(0))
|
||||
|
@ -102,6 +104,10 @@ var encTests = []encTest{
|
|||
// byte slices, strings
|
||||
{val: []byte{}, output: "80"},
|
||||
{val: []byte{1, 2, 3}, output: "83010203"},
|
||||
|
||||
{val: []namedByteType{1, 2, 3}, output: "83010203"},
|
||||
{val: [...]namedByteType{1, 2, 3}, output: "83010203"},
|
||||
|
||||
{val: "", output: "80"},
|
||||
{val: "dog", output: "83646F67"},
|
||||
{
|
||||
|
|
|
@ -215,7 +215,7 @@ func NewPeer(peer *p2p.Peer) *Peer {
|
|||
return &Peer{
|
||||
ref: peer,
|
||||
Ip: fmt.Sprintf("%v", peer.RemoteAddr()),
|
||||
Version: fmt.Sprintf("%v", peer.Identity()),
|
||||
Version: fmt.Sprintf("%v", peer.ID()),
|
||||
Caps: fmt.Sprintf("%v", caps),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,6 @@ type Backend interface {
|
|||
IsListening() bool
|
||||
Peers() []*p2p.Peer
|
||||
KeyManager() *crypto.KeyManager
|
||||
ClientIdentity() p2p.ClientIdentity
|
||||
Db() ethutil.Database
|
||||
EventMux() *event.TypeMux
|
||||
Whisper() *whisper.Whisper
|
||||
|
|
Loading…
Reference in New Issue