mirror of https://github.com/status-im/op-geth.git
Merge branch 'develop'
This commit is contained in:
commit
2f2dd80e48
|
@ -15,6 +15,10 @@
|
||||||
"Comment": "1.2.0-95-g9b2bd2b",
|
"Comment": "1.2.0-95-g9b2bd2b",
|
||||||
"Rev": "9b2bd2b3489748d4d0a204fa4eb2ee9e89e0ebc6"
|
"Rev": "9b2bd2b3489748d4d0a204fa4eb2ee9e89e0ebc6"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"ImportPath": "github.com/davecgh/go-spew/spew",
|
||||||
|
"Rev": "3e6e67c4dcea3ac2f25fd4731abc0e1deaf36216"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"ImportPath": "github.com/ethereum/ethash",
|
"ImportPath": "github.com/ethereum/ethash",
|
||||||
"Comment": "v23.1-206-gf0e6321",
|
"Comment": "v23.1-206-gf0e6321",
|
||||||
|
|
|
@ -0,0 +1,450 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ptrSize is the size of a pointer on the current arch.
|
||||||
|
ptrSize = unsafe.Sizeof((*byte)(nil))
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// offsetPtr, offsetScalar, and offsetFlag are the offsets for the
|
||||||
|
// internal reflect.Value fields. These values are valid before golang
|
||||||
|
// commit ecccf07e7f9d which changed the format. The are also valid
|
||||||
|
// after commit 82f48826c6c7 which changed the format again to mirror
|
||||||
|
// the original format. Code in the init function updates these offsets
|
||||||
|
// as necessary.
|
||||||
|
offsetPtr = uintptr(ptrSize)
|
||||||
|
offsetScalar = uintptr(0)
|
||||||
|
offsetFlag = uintptr(ptrSize * 2)
|
||||||
|
|
||||||
|
// flagKindWidth and flagKindShift indicate various bits that the
|
||||||
|
// reflect package uses internally to track kind information.
|
||||||
|
//
|
||||||
|
// flagRO indicates whether or not the value field of a reflect.Value is
|
||||||
|
// read-only.
|
||||||
|
//
|
||||||
|
// flagIndir indicates whether the value field of a reflect.Value is
|
||||||
|
// the actual data or a pointer to the data.
|
||||||
|
//
|
||||||
|
// These values are valid before golang commit 90a7c3c86944 which
|
||||||
|
// changed their positions. Code in the init function updates these
|
||||||
|
// flags as necessary.
|
||||||
|
flagKindWidth = uintptr(5)
|
||||||
|
flagKindShift = uintptr(flagKindWidth - 1)
|
||||||
|
flagRO = uintptr(1 << 0)
|
||||||
|
flagIndir = uintptr(1 << 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Older versions of reflect.Value stored small integers directly in the
|
||||||
|
// ptr field (which is named val in the older versions). Versions
|
||||||
|
// between commits ecccf07e7f9d and 82f48826c6c7 added a new field named
|
||||||
|
// scalar for this purpose which unfortunately came before the flag
|
||||||
|
// field, so the offset of the flag field is different for those
|
||||||
|
// versions.
|
||||||
|
//
|
||||||
|
// This code constructs a new reflect.Value from a known small integer
|
||||||
|
// and checks if the size of the reflect.Value struct indicates it has
|
||||||
|
// the scalar field. When it does, the offsets are updated accordingly.
|
||||||
|
vv := reflect.ValueOf(0xf00)
|
||||||
|
if unsafe.Sizeof(vv) == (ptrSize * 4) {
|
||||||
|
offsetScalar = ptrSize * 2
|
||||||
|
offsetFlag = ptrSize * 3
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit 90a7c3c86944 changed the flag positions such that the low
|
||||||
|
// order bits are the kind. This code extracts the kind from the flags
|
||||||
|
// field and ensures it's the correct type. When it's not, the flag
|
||||||
|
// order has been changed to the newer format, so the flags are updated
|
||||||
|
// accordingly.
|
||||||
|
upf := unsafe.Pointer(uintptr(unsafe.Pointer(&vv)) + offsetFlag)
|
||||||
|
upfv := *(*uintptr)(upf)
|
||||||
|
flagKindMask := uintptr((1<<flagKindWidth - 1) << flagKindShift)
|
||||||
|
if (upfv&flagKindMask)>>flagKindShift != uintptr(reflect.Int) {
|
||||||
|
flagKindShift = 0
|
||||||
|
flagRO = 1 << 5
|
||||||
|
flagIndir = 1 << 6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unsafeReflectValue converts the passed reflect.Value into a one that bypasses
|
||||||
|
// the typical safety restrictions preventing access to unaddressable and
|
||||||
|
// unexported data. It works by digging the raw pointer to the underlying
|
||||||
|
// value out of the protected value and generating a new unprotected (unsafe)
|
||||||
|
// reflect.Value to it.
|
||||||
|
//
|
||||||
|
// This allows us to check for implementations of the Stringer and error
|
||||||
|
// interfaces to be used for pretty printing ordinarily unaddressable and
|
||||||
|
// inaccessible values such as unexported struct fields.
|
||||||
|
func unsafeReflectValue(v reflect.Value) (rv reflect.Value) {
|
||||||
|
indirects := 1
|
||||||
|
vt := v.Type()
|
||||||
|
upv := unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetPtr)
|
||||||
|
rvf := *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetFlag))
|
||||||
|
if rvf&flagIndir != 0 {
|
||||||
|
vt = reflect.PtrTo(v.Type())
|
||||||
|
indirects++
|
||||||
|
} else if offsetScalar != 0 {
|
||||||
|
// The value is in the scalar field when it's not one of the
|
||||||
|
// reference types.
|
||||||
|
switch vt.Kind() {
|
||||||
|
case reflect.Uintptr:
|
||||||
|
case reflect.Chan:
|
||||||
|
case reflect.Func:
|
||||||
|
case reflect.Map:
|
||||||
|
case reflect.Ptr:
|
||||||
|
case reflect.UnsafePointer:
|
||||||
|
default:
|
||||||
|
upv = unsafe.Pointer(uintptr(unsafe.Pointer(&v)) +
|
||||||
|
offsetScalar)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pv := reflect.NewAt(vt, upv)
|
||||||
|
rv = pv
|
||||||
|
for i := 0; i < indirects; i++ {
|
||||||
|
rv = rv.Elem()
|
||||||
|
}
|
||||||
|
return rv
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some constants in the form of bytes to avoid string overhead. This mirrors
|
||||||
|
// the technique used in the fmt package.
|
||||||
|
var (
|
||||||
|
panicBytes = []byte("(PANIC=")
|
||||||
|
plusBytes = []byte("+")
|
||||||
|
iBytes = []byte("i")
|
||||||
|
trueBytes = []byte("true")
|
||||||
|
falseBytes = []byte("false")
|
||||||
|
interfaceBytes = []byte("(interface {})")
|
||||||
|
commaNewlineBytes = []byte(",\n")
|
||||||
|
newlineBytes = []byte("\n")
|
||||||
|
openBraceBytes = []byte("{")
|
||||||
|
openBraceNewlineBytes = []byte("{\n")
|
||||||
|
closeBraceBytes = []byte("}")
|
||||||
|
asteriskBytes = []byte("*")
|
||||||
|
colonBytes = []byte(":")
|
||||||
|
colonSpaceBytes = []byte(": ")
|
||||||
|
openParenBytes = []byte("(")
|
||||||
|
closeParenBytes = []byte(")")
|
||||||
|
spaceBytes = []byte(" ")
|
||||||
|
pointerChainBytes = []byte("->")
|
||||||
|
nilAngleBytes = []byte("<nil>")
|
||||||
|
maxNewlineBytes = []byte("<max depth reached>\n")
|
||||||
|
maxShortBytes = []byte("<max>")
|
||||||
|
circularBytes = []byte("<already shown>")
|
||||||
|
circularShortBytes = []byte("<shown>")
|
||||||
|
invalidAngleBytes = []byte("<invalid>")
|
||||||
|
openBracketBytes = []byte("[")
|
||||||
|
closeBracketBytes = []byte("]")
|
||||||
|
percentBytes = []byte("%")
|
||||||
|
precisionBytes = []byte(".")
|
||||||
|
openAngleBytes = []byte("<")
|
||||||
|
closeAngleBytes = []byte(">")
|
||||||
|
openMapBytes = []byte("map[")
|
||||||
|
closeMapBytes = []byte("]")
|
||||||
|
lenEqualsBytes = []byte("len=")
|
||||||
|
capEqualsBytes = []byte("cap=")
|
||||||
|
)
|
||||||
|
|
||||||
|
// hexDigits is used to map a decimal value to a hex digit.
|
||||||
|
var hexDigits = "0123456789abcdef"
|
||||||
|
|
||||||
|
// catchPanic handles any panics that might occur during the handleMethods
|
||||||
|
// calls.
|
||||||
|
func catchPanic(w io.Writer, v reflect.Value) {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
w.Write(panicBytes)
|
||||||
|
fmt.Fprintf(w, "%v", err)
|
||||||
|
w.Write(closeParenBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMethods attempts to call the Error and String methods on the underlying
|
||||||
|
// type the passed reflect.Value represents and outputes the result to Writer w.
|
||||||
|
//
|
||||||
|
// It handles panics in any called methods by catching and displaying the error
|
||||||
|
// as the formatted value.
|
||||||
|
func handleMethods(cs *ConfigState, w io.Writer, v reflect.Value) (handled bool) {
|
||||||
|
// We need an interface to check if the type implements the error or
|
||||||
|
// Stringer interface. However, the reflect package won't give us an
|
||||||
|
// interface on certain things like unexported struct fields in order
|
||||||
|
// to enforce visibility rules. We use unsafe to bypass these restrictions
|
||||||
|
// since this package does not mutate the values.
|
||||||
|
if !v.CanInterface() {
|
||||||
|
v = unsafeReflectValue(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Choose whether or not to do error and Stringer interface lookups against
|
||||||
|
// the base type or a pointer to the base type depending on settings.
|
||||||
|
// Technically calling one of these methods with a pointer receiver can
|
||||||
|
// mutate the value, however, types which choose to satisify an error or
|
||||||
|
// Stringer interface with a pointer receiver should not be mutating their
|
||||||
|
// state inside these interface methods.
|
||||||
|
var viface interface{}
|
||||||
|
if !cs.DisablePointerMethods {
|
||||||
|
if !v.CanAddr() {
|
||||||
|
v = unsafeReflectValue(v)
|
||||||
|
}
|
||||||
|
viface = v.Addr().Interface()
|
||||||
|
} else {
|
||||||
|
if v.CanAddr() {
|
||||||
|
v = v.Addr()
|
||||||
|
}
|
||||||
|
viface = v.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is it an error or Stringer?
|
||||||
|
switch iface := viface.(type) {
|
||||||
|
case error:
|
||||||
|
defer catchPanic(w, v)
|
||||||
|
if cs.ContinueOnMethod {
|
||||||
|
w.Write(openParenBytes)
|
||||||
|
w.Write([]byte(iface.Error()))
|
||||||
|
w.Write(closeParenBytes)
|
||||||
|
w.Write(spaceBytes)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Write([]byte(iface.Error()))
|
||||||
|
return true
|
||||||
|
|
||||||
|
case fmt.Stringer:
|
||||||
|
defer catchPanic(w, v)
|
||||||
|
if cs.ContinueOnMethod {
|
||||||
|
w.Write(openParenBytes)
|
||||||
|
w.Write([]byte(iface.String()))
|
||||||
|
w.Write(closeParenBytes)
|
||||||
|
w.Write(spaceBytes)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
w.Write([]byte(iface.String()))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// printBool outputs a boolean value as true or false to Writer w.
|
||||||
|
func printBool(w io.Writer, val bool) {
|
||||||
|
if val {
|
||||||
|
w.Write(trueBytes)
|
||||||
|
} else {
|
||||||
|
w.Write(falseBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// printInt outputs a signed integer value to Writer w.
|
||||||
|
func printInt(w io.Writer, val int64, base int) {
|
||||||
|
w.Write([]byte(strconv.FormatInt(val, base)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// printUint outputs an unsigned integer value to Writer w.
|
||||||
|
func printUint(w io.Writer, val uint64, base int) {
|
||||||
|
w.Write([]byte(strconv.FormatUint(val, base)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// printFloat outputs a floating point value using the specified precision,
|
||||||
|
// which is expected to be 32 or 64bit, to Writer w.
|
||||||
|
func printFloat(w io.Writer, val float64, precision int) {
|
||||||
|
w.Write([]byte(strconv.FormatFloat(val, 'g', -1, precision)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// printComplex outputs a complex value using the specified float precision
|
||||||
|
// for the real and imaginary parts to Writer w.
|
||||||
|
func printComplex(w io.Writer, c complex128, floatPrecision int) {
|
||||||
|
r := real(c)
|
||||||
|
w.Write(openParenBytes)
|
||||||
|
w.Write([]byte(strconv.FormatFloat(r, 'g', -1, floatPrecision)))
|
||||||
|
i := imag(c)
|
||||||
|
if i >= 0 {
|
||||||
|
w.Write(plusBytes)
|
||||||
|
}
|
||||||
|
w.Write([]byte(strconv.FormatFloat(i, 'g', -1, floatPrecision)))
|
||||||
|
w.Write(iBytes)
|
||||||
|
w.Write(closeParenBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// printHexPtr outputs a uintptr formatted as hexidecimal with a leading '0x'
|
||||||
|
// prefix to Writer w.
|
||||||
|
func printHexPtr(w io.Writer, p uintptr) {
|
||||||
|
// Null pointer.
|
||||||
|
num := uint64(p)
|
||||||
|
if num == 0 {
|
||||||
|
w.Write(nilAngleBytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Max uint64 is 16 bytes in hex + 2 bytes for '0x' prefix
|
||||||
|
buf := make([]byte, 18)
|
||||||
|
|
||||||
|
// It's simpler to construct the hex string right to left.
|
||||||
|
base := uint64(16)
|
||||||
|
i := len(buf) - 1
|
||||||
|
for num >= base {
|
||||||
|
buf[i] = hexDigits[num%base]
|
||||||
|
num /= base
|
||||||
|
i--
|
||||||
|
}
|
||||||
|
buf[i] = hexDigits[num]
|
||||||
|
|
||||||
|
// Add '0x' prefix.
|
||||||
|
i--
|
||||||
|
buf[i] = 'x'
|
||||||
|
i--
|
||||||
|
buf[i] = '0'
|
||||||
|
|
||||||
|
// Strip unused leading bytes.
|
||||||
|
buf = buf[i:]
|
||||||
|
w.Write(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// valuesSorter implements sort.Interface to allow a slice of reflect.Value
|
||||||
|
// elements to be sorted.
|
||||||
|
type valuesSorter struct {
|
||||||
|
values []reflect.Value
|
||||||
|
strings []string // either nil or same len and values
|
||||||
|
cs *ConfigState
|
||||||
|
}
|
||||||
|
|
||||||
|
// newValuesSorter initializes a valuesSorter instance, which holds a set of
|
||||||
|
// surrogate keys on which the data should be sorted. It uses flags in
|
||||||
|
// ConfigState to decide if and how to populate those surrogate keys.
|
||||||
|
func newValuesSorter(values []reflect.Value, cs *ConfigState) sort.Interface {
|
||||||
|
vs := &valuesSorter{values: values, cs: cs}
|
||||||
|
if canSortSimply(vs.values[0].Kind()) {
|
||||||
|
return vs
|
||||||
|
}
|
||||||
|
if !cs.DisableMethods {
|
||||||
|
vs.strings = make([]string, len(values))
|
||||||
|
for i := range vs.values {
|
||||||
|
b := bytes.Buffer{}
|
||||||
|
if !handleMethods(cs, &b, vs.values[i]) {
|
||||||
|
vs.strings = nil
|
||||||
|
break
|
||||||
|
}
|
||||||
|
vs.strings[i] = b.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if vs.strings == nil && cs.SpewKeys {
|
||||||
|
vs.strings = make([]string, len(values))
|
||||||
|
for i := range vs.values {
|
||||||
|
vs.strings[i] = Sprintf("%#v", vs.values[i].Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return vs
|
||||||
|
}
|
||||||
|
|
||||||
|
// canSortSimply tests whether a reflect.Kind is a primitive that can be sorted
|
||||||
|
// directly, or whether it should be considered for sorting by surrogate keys
|
||||||
|
// (if the ConfigState allows it).
|
||||||
|
func canSortSimply(kind reflect.Kind) bool {
|
||||||
|
// This switch parallels valueSortLess, except for the default case.
|
||||||
|
switch kind {
|
||||||
|
case reflect.Bool:
|
||||||
|
return true
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||||
|
return true
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||||
|
return true
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return true
|
||||||
|
case reflect.String:
|
||||||
|
return true
|
||||||
|
case reflect.Uintptr:
|
||||||
|
return true
|
||||||
|
case reflect.Array:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of values in the slice. It is part of the
|
||||||
|
// sort.Interface implementation.
|
||||||
|
func (s *valuesSorter) Len() int {
|
||||||
|
return len(s.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap swaps the values at the passed indices. It is part of the
|
||||||
|
// sort.Interface implementation.
|
||||||
|
func (s *valuesSorter) Swap(i, j int) {
|
||||||
|
s.values[i], s.values[j] = s.values[j], s.values[i]
|
||||||
|
if s.strings != nil {
|
||||||
|
s.strings[i], s.strings[j] = s.strings[j], s.strings[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// valueSortLess returns whether the first value should sort before the second
|
||||||
|
// value. It is used by valueSorter.Less as part of the sort.Interface
|
||||||
|
// implementation.
|
||||||
|
func valueSortLess(a, b reflect.Value) bool {
|
||||||
|
switch a.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
return !a.Bool() && b.Bool()
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||||
|
return a.Int() < b.Int()
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||||
|
return a.Uint() < b.Uint()
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return a.Float() < b.Float()
|
||||||
|
case reflect.String:
|
||||||
|
return a.String() < b.String()
|
||||||
|
case reflect.Uintptr:
|
||||||
|
return a.Uint() < b.Uint()
|
||||||
|
case reflect.Array:
|
||||||
|
// Compare the contents of both arrays.
|
||||||
|
l := a.Len()
|
||||||
|
for i := 0; i < l; i++ {
|
||||||
|
av := a.Index(i)
|
||||||
|
bv := b.Index(i)
|
||||||
|
if av.Interface() == bv.Interface() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return valueSortLess(av, bv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return a.String() < b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Less returns whether the value at index i should sort before the
|
||||||
|
// value at index j. It is part of the sort.Interface implementation.
|
||||||
|
func (s *valuesSorter) Less(i, j int) bool {
|
||||||
|
if s.strings == nil {
|
||||||
|
return valueSortLess(s.values[i], s.values[j])
|
||||||
|
}
|
||||||
|
return s.strings[i] < s.strings[j]
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortValues is a sort function that handles both native types and any type that
|
||||||
|
// can be converted to error or Stringer. Other inputs are sorted according to
|
||||||
|
// their Value.String() value to ensure display stability.
|
||||||
|
func sortValues(values []reflect.Value, cs *ConfigState) {
|
||||||
|
if len(values) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sort.Sort(newValuesSorter(values, cs))
|
||||||
|
}
|
298
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/common_test.go
generated
vendored
Normal file
298
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/common_test.go
generated
vendored
Normal file
|
@ -0,0 +1,298 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
)
|
||||||
|
|
||||||
|
// custom type to test Stinger interface on non-pointer receiver.
|
||||||
|
type stringer string
|
||||||
|
|
||||||
|
// String implements the Stringer interface for testing invocation of custom
|
||||||
|
// stringers on types with non-pointer receivers.
|
||||||
|
func (s stringer) String() string {
|
||||||
|
return "stringer " + string(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// custom type to test Stinger interface on pointer receiver.
|
||||||
|
type pstringer string
|
||||||
|
|
||||||
|
// String implements the Stringer interface for testing invocation of custom
|
||||||
|
// stringers on types with only pointer receivers.
|
||||||
|
func (s *pstringer) String() string {
|
||||||
|
return "stringer " + string(*s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// xref1 and xref2 are cross referencing structs for testing circular reference
|
||||||
|
// detection.
|
||||||
|
type xref1 struct {
|
||||||
|
ps2 *xref2
|
||||||
|
}
|
||||||
|
type xref2 struct {
|
||||||
|
ps1 *xref1
|
||||||
|
}
|
||||||
|
|
||||||
|
// indirCir1, indirCir2, and indirCir3 are used to generate an indirect circular
|
||||||
|
// reference for testing detection.
|
||||||
|
type indirCir1 struct {
|
||||||
|
ps2 *indirCir2
|
||||||
|
}
|
||||||
|
type indirCir2 struct {
|
||||||
|
ps3 *indirCir3
|
||||||
|
}
|
||||||
|
type indirCir3 struct {
|
||||||
|
ps1 *indirCir1
|
||||||
|
}
|
||||||
|
|
||||||
|
// embed is used to test embedded structures.
|
||||||
|
type embed struct {
|
||||||
|
a string
|
||||||
|
}
|
||||||
|
|
||||||
|
// embedwrap is used to test embedded structures.
|
||||||
|
type embedwrap struct {
|
||||||
|
*embed
|
||||||
|
e *embed
|
||||||
|
}
|
||||||
|
|
||||||
|
// panicer is used to intentionally cause a panic for testing spew properly
|
||||||
|
// handles them
|
||||||
|
type panicer int
|
||||||
|
|
||||||
|
func (p panicer) String() string {
|
||||||
|
panic("test panic")
|
||||||
|
}
|
||||||
|
|
||||||
|
// customError is used to test custom error interface invocation.
|
||||||
|
type customError int
|
||||||
|
|
||||||
|
func (e customError) Error() string {
|
||||||
|
return fmt.Sprintf("error: %d", int(e))
|
||||||
|
}
|
||||||
|
|
||||||
|
// stringizeWants converts a slice of wanted test output into a format suitable
|
||||||
|
// for a test error message.
|
||||||
|
func stringizeWants(wants []string) string {
|
||||||
|
s := ""
|
||||||
|
for i, want := range wants {
|
||||||
|
if i > 0 {
|
||||||
|
s += fmt.Sprintf("want%d: %s", i+1, want)
|
||||||
|
} else {
|
||||||
|
s += "want: " + want
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// testFailed returns whether or not a test failed by checking if the result
|
||||||
|
// of the test is in the slice of wanted strings.
|
||||||
|
func testFailed(result string, wants []string) bool {
|
||||||
|
for _, want := range wants {
|
||||||
|
if result == want {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type sortableStruct struct {
|
||||||
|
x int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss sortableStruct) String() string {
|
||||||
|
return fmt.Sprintf("ss.%d", ss.x)
|
||||||
|
}
|
||||||
|
|
||||||
|
type unsortableStruct struct {
|
||||||
|
x int
|
||||||
|
}
|
||||||
|
|
||||||
|
type sortTestCase struct {
|
||||||
|
input []reflect.Value
|
||||||
|
expected []reflect.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func helpTestSortValues(tests []sortTestCase, cs *spew.ConfigState, t *testing.T) {
|
||||||
|
getInterfaces := func(values []reflect.Value) []interface{} {
|
||||||
|
interfaces := []interface{}{}
|
||||||
|
for _, v := range values {
|
||||||
|
interfaces = append(interfaces, v.Interface())
|
||||||
|
}
|
||||||
|
return interfaces
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
spew.SortValues(test.input, cs)
|
||||||
|
// reflect.DeepEqual cannot really make sense of reflect.Value,
|
||||||
|
// probably because of all the pointer tricks. For instance,
|
||||||
|
// v(2.0) != v(2.0) on a 32-bits system. Turn them into interface{}
|
||||||
|
// instead.
|
||||||
|
input := getInterfaces(test.input)
|
||||||
|
expected := getInterfaces(test.expected)
|
||||||
|
if !reflect.DeepEqual(input, expected) {
|
||||||
|
t.Errorf("Sort mismatch:\n %v != %v", input, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSortValues ensures the sort functionality for relect.Value based sorting
|
||||||
|
// works as intended.
|
||||||
|
func TestSortValues(t *testing.T) {
|
||||||
|
v := reflect.ValueOf
|
||||||
|
|
||||||
|
a := v("a")
|
||||||
|
b := v("b")
|
||||||
|
c := v("c")
|
||||||
|
embedA := v(embed{"a"})
|
||||||
|
embedB := v(embed{"b"})
|
||||||
|
embedC := v(embed{"c"})
|
||||||
|
tests := []sortTestCase{
|
||||||
|
// No values.
|
||||||
|
{
|
||||||
|
[]reflect.Value{},
|
||||||
|
[]reflect.Value{},
|
||||||
|
},
|
||||||
|
// Bools.
|
||||||
|
{
|
||||||
|
[]reflect.Value{v(false), v(true), v(false)},
|
||||||
|
[]reflect.Value{v(false), v(false), v(true)},
|
||||||
|
},
|
||||||
|
// Ints.
|
||||||
|
{
|
||||||
|
[]reflect.Value{v(2), v(1), v(3)},
|
||||||
|
[]reflect.Value{v(1), v(2), v(3)},
|
||||||
|
},
|
||||||
|
// Uints.
|
||||||
|
{
|
||||||
|
[]reflect.Value{v(uint8(2)), v(uint8(1)), v(uint8(3))},
|
||||||
|
[]reflect.Value{v(uint8(1)), v(uint8(2)), v(uint8(3))},
|
||||||
|
},
|
||||||
|
// Floats.
|
||||||
|
{
|
||||||
|
[]reflect.Value{v(2.0), v(1.0), v(3.0)},
|
||||||
|
[]reflect.Value{v(1.0), v(2.0), v(3.0)},
|
||||||
|
},
|
||||||
|
// Strings.
|
||||||
|
{
|
||||||
|
[]reflect.Value{b, a, c},
|
||||||
|
[]reflect.Value{a, b, c},
|
||||||
|
},
|
||||||
|
// Array
|
||||||
|
{
|
||||||
|
[]reflect.Value{v([3]int{3, 2, 1}), v([3]int{1, 3, 2}), v([3]int{1, 2, 3})},
|
||||||
|
[]reflect.Value{v([3]int{1, 2, 3}), v([3]int{1, 3, 2}), v([3]int{3, 2, 1})},
|
||||||
|
},
|
||||||
|
// Uintptrs.
|
||||||
|
{
|
||||||
|
[]reflect.Value{v(uintptr(2)), v(uintptr(1)), v(uintptr(3))},
|
||||||
|
[]reflect.Value{v(uintptr(1)), v(uintptr(2)), v(uintptr(3))},
|
||||||
|
},
|
||||||
|
// SortableStructs.
|
||||||
|
{
|
||||||
|
// Note: not sorted - DisableMethods is set.
|
||||||
|
[]reflect.Value{v(sortableStruct{2}), v(sortableStruct{1}), v(sortableStruct{3})},
|
||||||
|
[]reflect.Value{v(sortableStruct{2}), v(sortableStruct{1}), v(sortableStruct{3})},
|
||||||
|
},
|
||||||
|
// UnsortableStructs.
|
||||||
|
{
|
||||||
|
// Note: not sorted - SpewKeys is false.
|
||||||
|
[]reflect.Value{v(unsortableStruct{2}), v(unsortableStruct{1}), v(unsortableStruct{3})},
|
||||||
|
[]reflect.Value{v(unsortableStruct{2}), v(unsortableStruct{1}), v(unsortableStruct{3})},
|
||||||
|
},
|
||||||
|
// Invalid.
|
||||||
|
{
|
||||||
|
[]reflect.Value{embedB, embedA, embedC},
|
||||||
|
[]reflect.Value{embedB, embedA, embedC},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cs := spew.ConfigState{DisableMethods: true, SpewKeys: false}
|
||||||
|
helpTestSortValues(tests, &cs, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSortValuesWithMethods ensures the sort functionality for relect.Value
|
||||||
|
// based sorting works as intended when using string methods.
|
||||||
|
func TestSortValuesWithMethods(t *testing.T) {
|
||||||
|
v := reflect.ValueOf
|
||||||
|
|
||||||
|
a := v("a")
|
||||||
|
b := v("b")
|
||||||
|
c := v("c")
|
||||||
|
tests := []sortTestCase{
|
||||||
|
// Ints.
|
||||||
|
{
|
||||||
|
[]reflect.Value{v(2), v(1), v(3)},
|
||||||
|
[]reflect.Value{v(1), v(2), v(3)},
|
||||||
|
},
|
||||||
|
// Strings.
|
||||||
|
{
|
||||||
|
[]reflect.Value{b, a, c},
|
||||||
|
[]reflect.Value{a, b, c},
|
||||||
|
},
|
||||||
|
// SortableStructs.
|
||||||
|
{
|
||||||
|
[]reflect.Value{v(sortableStruct{2}), v(sortableStruct{1}), v(sortableStruct{3})},
|
||||||
|
[]reflect.Value{v(sortableStruct{1}), v(sortableStruct{2}), v(sortableStruct{3})},
|
||||||
|
},
|
||||||
|
// UnsortableStructs.
|
||||||
|
{
|
||||||
|
// Note: not sorted - SpewKeys is false.
|
||||||
|
[]reflect.Value{v(unsortableStruct{2}), v(unsortableStruct{1}), v(unsortableStruct{3})},
|
||||||
|
[]reflect.Value{v(unsortableStruct{2}), v(unsortableStruct{1}), v(unsortableStruct{3})},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cs := spew.ConfigState{DisableMethods: false, SpewKeys: false}
|
||||||
|
helpTestSortValues(tests, &cs, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSortValuesWithSpew ensures the sort functionality for relect.Value
|
||||||
|
// based sorting works as intended when using spew to stringify keys.
|
||||||
|
func TestSortValuesWithSpew(t *testing.T) {
|
||||||
|
v := reflect.ValueOf
|
||||||
|
|
||||||
|
a := v("a")
|
||||||
|
b := v("b")
|
||||||
|
c := v("c")
|
||||||
|
tests := []sortTestCase{
|
||||||
|
// Ints.
|
||||||
|
{
|
||||||
|
[]reflect.Value{v(2), v(1), v(3)},
|
||||||
|
[]reflect.Value{v(1), v(2), v(3)},
|
||||||
|
},
|
||||||
|
// Strings.
|
||||||
|
{
|
||||||
|
[]reflect.Value{b, a, c},
|
||||||
|
[]reflect.Value{a, b, c},
|
||||||
|
},
|
||||||
|
// SortableStructs.
|
||||||
|
{
|
||||||
|
[]reflect.Value{v(sortableStruct{2}), v(sortableStruct{1}), v(sortableStruct{3})},
|
||||||
|
[]reflect.Value{v(sortableStruct{1}), v(sortableStruct{2}), v(sortableStruct{3})},
|
||||||
|
},
|
||||||
|
// UnsortableStructs.
|
||||||
|
{
|
||||||
|
[]reflect.Value{v(unsortableStruct{2}), v(unsortableStruct{1}), v(unsortableStruct{3})},
|
||||||
|
[]reflect.Value{v(unsortableStruct{1}), v(unsortableStruct{2}), v(unsortableStruct{3})},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cs := spew.ConfigState{DisableMethods: true, SpewKeys: true}
|
||||||
|
helpTestSortValues(tests, &cs, t)
|
||||||
|
}
|
|
@ -0,0 +1,294 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigState houses the configuration options used by spew to format and
|
||||||
|
// display values. There is a global instance, Config, that is used to control
|
||||||
|
// all top-level Formatter and Dump functionality. Each ConfigState instance
|
||||||
|
// provides methods equivalent to the top-level functions.
|
||||||
|
//
|
||||||
|
// The zero value for ConfigState provides no indentation. You would typically
|
||||||
|
// want to set it to a space or a tab.
|
||||||
|
//
|
||||||
|
// Alternatively, you can use NewDefaultConfig to get a ConfigState instance
|
||||||
|
// with default settings. See the documentation of NewDefaultConfig for default
|
||||||
|
// values.
|
||||||
|
type ConfigState struct {
|
||||||
|
// Indent specifies the string to use for each indentation level. The
|
||||||
|
// global config instance that all top-level functions use set this to a
|
||||||
|
// single space by default. If you would like more indentation, you might
|
||||||
|
// set this to a tab with "\t" or perhaps two spaces with " ".
|
||||||
|
Indent string
|
||||||
|
|
||||||
|
// MaxDepth controls the maximum number of levels to descend into nested
|
||||||
|
// data structures. The default, 0, means there is no limit.
|
||||||
|
//
|
||||||
|
// NOTE: Circular data structures are properly detected, so it is not
|
||||||
|
// necessary to set this value unless you specifically want to limit deeply
|
||||||
|
// nested data structures.
|
||||||
|
MaxDepth int
|
||||||
|
|
||||||
|
// DisableMethods specifies whether or not error and Stringer interfaces are
|
||||||
|
// invoked for types that implement them.
|
||||||
|
DisableMethods bool
|
||||||
|
|
||||||
|
// DisablePointerMethods specifies whether or not to check for and invoke
|
||||||
|
// error and Stringer interfaces on types which only accept a pointer
|
||||||
|
// receiver when the current type is not a pointer.
|
||||||
|
//
|
||||||
|
// NOTE: This might be an unsafe action since calling one of these methods
|
||||||
|
// with a pointer receiver could technically mutate the value, however,
|
||||||
|
// in practice, types which choose to satisify an error or Stringer
|
||||||
|
// interface with a pointer receiver should not be mutating their state
|
||||||
|
// inside these interface methods.
|
||||||
|
DisablePointerMethods bool
|
||||||
|
|
||||||
|
// ContinueOnMethod specifies whether or not recursion should continue once
|
||||||
|
// a custom error or Stringer interface is invoked. The default, false,
|
||||||
|
// means it will print the results of invoking the custom error or Stringer
|
||||||
|
// interface and return immediately instead of continuing to recurse into
|
||||||
|
// the internals of the data type.
|
||||||
|
//
|
||||||
|
// NOTE: This flag does not have any effect if method invocation is disabled
|
||||||
|
// via the DisableMethods or DisablePointerMethods options.
|
||||||
|
ContinueOnMethod bool
|
||||||
|
|
||||||
|
// SortKeys specifies map keys should be sorted before being printed. Use
|
||||||
|
// this to have a more deterministic, diffable output. Note that only
|
||||||
|
// native types (bool, int, uint, floats, uintptr and string) and types
|
||||||
|
// that support the error or Stringer interfaces (if methods are
|
||||||
|
// enabled) are supported, with other types sorted according to the
|
||||||
|
// reflect.Value.String() output which guarantees display stability.
|
||||||
|
SortKeys bool
|
||||||
|
|
||||||
|
// SpewKeys specifies that, as a last resort attempt, map keys should
|
||||||
|
// be spewed to strings and sorted by those strings. This is only
|
||||||
|
// considered if SortKeys is true.
|
||||||
|
SpewKeys bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config is the active configuration of the top-level functions.
|
||||||
|
// The configuration can be changed by modifying the contents of spew.Config.
|
||||||
|
var Config = ConfigState{Indent: " "}
|
||||||
|
|
||||||
|
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the formatted string as a value that satisfies error. See NewFormatter
|
||||||
|
// for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Errorf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Errorf(format string, a ...interface{}) (err error) {
|
||||||
|
return fmt.Errorf(format, c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Fprint(w, c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Fprint(w io.Writer, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Fprint(w, c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Fprintf(w, format, c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Fprintf(w, format, c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Fprintln(w, c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Fprintln(w, c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print is a wrapper for fmt.Print that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Print(c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Print(a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Print(c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Printf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Printf(format string, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Printf(format, c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Println is a wrapper for fmt.Println that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Println(c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Println(a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Println(c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the resulting string. See NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Sprint(c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Sprint(a ...interface{}) string {
|
||||||
|
return fmt.Sprint(c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the resulting string. See NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Sprintf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Sprintf(format string, a ...interface{}) string {
|
||||||
|
return fmt.Sprintf(format, c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
|
||||||
|
// were passed with a Formatter interface returned by c.NewFormatter. It
|
||||||
|
// returns the resulting string. See NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Sprintln(c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Sprintln(a ...interface{}) string {
|
||||||
|
return fmt.Sprintln(c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
|
||||||
|
interface. As a result, it integrates cleanly with standard fmt package
|
||||||
|
printing functions. The formatter is useful for inline printing of smaller data
|
||||||
|
types similar to the standard %v format specifier.
|
||||||
|
|
||||||
|
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||||
|
addresses), %#v (adds types), and %#+v (adds types and pointer addresses) verb
|
||||||
|
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||||
|
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||||
|
the width and precision arguments (however they will still work on the format
|
||||||
|
specifiers not handled by the custom formatter).
|
||||||
|
|
||||||
|
Typically this function shouldn't be called directly. It is much easier to make
|
||||||
|
use of the custom formatter by calling one of the convenience functions such as
|
||||||
|
c.Printf, c.Println, or c.Printf.
|
||||||
|
*/
|
||||||
|
func (c *ConfigState) NewFormatter(v interface{}) fmt.Formatter {
|
||||||
|
return newFormatter(c, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fdump formats and displays the passed arguments to io.Writer w. It formats
|
||||||
|
// exactly the same as Dump.
|
||||||
|
func (c *ConfigState) Fdump(w io.Writer, a ...interface{}) {
|
||||||
|
fdump(c, w, a...)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Dump displays the passed parameters to standard out with newlines, customizable
|
||||||
|
indentation, and additional debug information such as complete types and all
|
||||||
|
pointer addresses used to indirect to the final value. It provides the
|
||||||
|
following features over the built-in printing facilities provided by the fmt
|
||||||
|
package:
|
||||||
|
|
||||||
|
* Pointers are dereferenced and followed
|
||||||
|
* Circular data structures are detected and handled properly
|
||||||
|
* Custom Stringer/error interfaces are optionally invoked, including
|
||||||
|
on unexported types
|
||||||
|
* Custom types which only implement the Stringer/error interfaces via
|
||||||
|
a pointer receiver are optionally invoked when passing non-pointer
|
||||||
|
variables
|
||||||
|
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||||
|
includes offsets, byte values in hex, and ASCII output
|
||||||
|
|
||||||
|
The configuration options are controlled by modifying the public members
|
||||||
|
of c. See ConfigState for options documentation.
|
||||||
|
|
||||||
|
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
|
||||||
|
get the formatted result as a string.
|
||||||
|
*/
|
||||||
|
func (c *ConfigState) Dump(a ...interface{}) {
|
||||||
|
fdump(c, os.Stdout, a...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sdump returns a string with the passed arguments formatted exactly the same
|
||||||
|
// as Dump.
|
||||||
|
func (c *ConfigState) Sdump(a ...interface{}) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
fdump(c, &buf, a...)
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertArgs accepts a slice of arguments and returns a slice of the same
|
||||||
|
// length with each argument converted to a spew Formatter interface using
|
||||||
|
// the ConfigState associated with s.
|
||||||
|
func (c *ConfigState) convertArgs(args []interface{}) (formatters []interface{}) {
|
||||||
|
formatters = make([]interface{}, len(args))
|
||||||
|
for index, arg := range args {
|
||||||
|
formatters[index] = newFormatter(c, arg)
|
||||||
|
}
|
||||||
|
return formatters
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultConfig returns a ConfigState with the following default settings.
|
||||||
|
//
|
||||||
|
// Indent: " "
|
||||||
|
// MaxDepth: 0
|
||||||
|
// DisableMethods: false
|
||||||
|
// DisablePointerMethods: false
|
||||||
|
// ContinueOnMethod: false
|
||||||
|
// SortKeys: false
|
||||||
|
func NewDefaultConfig() *ConfigState {
|
||||||
|
return &ConfigState{Indent: " "}
|
||||||
|
}
|
|
@ -0,0 +1,202 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
Package spew implements a deep pretty printer for Go data structures to aid in
|
||||||
|
debugging.
|
||||||
|
|
||||||
|
A quick overview of the additional features spew provides over the built-in
|
||||||
|
printing facilities for Go data types are as follows:
|
||||||
|
|
||||||
|
* Pointers are dereferenced and followed
|
||||||
|
* Circular data structures are detected and handled properly
|
||||||
|
* Custom Stringer/error interfaces are optionally invoked, including
|
||||||
|
on unexported types
|
||||||
|
* Custom types which only implement the Stringer/error interfaces via
|
||||||
|
a pointer receiver are optionally invoked when passing non-pointer
|
||||||
|
variables
|
||||||
|
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||||
|
includes offsets, byte values in hex, and ASCII output (only when using
|
||||||
|
Dump style)
|
||||||
|
|
||||||
|
There are two different approaches spew allows for dumping Go data structures:
|
||||||
|
|
||||||
|
* Dump style which prints with newlines, customizable indentation,
|
||||||
|
and additional debug information such as types and all pointer addresses
|
||||||
|
used to indirect to the final value
|
||||||
|
* A custom Formatter interface that integrates cleanly with the standard fmt
|
||||||
|
package and replaces %v, %+v, %#v, and %#+v to provide inline printing
|
||||||
|
similar to the default %v while providing the additional functionality
|
||||||
|
outlined above and passing unsupported format verbs such as %x and %q
|
||||||
|
along to fmt
|
||||||
|
|
||||||
|
Quick Start
|
||||||
|
|
||||||
|
This section demonstrates how to quickly get started with spew. See the
|
||||||
|
sections below for further details on formatting and configuration options.
|
||||||
|
|
||||||
|
To dump a variable with full newlines, indentation, type, and pointer
|
||||||
|
information use Dump, Fdump, or Sdump:
|
||||||
|
spew.Dump(myVar1, myVar2, ...)
|
||||||
|
spew.Fdump(someWriter, myVar1, myVar2, ...)
|
||||||
|
str := spew.Sdump(myVar1, myVar2, ...)
|
||||||
|
|
||||||
|
Alternatively, if you would prefer to use format strings with a compacted inline
|
||||||
|
printing style, use the convenience wrappers Printf, Fprintf, etc with
|
||||||
|
%v (most compact), %+v (adds pointer addresses), %#v (adds types), or
|
||||||
|
%#+v (adds types and pointer addresses):
|
||||||
|
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||||
|
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||||
|
spew.Fprintf(someWriter, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||||
|
spew.Fprintf(someWriter, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||||
|
|
||||||
|
Configuration Options
|
||||||
|
|
||||||
|
Configuration of spew is handled by fields in the ConfigState type. For
|
||||||
|
convenience, all of the top-level functions use a global state available
|
||||||
|
via the spew.Config global.
|
||||||
|
|
||||||
|
It is also possible to create a ConfigState instance that provides methods
|
||||||
|
equivalent to the top-level functions. This allows concurrent configuration
|
||||||
|
options. See the ConfigState documentation for more details.
|
||||||
|
|
||||||
|
The following configuration options are available:
|
||||||
|
* Indent
|
||||||
|
String to use for each indentation level for Dump functions.
|
||||||
|
It is a single space by default. A popular alternative is "\t".
|
||||||
|
|
||||||
|
* MaxDepth
|
||||||
|
Maximum number of levels to descend into nested data structures.
|
||||||
|
There is no limit by default.
|
||||||
|
|
||||||
|
* DisableMethods
|
||||||
|
Disables invocation of error and Stringer interface methods.
|
||||||
|
Method invocation is enabled by default.
|
||||||
|
|
||||||
|
* DisablePointerMethods
|
||||||
|
Disables invocation of error and Stringer interface methods on types
|
||||||
|
which only accept pointer receivers from non-pointer variables.
|
||||||
|
Pointer method invocation is enabled by default.
|
||||||
|
|
||||||
|
* ContinueOnMethod
|
||||||
|
Enables recursion into types after invoking error and Stringer interface
|
||||||
|
methods. Recursion after method invocation is disabled by default.
|
||||||
|
|
||||||
|
* SortKeys
|
||||||
|
Specifies map keys should be sorted before being printed. Use
|
||||||
|
this to have a more deterministic, diffable output. Note that
|
||||||
|
only native types (bool, int, uint, floats, uintptr and string)
|
||||||
|
and types which implement error or Stringer interfaces are
|
||||||
|
supported with other types sorted according to the
|
||||||
|
reflect.Value.String() output which guarantees display
|
||||||
|
stability. Natural map order is used by default.
|
||||||
|
|
||||||
|
* SpewKeys
|
||||||
|
Specifies that, as a last resort attempt, map keys should be
|
||||||
|
spewed to strings and sorted by those strings. This is only
|
||||||
|
considered if SortKeys is true.
|
||||||
|
|
||||||
|
Dump Usage
|
||||||
|
|
||||||
|
Simply call spew.Dump with a list of variables you want to dump:
|
||||||
|
|
||||||
|
spew.Dump(myVar1, myVar2, ...)
|
||||||
|
|
||||||
|
You may also call spew.Fdump if you would prefer to output to an arbitrary
|
||||||
|
io.Writer. For example, to dump to standard error:
|
||||||
|
|
||||||
|
spew.Fdump(os.Stderr, myVar1, myVar2, ...)
|
||||||
|
|
||||||
|
A third option is to call spew.Sdump to get the formatted output as a string:
|
||||||
|
|
||||||
|
str := spew.Sdump(myVar1, myVar2, ...)
|
||||||
|
|
||||||
|
Sample Dump Output
|
||||||
|
|
||||||
|
See the Dump example for details on the setup of the types and variables being
|
||||||
|
shown here.
|
||||||
|
|
||||||
|
(main.Foo) {
|
||||||
|
unexportedField: (*main.Bar)(0xf84002e210)({
|
||||||
|
flag: (main.Flag) flagTwo,
|
||||||
|
data: (uintptr) <nil>
|
||||||
|
}),
|
||||||
|
ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||||
|
(string) (len=3) "one": (bool) true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Byte (and uint8) arrays and slices are displayed uniquely like the hexdump -C
|
||||||
|
command as shown.
|
||||||
|
([]uint8) (len=32 cap=32) {
|
||||||
|
00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... |
|
||||||
|
00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0|
|
||||||
|
00000020 31 32 |12|
|
||||||
|
}
|
||||||
|
|
||||||
|
Custom Formatter
|
||||||
|
|
||||||
|
Spew provides a custom formatter that implements the fmt.Formatter interface
|
||||||
|
so that it integrates cleanly with standard fmt package printing functions. The
|
||||||
|
formatter is useful for inline printing of smaller data types similar to the
|
||||||
|
standard %v format specifier.
|
||||||
|
|
||||||
|
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||||
|
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
|
||||||
|
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||||
|
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||||
|
the width and precision arguments (however they will still work on the format
|
||||||
|
specifiers not handled by the custom formatter).
|
||||||
|
|
||||||
|
Custom Formatter Usage
|
||||||
|
|
||||||
|
The simplest way to make use of the spew custom formatter is to call one of the
|
||||||
|
convenience functions such as spew.Printf, spew.Println, or spew.Printf. The
|
||||||
|
functions have syntax you are most likely already familiar with:
|
||||||
|
|
||||||
|
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||||
|
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||||
|
spew.Println(myVar, myVar2)
|
||||||
|
spew.Fprintf(os.Stderr, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||||
|
spew.Fprintf(os.Stderr, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||||
|
|
||||||
|
See the Index for the full list convenience functions.
|
||||||
|
|
||||||
|
Sample Formatter Output
|
||||||
|
|
||||||
|
Double pointer to a uint8:
|
||||||
|
%v: <**>5
|
||||||
|
%+v: <**>(0xf8400420d0->0xf8400420c8)5
|
||||||
|
%#v: (**uint8)5
|
||||||
|
%#+v: (**uint8)(0xf8400420d0->0xf8400420c8)5
|
||||||
|
|
||||||
|
Pointer to circular struct with a uint8 field and a pointer to itself:
|
||||||
|
%v: <*>{1 <*><shown>}
|
||||||
|
%+v: <*>(0xf84003e260){ui8:1 c:<*>(0xf84003e260)<shown>}
|
||||||
|
%#v: (*main.circular){ui8:(uint8)1 c:(*main.circular)<shown>}
|
||||||
|
%#+v: (*main.circular)(0xf84003e260){ui8:(uint8)1 c:(*main.circular)(0xf84003e260)<shown>}
|
||||||
|
|
||||||
|
See the Printf example for details on the setup of variables being shown
|
||||||
|
here.
|
||||||
|
|
||||||
|
Errors
|
||||||
|
|
||||||
|
Since it is possible for custom Stringer/error interfaces to panic, spew
|
||||||
|
detects them and handles them internally by printing the panic information
|
||||||
|
inline with the output. Since spew is intended to provide deep pretty printing
|
||||||
|
capabilities on structures, it intentionally does not return any errors.
|
||||||
|
*/
|
||||||
|
package spew
|
|
@ -0,0 +1,506 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// uint8Type is a reflect.Type representing a uint8. It is used to
|
||||||
|
// convert cgo types to uint8 slices for hexdumping.
|
||||||
|
uint8Type = reflect.TypeOf(uint8(0))
|
||||||
|
|
||||||
|
// cCharRE is a regular expression that matches a cgo char.
|
||||||
|
// It is used to detect character arrays to hexdump them.
|
||||||
|
cCharRE = regexp.MustCompile("^.*\\._Ctype_char$")
|
||||||
|
|
||||||
|
// cUnsignedCharRE is a regular expression that matches a cgo unsigned
|
||||||
|
// char. It is used to detect unsigned character arrays to hexdump
|
||||||
|
// them.
|
||||||
|
cUnsignedCharRE = regexp.MustCompile("^.*\\._Ctype_unsignedchar$")
|
||||||
|
|
||||||
|
// cUint8tCharRE is a regular expression that matches a cgo uint8_t.
|
||||||
|
// It is used to detect uint8_t arrays to hexdump them.
|
||||||
|
cUint8tCharRE = regexp.MustCompile("^.*\\._Ctype_uint8_t$")
|
||||||
|
)
|
||||||
|
|
||||||
|
// dumpState contains information about the state of a dump operation.
|
||||||
|
type dumpState struct {
|
||||||
|
w io.Writer
|
||||||
|
depth int
|
||||||
|
pointers map[uintptr]int
|
||||||
|
ignoreNextType bool
|
||||||
|
ignoreNextIndent bool
|
||||||
|
cs *ConfigState
|
||||||
|
}
|
||||||
|
|
||||||
|
// indent performs indentation according to the depth level and cs.Indent
|
||||||
|
// option.
|
||||||
|
func (d *dumpState) indent() {
|
||||||
|
if d.ignoreNextIndent {
|
||||||
|
d.ignoreNextIndent = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.w.Write(bytes.Repeat([]byte(d.cs.Indent), d.depth))
|
||||||
|
}
|
||||||
|
|
||||||
|
// unpackValue returns values inside of non-nil interfaces when possible.
|
||||||
|
// This is useful for data types like structs, arrays, slices, and maps which
|
||||||
|
// can contain varying types packed inside an interface.
|
||||||
|
func (d *dumpState) unpackValue(v reflect.Value) reflect.Value {
|
||||||
|
if v.Kind() == reflect.Interface && !v.IsNil() {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// dumpPtr handles formatting of pointers by indirecting them as necessary.
|
||||||
|
func (d *dumpState) dumpPtr(v reflect.Value) {
|
||||||
|
// Remove pointers at or below the current depth from map used to detect
|
||||||
|
// circular refs.
|
||||||
|
for k, depth := range d.pointers {
|
||||||
|
if depth >= d.depth {
|
||||||
|
delete(d.pointers, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep list of all dereferenced pointers to show later.
|
||||||
|
pointerChain := make([]uintptr, 0)
|
||||||
|
|
||||||
|
// Figure out how many levels of indirection there are by dereferencing
|
||||||
|
// pointers and unpacking interfaces down the chain while detecting circular
|
||||||
|
// references.
|
||||||
|
nilFound := false
|
||||||
|
cycleFound := false
|
||||||
|
indirects := 0
|
||||||
|
ve := v
|
||||||
|
for ve.Kind() == reflect.Ptr {
|
||||||
|
if ve.IsNil() {
|
||||||
|
nilFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
indirects++
|
||||||
|
addr := ve.Pointer()
|
||||||
|
pointerChain = append(pointerChain, addr)
|
||||||
|
if pd, ok := d.pointers[addr]; ok && pd < d.depth {
|
||||||
|
cycleFound = true
|
||||||
|
indirects--
|
||||||
|
break
|
||||||
|
}
|
||||||
|
d.pointers[addr] = d.depth
|
||||||
|
|
||||||
|
ve = ve.Elem()
|
||||||
|
if ve.Kind() == reflect.Interface {
|
||||||
|
if ve.IsNil() {
|
||||||
|
nilFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
ve = ve.Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display type information.
|
||||||
|
d.w.Write(openParenBytes)
|
||||||
|
d.w.Write(bytes.Repeat(asteriskBytes, indirects))
|
||||||
|
d.w.Write([]byte(ve.Type().String()))
|
||||||
|
d.w.Write(closeParenBytes)
|
||||||
|
|
||||||
|
// Display pointer information.
|
||||||
|
if len(pointerChain) > 0 {
|
||||||
|
d.w.Write(openParenBytes)
|
||||||
|
for i, addr := range pointerChain {
|
||||||
|
if i > 0 {
|
||||||
|
d.w.Write(pointerChainBytes)
|
||||||
|
}
|
||||||
|
printHexPtr(d.w, addr)
|
||||||
|
}
|
||||||
|
d.w.Write(closeParenBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display dereferenced value.
|
||||||
|
d.w.Write(openParenBytes)
|
||||||
|
switch {
|
||||||
|
case nilFound == true:
|
||||||
|
d.w.Write(nilAngleBytes)
|
||||||
|
|
||||||
|
case cycleFound == true:
|
||||||
|
d.w.Write(circularBytes)
|
||||||
|
|
||||||
|
default:
|
||||||
|
d.ignoreNextType = true
|
||||||
|
d.dump(ve)
|
||||||
|
}
|
||||||
|
d.w.Write(closeParenBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dumpSlice handles formatting of arrays and slices. Byte (uint8 under
|
||||||
|
// reflection) arrays and slices are dumped in hexdump -C fashion.
|
||||||
|
func (d *dumpState) dumpSlice(v reflect.Value) {
|
||||||
|
// Determine whether this type should be hex dumped or not. Also,
|
||||||
|
// for types which should be hexdumped, try to use the underlying data
|
||||||
|
// first, then fall back to trying to convert them to a uint8 slice.
|
||||||
|
var buf []uint8
|
||||||
|
doConvert := false
|
||||||
|
doHexDump := false
|
||||||
|
numEntries := v.Len()
|
||||||
|
if numEntries > 0 {
|
||||||
|
vt := v.Index(0).Type()
|
||||||
|
vts := vt.String()
|
||||||
|
switch {
|
||||||
|
// C types that need to be converted.
|
||||||
|
case cCharRE.MatchString(vts):
|
||||||
|
fallthrough
|
||||||
|
case cUnsignedCharRE.MatchString(vts):
|
||||||
|
fallthrough
|
||||||
|
case cUint8tCharRE.MatchString(vts):
|
||||||
|
doConvert = true
|
||||||
|
|
||||||
|
// Try to use existing uint8 slices and fall back to converting
|
||||||
|
// and copying if that fails.
|
||||||
|
case vt.Kind() == reflect.Uint8:
|
||||||
|
// We need an addressable interface to convert the type back
|
||||||
|
// into a byte slice. However, the reflect package won't give
|
||||||
|
// us an interface on certain things like unexported struct
|
||||||
|
// fields in order to enforce visibility rules. We use unsafe
|
||||||
|
// to bypass these restrictions since this package does not
|
||||||
|
// mutate the values.
|
||||||
|
vs := v
|
||||||
|
if !vs.CanInterface() || !vs.CanAddr() {
|
||||||
|
vs = unsafeReflectValue(vs)
|
||||||
|
}
|
||||||
|
vs = vs.Slice(0, numEntries)
|
||||||
|
|
||||||
|
// Use the existing uint8 slice if it can be type
|
||||||
|
// asserted.
|
||||||
|
iface := vs.Interface()
|
||||||
|
if slice, ok := iface.([]uint8); ok {
|
||||||
|
buf = slice
|
||||||
|
doHexDump = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// The underlying data needs to be converted if it can't
|
||||||
|
// be type asserted to a uint8 slice.
|
||||||
|
doConvert = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy and convert the underlying type if needed.
|
||||||
|
if doConvert && vt.ConvertibleTo(uint8Type) {
|
||||||
|
// Convert and copy each element into a uint8 byte
|
||||||
|
// slice.
|
||||||
|
buf = make([]uint8, numEntries)
|
||||||
|
for i := 0; i < numEntries; i++ {
|
||||||
|
vv := v.Index(i)
|
||||||
|
buf[i] = uint8(vv.Convert(uint8Type).Uint())
|
||||||
|
}
|
||||||
|
doHexDump = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hexdump the entire slice as needed.
|
||||||
|
if doHexDump {
|
||||||
|
indent := strings.Repeat(d.cs.Indent, d.depth)
|
||||||
|
str := indent + hex.Dump(buf)
|
||||||
|
str = strings.Replace(str, "\n", "\n"+indent, -1)
|
||||||
|
str = strings.TrimRight(str, d.cs.Indent)
|
||||||
|
d.w.Write([]byte(str))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively call dump for each item.
|
||||||
|
for i := 0; i < numEntries; i++ {
|
||||||
|
d.dump(d.unpackValue(v.Index(i)))
|
||||||
|
if i < (numEntries - 1) {
|
||||||
|
d.w.Write(commaNewlineBytes)
|
||||||
|
} else {
|
||||||
|
d.w.Write(newlineBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dump is the main workhorse for dumping a value. It uses the passed reflect
|
||||||
|
// value to figure out what kind of object we are dealing with and formats it
|
||||||
|
// appropriately. It is a recursive function, however circular data structures
|
||||||
|
// are detected and handled properly.
|
||||||
|
func (d *dumpState) dump(v reflect.Value) {
|
||||||
|
// Handle invalid reflect values immediately.
|
||||||
|
kind := v.Kind()
|
||||||
|
if kind == reflect.Invalid {
|
||||||
|
d.w.Write(invalidAngleBytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle pointers specially.
|
||||||
|
if kind == reflect.Ptr {
|
||||||
|
d.indent()
|
||||||
|
d.dumpPtr(v)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print type information unless already handled elsewhere.
|
||||||
|
if !d.ignoreNextType {
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(openParenBytes)
|
||||||
|
d.w.Write([]byte(v.Type().String()))
|
||||||
|
d.w.Write(closeParenBytes)
|
||||||
|
d.w.Write(spaceBytes)
|
||||||
|
}
|
||||||
|
d.ignoreNextType = false
|
||||||
|
|
||||||
|
// Display length and capacity if the built-in len and cap functions
|
||||||
|
// work with the value's kind and the len/cap itself is non-zero.
|
||||||
|
valueLen, valueCap := 0, 0
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Array, reflect.Slice, reflect.Chan:
|
||||||
|
valueLen, valueCap = v.Len(), v.Cap()
|
||||||
|
case reflect.Map, reflect.String:
|
||||||
|
valueLen = v.Len()
|
||||||
|
}
|
||||||
|
if valueLen != 0 || valueCap != 0 {
|
||||||
|
d.w.Write(openParenBytes)
|
||||||
|
if valueLen != 0 {
|
||||||
|
d.w.Write(lenEqualsBytes)
|
||||||
|
printInt(d.w, int64(valueLen), 10)
|
||||||
|
}
|
||||||
|
if valueCap != 0 {
|
||||||
|
if valueLen != 0 {
|
||||||
|
d.w.Write(spaceBytes)
|
||||||
|
}
|
||||||
|
d.w.Write(capEqualsBytes)
|
||||||
|
printInt(d.w, int64(valueCap), 10)
|
||||||
|
}
|
||||||
|
d.w.Write(closeParenBytes)
|
||||||
|
d.w.Write(spaceBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call Stringer/error interfaces if they exist and the handle methods flag
|
||||||
|
// is enabled
|
||||||
|
if !d.cs.DisableMethods {
|
||||||
|
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
|
||||||
|
if handled := handleMethods(d.cs, d.w, v); handled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case reflect.Invalid:
|
||||||
|
// Do nothing. We should never get here since invalid has already
|
||||||
|
// been handled above.
|
||||||
|
|
||||||
|
case reflect.Bool:
|
||||||
|
printBool(d.w, v.Bool())
|
||||||
|
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||||
|
printInt(d.w, v.Int(), 10)
|
||||||
|
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||||
|
printUint(d.w, v.Uint(), 10)
|
||||||
|
|
||||||
|
case reflect.Float32:
|
||||||
|
printFloat(d.w, v.Float(), 32)
|
||||||
|
|
||||||
|
case reflect.Float64:
|
||||||
|
printFloat(d.w, v.Float(), 64)
|
||||||
|
|
||||||
|
case reflect.Complex64:
|
||||||
|
printComplex(d.w, v.Complex(), 32)
|
||||||
|
|
||||||
|
case reflect.Complex128:
|
||||||
|
printComplex(d.w, v.Complex(), 64)
|
||||||
|
|
||||||
|
case reflect.Slice:
|
||||||
|
if v.IsNil() {
|
||||||
|
d.w.Write(nilAngleBytes)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
|
||||||
|
case reflect.Array:
|
||||||
|
d.w.Write(openBraceNewlineBytes)
|
||||||
|
d.depth++
|
||||||
|
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(maxNewlineBytes)
|
||||||
|
} else {
|
||||||
|
d.dumpSlice(v)
|
||||||
|
}
|
||||||
|
d.depth--
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(closeBraceBytes)
|
||||||
|
|
||||||
|
case reflect.String:
|
||||||
|
d.w.Write([]byte(strconv.Quote(v.String())))
|
||||||
|
|
||||||
|
case reflect.Interface:
|
||||||
|
// The only time we should get here is for nil interfaces due to
|
||||||
|
// unpackValue calls.
|
||||||
|
if v.IsNil() {
|
||||||
|
d.w.Write(nilAngleBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Ptr:
|
||||||
|
// Do nothing. We should never get here since pointers have already
|
||||||
|
// been handled above.
|
||||||
|
|
||||||
|
case reflect.Map:
|
||||||
|
// nil maps should be indicated as different than empty maps
|
||||||
|
if v.IsNil() {
|
||||||
|
d.w.Write(nilAngleBytes)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
d.w.Write(openBraceNewlineBytes)
|
||||||
|
d.depth++
|
||||||
|
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(maxNewlineBytes)
|
||||||
|
} else {
|
||||||
|
numEntries := v.Len()
|
||||||
|
keys := v.MapKeys()
|
||||||
|
if d.cs.SortKeys {
|
||||||
|
sortValues(keys, d.cs)
|
||||||
|
}
|
||||||
|
for i, key := range keys {
|
||||||
|
d.dump(d.unpackValue(key))
|
||||||
|
d.w.Write(colonSpaceBytes)
|
||||||
|
d.ignoreNextIndent = true
|
||||||
|
d.dump(d.unpackValue(v.MapIndex(key)))
|
||||||
|
if i < (numEntries - 1) {
|
||||||
|
d.w.Write(commaNewlineBytes)
|
||||||
|
} else {
|
||||||
|
d.w.Write(newlineBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d.depth--
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(closeBraceBytes)
|
||||||
|
|
||||||
|
case reflect.Struct:
|
||||||
|
d.w.Write(openBraceNewlineBytes)
|
||||||
|
d.depth++
|
||||||
|
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(maxNewlineBytes)
|
||||||
|
} else {
|
||||||
|
vt := v.Type()
|
||||||
|
numFields := v.NumField()
|
||||||
|
for i := 0; i < numFields; i++ {
|
||||||
|
d.indent()
|
||||||
|
vtf := vt.Field(i)
|
||||||
|
d.w.Write([]byte(vtf.Name))
|
||||||
|
d.w.Write(colonSpaceBytes)
|
||||||
|
d.ignoreNextIndent = true
|
||||||
|
d.dump(d.unpackValue(v.Field(i)))
|
||||||
|
if i < (numFields - 1) {
|
||||||
|
d.w.Write(commaNewlineBytes)
|
||||||
|
} else {
|
||||||
|
d.w.Write(newlineBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d.depth--
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(closeBraceBytes)
|
||||||
|
|
||||||
|
case reflect.Uintptr:
|
||||||
|
printHexPtr(d.w, uintptr(v.Uint()))
|
||||||
|
|
||||||
|
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
|
||||||
|
printHexPtr(d.w, v.Pointer())
|
||||||
|
|
||||||
|
// There were not any other types at the time this code was written, but
|
||||||
|
// fall back to letting the default fmt package handle it in case any new
|
||||||
|
// types are added.
|
||||||
|
default:
|
||||||
|
if v.CanInterface() {
|
||||||
|
fmt.Fprintf(d.w, "%v", v.Interface())
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(d.w, "%v", v.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fdump is a helper function to consolidate the logic from the various public
|
||||||
|
// methods which take varying writers and config states.
|
||||||
|
func fdump(cs *ConfigState, w io.Writer, a ...interface{}) {
|
||||||
|
for _, arg := range a {
|
||||||
|
if arg == nil {
|
||||||
|
w.Write(interfaceBytes)
|
||||||
|
w.Write(spaceBytes)
|
||||||
|
w.Write(nilAngleBytes)
|
||||||
|
w.Write(newlineBytes)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
d := dumpState{w: w, cs: cs}
|
||||||
|
d.pointers = make(map[uintptr]int)
|
||||||
|
d.dump(reflect.ValueOf(arg))
|
||||||
|
d.w.Write(newlineBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fdump formats and displays the passed arguments to io.Writer w. It formats
|
||||||
|
// exactly the same as Dump.
|
||||||
|
func Fdump(w io.Writer, a ...interface{}) {
|
||||||
|
fdump(&Config, w, a...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sdump returns a string with the passed arguments formatted exactly the same
|
||||||
|
// as Dump.
|
||||||
|
func Sdump(a ...interface{}) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
fdump(&Config, &buf, a...)
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Dump displays the passed parameters to standard out with newlines, customizable
|
||||||
|
indentation, and additional debug information such as complete types and all
|
||||||
|
pointer addresses used to indirect to the final value. It provides the
|
||||||
|
following features over the built-in printing facilities provided by the fmt
|
||||||
|
package:
|
||||||
|
|
||||||
|
* Pointers are dereferenced and followed
|
||||||
|
* Circular data structures are detected and handled properly
|
||||||
|
* Custom Stringer/error interfaces are optionally invoked, including
|
||||||
|
on unexported types
|
||||||
|
* Custom types which only implement the Stringer/error interfaces via
|
||||||
|
a pointer receiver are optionally invoked when passing non-pointer
|
||||||
|
variables
|
||||||
|
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||||
|
includes offsets, byte values in hex, and ASCII output
|
||||||
|
|
||||||
|
The configuration options are controlled by an exported package global,
|
||||||
|
spew.Config. See ConfigState for options documentation.
|
||||||
|
|
||||||
|
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
|
||||||
|
get the formatted result as a string.
|
||||||
|
*/
|
||||||
|
func Dump(a ...interface{}) {
|
||||||
|
fdump(&Config, os.Stdout, a...)
|
||||||
|
}
|
1021
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/dump_test.go
generated
vendored
Normal file
1021
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/dump_test.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
97
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/dumpcgo_test.go
generated
vendored
Normal file
97
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/dumpcgo_test.go
generated
vendored
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
// Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
//
|
||||||
|
// Permission to use, copy, modify, and distribute this software for any
|
||||||
|
// purpose with or without fee is hereby granted, provided that the above
|
||||||
|
// copyright notice and this permission notice appear in all copies.
|
||||||
|
//
|
||||||
|
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||||
|
// when both cgo is supported and "-tags testcgo" is added to the go test
|
||||||
|
// command line. This means the cgo tests are only added (and hence run) when
|
||||||
|
// specifially requested. This configuration is used because spew itself
|
||||||
|
// does not require cgo to run even though it does handle certain cgo types
|
||||||
|
// specially. Rather than forcing all clients to require cgo and an external
|
||||||
|
// C compiler just to run the tests, this scheme makes them optional.
|
||||||
|
// +build cgo,testcgo
|
||||||
|
|
||||||
|
package spew_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/davecgh/go-spew/spew/testdata"
|
||||||
|
)
|
||||||
|
|
||||||
|
func addCgoDumpTests() {
|
||||||
|
// C char pointer.
|
||||||
|
v := testdata.GetCgoCharPointer()
|
||||||
|
nv := testdata.GetCgoNullCharPointer()
|
||||||
|
pv := &v
|
||||||
|
vcAddr := fmt.Sprintf("%p", v)
|
||||||
|
vAddr := fmt.Sprintf("%p", pv)
|
||||||
|
pvAddr := fmt.Sprintf("%p", &pv)
|
||||||
|
vt := "*testdata._Ctype_char"
|
||||||
|
vs := "116"
|
||||||
|
addDumpTest(v, "("+vt+")("+vcAddr+")("+vs+")\n")
|
||||||
|
addDumpTest(pv, "(*"+vt+")("+vAddr+"->"+vcAddr+")("+vs+")\n")
|
||||||
|
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+"->"+vcAddr+")("+vs+")\n")
|
||||||
|
addDumpTest(nv, "("+vt+")(<nil>)\n")
|
||||||
|
|
||||||
|
// C char array.
|
||||||
|
v2, v2l, v2c := testdata.GetCgoCharArray()
|
||||||
|
v2Len := fmt.Sprintf("%d", v2l)
|
||||||
|
v2Cap := fmt.Sprintf("%d", v2c)
|
||||||
|
v2t := "[6]testdata._Ctype_char"
|
||||||
|
v2s := "(len=" + v2Len + " cap=" + v2Cap + ") " +
|
||||||
|
"{\n 00000000 74 65 73 74 32 00 " +
|
||||||
|
" |test2.|\n}"
|
||||||
|
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||||
|
|
||||||
|
// C unsigned char array.
|
||||||
|
v3, v3l, v3c := testdata.GetCgoUnsignedCharArray()
|
||||||
|
v3Len := fmt.Sprintf("%d", v3l)
|
||||||
|
v3Cap := fmt.Sprintf("%d", v3c)
|
||||||
|
v3t := "[6]testdata._Ctype_unsignedchar"
|
||||||
|
v3s := "(len=" + v3Len + " cap=" + v3Cap + ") " +
|
||||||
|
"{\n 00000000 74 65 73 74 33 00 " +
|
||||||
|
" |test3.|\n}"
|
||||||
|
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
|
||||||
|
|
||||||
|
// C signed char array.
|
||||||
|
v4, v4l, v4c := testdata.GetCgoSignedCharArray()
|
||||||
|
v4Len := fmt.Sprintf("%d", v4l)
|
||||||
|
v4Cap := fmt.Sprintf("%d", v4c)
|
||||||
|
v4t := "[6]testdata._Ctype_schar"
|
||||||
|
v4t2 := "testdata._Ctype_schar"
|
||||||
|
v4s := "(len=" + v4Len + " cap=" + v4Cap + ") " +
|
||||||
|
"{\n (" + v4t2 + ") 116,\n (" + v4t2 + ") 101,\n (" + v4t2 +
|
||||||
|
") 115,\n (" + v4t2 + ") 116,\n (" + v4t2 + ") 52,\n (" + v4t2 +
|
||||||
|
") 0\n}"
|
||||||
|
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
|
||||||
|
|
||||||
|
// C uint8_t array.
|
||||||
|
v5, v5l, v5c := testdata.GetCgoUint8tArray()
|
||||||
|
v5Len := fmt.Sprintf("%d", v5l)
|
||||||
|
v5Cap := fmt.Sprintf("%d", v5c)
|
||||||
|
v5t := "[6]testdata._Ctype_uint8_t"
|
||||||
|
v5s := "(len=" + v5Len + " cap=" + v5Cap + ") " +
|
||||||
|
"{\n 00000000 74 65 73 74 35 00 " +
|
||||||
|
" |test5.|\n}"
|
||||||
|
addDumpTest(v5, "("+v5t+") "+v5s+"\n")
|
||||||
|
|
||||||
|
// C typedefed unsigned char array.
|
||||||
|
v6, v6l, v6c := testdata.GetCgoTypdefedUnsignedCharArray()
|
||||||
|
v6Len := fmt.Sprintf("%d", v6l)
|
||||||
|
v6Cap := fmt.Sprintf("%d", v6c)
|
||||||
|
v6t := "[6]testdata._Ctype_custom_uchar_t"
|
||||||
|
v6s := "(len=" + v6Len + " cap=" + v6Cap + ") " +
|
||||||
|
"{\n 00000000 74 65 73 74 36 00 " +
|
||||||
|
" |test6.|\n}"
|
||||||
|
addDumpTest(v6, "("+v6t+") "+v6s+"\n")
|
||||||
|
}
|
26
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/dumpnocgo_test.go
generated
vendored
Normal file
26
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/dumpnocgo_test.go
generated
vendored
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
// Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
//
|
||||||
|
// Permission to use, copy, modify, and distribute this software for any
|
||||||
|
// purpose with or without fee is hereby granted, provided that the above
|
||||||
|
// copyright notice and this permission notice appear in all copies.
|
||||||
|
//
|
||||||
|
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||||
|
// when either cgo is not supported or "-tags testcgo" is not added to the go
|
||||||
|
// test command line. This file intentionally does not setup any cgo tests in
|
||||||
|
// this scenario.
|
||||||
|
// +build !cgo !testcgo
|
||||||
|
|
||||||
|
package spew_test
|
||||||
|
|
||||||
|
func addCgoDumpTests() {
|
||||||
|
// Don't add any tests for cgo since this file is only compiled when
|
||||||
|
// there should not be any cgo tests.
|
||||||
|
}
|
230
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/example_test.go
generated
vendored
Normal file
230
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/example_test.go
generated
vendored
Normal file
|
@ -0,0 +1,230 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Flag int
|
||||||
|
|
||||||
|
const (
|
||||||
|
flagOne Flag = iota
|
||||||
|
flagTwo
|
||||||
|
)
|
||||||
|
|
||||||
|
var flagStrings = map[Flag]string{
|
||||||
|
flagOne: "flagOne",
|
||||||
|
flagTwo: "flagTwo",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Flag) String() string {
|
||||||
|
if s, ok := flagStrings[f]; ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("Unknown flag (%d)", int(f))
|
||||||
|
}
|
||||||
|
|
||||||
|
type Bar struct {
|
||||||
|
flag Flag
|
||||||
|
data uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
type Foo struct {
|
||||||
|
unexportedField Bar
|
||||||
|
ExportedField map[interface{}]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This example demonstrates how to use Dump to dump variables to stdout.
|
||||||
|
func ExampleDump() {
|
||||||
|
// The following package level declarations are assumed for this example:
|
||||||
|
/*
|
||||||
|
type Flag int
|
||||||
|
|
||||||
|
const (
|
||||||
|
flagOne Flag = iota
|
||||||
|
flagTwo
|
||||||
|
)
|
||||||
|
|
||||||
|
var flagStrings = map[Flag]string{
|
||||||
|
flagOne: "flagOne",
|
||||||
|
flagTwo: "flagTwo",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Flag) String() string {
|
||||||
|
if s, ok := flagStrings[f]; ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("Unknown flag (%d)", int(f))
|
||||||
|
}
|
||||||
|
|
||||||
|
type Bar struct {
|
||||||
|
flag Flag
|
||||||
|
data uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
type Foo struct {
|
||||||
|
unexportedField Bar
|
||||||
|
ExportedField map[interface{}]interface{}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Setup some sample data structures for the example.
|
||||||
|
bar := Bar{Flag(flagTwo), uintptr(0)}
|
||||||
|
s1 := Foo{bar, map[interface{}]interface{}{"one": true}}
|
||||||
|
f := Flag(5)
|
||||||
|
b := []byte{
|
||||||
|
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
|
||||||
|
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
|
||||||
|
0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
|
||||||
|
0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
|
||||||
|
0x31, 0x32,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dump!
|
||||||
|
spew.Dump(s1, f, b)
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// (spew_test.Foo) {
|
||||||
|
// unexportedField: (spew_test.Bar) {
|
||||||
|
// flag: (spew_test.Flag) flagTwo,
|
||||||
|
// data: (uintptr) <nil>
|
||||||
|
// },
|
||||||
|
// ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||||
|
// (string) (len=3) "one": (bool) true
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// (spew_test.Flag) Unknown flag (5)
|
||||||
|
// ([]uint8) (len=34 cap=34) {
|
||||||
|
// 00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... |
|
||||||
|
// 00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0|
|
||||||
|
// 00000020 31 32 |12|
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
// This example demonstrates how to use Printf to display a variable with a
|
||||||
|
// format string and inline formatting.
|
||||||
|
func ExamplePrintf() {
|
||||||
|
// Create a double pointer to a uint 8.
|
||||||
|
ui8 := uint8(5)
|
||||||
|
pui8 := &ui8
|
||||||
|
ppui8 := &pui8
|
||||||
|
|
||||||
|
// Create a circular data type.
|
||||||
|
type circular struct {
|
||||||
|
ui8 uint8
|
||||||
|
c *circular
|
||||||
|
}
|
||||||
|
c := circular{ui8: 1}
|
||||||
|
c.c = &c
|
||||||
|
|
||||||
|
// Print!
|
||||||
|
spew.Printf("ppui8: %v\n", ppui8)
|
||||||
|
spew.Printf("circular: %v\n", c)
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// ppui8: <**>5
|
||||||
|
// circular: {1 <*>{1 <*><shown>}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This example demonstrates how to use a ConfigState.
|
||||||
|
func ExampleConfigState() {
|
||||||
|
// Modify the indent level of the ConfigState only. The global
|
||||||
|
// configuration is not modified.
|
||||||
|
scs := spew.ConfigState{Indent: "\t"}
|
||||||
|
|
||||||
|
// Output using the ConfigState instance.
|
||||||
|
v := map[string]int{"one": 1}
|
||||||
|
scs.Printf("v: %v\n", v)
|
||||||
|
scs.Dump(v)
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// v: map[one:1]
|
||||||
|
// (map[string]int) (len=1) {
|
||||||
|
// (string) (len=3) "one": (int) 1
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
// This example demonstrates how to use ConfigState.Dump to dump variables to
|
||||||
|
// stdout
|
||||||
|
func ExampleConfigState_Dump() {
|
||||||
|
// See the top-level Dump example for details on the types used in this
|
||||||
|
// example.
|
||||||
|
|
||||||
|
// Create two ConfigState instances with different indentation.
|
||||||
|
scs := spew.ConfigState{Indent: "\t"}
|
||||||
|
scs2 := spew.ConfigState{Indent: " "}
|
||||||
|
|
||||||
|
// Setup some sample data structures for the example.
|
||||||
|
bar := Bar{Flag(flagTwo), uintptr(0)}
|
||||||
|
s1 := Foo{bar, map[interface{}]interface{}{"one": true}}
|
||||||
|
|
||||||
|
// Dump using the ConfigState instances.
|
||||||
|
scs.Dump(s1)
|
||||||
|
scs2.Dump(s1)
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// (spew_test.Foo) {
|
||||||
|
// unexportedField: (spew_test.Bar) {
|
||||||
|
// flag: (spew_test.Flag) flagTwo,
|
||||||
|
// data: (uintptr) <nil>
|
||||||
|
// },
|
||||||
|
// ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||||
|
// (string) (len=3) "one": (bool) true
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// (spew_test.Foo) {
|
||||||
|
// unexportedField: (spew_test.Bar) {
|
||||||
|
// flag: (spew_test.Flag) flagTwo,
|
||||||
|
// data: (uintptr) <nil>
|
||||||
|
// },
|
||||||
|
// ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||||
|
// (string) (len=3) "one": (bool) true
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
// This example demonstrates how to use ConfigState.Printf to display a variable
|
||||||
|
// with a format string and inline formatting.
|
||||||
|
func ExampleConfigState_Printf() {
|
||||||
|
// See the top-level Dump example for details on the types used in this
|
||||||
|
// example.
|
||||||
|
|
||||||
|
// Create two ConfigState instances and modify the method handling of the
|
||||||
|
// first ConfigState only.
|
||||||
|
scs := spew.NewDefaultConfig()
|
||||||
|
scs2 := spew.NewDefaultConfig()
|
||||||
|
scs.DisableMethods = true
|
||||||
|
|
||||||
|
// Alternatively
|
||||||
|
// scs := spew.ConfigState{Indent: " ", DisableMethods: true}
|
||||||
|
// scs2 := spew.ConfigState{Indent: " "}
|
||||||
|
|
||||||
|
// This is of type Flag which implements a Stringer and has raw value 1.
|
||||||
|
f := flagTwo
|
||||||
|
|
||||||
|
// Dump using the ConfigState instances.
|
||||||
|
scs.Printf("f: %v\n", f)
|
||||||
|
scs2.Printf("f: %v\n", f)
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// f: 1
|
||||||
|
// f: flagTwo
|
||||||
|
}
|
|
@ -0,0 +1,419 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// supportedFlags is a list of all the character flags supported by fmt package.
|
||||||
|
const supportedFlags = "0-+# "
|
||||||
|
|
||||||
|
// formatState implements the fmt.Formatter interface and contains information
|
||||||
|
// about the state of a formatting operation. The NewFormatter function can
|
||||||
|
// be used to get a new Formatter which can be used directly as arguments
|
||||||
|
// in standard fmt package printing calls.
|
||||||
|
type formatState struct {
|
||||||
|
value interface{}
|
||||||
|
fs fmt.State
|
||||||
|
depth int
|
||||||
|
pointers map[uintptr]int
|
||||||
|
ignoreNextType bool
|
||||||
|
cs *ConfigState
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildDefaultFormat recreates the original format string without precision
|
||||||
|
// and width information to pass in to fmt.Sprintf in the case of an
|
||||||
|
// unrecognized type. Unless new types are added to the language, this
|
||||||
|
// function won't ever be called.
|
||||||
|
func (f *formatState) buildDefaultFormat() (format string) {
|
||||||
|
buf := bytes.NewBuffer(percentBytes)
|
||||||
|
|
||||||
|
for _, flag := range supportedFlags {
|
||||||
|
if f.fs.Flag(int(flag)) {
|
||||||
|
buf.WriteRune(flag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteRune('v')
|
||||||
|
|
||||||
|
format = buf.String()
|
||||||
|
return format
|
||||||
|
}
|
||||||
|
|
||||||
|
// constructOrigFormat recreates the original format string including precision
|
||||||
|
// and width information to pass along to the standard fmt package. This allows
|
||||||
|
// automatic deferral of all format strings this package doesn't support.
|
||||||
|
func (f *formatState) constructOrigFormat(verb rune) (format string) {
|
||||||
|
buf := bytes.NewBuffer(percentBytes)
|
||||||
|
|
||||||
|
for _, flag := range supportedFlags {
|
||||||
|
if f.fs.Flag(int(flag)) {
|
||||||
|
buf.WriteRune(flag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if width, ok := f.fs.Width(); ok {
|
||||||
|
buf.WriteString(strconv.Itoa(width))
|
||||||
|
}
|
||||||
|
|
||||||
|
if precision, ok := f.fs.Precision(); ok {
|
||||||
|
buf.Write(precisionBytes)
|
||||||
|
buf.WriteString(strconv.Itoa(precision))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteRune(verb)
|
||||||
|
|
||||||
|
format = buf.String()
|
||||||
|
return format
|
||||||
|
}
|
||||||
|
|
||||||
|
// unpackValue returns values inside of non-nil interfaces when possible and
|
||||||
|
// ensures that types for values which have been unpacked from an interface
|
||||||
|
// are displayed when the show types flag is also set.
|
||||||
|
// This is useful for data types like structs, arrays, slices, and maps which
|
||||||
|
// can contain varying types packed inside an interface.
|
||||||
|
func (f *formatState) unpackValue(v reflect.Value) reflect.Value {
|
||||||
|
if v.Kind() == reflect.Interface {
|
||||||
|
f.ignoreNextType = false
|
||||||
|
if !v.IsNil() {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatPtr handles formatting of pointers by indirecting them as necessary.
|
||||||
|
func (f *formatState) formatPtr(v reflect.Value) {
|
||||||
|
// Display nil if top level pointer is nil.
|
||||||
|
showTypes := f.fs.Flag('#')
|
||||||
|
if v.IsNil() && (!showTypes || f.ignoreNextType) {
|
||||||
|
f.fs.Write(nilAngleBytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove pointers at or below the current depth from map used to detect
|
||||||
|
// circular refs.
|
||||||
|
for k, depth := range f.pointers {
|
||||||
|
if depth >= f.depth {
|
||||||
|
delete(f.pointers, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep list of all dereferenced pointers to possibly show later.
|
||||||
|
pointerChain := make([]uintptr, 0)
|
||||||
|
|
||||||
|
// Figure out how many levels of indirection there are by derferencing
|
||||||
|
// pointers and unpacking interfaces down the chain while detecting circular
|
||||||
|
// references.
|
||||||
|
nilFound := false
|
||||||
|
cycleFound := false
|
||||||
|
indirects := 0
|
||||||
|
ve := v
|
||||||
|
for ve.Kind() == reflect.Ptr {
|
||||||
|
if ve.IsNil() {
|
||||||
|
nilFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
indirects++
|
||||||
|
addr := ve.Pointer()
|
||||||
|
pointerChain = append(pointerChain, addr)
|
||||||
|
if pd, ok := f.pointers[addr]; ok && pd < f.depth {
|
||||||
|
cycleFound = true
|
||||||
|
indirects--
|
||||||
|
break
|
||||||
|
}
|
||||||
|
f.pointers[addr] = f.depth
|
||||||
|
|
||||||
|
ve = ve.Elem()
|
||||||
|
if ve.Kind() == reflect.Interface {
|
||||||
|
if ve.IsNil() {
|
||||||
|
nilFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
ve = ve.Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display type or indirection level depending on flags.
|
||||||
|
if showTypes && !f.ignoreNextType {
|
||||||
|
f.fs.Write(openParenBytes)
|
||||||
|
f.fs.Write(bytes.Repeat(asteriskBytes, indirects))
|
||||||
|
f.fs.Write([]byte(ve.Type().String()))
|
||||||
|
f.fs.Write(closeParenBytes)
|
||||||
|
} else {
|
||||||
|
if nilFound || cycleFound {
|
||||||
|
indirects += strings.Count(ve.Type().String(), "*")
|
||||||
|
}
|
||||||
|
f.fs.Write(openAngleBytes)
|
||||||
|
f.fs.Write([]byte(strings.Repeat("*", indirects)))
|
||||||
|
f.fs.Write(closeAngleBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display pointer information depending on flags.
|
||||||
|
if f.fs.Flag('+') && (len(pointerChain) > 0) {
|
||||||
|
f.fs.Write(openParenBytes)
|
||||||
|
for i, addr := range pointerChain {
|
||||||
|
if i > 0 {
|
||||||
|
f.fs.Write(pointerChainBytes)
|
||||||
|
}
|
||||||
|
printHexPtr(f.fs, addr)
|
||||||
|
}
|
||||||
|
f.fs.Write(closeParenBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display dereferenced value.
|
||||||
|
switch {
|
||||||
|
case nilFound == true:
|
||||||
|
f.fs.Write(nilAngleBytes)
|
||||||
|
|
||||||
|
case cycleFound == true:
|
||||||
|
f.fs.Write(circularShortBytes)
|
||||||
|
|
||||||
|
default:
|
||||||
|
f.ignoreNextType = true
|
||||||
|
f.format(ve)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// format is the main workhorse for providing the Formatter interface. It
|
||||||
|
// uses the passed reflect value to figure out what kind of object we are
|
||||||
|
// dealing with and formats it appropriately. It is a recursive function,
|
||||||
|
// however circular data structures are detected and handled properly.
|
||||||
|
func (f *formatState) format(v reflect.Value) {
|
||||||
|
// Handle invalid reflect values immediately.
|
||||||
|
kind := v.Kind()
|
||||||
|
if kind == reflect.Invalid {
|
||||||
|
f.fs.Write(invalidAngleBytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle pointers specially.
|
||||||
|
if kind == reflect.Ptr {
|
||||||
|
f.formatPtr(v)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print type information unless already handled elsewhere.
|
||||||
|
if !f.ignoreNextType && f.fs.Flag('#') {
|
||||||
|
f.fs.Write(openParenBytes)
|
||||||
|
f.fs.Write([]byte(v.Type().String()))
|
||||||
|
f.fs.Write(closeParenBytes)
|
||||||
|
}
|
||||||
|
f.ignoreNextType = false
|
||||||
|
|
||||||
|
// Call Stringer/error interfaces if they exist and the handle methods
|
||||||
|
// flag is enabled.
|
||||||
|
if !f.cs.DisableMethods {
|
||||||
|
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
|
||||||
|
if handled := handleMethods(f.cs, f.fs, v); handled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case reflect.Invalid:
|
||||||
|
// Do nothing. We should never get here since invalid has already
|
||||||
|
// been handled above.
|
||||||
|
|
||||||
|
case reflect.Bool:
|
||||||
|
printBool(f.fs, v.Bool())
|
||||||
|
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||||
|
printInt(f.fs, v.Int(), 10)
|
||||||
|
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||||
|
printUint(f.fs, v.Uint(), 10)
|
||||||
|
|
||||||
|
case reflect.Float32:
|
||||||
|
printFloat(f.fs, v.Float(), 32)
|
||||||
|
|
||||||
|
case reflect.Float64:
|
||||||
|
printFloat(f.fs, v.Float(), 64)
|
||||||
|
|
||||||
|
case reflect.Complex64:
|
||||||
|
printComplex(f.fs, v.Complex(), 32)
|
||||||
|
|
||||||
|
case reflect.Complex128:
|
||||||
|
printComplex(f.fs, v.Complex(), 64)
|
||||||
|
|
||||||
|
case reflect.Slice:
|
||||||
|
if v.IsNil() {
|
||||||
|
f.fs.Write(nilAngleBytes)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
|
||||||
|
case reflect.Array:
|
||||||
|
f.fs.Write(openBracketBytes)
|
||||||
|
f.depth++
|
||||||
|
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||||
|
f.fs.Write(maxShortBytes)
|
||||||
|
} else {
|
||||||
|
numEntries := v.Len()
|
||||||
|
for i := 0; i < numEntries; i++ {
|
||||||
|
if i > 0 {
|
||||||
|
f.fs.Write(spaceBytes)
|
||||||
|
}
|
||||||
|
f.ignoreNextType = true
|
||||||
|
f.format(f.unpackValue(v.Index(i)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.depth--
|
||||||
|
f.fs.Write(closeBracketBytes)
|
||||||
|
|
||||||
|
case reflect.String:
|
||||||
|
f.fs.Write([]byte(v.String()))
|
||||||
|
|
||||||
|
case reflect.Interface:
|
||||||
|
// The only time we should get here is for nil interfaces due to
|
||||||
|
// unpackValue calls.
|
||||||
|
if v.IsNil() {
|
||||||
|
f.fs.Write(nilAngleBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Ptr:
|
||||||
|
// Do nothing. We should never get here since pointers have already
|
||||||
|
// been handled above.
|
||||||
|
|
||||||
|
case reflect.Map:
|
||||||
|
// nil maps should be indicated as different than empty maps
|
||||||
|
if v.IsNil() {
|
||||||
|
f.fs.Write(nilAngleBytes)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
f.fs.Write(openMapBytes)
|
||||||
|
f.depth++
|
||||||
|
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||||
|
f.fs.Write(maxShortBytes)
|
||||||
|
} else {
|
||||||
|
keys := v.MapKeys()
|
||||||
|
if f.cs.SortKeys {
|
||||||
|
sortValues(keys, f.cs)
|
||||||
|
}
|
||||||
|
for i, key := range keys {
|
||||||
|
if i > 0 {
|
||||||
|
f.fs.Write(spaceBytes)
|
||||||
|
}
|
||||||
|
f.ignoreNextType = true
|
||||||
|
f.format(f.unpackValue(key))
|
||||||
|
f.fs.Write(colonBytes)
|
||||||
|
f.ignoreNextType = true
|
||||||
|
f.format(f.unpackValue(v.MapIndex(key)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.depth--
|
||||||
|
f.fs.Write(closeMapBytes)
|
||||||
|
|
||||||
|
case reflect.Struct:
|
||||||
|
numFields := v.NumField()
|
||||||
|
f.fs.Write(openBraceBytes)
|
||||||
|
f.depth++
|
||||||
|
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||||
|
f.fs.Write(maxShortBytes)
|
||||||
|
} else {
|
||||||
|
vt := v.Type()
|
||||||
|
for i := 0; i < numFields; i++ {
|
||||||
|
if i > 0 {
|
||||||
|
f.fs.Write(spaceBytes)
|
||||||
|
}
|
||||||
|
vtf := vt.Field(i)
|
||||||
|
if f.fs.Flag('+') || f.fs.Flag('#') {
|
||||||
|
f.fs.Write([]byte(vtf.Name))
|
||||||
|
f.fs.Write(colonBytes)
|
||||||
|
}
|
||||||
|
f.format(f.unpackValue(v.Field(i)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.depth--
|
||||||
|
f.fs.Write(closeBraceBytes)
|
||||||
|
|
||||||
|
case reflect.Uintptr:
|
||||||
|
printHexPtr(f.fs, uintptr(v.Uint()))
|
||||||
|
|
||||||
|
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
|
||||||
|
printHexPtr(f.fs, v.Pointer())
|
||||||
|
|
||||||
|
// There were not any other types at the time this code was written, but
|
||||||
|
// fall back to letting the default fmt package handle it if any get added.
|
||||||
|
default:
|
||||||
|
format := f.buildDefaultFormat()
|
||||||
|
if v.CanInterface() {
|
||||||
|
fmt.Fprintf(f.fs, format, v.Interface())
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(f.fs, format, v.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format satisfies the fmt.Formatter interface. See NewFormatter for usage
|
||||||
|
// details.
|
||||||
|
func (f *formatState) Format(fs fmt.State, verb rune) {
|
||||||
|
f.fs = fs
|
||||||
|
|
||||||
|
// Use standard formatting for verbs that are not v.
|
||||||
|
if verb != 'v' {
|
||||||
|
format := f.constructOrigFormat(verb)
|
||||||
|
fmt.Fprintf(fs, format, f.value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.value == nil {
|
||||||
|
if fs.Flag('#') {
|
||||||
|
fs.Write(interfaceBytes)
|
||||||
|
}
|
||||||
|
fs.Write(nilAngleBytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.format(reflect.ValueOf(f.value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// newFormatter is a helper function to consolidate the logic from the various
|
||||||
|
// public methods which take varying config states.
|
||||||
|
func newFormatter(cs *ConfigState, v interface{}) fmt.Formatter {
|
||||||
|
fs := &formatState{value: v, cs: cs}
|
||||||
|
fs.pointers = make(map[uintptr]int)
|
||||||
|
return fs
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
|
||||||
|
interface. As a result, it integrates cleanly with standard fmt package
|
||||||
|
printing functions. The formatter is useful for inline printing of smaller data
|
||||||
|
types similar to the standard %v format specifier.
|
||||||
|
|
||||||
|
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||||
|
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
|
||||||
|
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||||
|
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||||
|
the width and precision arguments (however they will still work on the format
|
||||||
|
specifiers not handled by the custom formatter).
|
||||||
|
|
||||||
|
Typically this function shouldn't be called directly. It is much easier to make
|
||||||
|
use of the custom formatter by calling one of the convenience functions such as
|
||||||
|
Printf, Println, or Fprintf.
|
||||||
|
*/
|
||||||
|
func NewFormatter(v interface{}) fmt.Formatter {
|
||||||
|
return newFormatter(&Config, v)
|
||||||
|
}
|
1535
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/format_test.go
generated
vendored
Normal file
1535
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/format_test.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
156
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/internal_test.go
generated
vendored
Normal file
156
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/internal_test.go
generated
vendored
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
This test file is part of the spew package rather than than the spew_test
|
||||||
|
package because it needs access to internals to properly test certain cases
|
||||||
|
which are not possible via the public interface since they should never happen.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dummyFmtState implements a fake fmt.State to use for testing invalid
|
||||||
|
// reflect.Value handling. This is necessary because the fmt package catches
|
||||||
|
// invalid values before invoking the formatter on them.
|
||||||
|
type dummyFmtState struct {
|
||||||
|
bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dfs *dummyFmtState) Flag(f int) bool {
|
||||||
|
if f == int('+') {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dfs *dummyFmtState) Precision() (int, bool) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dfs *dummyFmtState) Width() (int, bool) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInvalidReflectValue ensures the dump and formatter code handles an
|
||||||
|
// invalid reflect value properly. This needs access to internal state since it
|
||||||
|
// should never happen in real code and therefore can't be tested via the public
|
||||||
|
// API.
|
||||||
|
func TestInvalidReflectValue(t *testing.T) {
|
||||||
|
i := 1
|
||||||
|
|
||||||
|
// Dump invalid reflect value.
|
||||||
|
v := new(reflect.Value)
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
d := dumpState{w: buf, cs: &Config}
|
||||||
|
d.dump(*v)
|
||||||
|
s := buf.String()
|
||||||
|
want := "<invalid>"
|
||||||
|
if s != want {
|
||||||
|
t.Errorf("InvalidReflectValue #%d\n got: %s want: %s", i, s, want)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
|
||||||
|
// Formatter invalid reflect value.
|
||||||
|
buf2 := new(dummyFmtState)
|
||||||
|
f := formatState{value: *v, cs: &Config, fs: buf2}
|
||||||
|
f.format(*v)
|
||||||
|
s = buf2.String()
|
||||||
|
want = "<invalid>"
|
||||||
|
if s != want {
|
||||||
|
t.Errorf("InvalidReflectValue #%d got: %s want: %s", i, s, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// changeKind uses unsafe to intentionally change the kind of a reflect.Value to
|
||||||
|
// the maximum kind value which does not exist. This is needed to test the
|
||||||
|
// fallback code which punts to the standard fmt library for new types that
|
||||||
|
// might get added to the language.
|
||||||
|
func changeKind(v *reflect.Value, readOnly bool) {
|
||||||
|
rvf := (*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + offsetFlag))
|
||||||
|
*rvf = *rvf | ((1<<flagKindWidth - 1) << flagKindShift)
|
||||||
|
if readOnly {
|
||||||
|
*rvf |= flagRO
|
||||||
|
} else {
|
||||||
|
*rvf &= ^uintptr(flagRO)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddedReflectValue tests functionaly of the dump and formatter code which
|
||||||
|
// falls back to the standard fmt library for new types that might get added to
|
||||||
|
// the language.
|
||||||
|
func TestAddedReflectValue(t *testing.T) {
|
||||||
|
i := 1
|
||||||
|
|
||||||
|
// Dump using a reflect.Value that is exported.
|
||||||
|
v := reflect.ValueOf(int8(5))
|
||||||
|
changeKind(&v, false)
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
d := dumpState{w: buf, cs: &Config}
|
||||||
|
d.dump(v)
|
||||||
|
s := buf.String()
|
||||||
|
want := "(int8) 5"
|
||||||
|
if s != want {
|
||||||
|
t.Errorf("TestAddedReflectValue #%d\n got: %s want: %s", i, s, want)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
|
||||||
|
// Dump using a reflect.Value that is not exported.
|
||||||
|
changeKind(&v, true)
|
||||||
|
buf.Reset()
|
||||||
|
d.dump(v)
|
||||||
|
s = buf.String()
|
||||||
|
want = "(int8) <int8 Value>"
|
||||||
|
if s != want {
|
||||||
|
t.Errorf("TestAddedReflectValue #%d\n got: %s want: %s", i, s, want)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
|
||||||
|
// Formatter using a reflect.Value that is exported.
|
||||||
|
changeKind(&v, false)
|
||||||
|
buf2 := new(dummyFmtState)
|
||||||
|
f := formatState{value: v, cs: &Config, fs: buf2}
|
||||||
|
f.format(v)
|
||||||
|
s = buf2.String()
|
||||||
|
want = "5"
|
||||||
|
if s != want {
|
||||||
|
t.Errorf("TestAddedReflectValue #%d got: %s want: %s", i, s, want)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
|
||||||
|
// Formatter using a reflect.Value that is not exported.
|
||||||
|
changeKind(&v, true)
|
||||||
|
buf2.Reset()
|
||||||
|
f = formatState{value: v, cs: &Config, fs: buf2}
|
||||||
|
f.format(v)
|
||||||
|
s = buf2.String()
|
||||||
|
want = "<int8 Value>"
|
||||||
|
if s != want {
|
||||||
|
t.Errorf("TestAddedReflectValue #%d got: %s want: %s", i, s, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SortValues makes the internal sortValues function available to the test
|
||||||
|
// package.
|
||||||
|
func SortValues(values []reflect.Value, cs *ConfigState) {
|
||||||
|
sortValues(values, cs)
|
||||||
|
}
|
|
@ -0,0 +1,148 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the formatted string as a value that satisfies error. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Errorf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Errorf(format string, a ...interface{}) (err error) {
|
||||||
|
return fmt.Errorf(format, convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Fprint(w, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Fprint(w io.Writer, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Fprint(w, convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Fprintf(w, format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Fprintf(w, format, convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Fprintln(w, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Fprintln(w, convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print is a wrapper for fmt.Print that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Print(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Print(a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Print(convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Printf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Printf(format string, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Printf(format, convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Println is a wrapper for fmt.Println that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Println(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Println(a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Println(convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the resulting string. See NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Sprint(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Sprint(a ...interface{}) string {
|
||||||
|
return fmt.Sprint(convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the resulting string. See NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Sprintf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Sprintf(format string, a ...interface{}) string {
|
||||||
|
return fmt.Sprintf(format, convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
|
||||||
|
// were passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the resulting string. See NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Sprintln(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Sprintln(a ...interface{}) string {
|
||||||
|
return fmt.Sprintln(convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertArgs accepts a slice of arguments and returns a slice of the same
|
||||||
|
// length with each argument converted to a default spew Formatter interface.
|
||||||
|
func convertArgs(args []interface{}) (formatters []interface{}) {
|
||||||
|
formatters = make([]interface{}, len(args))
|
||||||
|
for index, arg := range args {
|
||||||
|
formatters[index] = NewFormatter(arg)
|
||||||
|
}
|
||||||
|
return formatters
|
||||||
|
}
|
308
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/spew_test.go
generated
vendored
Normal file
308
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/spew_test.go
generated
vendored
Normal file
|
@ -0,0 +1,308 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// spewFunc is used to identify which public function of the spew package or
|
||||||
|
// ConfigState a test applies to.
|
||||||
|
type spewFunc int
|
||||||
|
|
||||||
|
const (
|
||||||
|
fCSFdump spewFunc = iota
|
||||||
|
fCSFprint
|
||||||
|
fCSFprintf
|
||||||
|
fCSFprintln
|
||||||
|
fCSPrint
|
||||||
|
fCSPrintln
|
||||||
|
fCSSdump
|
||||||
|
fCSSprint
|
||||||
|
fCSSprintf
|
||||||
|
fCSSprintln
|
||||||
|
fCSErrorf
|
||||||
|
fCSNewFormatter
|
||||||
|
fErrorf
|
||||||
|
fFprint
|
||||||
|
fFprintln
|
||||||
|
fPrint
|
||||||
|
fPrintln
|
||||||
|
fSdump
|
||||||
|
fSprint
|
||||||
|
fSprintf
|
||||||
|
fSprintln
|
||||||
|
)
|
||||||
|
|
||||||
|
// Map of spewFunc values to names for pretty printing.
|
||||||
|
var spewFuncStrings = map[spewFunc]string{
|
||||||
|
fCSFdump: "ConfigState.Fdump",
|
||||||
|
fCSFprint: "ConfigState.Fprint",
|
||||||
|
fCSFprintf: "ConfigState.Fprintf",
|
||||||
|
fCSFprintln: "ConfigState.Fprintln",
|
||||||
|
fCSSdump: "ConfigState.Sdump",
|
||||||
|
fCSPrint: "ConfigState.Print",
|
||||||
|
fCSPrintln: "ConfigState.Println",
|
||||||
|
fCSSprint: "ConfigState.Sprint",
|
||||||
|
fCSSprintf: "ConfigState.Sprintf",
|
||||||
|
fCSSprintln: "ConfigState.Sprintln",
|
||||||
|
fCSErrorf: "ConfigState.Errorf",
|
||||||
|
fCSNewFormatter: "ConfigState.NewFormatter",
|
||||||
|
fErrorf: "spew.Errorf",
|
||||||
|
fFprint: "spew.Fprint",
|
||||||
|
fFprintln: "spew.Fprintln",
|
||||||
|
fPrint: "spew.Print",
|
||||||
|
fPrintln: "spew.Println",
|
||||||
|
fSdump: "spew.Sdump",
|
||||||
|
fSprint: "spew.Sprint",
|
||||||
|
fSprintf: "spew.Sprintf",
|
||||||
|
fSprintln: "spew.Sprintln",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f spewFunc) String() string {
|
||||||
|
if s, ok := spewFuncStrings[f]; ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("Unknown spewFunc (%d)", int(f))
|
||||||
|
}
|
||||||
|
|
||||||
|
// spewTest is used to describe a test to be performed against the public
|
||||||
|
// functions of the spew package or ConfigState.
|
||||||
|
type spewTest struct {
|
||||||
|
cs *spew.ConfigState
|
||||||
|
f spewFunc
|
||||||
|
format string
|
||||||
|
in interface{}
|
||||||
|
want string
|
||||||
|
}
|
||||||
|
|
||||||
|
// spewTests houses the tests to be performed against the public functions of
|
||||||
|
// the spew package and ConfigState.
|
||||||
|
//
|
||||||
|
// These tests are only intended to ensure the public functions are exercised
|
||||||
|
// and are intentionally not exhaustive of types. The exhaustive type
|
||||||
|
// tests are handled in the dump and format tests.
|
||||||
|
var spewTests []spewTest
|
||||||
|
|
||||||
|
// redirStdout is a helper function to return the standard output from f as a
|
||||||
|
// byte slice.
|
||||||
|
func redirStdout(f func()) ([]byte, error) {
|
||||||
|
tempFile, err := ioutil.TempFile("", "ss-test")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fileName := tempFile.Name()
|
||||||
|
defer os.Remove(fileName) // Ignore error
|
||||||
|
|
||||||
|
origStdout := os.Stdout
|
||||||
|
os.Stdout = tempFile
|
||||||
|
f()
|
||||||
|
os.Stdout = origStdout
|
||||||
|
tempFile.Close()
|
||||||
|
|
||||||
|
return ioutil.ReadFile(fileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func initSpewTests() {
|
||||||
|
// Config states with various settings.
|
||||||
|
scsDefault := spew.NewDefaultConfig()
|
||||||
|
scsNoMethods := &spew.ConfigState{Indent: " ", DisableMethods: true}
|
||||||
|
scsNoPmethods := &spew.ConfigState{Indent: " ", DisablePointerMethods: true}
|
||||||
|
scsMaxDepth := &spew.ConfigState{Indent: " ", MaxDepth: 1}
|
||||||
|
scsContinue := &spew.ConfigState{Indent: " ", ContinueOnMethod: true}
|
||||||
|
|
||||||
|
// Variables for tests on types which implement Stringer interface with and
|
||||||
|
// without a pointer receiver.
|
||||||
|
ts := stringer("test")
|
||||||
|
tps := pstringer("test")
|
||||||
|
|
||||||
|
// depthTester is used to test max depth handling for structs, array, slices
|
||||||
|
// and maps.
|
||||||
|
type depthTester struct {
|
||||||
|
ic indirCir1
|
||||||
|
arr [1]string
|
||||||
|
slice []string
|
||||||
|
m map[string]int
|
||||||
|
}
|
||||||
|
dt := depthTester{indirCir1{nil}, [1]string{"arr"}, []string{"slice"},
|
||||||
|
map[string]int{"one": 1}}
|
||||||
|
|
||||||
|
// Variable for tests on types which implement error interface.
|
||||||
|
te := customError(10)
|
||||||
|
|
||||||
|
spewTests = []spewTest{
|
||||||
|
{scsDefault, fCSFdump, "", int8(127), "(int8) 127\n"},
|
||||||
|
{scsDefault, fCSFprint, "", int16(32767), "32767"},
|
||||||
|
{scsDefault, fCSFprintf, "%v", int32(2147483647), "2147483647"},
|
||||||
|
{scsDefault, fCSFprintln, "", int(2147483647), "2147483647\n"},
|
||||||
|
{scsDefault, fCSPrint, "", int64(9223372036854775807), "9223372036854775807"},
|
||||||
|
{scsDefault, fCSPrintln, "", uint8(255), "255\n"},
|
||||||
|
{scsDefault, fCSSdump, "", uint8(64), "(uint8) 64\n"},
|
||||||
|
{scsDefault, fCSSprint, "", complex(1, 2), "(1+2i)"},
|
||||||
|
{scsDefault, fCSSprintf, "%v", complex(float32(3), 4), "(3+4i)"},
|
||||||
|
{scsDefault, fCSSprintln, "", complex(float64(5), 6), "(5+6i)\n"},
|
||||||
|
{scsDefault, fCSErrorf, "%#v", uint16(65535), "(uint16)65535"},
|
||||||
|
{scsDefault, fCSNewFormatter, "%v", uint32(4294967295), "4294967295"},
|
||||||
|
{scsDefault, fErrorf, "%v", uint64(18446744073709551615), "18446744073709551615"},
|
||||||
|
{scsDefault, fFprint, "", float32(3.14), "3.14"},
|
||||||
|
{scsDefault, fFprintln, "", float64(6.28), "6.28\n"},
|
||||||
|
{scsDefault, fPrint, "", true, "true"},
|
||||||
|
{scsDefault, fPrintln, "", false, "false\n"},
|
||||||
|
{scsDefault, fSdump, "", complex(-10, -20), "(complex128) (-10-20i)\n"},
|
||||||
|
{scsDefault, fSprint, "", complex(-1, -2), "(-1-2i)"},
|
||||||
|
{scsDefault, fSprintf, "%v", complex(float32(-3), -4), "(-3-4i)"},
|
||||||
|
{scsDefault, fSprintln, "", complex(float64(-5), -6), "(-5-6i)\n"},
|
||||||
|
{scsNoMethods, fCSFprint, "", ts, "test"},
|
||||||
|
{scsNoMethods, fCSFprint, "", &ts, "<*>test"},
|
||||||
|
{scsNoMethods, fCSFprint, "", tps, "test"},
|
||||||
|
{scsNoMethods, fCSFprint, "", &tps, "<*>test"},
|
||||||
|
{scsNoPmethods, fCSFprint, "", ts, "stringer test"},
|
||||||
|
{scsNoPmethods, fCSFprint, "", &ts, "<*>stringer test"},
|
||||||
|
{scsNoPmethods, fCSFprint, "", tps, "test"},
|
||||||
|
{scsNoPmethods, fCSFprint, "", &tps, "<*>stringer test"},
|
||||||
|
{scsMaxDepth, fCSFprint, "", dt, "{{<max>} [<max>] [<max>] map[<max>]}"},
|
||||||
|
{scsMaxDepth, fCSFdump, "", dt, "(spew_test.depthTester) {\n" +
|
||||||
|
" ic: (spew_test.indirCir1) {\n <max depth reached>\n },\n" +
|
||||||
|
" arr: ([1]string) (len=1 cap=1) {\n <max depth reached>\n },\n" +
|
||||||
|
" slice: ([]string) (len=1 cap=1) {\n <max depth reached>\n },\n" +
|
||||||
|
" m: (map[string]int) (len=1) {\n <max depth reached>\n }\n}\n"},
|
||||||
|
{scsContinue, fCSFprint, "", ts, "(stringer test) test"},
|
||||||
|
{scsContinue, fCSFdump, "", ts, "(spew_test.stringer) " +
|
||||||
|
"(len=4) (stringer test) \"test\"\n"},
|
||||||
|
{scsContinue, fCSFprint, "", te, "(error: 10) 10"},
|
||||||
|
{scsContinue, fCSFdump, "", te, "(spew_test.customError) " +
|
||||||
|
"(error: 10) 10\n"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSpew executes all of the tests described by spewTests.
|
||||||
|
func TestSpew(t *testing.T) {
|
||||||
|
initSpewTests()
|
||||||
|
|
||||||
|
t.Logf("Running %d tests", len(spewTests))
|
||||||
|
for i, test := range spewTests {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
switch test.f {
|
||||||
|
case fCSFdump:
|
||||||
|
test.cs.Fdump(buf, test.in)
|
||||||
|
|
||||||
|
case fCSFprint:
|
||||||
|
test.cs.Fprint(buf, test.in)
|
||||||
|
|
||||||
|
case fCSFprintf:
|
||||||
|
test.cs.Fprintf(buf, test.format, test.in)
|
||||||
|
|
||||||
|
case fCSFprintln:
|
||||||
|
test.cs.Fprintln(buf, test.in)
|
||||||
|
|
||||||
|
case fCSPrint:
|
||||||
|
b, err := redirStdout(func() { test.cs.Print(test.in) })
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v #%d %v", test.f, i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
buf.Write(b)
|
||||||
|
|
||||||
|
case fCSPrintln:
|
||||||
|
b, err := redirStdout(func() { test.cs.Println(test.in) })
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v #%d %v", test.f, i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
buf.Write(b)
|
||||||
|
|
||||||
|
case fCSSdump:
|
||||||
|
str := test.cs.Sdump(test.in)
|
||||||
|
buf.WriteString(str)
|
||||||
|
|
||||||
|
case fCSSprint:
|
||||||
|
str := test.cs.Sprint(test.in)
|
||||||
|
buf.WriteString(str)
|
||||||
|
|
||||||
|
case fCSSprintf:
|
||||||
|
str := test.cs.Sprintf(test.format, test.in)
|
||||||
|
buf.WriteString(str)
|
||||||
|
|
||||||
|
case fCSSprintln:
|
||||||
|
str := test.cs.Sprintln(test.in)
|
||||||
|
buf.WriteString(str)
|
||||||
|
|
||||||
|
case fCSErrorf:
|
||||||
|
err := test.cs.Errorf(test.format, test.in)
|
||||||
|
buf.WriteString(err.Error())
|
||||||
|
|
||||||
|
case fCSNewFormatter:
|
||||||
|
fmt.Fprintf(buf, test.format, test.cs.NewFormatter(test.in))
|
||||||
|
|
||||||
|
case fErrorf:
|
||||||
|
err := spew.Errorf(test.format, test.in)
|
||||||
|
buf.WriteString(err.Error())
|
||||||
|
|
||||||
|
case fFprint:
|
||||||
|
spew.Fprint(buf, test.in)
|
||||||
|
|
||||||
|
case fFprintln:
|
||||||
|
spew.Fprintln(buf, test.in)
|
||||||
|
|
||||||
|
case fPrint:
|
||||||
|
b, err := redirStdout(func() { spew.Print(test.in) })
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v #%d %v", test.f, i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
buf.Write(b)
|
||||||
|
|
||||||
|
case fPrintln:
|
||||||
|
b, err := redirStdout(func() { spew.Println(test.in) })
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v #%d %v", test.f, i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
buf.Write(b)
|
||||||
|
|
||||||
|
case fSdump:
|
||||||
|
str := spew.Sdump(test.in)
|
||||||
|
buf.WriteString(str)
|
||||||
|
|
||||||
|
case fSprint:
|
||||||
|
str := spew.Sprint(test.in)
|
||||||
|
buf.WriteString(str)
|
||||||
|
|
||||||
|
case fSprintf:
|
||||||
|
str := spew.Sprintf(test.format, test.in)
|
||||||
|
buf.WriteString(str)
|
||||||
|
|
||||||
|
case fSprintln:
|
||||||
|
str := spew.Sprintln(test.in)
|
||||||
|
buf.WriteString(str)
|
||||||
|
|
||||||
|
default:
|
||||||
|
t.Errorf("%v #%d unrecognized function", test.f, i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s := buf.String()
|
||||||
|
if test.want != s {
|
||||||
|
t.Errorf("ConfigState #%d\n got: %s want: %s", i, s, test.want)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
82
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/testdata/dumpcgo.go
generated
vendored
Normal file
82
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/testdata/dumpcgo.go
generated
vendored
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
// Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||||
|
//
|
||||||
|
// Permission to use, copy, modify, and distribute this software for any
|
||||||
|
// purpose with or without fee is hereby granted, provided that the above
|
||||||
|
// copyright notice and this permission notice appear in all copies.
|
||||||
|
//
|
||||||
|
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||||
|
// when both cgo is supported and "-tags testcgo" is added to the go test
|
||||||
|
// command line. This code should really only be in the dumpcgo_test.go file,
|
||||||
|
// but unfortunately Go will not allow cgo in test files, so this is a
|
||||||
|
// workaround to allow cgo types to be tested. This configuration is used
|
||||||
|
// because spew itself does not require cgo to run even though it does handle
|
||||||
|
// certain cgo types specially. Rather than forcing all clients to require cgo
|
||||||
|
// and an external C compiler just to run the tests, this scheme makes them
|
||||||
|
// optional.
|
||||||
|
// +build cgo,testcgo
|
||||||
|
|
||||||
|
package testdata
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdint.h>
|
||||||
|
typedef unsigned char custom_uchar_t;
|
||||||
|
|
||||||
|
char *ncp = 0;
|
||||||
|
char *cp = "test";
|
||||||
|
char ca[6] = {'t', 'e', 's', 't', '2', '\0'};
|
||||||
|
unsigned char uca[6] = {'t', 'e', 's', 't', '3', '\0'};
|
||||||
|
signed char sca[6] = {'t', 'e', 's', 't', '4', '\0'};
|
||||||
|
uint8_t ui8ta[6] = {'t', 'e', 's', 't', '5', '\0'};
|
||||||
|
custom_uchar_t tuca[6] = {'t', 'e', 's', 't', '6', '\0'};
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
// GetCgoNullCharPointer returns a null char pointer via cgo. This is only
|
||||||
|
// used for tests.
|
||||||
|
func GetCgoNullCharPointer() interface{} {
|
||||||
|
return C.ncp
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCgoCharPointer returns a char pointer via cgo. This is only used for
|
||||||
|
// tests.
|
||||||
|
func GetCgoCharPointer() interface{} {
|
||||||
|
return C.cp
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCgoCharArray returns a char array via cgo and the array's len and cap.
|
||||||
|
// This is only used for tests.
|
||||||
|
func GetCgoCharArray() (interface{}, int, int) {
|
||||||
|
return C.ca, len(C.ca), cap(C.ca)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCgoUnsignedCharArray returns an unsigned char array via cgo and the
|
||||||
|
// array's len and cap. This is only used for tests.
|
||||||
|
func GetCgoUnsignedCharArray() (interface{}, int, int) {
|
||||||
|
return C.uca, len(C.uca), cap(C.uca)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCgoSignedCharArray returns a signed char array via cgo and the array's len
|
||||||
|
// and cap. This is only used for tests.
|
||||||
|
func GetCgoSignedCharArray() (interface{}, int, int) {
|
||||||
|
return C.sca, len(C.sca), cap(C.sca)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCgoUint8tArray returns a uint8_t array via cgo and the array's len and
|
||||||
|
// cap. This is only used for tests.
|
||||||
|
func GetCgoUint8tArray() (interface{}, int, int) {
|
||||||
|
return C.ui8ta, len(C.ui8ta), cap(C.ui8ta)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCgoTypdefedUnsignedCharArray returns a typedefed unsigned char array via
|
||||||
|
// cgo and the array's len and cap. This is only used for tests.
|
||||||
|
func GetCgoTypdefedUnsignedCharArray() (interface{}, int, int) {
|
||||||
|
return C.tuca, len(C.tuca), cap(C.tuca)
|
||||||
|
}
|
|
@ -535,6 +535,7 @@ func (self *Ethereum) AddPeer(nodeURL string) error {
|
||||||
func (s *Ethereum) Stop() {
|
func (s *Ethereum) Stop() {
|
||||||
s.txSub.Unsubscribe() // quits txBroadcastLoop
|
s.txSub.Unsubscribe() // quits txBroadcastLoop
|
||||||
|
|
||||||
|
s.net.Stop()
|
||||||
s.protocolManager.Stop()
|
s.protocolManager.Stop()
|
||||||
s.chainManager.Stop()
|
s.chainManager.Stop()
|
||||||
s.txPool.Stop()
|
s.txPool.Stop()
|
||||||
|
@ -544,7 +545,6 @@ func (s *Ethereum) Stop() {
|
||||||
}
|
}
|
||||||
s.StopAutoDAG()
|
s.StopAutoDAG()
|
||||||
|
|
||||||
glog.V(logger.Info).Infoln("Server stopped")
|
|
||||||
close(s.shutdownChan)
|
close(s.shutdownChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,276 @@
|
||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/heap"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
|
"github.com/ethereum/go-ethereum/logger/glog"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// This is the amount of time spent waiting in between
|
||||||
|
// redialing a certain node.
|
||||||
|
dialHistoryExpiration = 30 * time.Second
|
||||||
|
|
||||||
|
// Discovery lookup tasks will wait for this long when
|
||||||
|
// no results are returned. This can happen if the table
|
||||||
|
// becomes empty (i.e. not often).
|
||||||
|
emptyLookupDelay = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// dialstate schedules dials and discovery lookups.
|
||||||
|
// it get's a chance to compute new tasks on every iteration
|
||||||
|
// of the main loop in Server.run.
|
||||||
|
type dialstate struct {
|
||||||
|
maxDynDials int
|
||||||
|
ntab discoverTable
|
||||||
|
|
||||||
|
lookupRunning bool
|
||||||
|
bootstrapped bool
|
||||||
|
|
||||||
|
dialing map[discover.NodeID]connFlag
|
||||||
|
lookupBuf []*discover.Node // current discovery lookup results
|
||||||
|
randomNodes []*discover.Node // filled from Table
|
||||||
|
static map[discover.NodeID]*discover.Node
|
||||||
|
hist *dialHistory
|
||||||
|
}
|
||||||
|
|
||||||
|
type discoverTable interface {
|
||||||
|
Self() *discover.Node
|
||||||
|
Close()
|
||||||
|
Bootstrap([]*discover.Node)
|
||||||
|
Lookup(target discover.NodeID) []*discover.Node
|
||||||
|
ReadRandomNodes([]*discover.Node) int
|
||||||
|
}
|
||||||
|
|
||||||
|
// the dial history remembers recent dials.
|
||||||
|
type dialHistory []pastDial
|
||||||
|
|
||||||
|
// pastDial is an entry in the dial history.
|
||||||
|
type pastDial struct {
|
||||||
|
id discover.NodeID
|
||||||
|
exp time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type task interface {
|
||||||
|
Do(*Server)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A dialTask is generated for each node that is dialed.
|
||||||
|
type dialTask struct {
|
||||||
|
flags connFlag
|
||||||
|
dest *discover.Node
|
||||||
|
}
|
||||||
|
|
||||||
|
// discoverTask runs discovery table operations.
|
||||||
|
// Only one discoverTask is active at any time.
|
||||||
|
//
|
||||||
|
// If bootstrap is true, the task runs Table.Bootstrap,
|
||||||
|
// otherwise it performs a random lookup and leaves the
|
||||||
|
// results in the task.
|
||||||
|
type discoverTask struct {
|
||||||
|
bootstrap bool
|
||||||
|
results []*discover.Node
|
||||||
|
}
|
||||||
|
|
||||||
|
// A waitExpireTask is generated if there are no other tasks
|
||||||
|
// to keep the loop in Server.run ticking.
|
||||||
|
type waitExpireTask struct {
|
||||||
|
time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate {
|
||||||
|
s := &dialstate{
|
||||||
|
maxDynDials: maxdyn,
|
||||||
|
ntab: ntab,
|
||||||
|
static: make(map[discover.NodeID]*discover.Node),
|
||||||
|
dialing: make(map[discover.NodeID]connFlag),
|
||||||
|
randomNodes: make([]*discover.Node, maxdyn/2),
|
||||||
|
hist: new(dialHistory),
|
||||||
|
}
|
||||||
|
for _, n := range static {
|
||||||
|
s.static[n.ID] = n
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dialstate) addStatic(n *discover.Node) {
|
||||||
|
s.static[n.ID] = n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
|
||||||
|
var newtasks []task
|
||||||
|
addDial := func(flag connFlag, n *discover.Node) bool {
|
||||||
|
_, dialing := s.dialing[n.ID]
|
||||||
|
if dialing || peers[n.ID] != nil || s.hist.contains(n.ID) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s.dialing[n.ID] = flag
|
||||||
|
newtasks = append(newtasks, &dialTask{flags: flag, dest: n})
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute number of dynamic dials necessary at this point.
|
||||||
|
needDynDials := s.maxDynDials
|
||||||
|
for _, p := range peers {
|
||||||
|
if p.rw.is(dynDialedConn) {
|
||||||
|
needDynDials--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, flag := range s.dialing {
|
||||||
|
if flag&dynDialedConn != 0 {
|
||||||
|
needDynDials--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expire the dial history on every invocation.
|
||||||
|
s.hist.expire(now)
|
||||||
|
|
||||||
|
// Create dials for static nodes if they are not connected.
|
||||||
|
for _, n := range s.static {
|
||||||
|
addDial(staticDialedConn, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use random nodes from the table for half of the necessary
|
||||||
|
// dynamic dials.
|
||||||
|
randomCandidates := needDynDials / 2
|
||||||
|
if randomCandidates > 0 && s.bootstrapped {
|
||||||
|
n := s.ntab.ReadRandomNodes(s.randomNodes)
|
||||||
|
for i := 0; i < randomCandidates && i < n; i++ {
|
||||||
|
if addDial(dynDialedConn, s.randomNodes[i]) {
|
||||||
|
needDynDials--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Create dynamic dials from random lookup results, removing tried
|
||||||
|
// items from the result buffer.
|
||||||
|
i := 0
|
||||||
|
for ; i < len(s.lookupBuf) && needDynDials > 0; i++ {
|
||||||
|
if addDial(dynDialedConn, s.lookupBuf[i]) {
|
||||||
|
needDynDials--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])]
|
||||||
|
// Launch a discovery lookup if more candidates are needed. The
|
||||||
|
// first discoverTask bootstraps the table and won't return any
|
||||||
|
// results.
|
||||||
|
if len(s.lookupBuf) < needDynDials && !s.lookupRunning {
|
||||||
|
s.lookupRunning = true
|
||||||
|
newtasks = append(newtasks, &discoverTask{bootstrap: !s.bootstrapped})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch a timer to wait for the next node to expire if all
|
||||||
|
// candidates have been tried and no task is currently active.
|
||||||
|
// This should prevent cases where the dialer logic is not ticked
|
||||||
|
// because there are no pending events.
|
||||||
|
if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
|
||||||
|
t := &waitExpireTask{s.hist.min().exp.Sub(now)}
|
||||||
|
newtasks = append(newtasks, t)
|
||||||
|
}
|
||||||
|
return newtasks
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dialstate) taskDone(t task, now time.Time) {
|
||||||
|
switch t := t.(type) {
|
||||||
|
case *dialTask:
|
||||||
|
s.hist.add(t.dest.ID, now.Add(dialHistoryExpiration))
|
||||||
|
delete(s.dialing, t.dest.ID)
|
||||||
|
case *discoverTask:
|
||||||
|
if t.bootstrap {
|
||||||
|
s.bootstrapped = true
|
||||||
|
}
|
||||||
|
s.lookupRunning = false
|
||||||
|
s.lookupBuf = append(s.lookupBuf, t.results...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *dialTask) Do(srv *Server) {
|
||||||
|
addr := &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)}
|
||||||
|
glog.V(logger.Debug).Infof("dialing %v\n", t.dest)
|
||||||
|
fd, err := srv.Dialer.Dial("tcp", addr.String())
|
||||||
|
if err != nil {
|
||||||
|
glog.V(logger.Detail).Infof("dial error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srv.setupConn(fd, t.flags, t.dest)
|
||||||
|
}
|
||||||
|
func (t *dialTask) String() string {
|
||||||
|
return fmt.Sprintf("%v %x %v:%d", t.flags, t.dest.ID[:8], t.dest.IP, t.dest.TCP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *discoverTask) Do(srv *Server) {
|
||||||
|
if t.bootstrap {
|
||||||
|
srv.ntab.Bootstrap(srv.BootstrapNodes)
|
||||||
|
} else {
|
||||||
|
var target discover.NodeID
|
||||||
|
rand.Read(target[:])
|
||||||
|
t.results = srv.ntab.Lookup(target)
|
||||||
|
// newTasks generates a lookup task whenever dynamic dials are
|
||||||
|
// necessary. Lookups need to take some time, otherwise the
|
||||||
|
// event loop spins too fast. An empty result can only be
|
||||||
|
// returned if the table is empty.
|
||||||
|
if len(t.results) == 0 {
|
||||||
|
time.Sleep(emptyLookupDelay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *discoverTask) String() (s string) {
|
||||||
|
if t.bootstrap {
|
||||||
|
s = "discovery bootstrap"
|
||||||
|
} else {
|
||||||
|
s = "discovery lookup"
|
||||||
|
}
|
||||||
|
if len(t.results) > 0 {
|
||||||
|
s += fmt.Sprintf(" (%d results)", len(t.results))
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t waitExpireTask) Do(*Server) {
|
||||||
|
time.Sleep(t.Duration)
|
||||||
|
}
|
||||||
|
func (t waitExpireTask) String() string {
|
||||||
|
return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use only these methods to access or modify dialHistory.
|
||||||
|
func (h dialHistory) min() pastDial {
|
||||||
|
return h[0]
|
||||||
|
}
|
||||||
|
func (h *dialHistory) add(id discover.NodeID, exp time.Time) {
|
||||||
|
heap.Push(h, pastDial{id, exp})
|
||||||
|
}
|
||||||
|
func (h dialHistory) contains(id discover.NodeID) bool {
|
||||||
|
for _, v := range h {
|
||||||
|
if v.id == id {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
func (h *dialHistory) expire(now time.Time) {
|
||||||
|
for h.Len() > 0 && h.min().exp.Before(now) {
|
||||||
|
heap.Pop(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// heap.Interface boilerplate
|
||||||
|
func (h dialHistory) Len() int { return len(h) }
|
||||||
|
func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
|
||||||
|
func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||||
|
func (h *dialHistory) Push(x interface{}) {
|
||||||
|
*h = append(*h, x.(pastDial))
|
||||||
|
}
|
||||||
|
func (h *dialHistory) Pop() interface{} {
|
||||||
|
old := *h
|
||||||
|
n := len(old)
|
||||||
|
x := old[n-1]
|
||||||
|
*h = old[0 : n-1]
|
||||||
|
return x
|
||||||
|
}
|
|
@ -0,0 +1,482 @@
|
||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
spew.Config.Indent = "\t"
|
||||||
|
}
|
||||||
|
|
||||||
|
type dialtest struct {
|
||||||
|
init *dialstate // state before and after the test.
|
||||||
|
rounds []round
|
||||||
|
}
|
||||||
|
|
||||||
|
type round struct {
|
||||||
|
peers []*Peer // current peer set
|
||||||
|
done []task // tasks that got done this round
|
||||||
|
new []task // the result must match this one
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDialTest(t *testing.T, test dialtest) {
|
||||||
|
var (
|
||||||
|
vtime time.Time
|
||||||
|
running int
|
||||||
|
)
|
||||||
|
pm := func(ps []*Peer) map[discover.NodeID]*Peer {
|
||||||
|
m := make(map[discover.NodeID]*Peer)
|
||||||
|
for _, p := range ps {
|
||||||
|
m[p.rw.id] = p
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
for i, round := range test.rounds {
|
||||||
|
for _, task := range round.done {
|
||||||
|
running--
|
||||||
|
if running < 0 {
|
||||||
|
panic("running task counter underflow")
|
||||||
|
}
|
||||||
|
test.init.taskDone(task, vtime)
|
||||||
|
}
|
||||||
|
|
||||||
|
new := test.init.newTasks(running, pm(round.peers), vtime)
|
||||||
|
if !sametasks(new, round.new) {
|
||||||
|
t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n",
|
||||||
|
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Time advances by 16 seconds on every round.
|
||||||
|
vtime = vtime.Add(16 * time.Second)
|
||||||
|
running += len(new)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeTable []*discover.Node
|
||||||
|
|
||||||
|
func (t fakeTable) Self() *discover.Node { return new(discover.Node) }
|
||||||
|
func (t fakeTable) Close() {}
|
||||||
|
func (t fakeTable) Bootstrap([]*discover.Node) {}
|
||||||
|
func (t fakeTable) Lookup(target discover.NodeID) []*discover.Node {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int {
|
||||||
|
return copy(buf, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This test checks that dynamic dials are launched from discovery results.
|
||||||
|
func TestDialStateDynDial(t *testing.T) {
|
||||||
|
runDialTest(t, dialtest{
|
||||||
|
init: newDialState(nil, fakeTable{}, 5),
|
||||||
|
rounds: []round{
|
||||||
|
// A discovery query is launched.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
},
|
||||||
|
new: []task{&discoverTask{bootstrap: true}},
|
||||||
|
},
|
||||||
|
// Dynamic dials are launched when it completes.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
},
|
||||||
|
done: []task{
|
||||||
|
&discoverTask{bootstrap: true, results: []*discover.Node{
|
||||||
|
{ID: uintID(2)}, // this one is already connected and not dialed.
|
||||||
|
{ID: uintID(3)},
|
||||||
|
{ID: uintID(4)},
|
||||||
|
{ID: uintID(5)},
|
||||||
|
{ID: uintID(6)}, // these are not tried because max dyn dials is 5
|
||||||
|
{ID: uintID(7)}, // ...
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
new: []task{
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Some of the dials complete but no new ones are launched yet because
|
||||||
|
// the sum of active dial count and dynamic peer count is == maxDynDials.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(3)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(4)}},
|
||||||
|
},
|
||||||
|
done: []task{
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// No new dial tasks are launched in the this round because
|
||||||
|
// maxDynDials has been reached.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(3)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(4)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
|
||||||
|
},
|
||||||
|
done: []task{
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
|
||||||
|
},
|
||||||
|
new: []task{
|
||||||
|
&waitExpireTask{Duration: 14 * time.Second},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// In this round, the peer with id 2 drops off. The query
|
||||||
|
// results from last discovery lookup are reused.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(3)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(4)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
|
||||||
|
},
|
||||||
|
new: []task{
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(6)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// More peers (3,4) drop off and dial for ID 6 completes.
|
||||||
|
// The last query result from the discovery lookup is reused
|
||||||
|
// and a new one is spawned because more candidates are needed.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
|
||||||
|
},
|
||||||
|
done: []task{
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(6)}},
|
||||||
|
},
|
||||||
|
new: []task{
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(7)}},
|
||||||
|
&discoverTask{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Peer 7 is connected, but there still aren't enough dynamic peers
|
||||||
|
// (4 out of 5). However, a discovery is already running, so ensure
|
||||||
|
// no new is started.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(7)}},
|
||||||
|
},
|
||||||
|
done: []task{
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(7)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Finish the running node discovery with an empty set. A new lookup
|
||||||
|
// should be immediately requested.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(7)}},
|
||||||
|
},
|
||||||
|
done: []task{
|
||||||
|
&discoverTask{},
|
||||||
|
},
|
||||||
|
new: []task{
|
||||||
|
&discoverTask{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialStateDynDialFromTable(t *testing.T) {
|
||||||
|
// This table always returns the same random nodes
|
||||||
|
// in the order given below.
|
||||||
|
table := fakeTable{
|
||||||
|
{ID: uintID(1)},
|
||||||
|
{ID: uintID(2)},
|
||||||
|
{ID: uintID(3)},
|
||||||
|
{ID: uintID(4)},
|
||||||
|
{ID: uintID(5)},
|
||||||
|
{ID: uintID(6)},
|
||||||
|
{ID: uintID(7)},
|
||||||
|
{ID: uintID(8)},
|
||||||
|
}
|
||||||
|
|
||||||
|
runDialTest(t, dialtest{
|
||||||
|
init: newDialState(nil, table, 10),
|
||||||
|
rounds: []round{
|
||||||
|
// Discovery bootstrap is launched.
|
||||||
|
{
|
||||||
|
new: []task{&discoverTask{bootstrap: true}},
|
||||||
|
},
|
||||||
|
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
|
||||||
|
{
|
||||||
|
done: []task{
|
||||||
|
&discoverTask{bootstrap: true},
|
||||||
|
},
|
||||||
|
new: []task{
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
|
||||||
|
&discoverTask{bootstrap: false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Dialing nodes 1,2 succeeds. Dials from the lookup are launched.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
},
|
||||||
|
done: []task{
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}},
|
||||||
|
&discoverTask{results: []*discover.Node{
|
||||||
|
{ID: uintID(10)},
|
||||||
|
{ID: uintID(11)},
|
||||||
|
{ID: uintID(12)},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
new: []task{
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}},
|
||||||
|
&discoverTask{bootstrap: false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Dialing nodes 3,4,5 fails. The dials from the lookup succeed.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(10)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(11)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(12)}},
|
||||||
|
},
|
||||||
|
done: []task{
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}},
|
||||||
|
&dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Waiting for expiry. No waitExpireTask is launched because the
|
||||||
|
// discovery query is still running.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(10)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(11)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(12)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Nodes 3,4 are not tried again because only the first two
|
||||||
|
// returned random nodes (nodes 1,2) are tried and they're
|
||||||
|
// already connected.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(10)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(11)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(12)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// This test checks that static dials are launched.
|
||||||
|
func TestDialStateStaticDial(t *testing.T) {
|
||||||
|
wantStatic := []*discover.Node{
|
||||||
|
{ID: uintID(1)},
|
||||||
|
{ID: uintID(2)},
|
||||||
|
{ID: uintID(3)},
|
||||||
|
{ID: uintID(4)},
|
||||||
|
{ID: uintID(5)},
|
||||||
|
}
|
||||||
|
|
||||||
|
runDialTest(t, dialtest{
|
||||||
|
init: newDialState(wantStatic, fakeTable{}, 0),
|
||||||
|
rounds: []round{
|
||||||
|
// Static dials are launched for the nodes that
|
||||||
|
// aren't yet connected.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
},
|
||||||
|
new: []task{
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}},
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(5)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// No new tasks are launched in this round because all static
|
||||||
|
// nodes are either connected or still being dialed.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(3)}},
|
||||||
|
},
|
||||||
|
done: []task{
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// No new dial tasks are launched because all static
|
||||||
|
// nodes are now connected.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(3)}},
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(4)}},
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(5)}},
|
||||||
|
},
|
||||||
|
done: []task{
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}},
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(5)}},
|
||||||
|
},
|
||||||
|
new: []task{
|
||||||
|
&waitExpireTask{Duration: 14 * time.Second},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Wait a round for dial history to expire, no new tasks should spawn.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(3)}},
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(4)}},
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(5)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// If a static node is dropped, it should be immediately redialed,
|
||||||
|
// irrespective whether it was originally static or dynamic.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(3)}},
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(5)}},
|
||||||
|
},
|
||||||
|
new: []task{
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}},
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// This test checks that past dials are not retried for some time.
|
||||||
|
func TestDialStateCache(t *testing.T) {
|
||||||
|
wantStatic := []*discover.Node{
|
||||||
|
{ID: uintID(1)},
|
||||||
|
{ID: uintID(2)},
|
||||||
|
{ID: uintID(3)},
|
||||||
|
}
|
||||||
|
|
||||||
|
runDialTest(t, dialtest{
|
||||||
|
init: newDialState(wantStatic, fakeTable{}, 0),
|
||||||
|
rounds: []round{
|
||||||
|
// Static dials are launched for the nodes that
|
||||||
|
// aren't yet connected.
|
||||||
|
{
|
||||||
|
peers: nil,
|
||||||
|
new: []task{
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(1)}},
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}},
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// No new tasks are launched in this round because all static
|
||||||
|
// nodes are either connected or still being dialed.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: staticDialedConn, id: uintID(2)}},
|
||||||
|
},
|
||||||
|
done: []task{
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(1)}},
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// A salvage task is launched to wait for node 3's history
|
||||||
|
// entry to expire.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
},
|
||||||
|
done: []task{
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
|
||||||
|
},
|
||||||
|
new: []task{
|
||||||
|
&waitExpireTask{Duration: 14 * time.Second},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Still waiting for node 3's entry to expire in the cache.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// The cache entry for node 3 has expired and is retried.
|
||||||
|
{
|
||||||
|
peers: []*Peer{
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||||
|
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||||
|
},
|
||||||
|
new: []task{
|
||||||
|
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// compares task lists but doesn't care about the order.
|
||||||
|
func sametasks(a, b []task) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
next:
|
||||||
|
for _, ta := range a {
|
||||||
|
for _, tb := range b {
|
||||||
|
if reflect.DeepEqual(ta, tb) {
|
||||||
|
continue next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func uintID(i uint32) discover.NodeID {
|
||||||
|
var id discover.NodeID
|
||||||
|
binary.BigEndian.PutUint32(id[:], i)
|
||||||
|
return id
|
||||||
|
}
|
|
@ -8,6 +8,7 @@ package discover
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
"net"
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -90,10 +91,58 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Self returns the local node.
|
// Self returns the local node.
|
||||||
|
// The returned node should not be modified by the caller.
|
||||||
func (tab *Table) Self() *Node {
|
func (tab *Table) Self() *Node {
|
||||||
return tab.self
|
return tab.self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadRandomNodes fills the given slice with random nodes from the
|
||||||
|
// table. It will not write the same node more than once. The nodes in
|
||||||
|
// the slice are copies and can be modified by the caller.
|
||||||
|
func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
|
||||||
|
tab.mutex.Lock()
|
||||||
|
defer tab.mutex.Unlock()
|
||||||
|
// TODO: tree-based buckets would help here
|
||||||
|
// Find all non-empty buckets and get a fresh slice of their entries.
|
||||||
|
var buckets [][]*Node
|
||||||
|
for _, b := range tab.buckets {
|
||||||
|
if len(b.entries) > 0 {
|
||||||
|
buckets = append(buckets, b.entries[:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(buckets) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
// Shuffle the buckets.
|
||||||
|
for i := uint32(len(buckets)) - 1; i > 0; i-- {
|
||||||
|
j := randUint(i)
|
||||||
|
buckets[i], buckets[j] = buckets[j], buckets[i]
|
||||||
|
}
|
||||||
|
// Move head of each bucket into buf, removing buckets that become empty.
|
||||||
|
var i, j int
|
||||||
|
for ; i < len(buf); i, j = i+1, (j+1)%len(buckets) {
|
||||||
|
b := buckets[j]
|
||||||
|
buf[i] = &(*b[0])
|
||||||
|
buckets[j] = b[1:]
|
||||||
|
if len(b) == 1 {
|
||||||
|
buckets = append(buckets[:j], buckets[j+1:]...)
|
||||||
|
}
|
||||||
|
if len(buckets) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return i + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func randUint(max uint32) uint32 {
|
||||||
|
if max == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
var b [4]byte
|
||||||
|
rand.Read(b[:])
|
||||||
|
return binary.BigEndian.Uint32(b[:]) % max
|
||||||
|
}
|
||||||
|
|
||||||
// Close terminates the network listener and flushes the node database.
|
// Close terminates the network listener and flushes the node database.
|
||||||
func (tab *Table) Close() {
|
func (tab *Table) Close() {
|
||||||
tab.net.close()
|
tab.net.close()
|
||||||
|
|
|
@ -210,6 +210,36 @@ func TestTable_closest(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTable_ReadRandomNodesGetAll(t *testing.T) {
|
||||||
|
cfg := &quick.Config{
|
||||||
|
MaxCount: 200,
|
||||||
|
Rand: quickrand,
|
||||||
|
Values: func(args []reflect.Value, rand *rand.Rand) {
|
||||||
|
args[0] = reflect.ValueOf(make([]*Node, rand.Intn(1000)))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
test := func(buf []*Node) bool {
|
||||||
|
tab := newTable(nil, NodeID{}, &net.UDPAddr{}, "")
|
||||||
|
for i := 0; i < len(buf); i++ {
|
||||||
|
ld := quickrand.Intn(len(tab.buckets))
|
||||||
|
tab.add([]*Node{nodeAtDistance(tab.self.sha, ld)})
|
||||||
|
}
|
||||||
|
gotN := tab.ReadRandomNodes(buf)
|
||||||
|
if gotN != tab.len() {
|
||||||
|
t.Errorf("wrong number of nodes, got %d, want %d", gotN, tab.len())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if hasDuplicates(buf[:gotN]) {
|
||||||
|
t.Errorf("result contains duplicates")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if err := quick.Check(test, cfg); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type closeTest struct {
|
type closeTest struct {
|
||||||
Self NodeID
|
Self NodeID
|
||||||
Target common.Hash
|
Target common.Hash
|
||||||
|
@ -517,7 +547,10 @@ func (n *preminedTestnet) mine(target NodeID) {
|
||||||
|
|
||||||
func hasDuplicates(slice []*Node) bool {
|
func hasDuplicates(slice []*Node) bool {
|
||||||
seen := make(map[NodeID]bool)
|
seen := make(map[NodeID]bool)
|
||||||
for _, e := range slice {
|
for i, e := range slice {
|
||||||
|
if e == nil {
|
||||||
|
panic(fmt.Sprintf("nil *Node at %d", i))
|
||||||
|
}
|
||||||
if seen[e.ID] {
|
if seen[e.ID] {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
448
p2p/handshake.go
448
p2p/handshake.go
|
@ -1,448 +0,0 @@
|
||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"hash"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
|
||||||
"github.com/ethereum/go-ethereum/crypto/ecies"
|
|
||||||
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
|
||||||
"github.com/ethereum/go-ethereum/crypto/sha3"
|
|
||||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
|
||||||
"github.com/ethereum/go-ethereum/rlp"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
|
|
||||||
sigLen = 65 // elliptic S256
|
|
||||||
pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
|
|
||||||
shaLen = 32 // hash length (for nonce etc)
|
|
||||||
|
|
||||||
authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
|
|
||||||
authRespLen = pubLen + shaLen + 1
|
|
||||||
|
|
||||||
eciesBytes = 65 + 16 + 32
|
|
||||||
encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
|
|
||||||
encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
|
|
||||||
)
|
|
||||||
|
|
||||||
// conn represents a remote connection after encryption handshake
|
|
||||||
// and protocol handshake have completed.
|
|
||||||
//
|
|
||||||
// The MsgReadWriter is usually layered as follows:
|
|
||||||
//
|
|
||||||
// netWrapper (I/O timeouts, thread-safe ReadMsg, WriteMsg)
|
|
||||||
// rlpxFrameRW (message encoding, encryption, authentication)
|
|
||||||
// bufio.ReadWriter (buffering)
|
|
||||||
// net.Conn (network I/O)
|
|
||||||
//
|
|
||||||
type conn struct {
|
|
||||||
MsgReadWriter
|
|
||||||
*protoHandshake
|
|
||||||
}
|
|
||||||
|
|
||||||
// secrets represents the connection secrets
|
|
||||||
// which are negotiated during the encryption handshake.
|
|
||||||
type secrets struct {
|
|
||||||
RemoteID discover.NodeID
|
|
||||||
AES, MAC []byte
|
|
||||||
EgressMAC, IngressMAC hash.Hash
|
|
||||||
Token []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// protoHandshake is the RLP structure of the protocol handshake.
|
|
||||||
type protoHandshake struct {
|
|
||||||
Version uint64
|
|
||||||
Name string
|
|
||||||
Caps []Cap
|
|
||||||
ListenPort uint64
|
|
||||||
ID discover.NodeID
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupConn starts a protocol session on the given connection. It
|
|
||||||
// runs the encryption handshake and the protocol handshake. If dial
|
|
||||||
// is non-nil, the connection the local node is the initiator. If
|
|
||||||
// keepconn returns false, the connection will be disconnected with
|
|
||||||
// DiscTooManyPeers after the key exchange.
|
|
||||||
func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) {
|
|
||||||
if dial == nil {
|
|
||||||
return setupInboundConn(fd, prv, our, keepconn)
|
|
||||||
} else {
|
|
||||||
return setupOutboundConn(fd, prv, our, dial, keepconn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, keepconn func(discover.NodeID) bool) (*conn, error) {
|
|
||||||
secrets, err := receiverEncHandshake(fd, prv, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("encryption handshake failed: %v", err)
|
|
||||||
}
|
|
||||||
rw := newRlpxFrameRW(fd, secrets)
|
|
||||||
if !keepconn(secrets.RemoteID) {
|
|
||||||
SendItems(rw, discMsg, DiscTooManyPeers)
|
|
||||||
return nil, errors.New("we have too many peers")
|
|
||||||
}
|
|
||||||
// Run the protocol handshake using authenticated messages.
|
|
||||||
rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := Send(rw, handshakeMsg, our); err != nil {
|
|
||||||
return nil, fmt.Errorf("protocol handshake write error: %v", err)
|
|
||||||
}
|
|
||||||
return &conn{rw, rhs}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) {
|
|
||||||
secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("encryption handshake failed: %v", err)
|
|
||||||
}
|
|
||||||
rw := newRlpxFrameRW(fd, secrets)
|
|
||||||
if !keepconn(secrets.RemoteID) {
|
|
||||||
SendItems(rw, discMsg, DiscTooManyPeers)
|
|
||||||
return nil, errors.New("we have too many peers")
|
|
||||||
}
|
|
||||||
// Run the protocol handshake using authenticated messages.
|
|
||||||
//
|
|
||||||
// Note that even though writing the handshake is first, we prefer
|
|
||||||
// returning the handshake read error. If the remote side
|
|
||||||
// disconnects us early with a valid reason, we should return it
|
|
||||||
// as the error so it can be tracked elsewhere.
|
|
||||||
werr := make(chan error, 1)
|
|
||||||
go func() { werr <- Send(rw, handshakeMsg, our) }()
|
|
||||||
rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := <-werr; err != nil {
|
|
||||||
return nil, fmt.Errorf("protocol handshake write error: %v", err)
|
|
||||||
}
|
|
||||||
if rhs.ID != dial.ID {
|
|
||||||
return nil, errors.New("dialed node id mismatch")
|
|
||||||
}
|
|
||||||
return &conn{rw, rhs}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// encHandshake contains the state of the encryption handshake.
|
|
||||||
type encHandshake struct {
|
|
||||||
initiator bool
|
|
||||||
remoteID discover.NodeID
|
|
||||||
|
|
||||||
remotePub *ecies.PublicKey // remote-pubk
|
|
||||||
initNonce, respNonce []byte // nonce
|
|
||||||
randomPrivKey *ecies.PrivateKey // ecdhe-random
|
|
||||||
remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk
|
|
||||||
}
|
|
||||||
|
|
||||||
// secrets is called after the handshake is completed.
|
|
||||||
// It extracts the connection secrets from the handshake values.
|
|
||||||
func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) {
|
|
||||||
ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen)
|
|
||||||
if err != nil {
|
|
||||||
return secrets{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// derive base secrets from ephemeral key agreement
|
|
||||||
sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce))
|
|
||||||
aesSecret := crypto.Sha3(ecdheSecret, sharedSecret)
|
|
||||||
s := secrets{
|
|
||||||
RemoteID: h.remoteID,
|
|
||||||
AES: aesSecret,
|
|
||||||
MAC: crypto.Sha3(ecdheSecret, aesSecret),
|
|
||||||
Token: crypto.Sha3(sharedSecret),
|
|
||||||
}
|
|
||||||
|
|
||||||
// setup sha3 instances for the MACs
|
|
||||||
mac1 := sha3.NewKeccak256()
|
|
||||||
mac1.Write(xor(s.MAC, h.respNonce))
|
|
||||||
mac1.Write(auth)
|
|
||||||
mac2 := sha3.NewKeccak256()
|
|
||||||
mac2.Write(xor(s.MAC, h.initNonce))
|
|
||||||
mac2.Write(authResp)
|
|
||||||
if h.initiator {
|
|
||||||
s.EgressMAC, s.IngressMAC = mac1, mac2
|
|
||||||
} else {
|
|
||||||
s.EgressMAC, s.IngressMAC = mac2, mac1
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) {
|
|
||||||
return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// initiatorEncHandshake negotiates a session token on conn.
|
|
||||||
// it should be called on the dialing side of the connection.
|
|
||||||
//
|
|
||||||
// prv is the local client's private key.
|
|
||||||
// token is the token from a previous session with this node.
|
|
||||||
func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) {
|
|
||||||
h, err := newInitiatorHandshake(remoteID)
|
|
||||||
if err != nil {
|
|
||||||
return s, err
|
|
||||||
}
|
|
||||||
auth, err := h.authMsg(prv, token)
|
|
||||||
if err != nil {
|
|
||||||
return s, err
|
|
||||||
}
|
|
||||||
if _, err = conn.Write(auth); err != nil {
|
|
||||||
return s, err
|
|
||||||
}
|
|
||||||
|
|
||||||
response := make([]byte, encAuthRespLen)
|
|
||||||
if _, err = io.ReadFull(conn, response); err != nil {
|
|
||||||
return s, err
|
|
||||||
}
|
|
||||||
if err := h.decodeAuthResp(response, prv); err != nil {
|
|
||||||
return s, err
|
|
||||||
}
|
|
||||||
return h.secrets(auth, response)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) {
|
|
||||||
// generate random initiator nonce
|
|
||||||
n := make([]byte, shaLen)
|
|
||||||
if _, err := rand.Read(n); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// generate random keypair to use for signing
|
|
||||||
randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
rpub, err := remoteID.Pubkey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("bad remoteID: %v", err)
|
|
||||||
}
|
|
||||||
h := &encHandshake{
|
|
||||||
initiator: true,
|
|
||||||
remoteID: remoteID,
|
|
||||||
remotePub: ecies.ImportECDSAPublic(rpub),
|
|
||||||
initNonce: n,
|
|
||||||
randomPrivKey: randpriv,
|
|
||||||
}
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// authMsg creates an encrypted initiator handshake message.
|
|
||||||
func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
|
|
||||||
var tokenFlag byte
|
|
||||||
if token == nil {
|
|
||||||
// no session token found means we need to generate shared secret.
|
|
||||||
// ecies shared secret is used as initial session token for new peers
|
|
||||||
// generate shared key from prv and remote pubkey
|
|
||||||
var err error
|
|
||||||
if token, err = h.ecdhShared(prv); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// for known peers, we use stored token from the previous session
|
|
||||||
tokenFlag = 0x01
|
|
||||||
}
|
|
||||||
|
|
||||||
// sign known message:
|
|
||||||
// ecdh-shared-secret^nonce for new peers
|
|
||||||
// token^nonce for old peers
|
|
||||||
signed := xor(token, h.initNonce)
|
|
||||||
signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// encode auth message
|
|
||||||
// signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
|
|
||||||
msg := make([]byte, authMsgLen)
|
|
||||||
n := copy(msg, signature)
|
|
||||||
n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey)))
|
|
||||||
n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:])
|
|
||||||
n += copy(msg[n:], h.initNonce)
|
|
||||||
msg[n] = tokenFlag
|
|
||||||
|
|
||||||
// encrypt auth message using remote-pubk
|
|
||||||
return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// decodeAuthResp decode an encrypted authentication response message.
|
|
||||||
func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error {
|
|
||||||
msg, err := crypto.Decrypt(prv, auth)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("could not decrypt auth response (%v)", err)
|
|
||||||
}
|
|
||||||
h.respNonce = msg[pubLen : pubLen+shaLen]
|
|
||||||
h.remoteRandomPub, err = importPublicKey(msg[:pubLen])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// ignore token flag for now
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// receiverEncHandshake negotiates a session token on conn.
|
|
||||||
// it should be called on the listening side of the connection.
|
|
||||||
//
|
|
||||||
// prv is the local client's private key.
|
|
||||||
// token is the token from a previous session with this node.
|
|
||||||
func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) {
|
|
||||||
// read remote auth sent by initiator.
|
|
||||||
auth := make([]byte, encAuthMsgLen)
|
|
||||||
if _, err := io.ReadFull(conn, auth); err != nil {
|
|
||||||
return s, err
|
|
||||||
}
|
|
||||||
h, err := decodeAuthMsg(prv, token, auth)
|
|
||||||
if err != nil {
|
|
||||||
return s, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// send auth response
|
|
||||||
resp, err := h.authResp(prv, token)
|
|
||||||
if err != nil {
|
|
||||||
return s, err
|
|
||||||
}
|
|
||||||
if _, err = conn.Write(resp); err != nil {
|
|
||||||
return s, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return h.secrets(auth, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) {
|
|
||||||
var err error
|
|
||||||
h := new(encHandshake)
|
|
||||||
// generate random keypair for session
|
|
||||||
h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// generate random nonce
|
|
||||||
h.respNonce = make([]byte, shaLen)
|
|
||||||
if _, err = rand.Read(h.respNonce); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := crypto.Decrypt(prv, auth)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not decrypt auth message (%v)", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// decode message parameters
|
|
||||||
// signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
|
|
||||||
h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
|
|
||||||
copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen])
|
|
||||||
rpub, err := h.remoteID.Pubkey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("bad remoteID: %#v", err)
|
|
||||||
}
|
|
||||||
h.remotePub = ecies.ImportECDSAPublic(rpub)
|
|
||||||
|
|
||||||
// recover remote random pubkey from signed message.
|
|
||||||
if token == nil {
|
|
||||||
// TODO: it is an error if the initiator has a token and we don't. check that.
|
|
||||||
|
|
||||||
// no session token means we need to generate shared secret.
|
|
||||||
// ecies shared secret is used as initial session token for new peers.
|
|
||||||
// generate shared key from prv and remote pubkey.
|
|
||||||
if token, err = h.ecdhShared(prv); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
signedMsg := xor(token, h.initNonce)
|
|
||||||
remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
h.remoteRandomPub, _ = importPublicKey(remoteRandomPub)
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// authResp generates the encrypted authentication response message.
|
|
||||||
func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
|
|
||||||
// responder auth message
|
|
||||||
// E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
|
|
||||||
resp := make([]byte, authRespLen)
|
|
||||||
n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey))
|
|
||||||
n += copy(resp[n:], h.respNonce)
|
|
||||||
if token == nil {
|
|
||||||
resp[n] = 0
|
|
||||||
} else {
|
|
||||||
resp[n] = 1
|
|
||||||
}
|
|
||||||
// encrypt using remote-pubk
|
|
||||||
return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// importPublicKey unmarshals 512 bit public keys.
|
|
||||||
func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) {
|
|
||||||
var pubKey65 []byte
|
|
||||||
switch len(pubKey) {
|
|
||||||
case 64:
|
|
||||||
// add 'uncompressed key' flag
|
|
||||||
pubKey65 = append([]byte{0x04}, pubKey...)
|
|
||||||
case 65:
|
|
||||||
pubKey65 = pubKey
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
|
|
||||||
}
|
|
||||||
// TODO: fewer pointless conversions
|
|
||||||
return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func exportPubkey(pub *ecies.PublicKey) []byte {
|
|
||||||
if pub == nil {
|
|
||||||
panic("nil pubkey")
|
|
||||||
}
|
|
||||||
return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func xor(one, other []byte) (xor []byte) {
|
|
||||||
xor = make([]byte, len(one))
|
|
||||||
for i := 0; i < len(one); i++ {
|
|
||||||
xor[i] = one[i] ^ other[i]
|
|
||||||
}
|
|
||||||
return xor
|
|
||||||
}
|
|
||||||
|
|
||||||
func readProtocolHandshake(rw MsgReadWriter, wantID discover.NodeID, our *protoHandshake) (*protoHandshake, error) {
|
|
||||||
msg, err := rw.ReadMsg()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if msg.Code == discMsg {
|
|
||||||
// disconnect before protocol handshake is valid according to the
|
|
||||||
// spec and we send it ourself if Server.addPeer fails.
|
|
||||||
var reason [1]DiscReason
|
|
||||||
rlp.Decode(msg.Payload, &reason)
|
|
||||||
return nil, reason[0]
|
|
||||||
}
|
|
||||||
if msg.Code != handshakeMsg {
|
|
||||||
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
|
|
||||||
}
|
|
||||||
if msg.Size > baseProtocolMaxMsgSize {
|
|
||||||
return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize)
|
|
||||||
}
|
|
||||||
var hs protoHandshake
|
|
||||||
if err := msg.Decode(&hs); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// validate handshake info
|
|
||||||
if hs.Version != our.Version {
|
|
||||||
SendItems(rw, discMsg, DiscIncompatibleVersion)
|
|
||||||
return nil, fmt.Errorf("required version %d, received %d\n", baseProtocolVersion, hs.Version)
|
|
||||||
}
|
|
||||||
if (hs.ID == discover.NodeID{}) {
|
|
||||||
SendItems(rw, discMsg, DiscInvalidIdentity)
|
|
||||||
return nil, errors.New("invalid public key in handshake")
|
|
||||||
}
|
|
||||||
if hs.ID != wantID {
|
|
||||||
SendItems(rw, discMsg, DiscUnexpectedIdentity)
|
|
||||||
return nil, errors.New("handshake node ID does not match encryption handshake")
|
|
||||||
}
|
|
||||||
return &hs, nil
|
|
||||||
}
|
|
|
@ -1,172 +0,0 @@
|
||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
|
||||||
"github.com/ethereum/go-ethereum/crypto/ecies"
|
|
||||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSharedSecret(t *testing.T) {
|
|
||||||
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
|
||||||
pub0 := &prv0.PublicKey
|
|
||||||
prv1, _ := crypto.GenerateKey()
|
|
||||||
pub1 := &prv1.PublicKey
|
|
||||||
|
|
||||||
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
|
|
||||||
if !bytes.Equal(ss0, ss1) {
|
|
||||||
t.Errorf("dont match :(")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEncHandshake(t *testing.T) {
|
|
||||||
for i := 0; i < 20; i++ {
|
|
||||||
start := time.Now()
|
|
||||||
if err := testEncHandshake(nil); err != nil {
|
|
||||||
t.Fatalf("i=%d %v", i, err)
|
|
||||||
}
|
|
||||||
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < 20; i++ {
|
|
||||||
tok := make([]byte, shaLen)
|
|
||||||
rand.Reader.Read(tok)
|
|
||||||
start := time.Now()
|
|
||||||
if err := testEncHandshake(tok); err != nil {
|
|
||||||
t.Fatalf("i=%d %v", i, err)
|
|
||||||
}
|
|
||||||
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testEncHandshake(token []byte) error {
|
|
||||||
type result struct {
|
|
||||||
side string
|
|
||||||
s secrets
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
var (
|
|
||||||
prv0, _ = crypto.GenerateKey()
|
|
||||||
prv1, _ = crypto.GenerateKey()
|
|
||||||
rw0, rw1 = net.Pipe()
|
|
||||||
output = make(chan result)
|
|
||||||
)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
r := result{side: "initiator"}
|
|
||||||
defer func() { output <- r }()
|
|
||||||
|
|
||||||
pub1s := discover.PubkeyID(&prv1.PublicKey)
|
|
||||||
r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
|
|
||||||
if r.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
id1 := discover.PubkeyID(&prv1.PublicKey)
|
|
||||||
if r.s.RemoteID != id1 {
|
|
||||||
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
r := result{side: "receiver"}
|
|
||||||
defer func() { output <- r }()
|
|
||||||
|
|
||||||
r.s, r.err = receiverEncHandshake(rw1, prv1, token)
|
|
||||||
if r.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
id0 := discover.PubkeyID(&prv0.PublicKey)
|
|
||||||
if r.s.RemoteID != id0 {
|
|
||||||
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// wait for results from both sides
|
|
||||||
r1, r2 := <-output, <-output
|
|
||||||
|
|
||||||
if r1.err != nil {
|
|
||||||
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
|
|
||||||
}
|
|
||||||
if r2.err != nil {
|
|
||||||
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// don't compare remote node IDs
|
|
||||||
r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
|
|
||||||
// flip MACs on one of them so they compare equal
|
|
||||||
r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
|
|
||||||
if !reflect.DeepEqual(r1.s, r2.s) {
|
|
||||||
return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetupConn(t *testing.T) {
|
|
||||||
prv0, _ := crypto.GenerateKey()
|
|
||||||
prv1, _ := crypto.GenerateKey()
|
|
||||||
node0 := &discover.Node{
|
|
||||||
ID: discover.PubkeyID(&prv0.PublicKey),
|
|
||||||
IP: net.IP{1, 2, 3, 4},
|
|
||||||
TCP: 33,
|
|
||||||
}
|
|
||||||
node1 := &discover.Node{
|
|
||||||
ID: discover.PubkeyID(&prv1.PublicKey),
|
|
||||||
IP: net.IP{5, 6, 7, 8},
|
|
||||||
TCP: 44,
|
|
||||||
}
|
|
||||||
hs0 := &protoHandshake{
|
|
||||||
Version: baseProtocolVersion,
|
|
||||||
ID: node0.ID,
|
|
||||||
Caps: []Cap{{"a", 0}, {"b", 2}},
|
|
||||||
}
|
|
||||||
hs1 := &protoHandshake{
|
|
||||||
Version: baseProtocolVersion,
|
|
||||||
ID: node1.ID,
|
|
||||||
Caps: []Cap{{"c", 1}, {"d", 3}},
|
|
||||||
}
|
|
||||||
fd0, fd1 := net.Pipe()
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
keepalways := func(discover.NodeID) bool { return true }
|
|
||||||
go func() {
|
|
||||||
defer close(done)
|
|
||||||
conn0, err := setupConn(fd0, prv0, hs0, node1, keepalways)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("outbound side error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if conn0.ID != node1.ID {
|
|
||||||
t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
|
|
||||||
t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
conn1, err := setupConn(fd1, prv1, hs1, nil, keepalways)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("inbound side error: %v", err)
|
|
||||||
}
|
|
||||||
if conn1.ID != node0.ID {
|
|
||||||
t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
|
|
||||||
t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
|
|
||||||
}
|
|
||||||
|
|
||||||
<-done
|
|
||||||
}
|
|
73
p2p/peer.go
73
p2p/peer.go
|
@ -18,7 +18,7 @@ import (
|
||||||
const (
|
const (
|
||||||
baseProtocolVersion = 4
|
baseProtocolVersion = 4
|
||||||
baseProtocolLength = uint64(16)
|
baseProtocolLength = uint64(16)
|
||||||
baseProtocolMaxMsgSize = 10 * 1024 * 1024
|
baseProtocolMaxMsgSize = 2 * 1024
|
||||||
|
|
||||||
pingInterval = 15 * time.Second
|
pingInterval = 15 * time.Second
|
||||||
)
|
)
|
||||||
|
@ -33,9 +33,17 @@ const (
|
||||||
peersMsg = 0x05
|
peersMsg = 0x05
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// protoHandshake is the RLP structure of the protocol handshake.
|
||||||
|
type protoHandshake struct {
|
||||||
|
Version uint64
|
||||||
|
Name string
|
||||||
|
Caps []Cap
|
||||||
|
ListenPort uint64
|
||||||
|
ID discover.NodeID
|
||||||
|
}
|
||||||
|
|
||||||
// Peer represents a connected remote node.
|
// Peer represents a connected remote node.
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
conn net.Conn
|
|
||||||
rw *conn
|
rw *conn
|
||||||
running map[string]*protoRW
|
running map[string]*protoRW
|
||||||
|
|
||||||
|
@ -48,37 +56,36 @@ type Peer struct {
|
||||||
// NewPeer returns a peer for testing purposes.
|
// NewPeer returns a peer for testing purposes.
|
||||||
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
|
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
|
||||||
pipe, _ := net.Pipe()
|
pipe, _ := net.Pipe()
|
||||||
msgpipe, _ := MsgPipe()
|
conn := &conn{fd: pipe, transport: nil, id: id, caps: caps, name: name}
|
||||||
conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}}
|
peer := newPeer(conn, nil)
|
||||||
peer := newPeer(pipe, conn, nil)
|
|
||||||
close(peer.closed) // ensures Disconnect doesn't block
|
close(peer.closed) // ensures Disconnect doesn't block
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the node's public key.
|
// ID returns the node's public key.
|
||||||
func (p *Peer) ID() discover.NodeID {
|
func (p *Peer) ID() discover.NodeID {
|
||||||
return p.rw.ID
|
return p.rw.id
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name returns the node name that the remote node advertised.
|
// Name returns the node name that the remote node advertised.
|
||||||
func (p *Peer) Name() string {
|
func (p *Peer) Name() string {
|
||||||
return p.rw.Name
|
return p.rw.name
|
||||||
}
|
}
|
||||||
|
|
||||||
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
||||||
func (p *Peer) Caps() []Cap {
|
func (p *Peer) Caps() []Cap {
|
||||||
// TODO: maybe return copy
|
// TODO: maybe return copy
|
||||||
return p.rw.Caps
|
return p.rw.caps
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoteAddr returns the remote address of the network connection.
|
// RemoteAddr returns the remote address of the network connection.
|
||||||
func (p *Peer) RemoteAddr() net.Addr {
|
func (p *Peer) RemoteAddr() net.Addr {
|
||||||
return p.conn.RemoteAddr()
|
return p.rw.fd.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalAddr returns the local address of the network connection.
|
// LocalAddr returns the local address of the network connection.
|
||||||
func (p *Peer) LocalAddr() net.Addr {
|
func (p *Peer) LocalAddr() net.Addr {
|
||||||
return p.conn.LocalAddr()
|
return p.rw.fd.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect terminates the peer connection with the given reason.
|
// Disconnect terminates the peer connection with the given reason.
|
||||||
|
@ -92,13 +99,12 @@ func (p *Peer) Disconnect(reason DiscReason) {
|
||||||
|
|
||||||
// String implements fmt.Stringer.
|
// String implements fmt.Stringer.
|
||||||
func (p *Peer) String() string {
|
func (p *Peer) String() string {
|
||||||
return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
|
return fmt.Sprintf("Peer %x %v", p.rw.id[:8], p.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer {
|
func newPeer(conn *conn, protocols []Protocol) *Peer {
|
||||||
protomap := matchProtocols(protocols, conn.Caps, conn)
|
protomap := matchProtocols(protocols, conn.caps, conn)
|
||||||
p := &Peer{
|
p := &Peer{
|
||||||
conn: fd,
|
|
||||||
rw: conn,
|
rw: conn,
|
||||||
running: protomap,
|
running: protomap,
|
||||||
disc: make(chan DiscReason),
|
disc: make(chan DiscReason),
|
||||||
|
@ -117,7 +123,10 @@ func (p *Peer) run() DiscReason {
|
||||||
p.startProtocols()
|
p.startProtocols()
|
||||||
|
|
||||||
// Wait for an error or disconnect.
|
// Wait for an error or disconnect.
|
||||||
var reason DiscReason
|
var (
|
||||||
|
reason DiscReason
|
||||||
|
requested bool
|
||||||
|
)
|
||||||
select {
|
select {
|
||||||
case err := <-readErr:
|
case err := <-readErr:
|
||||||
if r, ok := err.(DiscReason); ok {
|
if r, ok := err.(DiscReason); ok {
|
||||||
|
@ -131,23 +140,19 @@ func (p *Peer) run() DiscReason {
|
||||||
case err := <-p.protoErr:
|
case err := <-p.protoErr:
|
||||||
reason = discReasonForError(err)
|
reason = discReasonForError(err)
|
||||||
case reason = <-p.disc:
|
case reason = <-p.disc:
|
||||||
p.politeDisconnect(reason)
|
requested = true
|
||||||
|
}
|
||||||
|
close(p.closed)
|
||||||
|
p.rw.close(reason)
|
||||||
|
p.wg.Wait()
|
||||||
|
|
||||||
|
if requested {
|
||||||
reason = DiscRequested
|
reason = DiscRequested
|
||||||
}
|
}
|
||||||
|
|
||||||
close(p.closed)
|
|
||||||
p.wg.Wait()
|
|
||||||
glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
|
glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
|
||||||
return reason
|
return reason
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) politeDisconnect(reason DiscReason) {
|
|
||||||
if reason != DiscNetworkError {
|
|
||||||
SendItems(p.rw, discMsg, uint(reason))
|
|
||||||
}
|
|
||||||
p.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Peer) pingLoop() {
|
func (p *Peer) pingLoop() {
|
||||||
ping := time.NewTicker(pingInterval)
|
ping := time.NewTicker(pingInterval)
|
||||||
defer p.wg.Done()
|
defer p.wg.Done()
|
||||||
|
@ -254,7 +259,7 @@ func (p *Peer) startProtocols() {
|
||||||
glog.V(logger.Detail).Infof("%v: Protocol %s/%d returned\n", p, proto.Name, proto.Version)
|
glog.V(logger.Detail).Infof("%v: Protocol %s/%d returned\n", p, proto.Name, proto.Version)
|
||||||
err = errors.New("protocol returned")
|
err = errors.New("protocol returned")
|
||||||
} else if err != io.EOF {
|
} else if err != io.EOF {
|
||||||
glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: \n", p, proto.Name, proto.Version, err)
|
glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: %v\n", p, proto.Name, proto.Version, err)
|
||||||
}
|
}
|
||||||
p.protoErr <- err
|
p.protoErr <- err
|
||||||
p.wg.Done()
|
p.wg.Done()
|
||||||
|
@ -273,20 +278,6 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) {
|
||||||
return nil, newPeerError(errInvalidMsgCode, "%d", code)
|
return nil, newPeerError(errInvalidMsgCode, "%d", code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
|
||||||
// this exists because of Server.Broadcast.
|
|
||||||
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
|
|
||||||
proto, ok := p.running[protoName]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("protocol %s not handled by peer", protoName)
|
|
||||||
}
|
|
||||||
if msg.Code >= proto.Length {
|
|
||||||
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
|
|
||||||
}
|
|
||||||
msg.Code += proto.offset
|
|
||||||
return p.rw.WriteMsg(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
type protoRW struct {
|
type protoRW struct {
|
||||||
Protocol
|
Protocol
|
||||||
in chan Msg
|
in chan Msg
|
||||||
|
|
|
@ -5,39 +5,17 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
errMagicTokenMismatch = iota
|
errInvalidMsgCode = iota
|
||||||
errRead
|
|
||||||
errWrite
|
|
||||||
errMisc
|
|
||||||
errInvalidMsgCode
|
|
||||||
errInvalidMsg
|
errInvalidMsg
|
||||||
errP2PVersionMismatch
|
|
||||||
errPubkeyInvalid
|
|
||||||
errPubkeyForbidden
|
|
||||||
errProtocolBreach
|
|
||||||
errPingTimeout
|
|
||||||
errInvalidNetworkId
|
|
||||||
errInvalidProtocolVersion
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var errorToString = map[int]string{
|
var errorToString = map[int]string{
|
||||||
errMagicTokenMismatch: "magic token mismatch",
|
errInvalidMsgCode: "invalid message code",
|
||||||
errRead: "read error",
|
errInvalidMsg: "invalid message",
|
||||||
errWrite: "write error",
|
|
||||||
errMisc: "misc error",
|
|
||||||
errInvalidMsgCode: "invalid message code",
|
|
||||||
errInvalidMsg: "invalid message",
|
|
||||||
errP2PVersionMismatch: "P2P Version Mismatch",
|
|
||||||
errPubkeyInvalid: "public key invalid",
|
|
||||||
errPubkeyForbidden: "public key forbidden",
|
|
||||||
errProtocolBreach: "protocol Breach",
|
|
||||||
errPingTimeout: "ping timeout",
|
|
||||||
errInvalidNetworkId: "invalid network id",
|
|
||||||
errInvalidProtocolVersion: "invalid protocol version",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type peerError struct {
|
type peerError struct {
|
||||||
Code int
|
code int
|
||||||
message string
|
message string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,23 +85,13 @@ func discReasonForError(err error) DiscReason {
|
||||||
return reason
|
return reason
|
||||||
}
|
}
|
||||||
peerError, ok := err.(*peerError)
|
peerError, ok := err.(*peerError)
|
||||||
if !ok {
|
if ok {
|
||||||
return DiscSubprotocolError
|
switch peerError.code {
|
||||||
}
|
case errInvalidMsgCode, errInvalidMsg:
|
||||||
switch peerError.Code {
|
return DiscProtocolError
|
||||||
case errP2PVersionMismatch:
|
default:
|
||||||
return DiscIncompatibleVersion
|
return DiscSubprotocolError
|
||||||
case errPubkeyInvalid:
|
}
|
||||||
return DiscInvalidIdentity
|
|
||||||
case errPubkeyForbidden:
|
|
||||||
return DiscUselessPeer
|
|
||||||
case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
|
|
||||||
return DiscProtocolError
|
|
||||||
case errPingTimeout:
|
|
||||||
return DiscReadTimeout
|
|
||||||
case errRead, errWrite:
|
|
||||||
return DiscNetworkError
|
|
||||||
default:
|
|
||||||
return DiscSubprotocolError
|
|
||||||
}
|
}
|
||||||
|
return DiscSubprotocolError
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
@ -29,24 +28,20 @@ var discard = Protocol{
|
||||||
}
|
}
|
||||||
|
|
||||||
func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan DiscReason) {
|
func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan DiscReason) {
|
||||||
fd1, _ := net.Pipe()
|
fd1, fd2 := net.Pipe()
|
||||||
hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
|
c1 := &conn{fd: fd1, transport: newTestTransport(randomID(), fd1)}
|
||||||
hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
|
c2 := &conn{fd: fd2, transport: newTestTransport(randomID(), fd2)}
|
||||||
for _, p := range protos {
|
for _, p := range protos {
|
||||||
hs1.Caps = append(hs1.Caps, p.cap())
|
c1.caps = append(c1.caps, p.cap())
|
||||||
hs2.Caps = append(hs2.Caps, p.cap())
|
c2.caps = append(c2.caps, p.cap())
|
||||||
}
|
}
|
||||||
|
|
||||||
p1, p2 := MsgPipe()
|
peer := newPeer(c1, protos)
|
||||||
peer := newPeer(fd1, &conn{p1, hs1}, protos)
|
|
||||||
errc := make(chan DiscReason, 1)
|
errc := make(chan DiscReason, 1)
|
||||||
go func() { errc <- peer.run() }()
|
go func() { errc <- peer.run() }()
|
||||||
|
|
||||||
closer := func() {
|
closer := func() { c2.close(errors.New("close func called")) }
|
||||||
p1.Close()
|
return closer, c2, peer, errc
|
||||||
fd1.Close()
|
|
||||||
}
|
|
||||||
return closer, &conn{p2, hs2}, peer, errc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPeerProtoReadMsg(t *testing.T) {
|
func TestPeerProtoReadMsg(t *testing.T) {
|
||||||
|
@ -107,44 +102,6 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPeerWriteForBroadcast(t *testing.T) {
|
|
||||||
closer, rw, peer, peerErr := testPeer([]Protocol{discard})
|
|
||||||
defer closer()
|
|
||||||
|
|
||||||
emptymsg := func(code uint64) Msg {
|
|
||||||
return Msg{Code: code, Size: 0, Payload: bytes.NewReader(nil)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// test write errors
|
|
||||||
if err := peer.writeProtoMsg("b", emptymsg(3)); err == nil {
|
|
||||||
t.Errorf("expected error for unknown protocol, got nil")
|
|
||||||
}
|
|
||||||
if err := peer.writeProtoMsg("discard", emptymsg(8)); err == nil {
|
|
||||||
t.Errorf("expected error for out-of-range msg code, got nil")
|
|
||||||
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
|
|
||||||
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// setup for reading the message on the other end
|
|
||||||
read := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
if err := ExpectMsg(rw, 16, nil); err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
close(read)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// test successful write
|
|
||||||
if err := peer.writeProtoMsg("discard", emptymsg(0)); err != nil {
|
|
||||||
t.Errorf("expect no error for known protocol: %v", err)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-read:
|
|
||||||
case err := <-peerErr:
|
|
||||||
t.Fatalf("peer stopped: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPeerPing(t *testing.T) {
|
func TestPeerPing(t *testing.T) {
|
||||||
closer, rw, _, _ := testPeer(nil)
|
closer, rw, _, _ := testPeer(nil)
|
||||||
defer closer()
|
defer closer()
|
||||||
|
|
444
p2p/rlpx.go
444
p2p/rlpx.go
|
@ -4,23 +4,459 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/aes"
|
"crypto/aes"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"hash"
|
"hash"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
"github.com/ethereum/go-ethereum/crypto/ecies"
|
||||||
|
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
||||||
|
"github.com/ethereum/go-ethereum/crypto/sha3"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
"github.com/ethereum/go-ethereum/rlp"
|
"github.com/ethereum/go-ethereum/rlp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxUint24 = ^uint32(0) >> 8
|
||||||
|
|
||||||
|
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
|
||||||
|
sigLen = 65 // elliptic S256
|
||||||
|
pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
|
||||||
|
shaLen = 32 // hash length (for nonce etc)
|
||||||
|
|
||||||
|
authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
|
||||||
|
authRespLen = pubLen + shaLen + 1
|
||||||
|
|
||||||
|
eciesBytes = 65 + 16 + 32
|
||||||
|
encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
|
||||||
|
encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
|
||||||
|
|
||||||
|
// total timeout for encryption handshake and protocol
|
||||||
|
// handshake in both directions.
|
||||||
|
handshakeTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
// This is the timeout for sending the disconnect reason.
|
||||||
|
// This is shorter than the usual timeout because we don't want
|
||||||
|
// to wait if the connection is known to be bad anyway.
|
||||||
|
discWriteTimeout = 1 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// rlpx is the transport protocol used by actual (non-test) connections.
|
||||||
|
// It wraps the frame encoder with locks and read/write deadlines.
|
||||||
|
type rlpx struct {
|
||||||
|
fd net.Conn
|
||||||
|
|
||||||
|
rmu, wmu sync.Mutex
|
||||||
|
rw *rlpxFrameRW
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRLPX(fd net.Conn) transport {
|
||||||
|
fd.SetDeadline(time.Now().Add(handshakeTimeout))
|
||||||
|
return &rlpx{fd: fd}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *rlpx) ReadMsg() (Msg, error) {
|
||||||
|
t.rmu.Lock()
|
||||||
|
defer t.rmu.Unlock()
|
||||||
|
t.fd.SetReadDeadline(time.Now().Add(frameReadTimeout))
|
||||||
|
return t.rw.ReadMsg()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *rlpx) WriteMsg(msg Msg) error {
|
||||||
|
t.wmu.Lock()
|
||||||
|
defer t.wmu.Unlock()
|
||||||
|
t.fd.SetWriteDeadline(time.Now().Add(frameWriteTimeout))
|
||||||
|
return t.rw.WriteMsg(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *rlpx) close(err error) {
|
||||||
|
t.wmu.Lock()
|
||||||
|
defer t.wmu.Unlock()
|
||||||
|
// Tell the remote end why we're disconnecting if possible.
|
||||||
|
if t.rw != nil {
|
||||||
|
if r, ok := err.(DiscReason); ok && r != DiscNetworkError {
|
||||||
|
t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout))
|
||||||
|
SendItems(t.rw, discMsg, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.fd.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// doEncHandshake runs the protocol handshake using authenticated
|
||||||
|
// messages. the protocol handshake is the first authenticated message
|
||||||
|
// and also verifies whether the encryption handshake 'worked' and the
|
||||||
|
// remote side actually provided the right public key.
|
||||||
|
func (t *rlpx) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err error) {
|
||||||
|
// Writing our handshake happens concurrently, we prefer
|
||||||
|
// returning the handshake read error. If the remote side
|
||||||
|
// disconnects us early with a valid reason, we should return it
|
||||||
|
// as the error so it can be tracked elsewhere.
|
||||||
|
werr := make(chan error, 1)
|
||||||
|
go func() { werr <- Send(t.rw, handshakeMsg, our) }()
|
||||||
|
if their, err = readProtocolHandshake(t.rw, our); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := <-werr; err != nil {
|
||||||
|
return nil, fmt.Errorf("write error: %v", err)
|
||||||
|
}
|
||||||
|
return their, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake, error) {
|
||||||
|
msg, err := rw.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if msg.Size > baseProtocolMaxMsgSize {
|
||||||
|
return nil, fmt.Errorf("message too big")
|
||||||
|
}
|
||||||
|
if msg.Code == discMsg {
|
||||||
|
// Disconnect before protocol handshake is valid according to the
|
||||||
|
// spec and we send it ourself if the posthanshake checks fail.
|
||||||
|
// We can't return the reason directly, though, because it is echoed
|
||||||
|
// back otherwise. Wrap it in a string instead.
|
||||||
|
var reason [1]DiscReason
|
||||||
|
rlp.Decode(msg.Payload, &reason)
|
||||||
|
return nil, reason[0]
|
||||||
|
}
|
||||||
|
if msg.Code != handshakeMsg {
|
||||||
|
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
|
||||||
|
}
|
||||||
|
var hs protoHandshake
|
||||||
|
if err := msg.Decode(&hs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// validate handshake info
|
||||||
|
if hs.Version != our.Version {
|
||||||
|
return nil, DiscIncompatibleVersion
|
||||||
|
}
|
||||||
|
if (hs.ID == discover.NodeID{}) {
|
||||||
|
return nil, DiscInvalidIdentity
|
||||||
|
}
|
||||||
|
return &hs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (discover.NodeID, error) {
|
||||||
|
var (
|
||||||
|
sec secrets
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if dial == nil {
|
||||||
|
sec, err = receiverEncHandshake(t.fd, prv, nil)
|
||||||
|
} else {
|
||||||
|
sec, err = initiatorEncHandshake(t.fd, prv, dial.ID, nil)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return discover.NodeID{}, err
|
||||||
|
}
|
||||||
|
t.wmu.Lock()
|
||||||
|
t.rw = newRLPXFrameRW(t.fd, sec)
|
||||||
|
t.wmu.Unlock()
|
||||||
|
return sec.RemoteID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encHandshake contains the state of the encryption handshake.
|
||||||
|
type encHandshake struct {
|
||||||
|
initiator bool
|
||||||
|
remoteID discover.NodeID
|
||||||
|
|
||||||
|
remotePub *ecies.PublicKey // remote-pubk
|
||||||
|
initNonce, respNonce []byte // nonce
|
||||||
|
randomPrivKey *ecies.PrivateKey // ecdhe-random
|
||||||
|
remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk
|
||||||
|
}
|
||||||
|
|
||||||
|
// secrets represents the connection secrets
|
||||||
|
// which are negotiated during the encryption handshake.
|
||||||
|
type secrets struct {
|
||||||
|
RemoteID discover.NodeID
|
||||||
|
AES, MAC []byte
|
||||||
|
EgressMAC, IngressMAC hash.Hash
|
||||||
|
Token []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// secrets is called after the handshake is completed.
|
||||||
|
// It extracts the connection secrets from the handshake values.
|
||||||
|
func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) {
|
||||||
|
ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen)
|
||||||
|
if err != nil {
|
||||||
|
return secrets{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// derive base secrets from ephemeral key agreement
|
||||||
|
sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce))
|
||||||
|
aesSecret := crypto.Sha3(ecdheSecret, sharedSecret)
|
||||||
|
s := secrets{
|
||||||
|
RemoteID: h.remoteID,
|
||||||
|
AES: aesSecret,
|
||||||
|
MAC: crypto.Sha3(ecdheSecret, aesSecret),
|
||||||
|
Token: crypto.Sha3(sharedSecret),
|
||||||
|
}
|
||||||
|
|
||||||
|
// setup sha3 instances for the MACs
|
||||||
|
mac1 := sha3.NewKeccak256()
|
||||||
|
mac1.Write(xor(s.MAC, h.respNonce))
|
||||||
|
mac1.Write(auth)
|
||||||
|
mac2 := sha3.NewKeccak256()
|
||||||
|
mac2.Write(xor(s.MAC, h.initNonce))
|
||||||
|
mac2.Write(authResp)
|
||||||
|
if h.initiator {
|
||||||
|
s.EgressMAC, s.IngressMAC = mac1, mac2
|
||||||
|
} else {
|
||||||
|
s.EgressMAC, s.IngressMAC = mac2, mac1
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) {
|
||||||
|
return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// initiatorEncHandshake negotiates a session token on conn.
|
||||||
|
// it should be called on the dialing side of the connection.
|
||||||
|
//
|
||||||
|
// prv is the local client's private key.
|
||||||
|
// token is the token from a previous session with this node.
|
||||||
|
func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) {
|
||||||
|
h, err := newInitiatorHandshake(remoteID)
|
||||||
|
if err != nil {
|
||||||
|
return s, err
|
||||||
|
}
|
||||||
|
auth, err := h.authMsg(prv, token)
|
||||||
|
if err != nil {
|
||||||
|
return s, err
|
||||||
|
}
|
||||||
|
if _, err = conn.Write(auth); err != nil {
|
||||||
|
return s, err
|
||||||
|
}
|
||||||
|
|
||||||
|
response := make([]byte, encAuthRespLen)
|
||||||
|
if _, err = io.ReadFull(conn, response); err != nil {
|
||||||
|
return s, err
|
||||||
|
}
|
||||||
|
if err := h.decodeAuthResp(response, prv); err != nil {
|
||||||
|
return s, err
|
||||||
|
}
|
||||||
|
return h.secrets(auth, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) {
|
||||||
|
// generate random initiator nonce
|
||||||
|
n := make([]byte, shaLen)
|
||||||
|
if _, err := rand.Read(n); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// generate random keypair to use for signing
|
||||||
|
randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rpub, err := remoteID.Pubkey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("bad remoteID: %v", err)
|
||||||
|
}
|
||||||
|
h := &encHandshake{
|
||||||
|
initiator: true,
|
||||||
|
remoteID: remoteID,
|
||||||
|
remotePub: ecies.ImportECDSAPublic(rpub),
|
||||||
|
initNonce: n,
|
||||||
|
randomPrivKey: randpriv,
|
||||||
|
}
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// authMsg creates an encrypted initiator handshake message.
|
||||||
|
func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
|
||||||
|
var tokenFlag byte
|
||||||
|
if token == nil {
|
||||||
|
// no session token found means we need to generate shared secret.
|
||||||
|
// ecies shared secret is used as initial session token for new peers
|
||||||
|
// generate shared key from prv and remote pubkey
|
||||||
|
var err error
|
||||||
|
if token, err = h.ecdhShared(prv); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// for known peers, we use stored token from the previous session
|
||||||
|
tokenFlag = 0x01
|
||||||
|
}
|
||||||
|
|
||||||
|
// sign known message:
|
||||||
|
// ecdh-shared-secret^nonce for new peers
|
||||||
|
// token^nonce for old peers
|
||||||
|
signed := xor(token, h.initNonce)
|
||||||
|
signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// encode auth message
|
||||||
|
// signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
|
||||||
|
msg := make([]byte, authMsgLen)
|
||||||
|
n := copy(msg, signature)
|
||||||
|
n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey)))
|
||||||
|
n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:])
|
||||||
|
n += copy(msg[n:], h.initNonce)
|
||||||
|
msg[n] = tokenFlag
|
||||||
|
|
||||||
|
// encrypt auth message using remote-pubk
|
||||||
|
return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeAuthResp decode an encrypted authentication response message.
|
||||||
|
func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error {
|
||||||
|
msg, err := crypto.Decrypt(prv, auth)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not decrypt auth response (%v)", err)
|
||||||
|
}
|
||||||
|
h.respNonce = msg[pubLen : pubLen+shaLen]
|
||||||
|
h.remoteRandomPub, err = importPublicKey(msg[:pubLen])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// ignore token flag for now
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiverEncHandshake negotiates a session token on conn.
|
||||||
|
// it should be called on the listening side of the connection.
|
||||||
|
//
|
||||||
|
// prv is the local client's private key.
|
||||||
|
// token is the token from a previous session with this node.
|
||||||
|
func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) {
|
||||||
|
// read remote auth sent by initiator.
|
||||||
|
auth := make([]byte, encAuthMsgLen)
|
||||||
|
if _, err := io.ReadFull(conn, auth); err != nil {
|
||||||
|
return s, err
|
||||||
|
}
|
||||||
|
h, err := decodeAuthMsg(prv, token, auth)
|
||||||
|
if err != nil {
|
||||||
|
return s, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// send auth response
|
||||||
|
resp, err := h.authResp(prv, token)
|
||||||
|
if err != nil {
|
||||||
|
return s, err
|
||||||
|
}
|
||||||
|
if _, err = conn.Write(resp); err != nil {
|
||||||
|
return s, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.secrets(auth, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) {
|
||||||
|
var err error
|
||||||
|
h := new(encHandshake)
|
||||||
|
// generate random keypair for session
|
||||||
|
h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// generate random nonce
|
||||||
|
h.respNonce = make([]byte, shaLen)
|
||||||
|
if _, err = rand.Read(h.respNonce); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := crypto.Decrypt(prv, auth)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not decrypt auth message (%v)", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// decode message parameters
|
||||||
|
// signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
|
||||||
|
h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
|
||||||
|
copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen])
|
||||||
|
rpub, err := h.remoteID.Pubkey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("bad remoteID: %#v", err)
|
||||||
|
}
|
||||||
|
h.remotePub = ecies.ImportECDSAPublic(rpub)
|
||||||
|
|
||||||
|
// recover remote random pubkey from signed message.
|
||||||
|
if token == nil {
|
||||||
|
// TODO: it is an error if the initiator has a token and we don't. check that.
|
||||||
|
|
||||||
|
// no session token means we need to generate shared secret.
|
||||||
|
// ecies shared secret is used as initial session token for new peers.
|
||||||
|
// generate shared key from prv and remote pubkey.
|
||||||
|
if token, err = h.ecdhShared(prv); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
signedMsg := xor(token, h.initNonce)
|
||||||
|
remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h.remoteRandomPub, _ = importPublicKey(remoteRandomPub)
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// authResp generates the encrypted authentication response message.
|
||||||
|
func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
|
||||||
|
// responder auth message
|
||||||
|
// E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
|
||||||
|
resp := make([]byte, authRespLen)
|
||||||
|
n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey))
|
||||||
|
n += copy(resp[n:], h.respNonce)
|
||||||
|
if token == nil {
|
||||||
|
resp[n] = 0
|
||||||
|
} else {
|
||||||
|
resp[n] = 1
|
||||||
|
}
|
||||||
|
// encrypt using remote-pubk
|
||||||
|
return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// importPublicKey unmarshals 512 bit public keys.
|
||||||
|
func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) {
|
||||||
|
var pubKey65 []byte
|
||||||
|
switch len(pubKey) {
|
||||||
|
case 64:
|
||||||
|
// add 'uncompressed key' flag
|
||||||
|
pubKey65 = append([]byte{0x04}, pubKey...)
|
||||||
|
case 65:
|
||||||
|
pubKey65 = pubKey
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
|
||||||
|
}
|
||||||
|
// TODO: fewer pointless conversions
|
||||||
|
return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func exportPubkey(pub *ecies.PublicKey) []byte {
|
||||||
|
if pub == nil {
|
||||||
|
panic("nil pubkey")
|
||||||
|
}
|
||||||
|
return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func xor(one, other []byte) (xor []byte) {
|
||||||
|
xor = make([]byte, len(one))
|
||||||
|
for i := 0; i < len(one); i++ {
|
||||||
|
xor[i] = one[i] ^ other[i]
|
||||||
|
}
|
||||||
|
return xor
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// this is used in place of actual frame header data.
|
// this is used in place of actual frame header data.
|
||||||
// TODO: replace this when Msg contains the protocol type code.
|
// TODO: replace this when Msg contains the protocol type code.
|
||||||
zeroHeader = []byte{0xC2, 0x80, 0x80}
|
zeroHeader = []byte{0xC2, 0x80, 0x80}
|
||||||
|
|
||||||
// sixteen zero bytes
|
// sixteen zero bytes
|
||||||
zero16 = make([]byte, 16)
|
zero16 = make([]byte, 16)
|
||||||
|
|
||||||
maxUint24 = ^uint32(0) >> 8
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// rlpxFrameRW implements a simplified version of RLPx framing.
|
// rlpxFrameRW implements a simplified version of RLPx framing.
|
||||||
|
@ -38,7 +474,7 @@ type rlpxFrameRW struct {
|
||||||
ingressMAC hash.Hash
|
ingressMAC hash.Hash
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRlpxFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW {
|
func newRLPXFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW {
|
||||||
macc, err := aes.NewCipher(s.MAC)
|
macc, err := aes.NewCipher(s.MAC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("invalid MAC secret: " + err.Error())
|
panic("invalid MAC secret: " + err.Error())
|
||||||
|
|
244
p2p/rlpx_test.go
244
p2p/rlpx_test.go
|
@ -3,19 +3,253 @@ package p2p
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
"github.com/ethereum/go-ethereum/crypto/ecies"
|
||||||
"github.com/ethereum/go-ethereum/crypto/sha3"
|
"github.com/ethereum/go-ethereum/crypto/sha3"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
"github.com/ethereum/go-ethereum/rlp"
|
"github.com/ethereum/go-ethereum/rlp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRlpxFrameFake(t *testing.T) {
|
func TestSharedSecret(t *testing.T) {
|
||||||
|
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
||||||
|
pub0 := &prv0.PublicKey
|
||||||
|
prv1, _ := crypto.GenerateKey()
|
||||||
|
pub1 := &prv1.PublicKey
|
||||||
|
|
||||||
|
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
|
||||||
|
if !bytes.Equal(ss0, ss1) {
|
||||||
|
t.Errorf("dont match :(")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncHandshake(t *testing.T) {
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
start := time.Now()
|
||||||
|
if err := testEncHandshake(nil); err != nil {
|
||||||
|
t.Fatalf("i=%d %v", i, err)
|
||||||
|
}
|
||||||
|
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
|
||||||
|
}
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
tok := make([]byte, shaLen)
|
||||||
|
rand.Reader.Read(tok)
|
||||||
|
start := time.Now()
|
||||||
|
if err := testEncHandshake(tok); err != nil {
|
||||||
|
t.Fatalf("i=%d %v", i, err)
|
||||||
|
}
|
||||||
|
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testEncHandshake(token []byte) error {
|
||||||
|
type result struct {
|
||||||
|
side string
|
||||||
|
id discover.NodeID
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
prv0, _ = crypto.GenerateKey()
|
||||||
|
prv1, _ = crypto.GenerateKey()
|
||||||
|
fd0, fd1 = net.Pipe()
|
||||||
|
c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx)
|
||||||
|
output = make(chan result)
|
||||||
|
)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
r := result{side: "initiator"}
|
||||||
|
defer func() { output <- r }()
|
||||||
|
|
||||||
|
dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)}
|
||||||
|
r.id, r.err = c0.doEncHandshake(prv0, dest)
|
||||||
|
if r.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id1 := discover.PubkeyID(&prv1.PublicKey)
|
||||||
|
if r.id != id1 {
|
||||||
|
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
r := result{side: "receiver"}
|
||||||
|
defer func() { output <- r }()
|
||||||
|
|
||||||
|
r.id, r.err = c1.doEncHandshake(prv1, nil)
|
||||||
|
if r.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id0 := discover.PubkeyID(&prv0.PublicKey)
|
||||||
|
if r.id != id0 {
|
||||||
|
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait for results from both sides
|
||||||
|
r1, r2 := <-output, <-output
|
||||||
|
if r1.err != nil {
|
||||||
|
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
|
||||||
|
}
|
||||||
|
if r2.err != nil {
|
||||||
|
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// compare derived secrets
|
||||||
|
if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) {
|
||||||
|
return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) {
|
||||||
|
return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) {
|
||||||
|
return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) {
|
||||||
|
return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtocolHandshake(t *testing.T) {
|
||||||
|
var (
|
||||||
|
prv0, _ = crypto.GenerateKey()
|
||||||
|
node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33}
|
||||||
|
hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}}
|
||||||
|
|
||||||
|
prv1, _ = crypto.GenerateKey()
|
||||||
|
node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
|
||||||
|
hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
|
||||||
|
|
||||||
|
fd0, fd1 = net.Pipe()
|
||||||
|
wg sync.WaitGroup
|
||||||
|
)
|
||||||
|
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
rlpx := newRLPX(fd0)
|
||||||
|
remid, err := rlpx.doEncHandshake(prv0, node1)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("dial side enc handshake failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if remid != node1.ID {
|
||||||
|
t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
phs, err := rlpx.doProtoHandshake(hs0)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("dial side proto handshake error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(phs, hs1) {
|
||||||
|
t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rlpx.close(DiscQuitting)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
rlpx := newRLPX(fd1)
|
||||||
|
remid, err := rlpx.doEncHandshake(prv1, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("listen side enc handshake failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if remid != node0.ID {
|
||||||
|
t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
phs, err := rlpx.doProtoHandshake(hs1)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("listen side proto handshake error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(phs, hs0) {
|
||||||
|
t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
|
||||||
|
t.Errorf("error receiving disconnect: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtocolHandshakeErrors(t *testing.T) {
|
||||||
|
our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"}
|
||||||
|
id := randomID()
|
||||||
|
tests := []struct {
|
||||||
|
code uint64
|
||||||
|
msg interface{}
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
code: discMsg,
|
||||||
|
msg: []DiscReason{DiscQuitting},
|
||||||
|
err: DiscQuitting,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
code: 0x989898,
|
||||||
|
msg: []byte{1},
|
||||||
|
err: errors.New("expected handshake, got 989898"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
code: handshakeMsg,
|
||||||
|
msg: make([]byte, baseProtocolMaxMsgSize+2),
|
||||||
|
err: errors.New("message too big"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
code: handshakeMsg,
|
||||||
|
msg: []byte{1, 2, 3},
|
||||||
|
err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
code: handshakeMsg,
|
||||||
|
msg: &protoHandshake{Version: 9944, ID: id},
|
||||||
|
err: DiscIncompatibleVersion,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
code: handshakeMsg,
|
||||||
|
msg: &protoHandshake{Version: 3},
|
||||||
|
err: DiscInvalidIdentity,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
p1, p2 := MsgPipe()
|
||||||
|
go Send(p1, test.code, test.msg)
|
||||||
|
_, err := readProtocolHandshake(p2, our)
|
||||||
|
if !reflect.DeepEqual(err, test.err) {
|
||||||
|
t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRLPXFrameFake(t *testing.T) {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
|
hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
|
||||||
rw := newRlpxFrameRW(buf, secrets{
|
rw := newRLPXFrameRW(buf, secrets{
|
||||||
AES: crypto.Sha3(),
|
AES: crypto.Sha3(),
|
||||||
MAC: crypto.Sha3(),
|
MAC: crypto.Sha3(),
|
||||||
IngressMAC: hash,
|
IngressMAC: hash,
|
||||||
|
@ -66,7 +300,7 @@ func (fakeHash) BlockSize() int { return 0 }
|
||||||
func (h fakeHash) Size() int { return len(h) }
|
func (h fakeHash) Size() int { return len(h) }
|
||||||
func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) }
|
func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) }
|
||||||
|
|
||||||
func TestRlpxFrameRW(t *testing.T) {
|
func TestRLPXFrameRW(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
aesSecret = make([]byte, 16)
|
aesSecret = make([]byte, 16)
|
||||||
macSecret = make([]byte, 16)
|
macSecret = make([]byte, 16)
|
||||||
|
@ -86,7 +320,7 @@ func TestRlpxFrameRW(t *testing.T) {
|
||||||
}
|
}
|
||||||
s1.EgressMAC.Write(egressMACinit)
|
s1.EgressMAC.Write(egressMACinit)
|
||||||
s1.IngressMAC.Write(ingressMACinit)
|
s1.IngressMAC.Write(ingressMACinit)
|
||||||
rw1 := newRlpxFrameRW(conn, s1)
|
rw1 := newRLPXFrameRW(conn, s1)
|
||||||
|
|
||||||
s2 := secrets{
|
s2 := secrets{
|
||||||
AES: aesSecret,
|
AES: aesSecret,
|
||||||
|
@ -96,7 +330,7 @@ func TestRlpxFrameRW(t *testing.T) {
|
||||||
}
|
}
|
||||||
s2.EgressMAC.Write(ingressMACinit)
|
s2.EgressMAC.Write(ingressMACinit)
|
||||||
s2.IngressMAC.Write(egressMACinit)
|
s2.IngressMAC.Write(egressMACinit)
|
||||||
rw2 := newRlpxFrameRW(conn, s2)
|
rw2 := newRLPXFrameRW(conn, s2)
|
||||||
|
|
||||||
// send some messages
|
// send some messages
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
|
|
718
p2p/server.go
718
p2p/server.go
|
@ -1,9 +1,7 @@
|
||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/rand"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
@ -14,7 +12,6 @@ import (
|
||||||
"github.com/ethereum/go-ethereum/logger/glog"
|
"github.com/ethereum/go-ethereum/logger/glog"
|
||||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||||
"github.com/ethereum/go-ethereum/rlp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -26,18 +23,18 @@ const (
|
||||||
maxAcceptConns = 50
|
maxAcceptConns = 50
|
||||||
|
|
||||||
// Maximum number of concurrently dialing outbound connections.
|
// Maximum number of concurrently dialing outbound connections.
|
||||||
maxDialingConns = 10
|
maxActiveDialTasks = 16
|
||||||
|
|
||||||
// total timeout for encryption handshake and protocol
|
// Maximum time allowed for reading a complete message.
|
||||||
// handshake in both directions.
|
// This is effectively the amount of time a connection can be idle.
|
||||||
handshakeTimeout = 5 * time.Second
|
frameReadTimeout = 30 * time.Second
|
||||||
// maximum time allowed for reading a complete message.
|
|
||||||
// this is effectively the amount of time a connection can be idle.
|
// Maximum amount of time allowed for writing a complete message.
|
||||||
frameReadTimeout = 1 * time.Minute
|
|
||||||
// maximum amount of time allowed for writing a complete message.
|
|
||||||
frameWriteTimeout = 5 * time.Second
|
frameWriteTimeout = 5 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errServerStopped = errors.New("server stopped")
|
||||||
|
|
||||||
var srvjslog = logger.NewJsonLogger()
|
var srvjslog = logger.NewJsonLogger()
|
||||||
|
|
||||||
// Server manages all peer connections.
|
// Server manages all peer connections.
|
||||||
|
@ -105,107 +102,173 @@ type Server struct {
|
||||||
|
|
||||||
// Hooks for testing. These are useful because we can inhibit
|
// Hooks for testing. These are useful because we can inhibit
|
||||||
// the whole protocol stack.
|
// the whole protocol stack.
|
||||||
setupFunc
|
newTransport func(net.Conn) transport
|
||||||
newPeerHook
|
newPeerHook func(*Peer)
|
||||||
|
|
||||||
|
lock sync.Mutex // protects running
|
||||||
|
running bool
|
||||||
|
|
||||||
|
ntab discoverTable
|
||||||
|
listener net.Listener
|
||||||
ourHandshake *protoHandshake
|
ourHandshake *protoHandshake
|
||||||
|
|
||||||
lock sync.RWMutex // protects running, peers and the trust fields
|
// These are for Peers, PeerCount (and nothing else).
|
||||||
running bool
|
peerOp chan peerOpFunc
|
||||||
peers map[discover.NodeID]*Peer
|
peerOpDone chan struct{}
|
||||||
staticNodes map[discover.NodeID]*discover.Node // Map of currently maintained static remote nodes
|
|
||||||
staticDial chan *discover.Node // Dial request channel reserved for the static nodes
|
|
||||||
staticCycle time.Duration // Overrides staticPeerCheckInterval, used for testing
|
|
||||||
trustedNodes map[discover.NodeID]bool // Set of currently trusted remote nodes
|
|
||||||
|
|
||||||
ntab *discover.Table
|
quit chan struct{}
|
||||||
listener net.Listener
|
addstatic chan *discover.Node
|
||||||
|
posthandshake chan *conn
|
||||||
quit chan struct{}
|
addpeer chan *conn
|
||||||
loopWG sync.WaitGroup // {dial,listen,nat}Loop
|
delpeer chan *Peer
|
||||||
peerWG sync.WaitGroup // active peer goroutines
|
loopWG sync.WaitGroup // loop, listenLoop
|
||||||
}
|
}
|
||||||
|
|
||||||
type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node, func(discover.NodeID) bool) (*conn, error)
|
type peerOpFunc func(map[discover.NodeID]*Peer)
|
||||||
type newPeerHook func(*Peer)
|
|
||||||
|
type connFlag int
|
||||||
|
|
||||||
|
const (
|
||||||
|
dynDialedConn connFlag = 1 << iota
|
||||||
|
staticDialedConn
|
||||||
|
inboundConn
|
||||||
|
trustedConn
|
||||||
|
)
|
||||||
|
|
||||||
|
// conn wraps a network connection with information gathered
|
||||||
|
// during the two handshakes.
|
||||||
|
type conn struct {
|
||||||
|
fd net.Conn
|
||||||
|
transport
|
||||||
|
flags connFlag
|
||||||
|
cont chan error // The run loop uses cont to signal errors to setupConn.
|
||||||
|
id discover.NodeID // valid after the encryption handshake
|
||||||
|
caps []Cap // valid after the protocol handshake
|
||||||
|
name string // valid after the protocol handshake
|
||||||
|
}
|
||||||
|
|
||||||
|
type transport interface {
|
||||||
|
// The two handshakes.
|
||||||
|
doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error)
|
||||||
|
doProtoHandshake(our *protoHandshake) (*protoHandshake, error)
|
||||||
|
// The MsgReadWriter can only be used after the encryption
|
||||||
|
// handshake has completed. The code uses conn.id to track this
|
||||||
|
// by setting it to a non-nil value after the encryption handshake.
|
||||||
|
MsgReadWriter
|
||||||
|
// transports must provide Close because we use MsgPipe in some of
|
||||||
|
// the tests. Closing the actual network connection doesn't do
|
||||||
|
// anything in those tests because NsgPipe doesn't use it.
|
||||||
|
close(err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) String() string {
|
||||||
|
s := c.flags.String() + " conn"
|
||||||
|
if (c.id != discover.NodeID{}) {
|
||||||
|
s += fmt.Sprintf(" %x", c.id[:8])
|
||||||
|
}
|
||||||
|
s += " " + c.fd.RemoteAddr().String()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f connFlag) String() string {
|
||||||
|
s := ""
|
||||||
|
if f&trustedConn != 0 {
|
||||||
|
s += " trusted"
|
||||||
|
}
|
||||||
|
if f&dynDialedConn != 0 {
|
||||||
|
s += " dyn dial"
|
||||||
|
}
|
||||||
|
if f&staticDialedConn != 0 {
|
||||||
|
s += " static dial"
|
||||||
|
}
|
||||||
|
if f&inboundConn != 0 {
|
||||||
|
s += " inbound"
|
||||||
|
}
|
||||||
|
if s != "" {
|
||||||
|
s = s[1:]
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) is(f connFlag) bool {
|
||||||
|
return c.flags&f != 0
|
||||||
|
}
|
||||||
|
|
||||||
// Peers returns all connected peers.
|
// Peers returns all connected peers.
|
||||||
func (srv *Server) Peers() (peers []*Peer) {
|
func (srv *Server) Peers() []*Peer {
|
||||||
srv.lock.RLock()
|
var ps []*Peer
|
||||||
defer srv.lock.RUnlock()
|
select {
|
||||||
for _, peer := range srv.peers {
|
// Note: We'd love to put this function into a variable but
|
||||||
if peer != nil {
|
// that seems to cause a weird compiler error in some
|
||||||
peers = append(peers, peer)
|
// environments.
|
||||||
|
case srv.peerOp <- func(peers map[discover.NodeID]*Peer) {
|
||||||
|
for _, p := range peers {
|
||||||
|
ps = append(ps, p)
|
||||||
}
|
}
|
||||||
|
}:
|
||||||
|
<-srv.peerOpDone
|
||||||
|
case <-srv.quit:
|
||||||
}
|
}
|
||||||
return
|
return ps
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerCount returns the number of connected peers.
|
// PeerCount returns the number of connected peers.
|
||||||
func (srv *Server) PeerCount() int {
|
func (srv *Server) PeerCount() int {
|
||||||
srv.lock.RLock()
|
var count int
|
||||||
n := len(srv.peers)
|
select {
|
||||||
srv.lock.RUnlock()
|
case srv.peerOp <- func(ps map[discover.NodeID]*Peer) { count = len(ps) }:
|
||||||
return n
|
<-srv.peerOpDone
|
||||||
|
case <-srv.quit:
|
||||||
|
}
|
||||||
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeer connects to the given node and maintains the connection until the
|
// AddPeer connects to the given node and maintains the connection until the
|
||||||
// server is shut down. If the connection fails for any reason, the server will
|
// server is shut down. If the connection fails for any reason, the server will
|
||||||
// attempt to reconnect the peer.
|
// attempt to reconnect the peer.
|
||||||
func (srv *Server) AddPeer(node *discover.Node) {
|
func (srv *Server) AddPeer(node *discover.Node) {
|
||||||
|
select {
|
||||||
|
case srv.addstatic <- node:
|
||||||
|
case <-srv.quit:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Self returns the local node's endpoint information.
|
||||||
|
func (srv *Server) Self() *discover.Node {
|
||||||
srv.lock.Lock()
|
srv.lock.Lock()
|
||||||
defer srv.lock.Unlock()
|
defer srv.lock.Unlock()
|
||||||
|
if !srv.running {
|
||||||
srv.staticNodes[node.ID] = node
|
return &discover.Node{IP: net.ParseIP("0.0.0.0")}
|
||||||
|
}
|
||||||
|
return srv.ntab.Self()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Broadcast sends an RLP-encoded message to all connected peers.
|
// Stop terminates the server and all active peer connections.
|
||||||
// This method is deprecated and will be removed later.
|
// It blocks until all active connections have been closed.
|
||||||
func (srv *Server) Broadcast(protocol string, code uint64, data interface{}) error {
|
func (srv *Server) Stop() {
|
||||||
return srv.BroadcastLimited(protocol, code, func(i float64) float64 { return i }, data)
|
srv.lock.Lock()
|
||||||
}
|
defer srv.lock.Unlock()
|
||||||
|
if !srv.running {
|
||||||
// BroadcastsRange an RLP-encoded message to a random set of peers using the limit function to limit the amount
|
return
|
||||||
// of peers.
|
|
||||||
func (srv *Server) BroadcastLimited(protocol string, code uint64, limit func(float64) float64, data interface{}) error {
|
|
||||||
var payload []byte
|
|
||||||
if data != nil {
|
|
||||||
var err error
|
|
||||||
payload, err = rlp.EncodeToBytes(data)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
srv.lock.RLock()
|
srv.running = false
|
||||||
defer srv.lock.RUnlock()
|
if srv.listener != nil {
|
||||||
|
// this unblocks listener Accept
|
||||||
i, max := 0, int(limit(float64(len(srv.peers))))
|
srv.listener.Close()
|
||||||
for _, peer := range srv.peers {
|
|
||||||
if i >= max {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if peer != nil {
|
|
||||||
var msg = Msg{Code: code}
|
|
||||||
if data != nil {
|
|
||||||
msg.Payload = bytes.NewReader(payload)
|
|
||||||
msg.Size = uint32(len(payload))
|
|
||||||
}
|
|
||||||
peer.writeProtoMsg(protocol, msg)
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
close(srv.quit)
|
||||||
|
srv.loopWG.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts running the server.
|
// Start starts running the server.
|
||||||
// Servers can be re-used and started again after stopping.
|
// Servers can not be re-used after stopping.
|
||||||
func (srv *Server) Start() (err error) {
|
func (srv *Server) Start() (err error) {
|
||||||
srv.lock.Lock()
|
srv.lock.Lock()
|
||||||
defer srv.lock.Unlock()
|
defer srv.lock.Unlock()
|
||||||
if srv.running {
|
if srv.running {
|
||||||
return errors.New("server already running")
|
return errors.New("server already running")
|
||||||
}
|
}
|
||||||
|
srv.running = true
|
||||||
glog.V(logger.Info).Infoln("Starting Server")
|
glog.V(logger.Info).Infoln("Starting Server")
|
||||||
|
|
||||||
// static fields
|
// static fields
|
||||||
|
@ -215,23 +278,19 @@ func (srv *Server) Start() (err error) {
|
||||||
if srv.MaxPeers <= 0 {
|
if srv.MaxPeers <= 0 {
|
||||||
return fmt.Errorf("Server.MaxPeers must be > 0")
|
return fmt.Errorf("Server.MaxPeers must be > 0")
|
||||||
}
|
}
|
||||||
|
if srv.newTransport == nil {
|
||||||
|
srv.newTransport = newRLPX
|
||||||
|
}
|
||||||
|
if srv.Dialer == nil {
|
||||||
|
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
|
||||||
|
}
|
||||||
srv.quit = make(chan struct{})
|
srv.quit = make(chan struct{})
|
||||||
srv.peers = make(map[discover.NodeID]*Peer)
|
srv.addpeer = make(chan *conn)
|
||||||
|
srv.delpeer = make(chan *Peer)
|
||||||
// Create the current trust maps, and the associated dialing channel
|
srv.posthandshake = make(chan *conn)
|
||||||
srv.trustedNodes = make(map[discover.NodeID]bool)
|
srv.addstatic = make(chan *discover.Node)
|
||||||
for _, node := range srv.TrustedNodes {
|
srv.peerOp = make(chan peerOpFunc)
|
||||||
srv.trustedNodes[node.ID] = true
|
srv.peerOpDone = make(chan struct{})
|
||||||
}
|
|
||||||
srv.staticNodes = make(map[discover.NodeID]*discover.Node)
|
|
||||||
for _, node := range srv.StaticNodes {
|
|
||||||
srv.staticNodes[node.ID] = node
|
|
||||||
}
|
|
||||||
srv.staticDial = make(chan *discover.Node)
|
|
||||||
|
|
||||||
if srv.setupFunc == nil {
|
|
||||||
srv.setupFunc = setupConn
|
|
||||||
}
|
|
||||||
|
|
||||||
// node table
|
// node table
|
||||||
ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase)
|
ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase)
|
||||||
|
@ -239,37 +298,31 @@ func (srv *Server) Start() (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
srv.ntab = ntab
|
srv.ntab = ntab
|
||||||
|
dialer := newDialState(srv.StaticNodes, srv.ntab, srv.MaxPeers/2)
|
||||||
|
|
||||||
// handshake
|
// handshake
|
||||||
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self().ID}
|
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self().ID}
|
||||||
for _, p := range srv.Protocols {
|
for _, p := range srv.Protocols {
|
||||||
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
|
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
|
||||||
}
|
}
|
||||||
|
|
||||||
// listen/dial
|
// listen/dial
|
||||||
if srv.ListenAddr != "" {
|
if srv.ListenAddr != "" {
|
||||||
if err := srv.startListening(); err != nil {
|
if err := srv.startListening(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if srv.Dialer == nil {
|
|
||||||
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
|
|
||||||
}
|
|
||||||
if !srv.NoDial {
|
|
||||||
srv.loopWG.Add(1)
|
|
||||||
go srv.dialLoop()
|
|
||||||
}
|
|
||||||
if srv.NoDial && srv.ListenAddr == "" {
|
if srv.NoDial && srv.ListenAddr == "" {
|
||||||
glog.V(logger.Warn).Infoln("I will be kind-of useless, neither dialing nor listening.")
|
glog.V(logger.Warn).Infoln("I will be kind-of useless, neither dialing nor listening.")
|
||||||
}
|
}
|
||||||
// maintain the static peers
|
|
||||||
go srv.staticNodesLoop()
|
|
||||||
|
|
||||||
|
srv.loopWG.Add(1)
|
||||||
|
go srv.run(dialer)
|
||||||
srv.running = true
|
srv.running = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) startListening() error {
|
func (srv *Server) startListening() error {
|
||||||
|
// Launch the TCP listener.
|
||||||
listener, err := net.Listen("tcp", srv.ListenAddr)
|
listener, err := net.Listen("tcp", srv.ListenAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -279,6 +332,7 @@ func (srv *Server) startListening() error {
|
||||||
srv.listener = listener
|
srv.listener = listener
|
||||||
srv.loopWG.Add(1)
|
srv.loopWG.Add(1)
|
||||||
go srv.listenLoop()
|
go srv.listenLoop()
|
||||||
|
// Map the TCP listening port if NAT is configured.
|
||||||
if !laddr.IP.IsLoopback() && srv.NAT != nil {
|
if !laddr.IP.IsLoopback() && srv.NAT != nil {
|
||||||
srv.loopWG.Add(1)
|
srv.loopWG.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -289,50 +343,164 @@ func (srv *Server) startListening() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop terminates the server and all active peer connections.
|
type dialer interface {
|
||||||
// It blocks until all active connections have been closed.
|
newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task
|
||||||
func (srv *Server) Stop() {
|
taskDone(task, time.Time)
|
||||||
srv.lock.Lock()
|
addStatic(*discover.Node)
|
||||||
if !srv.running {
|
}
|
||||||
srv.lock.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
srv.running = false
|
|
||||||
srv.lock.Unlock()
|
|
||||||
|
|
||||||
glog.V(logger.Info).Infoln("Stopping Server")
|
func (srv *Server) run(dialstate dialer) {
|
||||||
|
defer srv.loopWG.Done()
|
||||||
|
var (
|
||||||
|
peers = make(map[discover.NodeID]*Peer)
|
||||||
|
trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
|
||||||
|
|
||||||
|
tasks []task
|
||||||
|
pendingTasks []task
|
||||||
|
taskdone = make(chan task, maxActiveDialTasks)
|
||||||
|
)
|
||||||
|
// Put trusted nodes into a map to speed up checks.
|
||||||
|
// Trusted peers are loaded on startup and cannot be
|
||||||
|
// modified while the server is running.
|
||||||
|
for _, n := range srv.TrustedNodes {
|
||||||
|
trusted[n.ID] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some task list helpers.
|
||||||
|
delTask := func(t task) {
|
||||||
|
for i := range tasks {
|
||||||
|
if tasks[i] == t {
|
||||||
|
tasks = append(tasks[:i], tasks[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scheduleTasks := func(new []task) {
|
||||||
|
pt := append(pendingTasks, new...)
|
||||||
|
start := maxActiveDialTasks - len(tasks)
|
||||||
|
if len(pt) < start {
|
||||||
|
start = len(pt)
|
||||||
|
}
|
||||||
|
if start > 0 {
|
||||||
|
tasks = append(tasks, pt[:start]...)
|
||||||
|
for _, t := range pt[:start] {
|
||||||
|
t := t
|
||||||
|
glog.V(logger.Detail).Infoln("new task:", t)
|
||||||
|
go func() { t.Do(srv); taskdone <- t }()
|
||||||
|
}
|
||||||
|
copy(pt, pt[start:])
|
||||||
|
pendingTasks = pt[:len(pt)-start]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
running:
|
||||||
|
for {
|
||||||
|
// Query the dialer for new tasks and launch them.
|
||||||
|
now := time.Now()
|
||||||
|
nt := dialstate.newTasks(len(pendingTasks)+len(tasks), peers, now)
|
||||||
|
scheduleTasks(nt)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-srv.quit:
|
||||||
|
// The server was stopped. Run the cleanup logic.
|
||||||
|
glog.V(logger.Detail).Infoln("<-quit: spinning down")
|
||||||
|
break running
|
||||||
|
case n := <-srv.addstatic:
|
||||||
|
// This channel is used by AddPeer to add to the
|
||||||
|
// ephemeral static peer list. Add it to the dialer,
|
||||||
|
// it will keep the node connected.
|
||||||
|
glog.V(logger.Detail).Infoln("<-addstatic:", n)
|
||||||
|
dialstate.addStatic(n)
|
||||||
|
case op := <-srv.peerOp:
|
||||||
|
// This channel is used by Peers and PeerCount.
|
||||||
|
op(peers)
|
||||||
|
srv.peerOpDone <- struct{}{}
|
||||||
|
case t := <-taskdone:
|
||||||
|
// A task got done. Tell dialstate about it so it
|
||||||
|
// can update its state and remove it from the active
|
||||||
|
// tasks list.
|
||||||
|
glog.V(logger.Detail).Infoln("<-taskdone:", t)
|
||||||
|
dialstate.taskDone(t, now)
|
||||||
|
delTask(t)
|
||||||
|
case c := <-srv.posthandshake:
|
||||||
|
// A connection has passed the encryption handshake so
|
||||||
|
// the remote identity is known (but hasn't been verified yet).
|
||||||
|
if trusted[c.id] {
|
||||||
|
// Ensure that the trusted flag is set before checking against MaxPeers.
|
||||||
|
c.flags |= trustedConn
|
||||||
|
}
|
||||||
|
glog.V(logger.Detail).Infoln("<-posthandshake:", c)
|
||||||
|
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
|
||||||
|
c.cont <- srv.encHandshakeChecks(peers, c)
|
||||||
|
case c := <-srv.addpeer:
|
||||||
|
// At this point the connection is past the protocol handshake.
|
||||||
|
// Its capabilities are known and the remote identity is verified.
|
||||||
|
glog.V(logger.Detail).Infoln("<-addpeer:", c)
|
||||||
|
err := srv.protoHandshakeChecks(peers, c)
|
||||||
|
if err != nil {
|
||||||
|
glog.V(logger.Detail).Infof("Not adding %v as peer: %v", c, err)
|
||||||
|
} else {
|
||||||
|
// The handshakes are done and it passed all checks.
|
||||||
|
p := newPeer(c, srv.Protocols)
|
||||||
|
peers[c.id] = p
|
||||||
|
go srv.runPeer(p)
|
||||||
|
}
|
||||||
|
// The dialer logic relies on the assumption that
|
||||||
|
// dial tasks complete after the peer has been added or
|
||||||
|
// discarded. Unblock the task last.
|
||||||
|
c.cont <- err
|
||||||
|
case p := <-srv.delpeer:
|
||||||
|
// A peer disconnected.
|
||||||
|
glog.V(logger.Detail).Infoln("<-delpeer:", p)
|
||||||
|
delete(peers, p.ID())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Terminate discovery. If there is a running lookup it will terminate soon.
|
||||||
srv.ntab.Close()
|
srv.ntab.Close()
|
||||||
if srv.listener != nil {
|
// Disconnect all peers.
|
||||||
// this unblocks listener Accept
|
for _, p := range peers {
|
||||||
srv.listener.Close()
|
p.Disconnect(DiscQuitting)
|
||||||
}
|
}
|
||||||
close(srv.quit)
|
// Wait for peers to shut down. Pending connections and tasks are
|
||||||
srv.loopWG.Wait()
|
// not handled here and will terminate soon-ish because srv.quit
|
||||||
|
// is closed.
|
||||||
// No new peers can be added at this point because dialLoop and
|
glog.V(logger.Detail).Infof("ignoring %d pending tasks at spindown", len(tasks))
|
||||||
// listenLoop are down. It is safe to call peerWG.Wait because
|
for len(peers) > 0 {
|
||||||
// peerWG.Add is not called outside of those loops.
|
p := <-srv.delpeer
|
||||||
srv.lock.Lock()
|
glog.V(logger.Detail).Infoln("<-delpeer (spindown):", p)
|
||||||
for _, peer := range srv.peers {
|
delete(peers, p.ID())
|
||||||
peer.Disconnect(DiscQuitting)
|
|
||||||
}
|
}
|
||||||
srv.lock.Unlock()
|
|
||||||
srv.peerWG.Wait()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Self returns the local node's endpoint information.
|
func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
|
||||||
func (srv *Server) Self() *discover.Node {
|
// Drop connections with no matching protocols.
|
||||||
srv.lock.RLock()
|
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
|
||||||
defer srv.lock.RUnlock()
|
return DiscUselessPeer
|
||||||
if !srv.running {
|
|
||||||
return &discover.Node{IP: net.ParseIP("0.0.0.0")}
|
|
||||||
}
|
}
|
||||||
return srv.ntab.Self()
|
// Repeat the encryption handshake checks because the
|
||||||
|
// peer set might have changed between the handshakes.
|
||||||
|
return srv.encHandshakeChecks(peers, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// main loop for adding connections via listening
|
func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
|
||||||
|
switch {
|
||||||
|
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
|
||||||
|
return DiscTooManyPeers
|
||||||
|
case peers[c.id] != nil:
|
||||||
|
return DiscAlreadyConnected
|
||||||
|
case c.id == srv.ntab.Self().ID:
|
||||||
|
return DiscSelf
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// listenLoop runs in its own goroutine and accepts
|
||||||
|
// inbound connections.
|
||||||
func (srv *Server) listenLoop() {
|
func (srv *Server) listenLoop() {
|
||||||
defer srv.loopWG.Done()
|
defer srv.loopWG.Done()
|
||||||
|
glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr())
|
||||||
|
|
||||||
// This channel acts as a semaphore limiting
|
// This channel acts as a semaphore limiting
|
||||||
// active inbound connections that are lingering pre-handshake.
|
// active inbound connections that are lingering pre-handshake.
|
||||||
|
@ -346,204 +514,92 @@ func (srv *Server) listenLoop() {
|
||||||
slots <- struct{}{}
|
slots <- struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr())
|
|
||||||
for {
|
for {
|
||||||
<-slots
|
<-slots
|
||||||
conn, err := srv.listener.Accept()
|
fd, err := srv.listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
glog.V(logger.Debug).Infof("Accepted conn %v\n", conn.RemoteAddr())
|
glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr())
|
||||||
srv.peerWG.Add(1)
|
|
||||||
go func() {
|
go func() {
|
||||||
srv.startPeer(conn, nil)
|
srv.setupConn(fd, inboundConn, nil)
|
||||||
slots <- struct{}{}
|
slots <- struct{}{}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// staticNodesLoop is responsible for periodically checking that static
|
// setupConn runs the handshakes and attempts to add the connection
|
||||||
// connections are actually live, and requests dialing if not.
|
// as a peer. It returns when the connection has been added as a peer
|
||||||
func (srv *Server) staticNodesLoop() {
|
// or the handshakes have failed.
|
||||||
// Create a default maintenance ticker, but override it requested
|
func (srv *Server) setupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) {
|
||||||
cycle := staticPeerCheckInterval
|
// Prevent leftover pending conns from entering the handshake.
|
||||||
if srv.staticCycle != 0 {
|
srv.lock.Lock()
|
||||||
cycle = srv.staticCycle
|
running := srv.running
|
||||||
|
srv.lock.Unlock()
|
||||||
|
c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)}
|
||||||
|
if !running {
|
||||||
|
c.close(errServerStopped)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
tick := time.NewTicker(cycle)
|
// Run the encryption handshake.
|
||||||
|
var err error
|
||||||
for {
|
if c.id, err = c.doEncHandshake(srv.PrivateKey, dialDest); err != nil {
|
||||||
select {
|
glog.V(logger.Debug).Infof("%v faild enc handshake: %v", c, err)
|
||||||
case <-srv.quit:
|
c.close(err)
|
||||||
return
|
return
|
||||||
|
|
||||||
case <-tick.C:
|
|
||||||
// Collect all the non-connected static nodes
|
|
||||||
needed := []*discover.Node{}
|
|
||||||
srv.lock.RLock()
|
|
||||||
for id, node := range srv.staticNodes {
|
|
||||||
if _, ok := srv.peers[id]; !ok {
|
|
||||||
needed = append(needed, node)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
srv.lock.RUnlock()
|
|
||||||
|
|
||||||
// Try to dial each of them (don't hang if server terminates)
|
|
||||||
for _, node := range needed {
|
|
||||||
glog.V(logger.Debug).Infof("Dialing static peer %v", node)
|
|
||||||
select {
|
|
||||||
case srv.staticDial <- node:
|
|
||||||
case <-srv.quit:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
// For dialed connections, check that the remote public key matches.
|
||||||
|
if dialDest != nil && c.id != dialDest.ID {
|
||||||
func (srv *Server) dialLoop() {
|
c.close(DiscUnexpectedIdentity)
|
||||||
var (
|
glog.V(logger.Debug).Infof("%v dialed identity mismatch, want %x", c, dialDest.ID[:8])
|
||||||
dialed = make(chan *discover.Node)
|
return
|
||||||
dialing = make(map[discover.NodeID]bool)
|
|
||||||
findresults = make(chan []*discover.Node)
|
|
||||||
refresh = time.NewTimer(0)
|
|
||||||
)
|
|
||||||
defer srv.loopWG.Done()
|
|
||||||
defer refresh.Stop()
|
|
||||||
|
|
||||||
// Limit the number of concurrent dials
|
|
||||||
tokens := maxDialingConns
|
|
||||||
if srv.MaxPendingPeers > 0 {
|
|
||||||
tokens = srv.MaxPendingPeers
|
|
||||||
}
|
}
|
||||||
slots := make(chan struct{}, tokens)
|
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
|
||||||
for i := 0; i < tokens; i++ {
|
glog.V(logger.Debug).Infof("%v failed checkpoint posthandshake: %v", c, err)
|
||||||
slots <- struct{}{}
|
c.close(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
dial := func(dest *discover.Node) {
|
// Run the protocol handshake
|
||||||
// Don't dial nodes that would fail the checks in addPeer.
|
phs, err := c.doProtoHandshake(srv.ourHandshake)
|
||||||
// This is important because the connection handshake is a lot
|
|
||||||
// of work and we'd rather avoid doing that work for peers
|
|
||||||
// that can't be added.
|
|
||||||
srv.lock.RLock()
|
|
||||||
ok, _ := srv.checkPeer(dest.ID)
|
|
||||||
srv.lock.RUnlock()
|
|
||||||
if !ok || dialing[dest.ID] {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Request a dial slot to prevent CPU exhaustion
|
|
||||||
<-slots
|
|
||||||
|
|
||||||
dialing[dest.ID] = true
|
|
||||||
srv.peerWG.Add(1)
|
|
||||||
go func() {
|
|
||||||
srv.dialNode(dest)
|
|
||||||
slots <- struct{}{}
|
|
||||||
dialed <- dest
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
srv.ntab.Bootstrap(srv.BootstrapNodes)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-refresh.C:
|
|
||||||
// Grab some nodes to connect to if we're not at capacity.
|
|
||||||
srv.lock.RLock()
|
|
||||||
needpeers := len(srv.peers) < srv.MaxPeers/2
|
|
||||||
srv.lock.RUnlock()
|
|
||||||
if needpeers {
|
|
||||||
go func() {
|
|
||||||
var target discover.NodeID
|
|
||||||
rand.Read(target[:])
|
|
||||||
findresults <- srv.ntab.Lookup(target)
|
|
||||||
}()
|
|
||||||
} else {
|
|
||||||
// Make sure we check again if the peer count falls
|
|
||||||
// below MaxPeers.
|
|
||||||
refresh.Reset(refreshPeersInterval)
|
|
||||||
}
|
|
||||||
case dest := <-srv.staticDial:
|
|
||||||
dial(dest)
|
|
||||||
case dests := <-findresults:
|
|
||||||
for _, dest := range dests {
|
|
||||||
dial(dest)
|
|
||||||
}
|
|
||||||
refresh.Reset(refreshPeersInterval)
|
|
||||||
case dest := <-dialed:
|
|
||||||
delete(dialing, dest.ID)
|
|
||||||
if len(dialing) == 0 {
|
|
||||||
// Check again immediately after dialing all current candidates.
|
|
||||||
refresh.Reset(0)
|
|
||||||
}
|
|
||||||
case <-srv.quit:
|
|
||||||
// TODO: maybe wait for active dials
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) dialNode(dest *discover.Node) {
|
|
||||||
addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)}
|
|
||||||
glog.V(logger.Debug).Infof("Dialing %v\n", dest)
|
|
||||||
conn, err := srv.Dialer.Dial("tcp", addr.String())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// dialLoop adds to the wait group counter when launching
|
glog.V(logger.Debug).Infof("%v failed proto handshake: %v", c, err)
|
||||||
// dialNode, so we need to count it down again. startPeer also
|
c.close(err)
|
||||||
// does that when an error occurs.
|
|
||||||
srv.peerWG.Done()
|
|
||||||
glog.V(logger.Detail).Infof("dial error: %v", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
srv.startPeer(conn, dest)
|
if phs.ID != c.id {
|
||||||
}
|
glog.V(logger.Debug).Infof("%v wrong proto handshake identity: %x", c, phs.ID[:8])
|
||||||
|
c.close(DiscUnexpectedIdentity)
|
||||||
func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
|
|
||||||
// TODO: handle/store session token
|
|
||||||
|
|
||||||
// Run setupFunc, which should create an authenticated connection
|
|
||||||
// and run the capability exchange. Note that any early error
|
|
||||||
// returns during that exchange need to call peerWG.Done because
|
|
||||||
// the callers of startPeer added the peer to the wait group already.
|
|
||||||
fd.SetDeadline(time.Now().Add(handshakeTimeout))
|
|
||||||
|
|
||||||
conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest, srv.keepconn)
|
|
||||||
if err != nil {
|
|
||||||
fd.Close()
|
|
||||||
glog.V(logger.Debug).Infof("Handshake with %v failed: %v", fd.RemoteAddr(), err)
|
|
||||||
srv.peerWG.Done()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.MsgReadWriter = &netWrapper{
|
c.caps, c.name = phs.Caps, phs.Name
|
||||||
wrapped: conn.MsgReadWriter,
|
if err := srv.checkpoint(c, srv.addpeer); err != nil {
|
||||||
conn: fd, rtimeout: frameReadTimeout, wtimeout: frameWriteTimeout,
|
glog.V(logger.Debug).Infof("%v failed checkpoint addpeer: %v", c, err)
|
||||||
}
|
c.close(err)
|
||||||
p := newPeer(fd, conn, srv.Protocols)
|
|
||||||
if ok, reason := srv.addPeer(conn, p); !ok {
|
|
||||||
glog.V(logger.Detail).Infof("Not adding %v (%v)\n", p, reason)
|
|
||||||
p.politeDisconnect(reason)
|
|
||||||
srv.peerWG.Done()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// The handshakes are done and it passed all checks.
|
// If the checks completed successfully, runPeer has now been
|
||||||
// Spawn the Peer loops.
|
// launched by run.
|
||||||
go srv.runPeer(p)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// preflight checks whether a connection should be kept. it runs
|
// checkpoint sends the conn to run, which performs the
|
||||||
// after the encryption handshake, as soon as the remote identity is
|
// post-handshake checks for the stage (posthandshake, addpeer).
|
||||||
// known.
|
func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error {
|
||||||
func (srv *Server) keepconn(id discover.NodeID) bool {
|
select {
|
||||||
srv.lock.RLock()
|
case stage <- c:
|
||||||
defer srv.lock.RUnlock()
|
case <-srv.quit:
|
||||||
if _, ok := srv.staticNodes[id]; ok {
|
return errServerStopped
|
||||||
return true // static nodes are always allowed
|
|
||||||
}
|
}
|
||||||
if _, ok := srv.trustedNodes[id]; ok {
|
select {
|
||||||
return true // trusted nodes are always allowed
|
case err := <-c.cont:
|
||||||
|
return err
|
||||||
|
case <-srv.quit:
|
||||||
|
return errServerStopped
|
||||||
}
|
}
|
||||||
return len(srv.peers) < srv.MaxPeers
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runPeer runs in its own goroutine for each peer.
|
||||||
|
// it waits until the Peer logic returns and removes
|
||||||
|
// the peer.
|
||||||
func (srv *Server) runPeer(p *Peer) {
|
func (srv *Server) runPeer(p *Peer) {
|
||||||
glog.V(logger.Debug).Infof("Added %v\n", p)
|
glog.V(logger.Debug).Infof("Added %v\n", p)
|
||||||
srvjslog.LogJson(&logger.P2PConnected{
|
srvjslog.LogJson(&logger.P2PConnected{
|
||||||
|
@ -552,58 +608,18 @@ func (srv *Server) runPeer(p *Peer) {
|
||||||
RemoteVersionString: p.Name(),
|
RemoteVersionString: p.Name(),
|
||||||
NumConnections: srv.PeerCount(),
|
NumConnections: srv.PeerCount(),
|
||||||
})
|
})
|
||||||
|
|
||||||
if srv.newPeerHook != nil {
|
if srv.newPeerHook != nil {
|
||||||
srv.newPeerHook(p)
|
srv.newPeerHook(p)
|
||||||
}
|
}
|
||||||
discreason := p.run()
|
discreason := p.run()
|
||||||
srv.removePeer(p)
|
// Note: run waits for existing peers to be sent on srv.delpeer
|
||||||
|
// before returning, so this send should not select on srv.quit.
|
||||||
|
srv.delpeer <- p
|
||||||
|
|
||||||
glog.V(logger.Debug).Infof("Removed %v (%v)\n", p, discreason)
|
glog.V(logger.Debug).Infof("Removed %v (%v)\n", p, discreason)
|
||||||
srvjslog.LogJson(&logger.P2PDisconnected{
|
srvjslog.LogJson(&logger.P2PDisconnected{
|
||||||
RemoteId: p.ID().String(),
|
RemoteId: p.ID().String(),
|
||||||
NumConnections: srv.PeerCount(),
|
NumConnections: srv.PeerCount(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) addPeer(conn *conn, p *Peer) (bool, DiscReason) {
|
|
||||||
// drop connections with no matching protocols.
|
|
||||||
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, conn.protoHandshake.Caps) == 0 {
|
|
||||||
return false, DiscUselessPeer
|
|
||||||
}
|
|
||||||
// add the peer if it passes the other checks.
|
|
||||||
srv.lock.Lock()
|
|
||||||
defer srv.lock.Unlock()
|
|
||||||
if ok, reason := srv.checkPeer(conn.ID); !ok {
|
|
||||||
return false, reason
|
|
||||||
}
|
|
||||||
srv.peers[conn.ID] = p
|
|
||||||
return true, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkPeer verifies whether a peer looks promising and should be allowed/kept
|
|
||||||
// in the pool, or if it's of no use.
|
|
||||||
func (srv *Server) checkPeer(id discover.NodeID) (bool, DiscReason) {
|
|
||||||
// First up, figure out if the peer is static or trusted
|
|
||||||
_, static := srv.staticNodes[id]
|
|
||||||
trusted := srv.trustedNodes[id]
|
|
||||||
|
|
||||||
// Make sure the peer passes all required checks
|
|
||||||
switch {
|
|
||||||
case !srv.running:
|
|
||||||
return false, DiscQuitting
|
|
||||||
case !static && !trusted && len(srv.peers) >= srv.MaxPeers:
|
|
||||||
return false, DiscTooManyPeers
|
|
||||||
case srv.peers[id] != nil:
|
|
||||||
return false, DiscAlreadyConnected
|
|
||||||
case id == srv.ntab.Self().ID:
|
|
||||||
return false, DiscSelf
|
|
||||||
default:
|
|
||||||
return true, 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) removePeer(p *Peer) {
|
|
||||||
srv.lock.Lock()
|
|
||||||
delete(srv.peers, p.ID())
|
|
||||||
srv.lock.Unlock()
|
|
||||||
srv.peerWG.Done()
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,12 +1,11 @@
|
||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"io"
|
"errors"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -15,29 +14,50 @@ import (
|
||||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
)
|
)
|
||||||
|
|
||||||
func startTestServer(t *testing.T, pf newPeerHook) *Server {
|
func init() {
|
||||||
|
// glog.SetV(6)
|
||||||
|
// glog.SetToStderr(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
type testTransport struct {
|
||||||
|
id discover.NodeID
|
||||||
|
*rlpx
|
||||||
|
|
||||||
|
closeErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestTransport(id discover.NodeID, fd net.Conn) transport {
|
||||||
|
wrapped := newRLPX(fd).(*rlpx)
|
||||||
|
wrapped.rw = newRLPXFrameRW(fd, secrets{
|
||||||
|
MAC: zero16,
|
||||||
|
AES: zero16,
|
||||||
|
IngressMAC: sha3.NewKeccak256(),
|
||||||
|
EgressMAC: sha3.NewKeccak256(),
|
||||||
|
})
|
||||||
|
return &testTransport{id: id, rlpx: wrapped}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
|
||||||
|
return c.id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
|
||||||
|
return &protoHandshake{ID: c.id, Name: "test"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testTransport) close(err error) {
|
||||||
|
c.rlpx.fd.Close()
|
||||||
|
c.closeErr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
|
||||||
server := &Server{
|
server := &Server{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
MaxPeers: 10,
|
MaxPeers: 10,
|
||||||
ListenAddr: "127.0.0.1:0",
|
ListenAddr: "127.0.0.1:0",
|
||||||
PrivateKey: newkey(),
|
PrivateKey: newkey(),
|
||||||
newPeerHook: pf,
|
newPeerHook: pf,
|
||||||
setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) {
|
newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
|
||||||
id := randomID()
|
|
||||||
if !keepconn(id) {
|
|
||||||
return nil, DiscAlreadyConnected
|
|
||||||
}
|
|
||||||
rw := newRlpxFrameRW(fd, secrets{
|
|
||||||
MAC: zero16,
|
|
||||||
AES: zero16,
|
|
||||||
IngressMAC: sha3.NewKeccak256(),
|
|
||||||
EgressMAC: sha3.NewKeccak256(),
|
|
||||||
})
|
|
||||||
return &conn{
|
|
||||||
MsgReadWriter: rw,
|
|
||||||
protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
if err := server.Start(); err != nil {
|
if err := server.Start(); err != nil {
|
||||||
t.Fatalf("Could not start server: %v", err)
|
t.Fatalf("Could not start server: %v", err)
|
||||||
|
@ -48,7 +68,11 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
|
||||||
func TestServerListen(t *testing.T) {
|
func TestServerListen(t *testing.T) {
|
||||||
// start the test server
|
// start the test server
|
||||||
connected := make(chan *Peer)
|
connected := make(chan *Peer)
|
||||||
srv := startTestServer(t, func(p *Peer) {
|
remid := randomID()
|
||||||
|
srv := startTestServer(t, remid, func(p *Peer) {
|
||||||
|
if p.ID() != remid {
|
||||||
|
t.Error("peer func called with wrong node id")
|
||||||
|
}
|
||||||
if p == nil {
|
if p == nil {
|
||||||
t.Error("peer func called with nil conn")
|
t.Error("peer func called with nil conn")
|
||||||
}
|
}
|
||||||
|
@ -70,6 +94,10 @@ func TestServerListen(t *testing.T) {
|
||||||
t.Errorf("peer started with wrong conn: got %v, want %v",
|
t.Errorf("peer started with wrong conn: got %v, want %v",
|
||||||
peer.LocalAddr(), conn.RemoteAddr())
|
peer.LocalAddr(), conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
peers := srv.Peers()
|
||||||
|
if !reflect.DeepEqual(peers, []*Peer{peer}) {
|
||||||
|
t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
|
||||||
|
}
|
||||||
case <-time.After(1 * time.Second):
|
case <-time.After(1 * time.Second):
|
||||||
t.Error("server did not accept within one second")
|
t.Error("server did not accept within one second")
|
||||||
}
|
}
|
||||||
|
@ -95,23 +123,33 @@ func TestServerDial(t *testing.T) {
|
||||||
|
|
||||||
// start the server
|
// start the server
|
||||||
connected := make(chan *Peer)
|
connected := make(chan *Peer)
|
||||||
srv := startTestServer(t, func(p *Peer) { connected <- p })
|
remid := randomID()
|
||||||
|
srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
|
||||||
defer close(connected)
|
defer close(connected)
|
||||||
defer srv.Stop()
|
defer srv.Stop()
|
||||||
|
|
||||||
// tell the server to connect
|
// tell the server to connect
|
||||||
tcpAddr := listener.Addr().(*net.TCPAddr)
|
tcpAddr := listener.Addr().(*net.TCPAddr)
|
||||||
srv.staticDial <- &discover.Node{IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}
|
srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case conn := <-accepted:
|
case conn := <-accepted:
|
||||||
select {
|
select {
|
||||||
case peer := <-connected:
|
case peer := <-connected:
|
||||||
|
if peer.ID() != remid {
|
||||||
|
t.Errorf("peer has wrong id")
|
||||||
|
}
|
||||||
|
if peer.Name() != "test" {
|
||||||
|
t.Errorf("peer has wrong name")
|
||||||
|
}
|
||||||
if peer.RemoteAddr().String() != conn.LocalAddr().String() {
|
if peer.RemoteAddr().String() != conn.LocalAddr().String() {
|
||||||
t.Errorf("peer started with wrong conn: got %v, want %v",
|
t.Errorf("peer started with wrong conn: got %v, want %v",
|
||||||
peer.RemoteAddr(), conn.LocalAddr())
|
peer.RemoteAddr(), conn.LocalAddr())
|
||||||
}
|
}
|
||||||
// TODO: validate more fields
|
peers := srv.Peers()
|
||||||
|
if !reflect.DeepEqual(peers, []*Peer{peer}) {
|
||||||
|
t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
|
||||||
|
}
|
||||||
case <-time.After(1 * time.Second):
|
case <-time.After(1 * time.Second):
|
||||||
t.Error("server did not launch peer within one second")
|
t.Error("server did not launch peer within one second")
|
||||||
}
|
}
|
||||||
|
@ -121,370 +159,250 @@ func TestServerDial(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerBroadcast(t *testing.T) {
|
// This test checks that tasks generated by dialstate are
|
||||||
var connected sync.WaitGroup
|
// actually executed and taskdone is called for them.
|
||||||
srv := startTestServer(t, func(p *Peer) {
|
func TestServerTaskScheduling(t *testing.T) {
|
||||||
p.running = matchProtocols([]Protocol{discard}, []Cap{discard.cap()}, p.rw)
|
var (
|
||||||
connected.Done()
|
done = make(chan *testTask)
|
||||||
})
|
quit, returned = make(chan struct{}), make(chan struct{})
|
||||||
defer srv.Stop()
|
tc = 0
|
||||||
|
tg = taskgen{
|
||||||
// create a few peers
|
newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
|
||||||
var conns = make([]net.Conn, 8)
|
tc++
|
||||||
connected.Add(len(conns))
|
return []task{&testTask{index: tc - 1}}
|
||||||
deadline := time.Now().Add(3 * time.Second)
|
},
|
||||||
dialer := &net.Dialer{Deadline: deadline}
|
doneFunc: func(t task) {
|
||||||
for i := range conns {
|
select {
|
||||||
conn, err := dialer.Dial("tcp", srv.ListenAddr)
|
case done <- t.(*testTask):
|
||||||
if err != nil {
|
case <-quit:
|
||||||
t.Fatalf("conn %d: dial error: %v", i, err)
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
)
|
||||||
conn.SetDeadline(deadline)
|
|
||||||
conns[i] = conn
|
// The Server in this test isn't actually running
|
||||||
|
// because we're only interested in what run does.
|
||||||
|
srv := &Server{
|
||||||
|
MaxPeers: 10,
|
||||||
|
quit: make(chan struct{}),
|
||||||
|
ntab: fakeTable{},
|
||||||
|
running: true,
|
||||||
}
|
}
|
||||||
connected.Wait()
|
srv.loopWG.Add(1)
|
||||||
|
go func() {
|
||||||
|
srv.run(tg)
|
||||||
|
close(returned)
|
||||||
|
}()
|
||||||
|
|
||||||
// broadcast one message
|
var gotdone []*testTask
|
||||||
srv.Broadcast("discard", 0, []string{"foo"})
|
for i := 0; i < 100; i++ {
|
||||||
golden := unhex("66e94d166f0a2c3b884cfa59ca34")
|
gotdone = append(gotdone, <-done)
|
||||||
|
}
|
||||||
// check that the message has been written everywhere
|
for i, task := range gotdone {
|
||||||
for i, conn := range conns {
|
if task.index != i {
|
||||||
buf := make([]byte, len(golden))
|
t.Errorf("task %d has wrong index, got %d", i, task.index)
|
||||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
break
|
||||||
t.Errorf("conn %d: read error: %v", i, err)
|
}
|
||||||
} else if !bytes.Equal(buf, golden) {
|
if !task.called {
|
||||||
t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden)
|
t.Errorf("task %d was not called", i)
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
close(quit)
|
||||||
|
srv.Stop()
|
||||||
|
select {
|
||||||
|
case <-returned:
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Error("Server.run did not return within 500ms")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type taskgen struct {
|
||||||
|
newFunc func(running int, peers map[discover.NodeID]*Peer) []task
|
||||||
|
doneFunc func(task)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
|
||||||
|
return tg.newFunc(running, peers)
|
||||||
|
}
|
||||||
|
func (tg taskgen) taskDone(t task, now time.Time) {
|
||||||
|
tg.doneFunc(t)
|
||||||
|
}
|
||||||
|
func (tg taskgen) addStatic(*discover.Node) {
|
||||||
|
}
|
||||||
|
|
||||||
|
type testTask struct {
|
||||||
|
index int
|
||||||
|
called bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testTask) Do(srv *Server) {
|
||||||
|
t.called = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// This test checks that connections are disconnected
|
// This test checks that connections are disconnected
|
||||||
// just after the encryption handshake when the server is
|
// just after the encryption handshake when the server is
|
||||||
// at capacity.
|
// at capacity. Trusted connections should still be accepted.
|
||||||
//
|
func TestServerAtCap(t *testing.T) {
|
||||||
// It also serves as a light-weight integration test.
|
trustedID := randomID()
|
||||||
func TestServerDisconnectAtCap(t *testing.T) {
|
|
||||||
started := make(chan *Peer)
|
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
ListenAddr: "127.0.0.1:0",
|
PrivateKey: newkey(),
|
||||||
PrivateKey: newkey(),
|
MaxPeers: 10,
|
||||||
MaxPeers: 10,
|
NoDial: true,
|
||||||
NoDial: true,
|
TrustedNodes: []*discover.Node{{ID: trustedID}},
|
||||||
// This hook signals that the peer was actually started. We
|
|
||||||
// need to wait for the peer to be started before dialing the
|
|
||||||
// next connection to get a deterministic peer count.
|
|
||||||
newPeerHook: func(p *Peer) { started <- p },
|
|
||||||
}
|
}
|
||||||
if err := srv.Start(); err != nil {
|
if err := srv.Start(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatalf("could not start: %v", err)
|
||||||
}
|
}
|
||||||
defer srv.Stop()
|
defer srv.Stop()
|
||||||
|
|
||||||
nconns := srv.MaxPeers + 1
|
newconn := func(id discover.NodeID) *conn {
|
||||||
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
|
fd, _ := net.Pipe()
|
||||||
for i := 0; i < nconns; i++ {
|
tx := newTestTransport(id, fd)
|
||||||
conn, err := dialer.Dial("tcp", srv.ListenAddr)
|
return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
|
||||||
if err != nil {
|
}
|
||||||
t.Fatalf("conn %d: dial error: %v", i, err)
|
|
||||||
|
// Inject a few connections to fill up the peer set.
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
c := newconn(randomID())
|
||||||
|
if err := srv.checkpoint(c, srv.addpeer); err != nil {
|
||||||
|
t.Fatalf("could not add conn %d: %v", i, err)
|
||||||
}
|
}
|
||||||
// Close the connection when the test ends, before
|
}
|
||||||
// shutting down the server.
|
// Try inserting a non-trusted connection.
|
||||||
defer conn.Close()
|
c := newconn(randomID())
|
||||||
// Run the handshakes just like a real peer would.
|
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
|
||||||
key := newkey()
|
t.Error("wrong error for insert:", err)
|
||||||
hs := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
|
}
|
||||||
_, err = setupConn(conn, key, hs, srv.Self(), keepalways)
|
// Try inserting a trusted connection.
|
||||||
if i == nconns-1 {
|
c = newconn(trustedID)
|
||||||
// When handling the last connection, the server should
|
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
|
||||||
// disconnect immediately instead of running the protocol
|
t.Error("unexpected error for trusted conn @posthandshake:", err)
|
||||||
// handshake.
|
}
|
||||||
if err != DiscTooManyPeers {
|
if !c.is(trustedConn) {
|
||||||
t.Errorf("conn %d: got error %q, expected %q", i, err, DiscTooManyPeers)
|
t.Error("Server did not set trusted flag")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerSetupConn(t *testing.T) {
|
||||||
|
id := randomID()
|
||||||
|
srvkey := newkey()
|
||||||
|
srvid := discover.PubkeyID(&srvkey.PublicKey)
|
||||||
|
tests := []struct {
|
||||||
|
dontstart bool
|
||||||
|
tt *setupTransport
|
||||||
|
flags connFlag
|
||||||
|
dialDest *discover.Node
|
||||||
|
|
||||||
|
wantCloseErr error
|
||||||
|
wantCalls string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
dontstart: true,
|
||||||
|
tt: &setupTransport{id: id},
|
||||||
|
wantCalls: "close,",
|
||||||
|
wantCloseErr: errServerStopped,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
|
||||||
|
flags: inboundConn,
|
||||||
|
wantCalls: "doEncHandshake,close,",
|
||||||
|
wantCloseErr: errors.New("read error"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
tt: &setupTransport{id: id},
|
||||||
|
dialDest: &discover.Node{ID: randomID()},
|
||||||
|
flags: dynDialedConn,
|
||||||
|
wantCalls: "doEncHandshake,close,",
|
||||||
|
wantCloseErr: DiscUnexpectedIdentity,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
|
||||||
|
dialDest: &discover.Node{ID: id},
|
||||||
|
flags: dynDialedConn,
|
||||||
|
wantCalls: "doEncHandshake,doProtoHandshake,close,",
|
||||||
|
wantCloseErr: DiscUnexpectedIdentity,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
|
||||||
|
dialDest: &discover.Node{ID: id},
|
||||||
|
flags: dynDialedConn,
|
||||||
|
wantCalls: "doEncHandshake,doProtoHandshake,close,",
|
||||||
|
wantCloseErr: errors.New("foo"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
|
||||||
|
flags: inboundConn,
|
||||||
|
wantCalls: "doEncHandshake,close,",
|
||||||
|
wantCloseErr: DiscSelf,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}},
|
||||||
|
flags: inboundConn,
|
||||||
|
wantCalls: "doEncHandshake,doProtoHandshake,close,",
|
||||||
|
wantCloseErr: DiscUselessPeer,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
srv := &Server{
|
||||||
|
PrivateKey: srvkey,
|
||||||
|
MaxPeers: 10,
|
||||||
|
NoDial: true,
|
||||||
|
Protocols: []Protocol{discard},
|
||||||
|
newTransport: func(fd net.Conn) transport { return test.tt },
|
||||||
|
}
|
||||||
|
if !test.dontstart {
|
||||||
|
if err := srv.Start(); err != nil {
|
||||||
|
t.Fatalf("couldn't start server: %v", err)
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
// For all earlier connections, the handshake should go through.
|
p1, _ := net.Pipe()
|
||||||
if err != nil {
|
srv.setupConn(p1, test.flags, test.dialDest)
|
||||||
t.Fatalf("conn %d: unexpected error: %v", i, err)
|
if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
|
||||||
}
|
t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
|
||||||
// Wait for runPeer to be started.
|
}
|
||||||
<-started
|
if test.tt.calls != test.wantCalls {
|
||||||
|
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that static peers are (re)connected, and done so even above max peers.
|
type setupTransport struct {
|
||||||
func TestServerStaticPeers(t *testing.T) {
|
id discover.NodeID
|
||||||
// Create a test server with limited connection slots
|
encHandshakeErr error
|
||||||
started := make(chan *Peer)
|
|
||||||
server := &Server{
|
|
||||||
ListenAddr: "127.0.0.1:0",
|
|
||||||
PrivateKey: newkey(),
|
|
||||||
MaxPeers: 3,
|
|
||||||
newPeerHook: func(p *Peer) { started <- p },
|
|
||||||
staticCycle: time.Second,
|
|
||||||
}
|
|
||||||
if err := server.Start(); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer server.Stop()
|
|
||||||
|
|
||||||
// Fill up all the slots on the server
|
phs *protoHandshake
|
||||||
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
|
protoHandshakeErr error
|
||||||
for i := 0; i < server.MaxPeers; i++ {
|
|
||||||
// Establish a new connection
|
|
||||||
conn, err := dialer.Dial("tcp", server.ListenAddr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn %d: dial error: %v", i, err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
// Run the handshakes just like a real peer would, and wait for completion
|
calls string
|
||||||
key := newkey()
|
closeErr error
|
||||||
shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
|
|
||||||
if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil {
|
|
||||||
t.Fatalf("conn %d: unexpected error: %v", i, err)
|
|
||||||
}
|
|
||||||
<-started
|
|
||||||
}
|
|
||||||
// Open a TCP listener to accept static connections
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to setup listener: %v", err)
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
|
|
||||||
connected := make(chan net.Conn)
|
|
||||||
go func() {
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err == nil {
|
|
||||||
connected <- conn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
// Inject a static node and wait for a remote dial, then redial, then nothing
|
|
||||||
addr := listener.Addr().(*net.TCPAddr)
|
|
||||||
static := &discover.Node{
|
|
||||||
ID: discover.PubkeyID(&newkey().PublicKey),
|
|
||||||
IP: addr.IP,
|
|
||||||
TCP: uint16(addr.Port),
|
|
||||||
}
|
|
||||||
server.AddPeer(static)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case conn := <-connected:
|
|
||||||
// Close the first connection, expect redial
|
|
||||||
conn.Close()
|
|
||||||
|
|
||||||
case <-time.After(2 * server.staticCycle):
|
|
||||||
t.Fatalf("remote dial timeout")
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case conn := <-connected:
|
|
||||||
// Keep the second connection, don't expect redial
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
case <-time.After(2 * server.staticCycle):
|
|
||||||
t.Fatalf("remote re-dial timeout")
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-time.After(2 * server.staticCycle):
|
|
||||||
// Timeout as no dial occurred
|
|
||||||
|
|
||||||
case <-connected:
|
|
||||||
t.Fatalf("connected node dialed")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that trusted peers and can connect above max peer caps.
|
func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
|
||||||
func TestServerTrustedPeers(t *testing.T) {
|
c.calls += "doEncHandshake,"
|
||||||
|
return c.id, c.encHandshakeErr
|
||||||
// Create a trusted peer to accept connections from
|
}
|
||||||
key := newkey()
|
func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
|
||||||
trusted := &discover.Node{
|
c.calls += "doProtoHandshake,"
|
||||||
ID: discover.PubkeyID(&key.PublicKey),
|
if c.protoHandshakeErr != nil {
|
||||||
}
|
return nil, c.protoHandshakeErr
|
||||||
// Create a test server with limited connection slots
|
|
||||||
started := make(chan *Peer)
|
|
||||||
server := &Server{
|
|
||||||
ListenAddr: "127.0.0.1:0",
|
|
||||||
PrivateKey: newkey(),
|
|
||||||
MaxPeers: 3,
|
|
||||||
NoDial: true,
|
|
||||||
TrustedNodes: []*discover.Node{trusted},
|
|
||||||
newPeerHook: func(p *Peer) { started <- p },
|
|
||||||
}
|
|
||||||
if err := server.Start(); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer server.Stop()
|
|
||||||
|
|
||||||
// Fill up all the slots on the server
|
|
||||||
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
|
|
||||||
for i := 0; i < server.MaxPeers; i++ {
|
|
||||||
// Establish a new connection
|
|
||||||
conn, err := dialer.Dial("tcp", server.ListenAddr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("conn %d: dial error: %v", i, err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
// Run the handshakes just like a real peer would, and wait for completion
|
|
||||||
key := newkey()
|
|
||||||
shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
|
|
||||||
if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil {
|
|
||||||
t.Fatalf("conn %d: unexpected error: %v", i, err)
|
|
||||||
}
|
|
||||||
<-started
|
|
||||||
}
|
|
||||||
// Dial from the trusted peer, ensure connection is accepted
|
|
||||||
conn, err := dialer.Dial("tcp", server.ListenAddr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("trusted node: dial error: %v", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
shake := &protoHandshake{Version: baseProtocolVersion, ID: trusted.ID}
|
|
||||||
if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil {
|
|
||||||
t.Fatalf("trusted node: unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-started:
|
|
||||||
// Ok, trusted peer accepted
|
|
||||||
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatalf("trusted node timeout")
|
|
||||||
}
|
}
|
||||||
|
return c.phs, nil
|
||||||
|
}
|
||||||
|
func (c *setupTransport) close(err error) {
|
||||||
|
c.calls += "close,"
|
||||||
|
c.closeErr = err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that a failed dial will temporarily throttle a peer.
|
// setupConn shouldn't write to/read from the connection.
|
||||||
func TestServerMaxPendingDials(t *testing.T) {
|
func (c *setupTransport) WriteMsg(Msg) error {
|
||||||
// Start a simple test server
|
panic("WriteMsg called on setupTransport")
|
||||||
server := &Server{
|
|
||||||
ListenAddr: "127.0.0.1:0",
|
|
||||||
PrivateKey: newkey(),
|
|
||||||
MaxPeers: 10,
|
|
||||||
MaxPendingPeers: 1,
|
|
||||||
}
|
|
||||||
if err := server.Start(); err != nil {
|
|
||||||
t.Fatal("failed to start test server: %v", err)
|
|
||||||
}
|
|
||||||
defer server.Stop()
|
|
||||||
|
|
||||||
// Simulate two separate remote peers
|
|
||||||
peers := make(chan *discover.Node, 2)
|
|
||||||
conns := make(chan net.Conn, 2)
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("listener %d: failed to setup: %v", i, err)
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
|
|
||||||
addr := listener.Addr().(*net.TCPAddr)
|
|
||||||
peers <- &discover.Node{
|
|
||||||
ID: discover.PubkeyID(&newkey().PublicKey),
|
|
||||||
IP: addr.IP,
|
|
||||||
TCP: uint16(addr.Port),
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err == nil {
|
|
||||||
conns <- conn
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
// Request a dial for both peers
|
|
||||||
go func() {
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
server.staticDial <- <-peers // hack piggybacking the static implementation
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Make sure only one outbound connection goes through
|
|
||||||
var conn net.Conn
|
|
||||||
|
|
||||||
select {
|
|
||||||
case conn = <-conns:
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatalf("first dial timeout")
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case conn = <-conns:
|
|
||||||
t.Fatalf("second dial completed prematurely")
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
}
|
|
||||||
// Finish the first dial, check the second
|
|
||||||
conn.Close()
|
|
||||||
select {
|
|
||||||
case conn = <-conns:
|
|
||||||
conn.Close()
|
|
||||||
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatalf("second dial timeout")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
func (c *setupTransport) ReadMsg() (Msg, error) {
|
||||||
func TestServerMaxPendingAccepts(t *testing.T) {
|
panic("ReadMsg called on setupTransport")
|
||||||
// Start a test server and a peer sink for synchronization
|
|
||||||
started := make(chan *Peer)
|
|
||||||
server := &Server{
|
|
||||||
ListenAddr: "127.0.0.1:0",
|
|
||||||
PrivateKey: newkey(),
|
|
||||||
MaxPeers: 10,
|
|
||||||
MaxPendingPeers: 1,
|
|
||||||
NoDial: true,
|
|
||||||
newPeerHook: func(p *Peer) { started <- p },
|
|
||||||
}
|
|
||||||
if err := server.Start(); err != nil {
|
|
||||||
t.Fatal("failed to start test server: %v", err)
|
|
||||||
}
|
|
||||||
defer server.Stop()
|
|
||||||
|
|
||||||
// Try and connect to the server on multiple threads concurrently
|
|
||||||
conns := make([]net.Conn, 2)
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
|
|
||||||
|
|
||||||
conn, err := dialer.Dial("tcp", server.ListenAddr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to dial server: %v", err)
|
|
||||||
}
|
|
||||||
conns[i] = conn
|
|
||||||
}
|
|
||||||
// Check that a handshake on the second doesn't pass
|
|
||||||
go func() {
|
|
||||||
key := newkey()
|
|
||||||
shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
|
|
||||||
if _, err := setupConn(conns[1], key, shake, server.Self(), keepalways); err != nil {
|
|
||||||
t.Fatalf("failed to run handshake: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
select {
|
|
||||||
case <-started:
|
|
||||||
t.Fatalf("handshake on second connection accepted")
|
|
||||||
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
}
|
|
||||||
// Shake on first, check that both go through
|
|
||||||
go func() {
|
|
||||||
key := newkey()
|
|
||||||
shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
|
|
||||||
if _, err := setupConn(conns[0], key, shake, server.Self(), keepalways); err != nil {
|
|
||||||
t.Fatalf("failed to run handshake: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
select {
|
|
||||||
case <-started:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatalf("peer %d: handshake timeout", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newkey() *ecdsa.PrivateKey {
|
func newkey() *ecdsa.PrivateKey {
|
||||||
|
@ -501,7 +419,3 @@ func randomID() (id discover.NodeID) {
|
||||||
}
|
}
|
||||||
return id
|
return id
|
||||||
}
|
}
|
||||||
|
|
||||||
func keepalways(id discover.NodeID) bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue