chore: upgrade deps
parent 6300751e36
commit a9a87f39e5
@@ -39,3 +39,18 @@ type CurveParams = secp.CurveParams
func Params() *CurveParams {
	return secp.Params()
}

// Generator returns the public key at the Generator Point.
func Generator() *PublicKey {
	var (
		result JacobianPoint
		k      secp.ModNScalar
	)

	k.SetInt(1)
	ScalarBaseMultNonConst(&k, &result)

	result.ToAffine()

	return NewPublicKey(&result.X, &result.Y)
}
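Illustration (editor's sketch, not part of the diff): exercising the new Generator helper and checking it against the well-known compressed encoding of the secp256k1 base point. The btcec import path is an assumption.

```go
package main

import (
	"encoding/hex"
	"fmt"

	"github.com/btcsuite/btcd/btcec/v2" // assumed import path for this package
)

func main() {
	// Serialize the generator point returned by the new helper.
	got := hex.EncodeToString(btcec.Generator().SerializeCompressed())

	// Well-known compressed encoding of the secp256k1 base point G.
	const want = "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"

	fmt.Println(got == want) // expected: true
}
```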
@@ -4,6 +4,8 @@
package btcec

import (
	"fmt"

	secp "github.com/decred/dcrd/dcrec/secp256k1/v4"
)
@@ -11,6 +13,9 @@ import (
// Jacobian projective coordinates and thus represents a point on the curve.
type JacobianPoint = secp.JacobianPoint

// infinityPoint is the Jacobian representation of the point at infinity.
var infinityPoint JacobianPoint

// MakeJacobianPoint returns a Jacobian point with the provided X, Y, and Z
// coordinates.
func MakeJacobianPoint(x, y, z *FieldVal) JacobianPoint {
@@ -61,3 +66,50 @@ func ScalarBaseMultNonConst(k *ModNScalar, result *JacobianPoint) {
func ScalarMultNonConst(k *ModNScalar, point, result *JacobianPoint) {
	secp.ScalarMultNonConst(k, point, result)
}

// ParseJacobian parses a byte slice point as a secp.PublicKey and returns the
// pubkey as a JacobianPoint. If the nonce is a zero slice, the infinityPoint
// is returned.
func ParseJacobian(point []byte) (JacobianPoint, error) {
	var result JacobianPoint

	if len(point) != 33 {
		str := fmt.Sprintf("invalid nonce: invalid length: %v",
			len(point))
		return JacobianPoint{}, makeError(secp.ErrPubKeyInvalidLen, str)
	}

	if point[0] == 0x00 {
		return infinityPoint, nil
	}

	noncePk, err := secp.ParsePubKey(point)
	if err != nil {
		return JacobianPoint{}, err
	}
	noncePk.AsJacobian(&result)

	return result, nil
}

// JacobianToByteSlice converts the passed JacobianPoint to a PublicKey
// and serializes that to a byte slice. If the JacobianPoint is the infinity
// point, a zero slice is returned.
func JacobianToByteSlice(point JacobianPoint) []byte {
	if point.X == infinityPoint.X && point.Y == infinityPoint.Y {
		return make([]byte, 33)
	}

	point.ToAffine()

	return NewPublicKey(
		&point.X, &point.Y,
	).SerializeCompressed()
}

// GeneratorJacobian sets the passed JacobianPoint to the Generator Point.
func GeneratorJacobian(jacobian *JacobianPoint) {
	var k ModNScalar
	k.SetInt(1)
	ScalarBaseMultNonConst(&k, jacobian)
}
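Illustration (editor's sketch, not part of the diff): a round trip through the new ParseJacobian/JacobianToByteSlice helpers, covering both the zero-slice infinity convention and an ordinary point. The import path is an assumption.

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/btcsuite/btcd/btcec/v2" // assumed import path for this package
)

func main() {
	// A zero 33-byte slice parses to the point at infinity...
	inf, err := btcec.ParseJacobian(make([]byte, 33))
	if err != nil {
		panic(err)
	}
	// ...and serializes back to a zero slice.
	fmt.Println(bytes.Equal(btcec.JacobianToByteSlice(inf), make([]byte, 33))) // true

	// An ordinary point round-trips through its 33-byte compressed encoding.
	var g btcec.JacobianPoint
	btcec.GeneratorJacobian(&g)
	enc := btcec.JacobianToByteSlice(g)
	back, err := btcec.ParseJacobian(enc)
	if err != nil {
		panic(err)
	}
	fmt.Println(bytes.Equal(btcec.JacobianToByteSlice(back), enc)) // true
}
```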
@@ -17,3 +17,8 @@ type Error = secp.Error
// errors.As, so the caller can directly check against an error kind when
// determining the reason for an error.
type ErrorKind = secp.ErrorKind

// makeError creates a secp.Error given a set of arguments.
func makeError(kind ErrorKind, desc string) Error {
	return Error{Err: kind, Description: desc}
}
@@ -1,26 +0,0 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so

# Folders
_obj
_test

# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out

*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*

_testmain.go

*.exe
*.test
*.prof

genny
@@ -1,6 +0,0 @@
language: go

go:
  - 1.7
  - 1.8
  - 1.9
@@ -1,245 +0,0 @@
# genny - Generics for Go

[![Build Status](https://travis-ci.org/cheekybits/genny.svg?branch=master)](https://travis-ci.org/cheekybits/genny) [![GoDoc](https://godoc.org/github.com/cheekybits/genny/parse?status.png)](http://godoc.org/github.com/cheekybits/genny/parse)

Install:

```
go get github.com/cheekybits/genny
```

=====

(pron. Jenny) by Mat Ryer ([@matryer](https://twitter.com/matryer)) and Tyler Bunnell ([@TylerJBunnell](https://twitter.com/TylerJBunnell)).

Until the Go core team includes support for [generics in Go](http://golang.org/doc/faq#generics), `genny` is a code-generation generics solution. It allows you to write normal buildable and testable Go code which, when processed by the `genny gen` tool, will replace the generics with specific types.

* Generic code is valid Go code
* Generic code compiles and can be tested
* Use `stdin` and `stdout` or specify in and out files
* Supports Go 1.4's [go generate](http://tip.golang.org/doc/go1.4#gogenerate)
* Multiple specific types will generate every permutation
* Use `BUILTINS` and `NUMBERS` wildtypes to generate specific code for all built-in (and number) Go types
* Function names and comments also get updated

## Library

We have started building a [library of common things](https://github.com/cheekybits/gennylib), and you can use `genny get` to generate the specific versions you need.

For example: `genny get maps/concurrentmap.go "KeyType=BUILTINS ValueType=BUILTINS"` will print out generated code for all types for a concurrent map. Any file in the library may be generated locally in this way using all the same options given to `genny gen`.
## Usage

```
genny [{flags}] gen "{types}"

gen - generates type specific code from generic code.
get <package/file> - fetch a generic template from the online library and gen it.

{flags} - (optional) Command line flags (see below)
{types} - (required) Specific types for each generic type in the source
{types} format: {generic}={specific}[,another][ {generic2}={specific2}]

Examples:
  Generic=Specific
  Generic1=Specific1 Generic2=Specific2
  Generic1=Specific1,Specific2 Generic2=Specific3,Specific4

Flags:
  -in="": file to parse instead of stdin
  -out="": file to save output to instead of stdout
  -pkg="": package name for generated files
```

* Comma separated type lists will generate code for each type

### Flags

* `-in` - specify the input file (rather than using stdin)
* `-out` - specify the output file (rather than using stdout)

### go generate

To use Go 1.4's `go generate` capability, insert the following comment in your source code file:

```
//go:generate genny -in=$GOFILE -out=gen-$GOFILE gen "KeyType=string,int ValueType=string,int"
```

* Start the line with `//go:generate `
* Use the `-in` and `-out` flags to specify the files to work on
* Use the `genny` command as usual after the flags

Now, running `go generate` (in a shell) for the package will cause the generic versions of the files to be generated.

* The output file will be overwritten, so it's safe to call `go generate` many times
* Use `$GOFILE` to refer to the current file
* The `//go:generate` line will be removed from the output

To see a real example of how to use `genny` with `go generate`, look in the [example/go-generate directory](https://github.com/cheekybits/genny/tree/master/examples/go-generate).
## How it works

Define your generic types using the special `generic.Type` placeholder type:

```go
type KeyType generic.Type
type ValueType generic.Type
```

* You can use as many as you like
* Give them meaningful names

Then write the generic code referencing the types as you normally would:

```go
func SetValueTypeForKeyType(key KeyType, value ValueType) { /* ... */ }
```

* Generic type names will also be replaced in comments and function names (see Real example below)

Since `generic.Type` is a real Go type, your code will compile, and you can even write unit tests against your generic code.

#### Generating specific versions

Pass the file through the `genny gen` tool with the specific types as the argument:

```
cat generic.go | genny gen "KeyType=string ValueType=interface{}"
```

The output will be the complete Go source file with the generic types replaced with the types specified in the arguments.
## Real example

Given [this generic Go code](https://github.com/cheekybits/genny/tree/master/examples/queue) which compiles and is tested:

```go
package queue

import "github.com/cheekybits/genny/generic"

// NOTE: this is how easy it is to define a generic type
type Something generic.Type

// SomethingQueue is a queue of Somethings.
type SomethingQueue struct {
	items []Something
}

func NewSomethingQueue() *SomethingQueue {
	return &SomethingQueue{items: make([]Something, 0)}
}
func (q *SomethingQueue) Push(item Something) {
	q.items = append(q.items, item)
}
func (q *SomethingQueue) Pop() Something {
	item := q.items[0]
	q.items = q.items[1:]
	return item
}
```

When `genny gen` is invoked like this:

```
cat source.go | genny gen "Something=string"
```

It outputs:

```go
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny

package queue

// StringQueue is a queue of Strings.
type StringQueue struct {
	items []string
}

func NewStringQueue() *StringQueue {
	return &StringQueue{items: make([]string, 0)}
}
func (q *StringQueue) Push(item string) {
	q.items = append(q.items, item)
}
func (q *StringQueue) Pop() string {
	item := q.items[0]
	q.items = q.items[1:]
	return item
}
```

To get a _something_ for every built-in Go type plus one of your own types, you could run:

```
cat source.go | genny gen "Something=BUILTINS,*MyType"
```

#### More examples

Check out the [test code files](https://github.com/cheekybits/genny/tree/master/parse/test) for more real examples.
## Writing test code

Once you have defined a generic type with some code worth testing:

```go
package slice

import (
	"log"
	"reflect"

	"github.com/stretchr/gogen/generic"
)

type MyType generic.Type

func EnsureMyTypeSlice(objectOrSlice interface{}) []MyType {
	log.Printf("%v", reflect.TypeOf(objectOrSlice))
	switch obj := objectOrSlice.(type) {
	case []MyType:
		log.Println(" returning it untouched")
		return obj
	case MyType:
		log.Println(" wrapping in slice")
		return []MyType{obj}
	default:
		panic("ensure slice needs MyType or []MyType")
	}
}
```

You can treat it like any normal Go type in your test code:

```go
func TestEnsureMyTypeSlice(t *testing.T) {

	myType := new(MyType)
	slice := EnsureMyTypeSlice(myType)
	if assert.NotNil(t, slice) {
		assert.Equal(t, slice[0], myType)
	}

	slice = EnsureMyTypeSlice(slice)
	log.Printf("%#v", slice[0])
	if assert.NotNil(t, slice) {
		assert.Equal(t, slice[0], myType)
	}

}
```

### Understanding what `generic.Type` is

Because `generic.Type` is an empty interface type (literally `interface{}`), every other type will be considered to be a `generic.Type` if you are switching on the type of an object. Of course, once the specific versions are generated, this issue goes away, but it's worth knowing when you are writing your tests against generic code.

### Contributions

* See the [API documentation for the parse package](http://godoc.org/github.com/cheekybits/genny/parse)
* Please do TDD
* All input welcome
@@ -1,2 +0,0 @@
// Package main is the command line tool for Genny.
package main
@@ -1,2 +0,0 @@
// Package generic contains the generic marker types.
package generic
@@ -1,13 +0,0 @@
package generic

// Type is the placeholder type that indicates a generic value.
// When genny is executed, variables of this type will be replaced with
// references to the specific types.
//   var GenericType generic.Type
type Type interface{}

// Number is the placeholder type that indicates a generic numerical value.
// When genny is executed, variables of this type will be replaced with
// references to the specific types.
//   var GenericType generic.Number
type Number float64
@@ -1,154 +0,0 @@
package main

import (
	"bytes"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"os"
	"strings"

	"github.com/cheekybits/genny/out"
	"github.com/cheekybits/genny/parse"
)

/*

  source | genny gen [-in=""] [-out=""] [-pkg=""] "KeyType=string,int ValueType=string,int"

*/

const (
	_ = iota
	exitcodeInvalidArgs
	exitcodeInvalidTypeSet
	exitcodeStdinFailed
	exitcodeGenFailed
	exitcodeGetFailed
	exitcodeSourceFileInvalid
	exitcodeDestFileFailed
)

func main() {
	var (
		in      = flag.String("in", "", "file to parse instead of stdin")
		out     = flag.String("out", "", "file to save output to instead of stdout")
		pkgName = flag.String("pkg", "", "package name for generated files")
		prefix  = "https://github.com/metabition/gennylib/raw/master/"
	)
	flag.Parse()
	args := flag.Args()

	if len(args) < 2 {
		usage()
		os.Exit(exitcodeInvalidArgs)
	}

	if strings.ToLower(args[0]) != "gen" && strings.ToLower(args[0]) != "get" {
		usage()
		os.Exit(exitcodeInvalidArgs)
	}

	// parse the typesets
	var setsArg = args[1]
	if strings.ToLower(args[0]) == "get" {
		setsArg = args[2]
	}
	typeSets, err := parse.TypeSet(setsArg)
	if err != nil {
		fatal(exitcodeInvalidTypeSet, err)
	}

	outWriter := newWriter(*out)

	if strings.ToLower(args[0]) == "get" {
		if len(args) != 3 {
			fmt.Println("not enough arguments to get")
			usage()
			os.Exit(exitcodeInvalidArgs)
		}
		r, err := http.Get(prefix + args[1])
		if err != nil {
			fatal(exitcodeGetFailed, err)
		}
		b, err := ioutil.ReadAll(r.Body)
		if err != nil {
			fatal(exitcodeGetFailed, err)
		}
		r.Body.Close()
		br := bytes.NewReader(b)
		err = gen(*in, *pkgName, br, typeSets, outWriter)
	} else if len(*in) > 0 {
		var file *os.File
		file, err = os.Open(*in)
		if err != nil {
			fatal(exitcodeSourceFileInvalid, err)
		}
		defer file.Close()
		err = gen(*in, *pkgName, file, typeSets, outWriter)
	} else {
		var source []byte
		source, err = ioutil.ReadAll(os.Stdin)
		if err != nil {
			fatal(exitcodeStdinFailed, err)
		}
		reader := bytes.NewReader(source)
		err = gen("stdin", *pkgName, reader, typeSets, outWriter)
	}

	// do the work
	if err != nil {
		fatal(exitcodeGenFailed, err)
	}

}

func usage() {
	fmt.Fprintln(os.Stderr, `usage: genny [{flags}] gen "{types}"

gen - generates type specific code from generic code.
get <package/file> - fetch a generic template from the online library and gen it.

{flags} - (optional) Command line flags (see below)
{types} - (required) Specific types for each generic type in the source
{types} format: {generic}={specific}[,another][ {generic2}={specific2}]

Examples:
  Generic=Specific
  Generic1=Specific1 Generic2=Specific2
  Generic1=Specific1,Specific2 Generic2=Specific3,Specific4

Flags:`)
	flag.PrintDefaults()
}

func newWriter(fileName string) io.Writer {
	if fileName == "" {
		return os.Stdout
	}
	lf := &out.LazyFile{FileName: fileName}
	defer lf.Close()
	return lf
}

func fatal(code int, a ...interface{}) {
	fmt.Println(a...)
	os.Exit(code)
}

// gen performs the generic generation.
func gen(filename, pkgName string, in io.ReadSeeker, typesets []map[string]string, out io.Writer) error {

	var output []byte
	var err error

	output, err = parse.Generics(filename, pkgName, in, typesets)
	if err != nil {
		return err
	}

	out.Write(output)
	return nil
}
@@ -1,38 +0,0 @@
package out

import (
	"os"
	"path"
)

// LazyFile is an io.WriteCloser which defers creation of the file it is supposed to write in
// till the first call to its write function in order to prevent creation of file, if no write
// is supposed to happen.
type LazyFile struct {
	// FileName is path to the file to which genny will write.
	FileName string
	file     *os.File
}

// Close closes the file if it is created. Returns nil if no file is created.
func (lw *LazyFile) Close() error {
	if lw.file != nil {
		return lw.file.Close()
	}
	return nil
}

// Write writes to the specified file and creates the file first time it is called.
func (lw *LazyFile) Write(p []byte) (int, error) {
	if lw.file == nil {
		err := os.MkdirAll(path.Dir(lw.FileName), 0755)
		if err != nil {
			return 0, err
		}
		lw.file, err = os.Create(lw.FileName)
		if err != nil {
			return 0, err
		}
	}
	return lw.file.Write(p)
}
@@ -1,41 +0,0 @@
package parse

// Builtins contains a slice of all built-in Go types.
var Builtins = []string{
	"bool",
	"byte",
	"complex128",
	"complex64",
	"error",
	"float32",
	"float64",
	"int",
	"int16",
	"int32",
	"int64",
	"int8",
	"rune",
	"string",
	"uint",
	"uint16",
	"uint32",
	"uint64",
	"uint8",
	"uintptr",
}

// Numbers contains a slice of all built-in number types.
var Numbers = []string{
	"float32",
	"float64",
	"int",
	"int16",
	"int32",
	"int64",
	"int8",
	"uint",
	"uint16",
	"uint32",
	"uint64",
	"uint8",
}
@@ -1,14 +0,0 @@
// Package parse contains the generic code generation capabilities
// that power genny.
//
//   genny gen "{types}"
//
//   gen - generates type specific code (to stdout) from generic code (via stdin)
//
//   {types} - (required) Specific types for each generic type in the source
//   {types} format: {generic}={specific}[,another][ {generic2}={specific2}]
//   Examples:
//     Generic=Specific
//     Generic1=Specific1 Generic2=Specific2
//     Generic1=Specific1,Specific2 Generic2=Specific3,Specific4
package parse
@@ -1,47 +0,0 @@
package parse

import (
	"errors"
)

// errMissingSpecificType represents an error when a generic type is not
// satisfied by a specific type.
type errMissingSpecificType struct {
	GenericType string
}

// Error gets a human readable string describing this error.
func (e errMissingSpecificType) Error() string {
	return "Missing specific type for '" + e.GenericType + "' generic type"
}

// errImports represents an error from goimports.
type errImports struct {
	Err error
}

// Error gets a human readable string describing this error.
func (e errImports) Error() string {
	return "Failed to goimports the generated code: " + e.Err.Error()
}

// errSource represents an error with the source file.
type errSource struct {
	Err error
}

// Error gets a human readable string describing this error.
func (e errSource) Error() string {
	return "Failed to parse source file: " + e.Err.Error()
}

type errBadTypeArgs struct {
	Message string
	Arg     string
}

func (e errBadTypeArgs) Error() string {
	return "\"" + e.Arg + "\" is bad: " + e.Message
}

var errMissingTypeInformation = errors.New("No type arguments were specified and no \"// +gogen\" tag was found in the source.")
@@ -1,298 +0,0 @@
package parse

import (
	"bufio"
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
	"go/scanner"
	"go/token"
	"io"
	"os"
	"strings"
	"unicode"

	"golang.org/x/tools/imports"
)

var header = []byte(`

// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny

`)

var (
	packageKeyword = []byte("package")
	importKeyword  = []byte("import")
	openBrace      = []byte("(")
	closeBrace     = []byte(")")
	genericPackage = "generic"
	genericType    = "generic.Type"
	genericNumber  = "generic.Number"
	linefeed       = "\r\n"
)
var unwantedLinePrefixes = [][]byte{
	[]byte("//go:generate genny "),
}

func subIntoLiteral(lit, typeTemplate, specificType string) string {
	if lit == typeTemplate {
		return specificType
	}
	if !strings.Contains(lit, typeTemplate) {
		return lit
	}
	specificLg := wordify(specificType, true)
	specificSm := wordify(specificType, false)
	result := strings.Replace(lit, typeTemplate, specificLg, -1)
	if strings.HasPrefix(result, specificLg) && !isExported(lit) {
		return strings.Replace(result, specificLg, specificSm, 1)
	}
	return result
}

func subTypeIntoComment(line, typeTemplate, specificType string) string {
	var subbed string
	for _, w := range strings.Fields(line) {
		subbed = subbed + subIntoLiteral(w, typeTemplate, specificType) + " "
	}
	return subbed
}

// Does the heavy lifting of taking a line of our code and
// substituting a type into there for our generic type
func subTypeIntoLine(line, typeTemplate, specificType string) string {
	src := []byte(line)
	var s scanner.Scanner
	fset := token.NewFileSet()
	file := fset.AddFile("", fset.Base(), len(src))
	s.Init(file, src, nil, scanner.ScanComments)
	output := ""
	for {
		_, tok, lit := s.Scan()
		if tok == token.EOF {
			break
		} else if tok == token.COMMENT {
			subbed := subTypeIntoComment(lit, typeTemplate, specificType)
			output = output + subbed + " "
		} else if tok.IsLiteral() {
			subbed := subIntoLiteral(lit, typeTemplate, specificType)
			output = output + subbed + " "
		} else {
			output = output + tok.String() + " "
		}
	}
	return output
}
// typeSet looks like "KeyType: int, ValueType: string"
func generateSpecific(filename string, in io.ReadSeeker, typeSet map[string]string) ([]byte, error) {

	// ensure we are at the beginning of the file
	in.Seek(0, os.SEEK_SET)

	// parse the source file
	fs := token.NewFileSet()
	file, err := parser.ParseFile(fs, filename, in, 0)
	if err != nil {
		return nil, &errSource{Err: err}
	}

	// make sure every generic.Type is represented in the types
	// argument.
	for _, decl := range file.Decls {
		switch it := decl.(type) {
		case *ast.GenDecl:
			for _, spec := range it.Specs {
				ts, ok := spec.(*ast.TypeSpec)
				if !ok {
					continue
				}
				switch tt := ts.Type.(type) {
				case *ast.SelectorExpr:
					if name, ok := tt.X.(*ast.Ident); ok {
						if name.Name == genericPackage {
							if _, ok := typeSet[ts.Name.Name]; !ok {
								return nil, &errMissingSpecificType{GenericType: ts.Name.Name}
							}
						}
					}
				}
			}
		}
	}

	in.Seek(0, os.SEEK_SET)

	var buf bytes.Buffer

	comment := ""
	scanner := bufio.NewScanner(in)
	for scanner.Scan() {

		line := scanner.Text()

		// does this line contain generic.Type?
		if strings.Contains(line, genericType) || strings.Contains(line, genericNumber) {
			comment = ""
			continue
		}

		for t, specificType := range typeSet {
			if strings.Contains(line, t) {
				newLine := subTypeIntoLine(line, t, specificType)
				line = newLine
			}
		}

		if comment != "" {
			buf.WriteString(makeLine(comment))
			comment = ""
		}

		// is this line a comment?
		// TODO: should we handle /* */ comments?
		if strings.HasPrefix(line, "//") {
			// record this line to print later
			comment = line
			continue
		}

		// write the line
		buf.WriteString(makeLine(line))
	}

	// write it out
	return buf.Bytes(), nil
}

// Generics parses the source file and generates the bytes replacing the
// generic types for the keys map with the specific types (its value).
func Generics(filename, pkgName string, in io.ReadSeeker, typeSets []map[string]string) ([]byte, error) {

	totalOutput := header

	for _, typeSet := range typeSets {

		// generate the specifics
		parsed, err := generateSpecific(filename, in, typeSet)
		if err != nil {
			return nil, err
		}

		totalOutput = append(totalOutput, parsed...)

	}

	// clean up the code line by line
	packageFound := false
	insideImportBlock := false
	var cleanOutputLines []string
	scanner := bufio.NewScanner(bytes.NewReader(totalOutput))
	for scanner.Scan() {

		// end of imports block?
		if insideImportBlock {
			if bytes.HasSuffix(scanner.Bytes(), closeBrace) {
				insideImportBlock = false
			}
			continue
		}

		if bytes.HasPrefix(scanner.Bytes(), packageKeyword) {
			if packageFound {
				continue
			} else {
				packageFound = true
			}
		} else if bytes.HasPrefix(scanner.Bytes(), importKeyword) {
			if bytes.HasSuffix(scanner.Bytes(), openBrace) {
				insideImportBlock = true
			}
			continue
		}

		// check all unwantedLinePrefixes - and skip them
		skipline := false
		for _, prefix := range unwantedLinePrefixes {
			if bytes.HasPrefix(scanner.Bytes(), prefix) {
				skipline = true
				continue
			}
		}

		if skipline {
			continue
		}

		cleanOutputLines = append(cleanOutputLines, makeLine(scanner.Text()))
	}

	cleanOutput := strings.Join(cleanOutputLines, "")

	output := []byte(cleanOutput)
	var err error

	// change package name
	if pkgName != "" {
		output = changePackage(bytes.NewReader([]byte(output)), pkgName)
	}
	// fix the imports
	output, err = imports.Process(filename, output, nil)
	if err != nil {
		return nil, &errImports{Err: err}
	}

	return output, nil
}

func makeLine(s string) string {
	return fmt.Sprintln(strings.TrimRight(s, linefeed))
}

// isAlphaNumeric gets whether the rune is alphanumeric or _.
func isAlphaNumeric(r rune) bool {
	return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r)
}

// wordify turns a type into a nice word for function and type
// names etc.
func wordify(s string, exported bool) string {
	s = strings.TrimRight(s, "{}")
	s = strings.TrimLeft(s, "*&")
	s = strings.Replace(s, ".", "", -1)
	if !exported {
		return s
	}
	return strings.ToUpper(string(s[0])) + s[1:]
}

func changePackage(r io.Reader, pkgName string) []byte {
	var out bytes.Buffer
	sc := bufio.NewScanner(r)
	done := false

	for sc.Scan() {
		s := sc.Text()

		if !done && strings.HasPrefix(s, "package") {
			parts := strings.Split(s, " ")
			parts[1] = pkgName
			s = strings.Join(parts, " ")
			done = true
		}

		fmt.Fprintln(&out, s)
	}
	return out.Bytes()
}

func isExported(lit string) bool {
	if len(lit) == 0 {
		return false
	}
	return unicode.IsUpper(rune(lit[0]))
}
@@ -1,89 +0,0 @@
package parse

import "strings"

const (
	typeSep     = " "
	keyValueSep = "="
	valuesSep   = ","
	builtins    = "BUILTINS"
	numbers     = "NUMBERS"
)

// TypeSet turns a type string into a []map[string]string
// that can be given to parse.Generics for it to do its magic.
//
// Acceptable args are:
//
//   Person=man
//   Person=man Animal=dog
//   Person=man Animal=dog Animal2=cat
//   Person=man,woman Animal=dog,cat
//   Person=man,woman,child Animal=dog,cat Place=london,paris
func TypeSet(arg string) ([]map[string]string, error) {

	types := make(map[string][]string)
	var keys []string
	for _, pair := range strings.Split(arg, typeSep) {
		segs := strings.Split(pair, keyValueSep)
		if len(segs) != 2 {
			return nil, &errBadTypeArgs{Arg: arg, Message: "Generic=Specific expected"}
		}
		key := segs[0]
		keys = append(keys, key)
		types[key] = make([]string, 0)
		for _, t := range strings.Split(segs[1], valuesSep) {
			if t == builtins {
				types[key] = append(types[key], Builtins...)
			} else if t == numbers {
				types[key] = append(types[key], Numbers...)
			} else {
				types[key] = append(types[key], t)
			}
		}
	}

	cursors := make(map[string]int)
	for _, key := range keys {
		cursors[key] = 0
	}

	outChan := make(chan map[string]string)
	go func() {
		buildTypeSet(keys, 0, cursors, types, outChan)
		close(outChan)
	}()

	var typeSets []map[string]string
	for typeSet := range outChan {
		typeSets = append(typeSets, typeSet)
	}

	return typeSets, nil

}

func buildTypeSet(keys []string, keyI int, cursors map[string]int, types map[string][]string, out chan<- map[string]string) {
	key := keys[keyI]
	for cursors[key] < len(types[key]) {
		if keyI < len(keys)-1 {
			buildTypeSet(keys, keyI+1, copycursors(cursors), types, out)
		} else {
			// build the typeset for this combination
			ts := make(map[string]string)
			for k, vals := range types {
				ts[k] = vals[cursors[k]]
			}
			out <- ts
		}
		cursors[key]++
	}
}

func copycursors(source map[string]int) map[string]int {
	copy := make(map[string]int)
	for k, v := range source {
		copy[k] = v
	}
	return copy
}
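Illustration (editor's sketch, not part of the diff): the permutation behavior the removed TypeSet/buildTypeSet code implemented, modeled as a standalone cartesian product so the "every permutation" claim in the README above is concrete. Two keys with two values each yield four type sets.

```go
package main

import "fmt"

// cartesian mimics the removed buildTypeSet recursion: it expands a map of
// comma-separated values into every key/value combination, in key order.
func cartesian(keys []string, types map[string][]string) []map[string]string {
	if len(keys) == 0 {
		return []map[string]string{{}}
	}
	var out []map[string]string
	for _, v := range types[keys[0]] {
		for _, rest := range cartesian(keys[1:], types) {
			m := map[string]string{keys[0]: v}
			for k, rv := range rest {
				m[k] = rv
			}
			out = append(out, m)
		}
	}
	return out
}

func main() {
	keys := []string{"KeyType", "ValueType"}
	types := map[string][]string{
		"KeyType":   {"string", "int"},
		"ValueType": {"string", "int"},
	}
	// Prints 4 maps: every KeyType/ValueType combination.
	for _, ts := range cartesian(keys, types) {
		fmt.Println(ts)
	}
}
```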
@@ -176,6 +176,11 @@ func (c *Conn) Close() {
	c.sigconn.Close()
}

// Connected returns whether conn is connected.
func (c *Conn) Connected() bool {
	return c.sysconn.Connected() && c.sigconn.Connected()
}

// NewConnection establishes a connection to a bus using a caller-supplied function.
// This allows connecting to remote buses through a user-supplied mechanism.
// The supplied function may be called multiple times, and should return independent connections.
@@ -417,6 +417,29 @@ func (c *Conn) listUnitsInternal(f storeFunc) ([]UnitStatus, error) {
	return status, nil
}

// GetUnitByPID returns the unit object path of the unit a process ID
// belongs to. It takes a UNIX PID and returns the object path. The PID must
// refer to an existing system process.
func (c *Conn) GetUnitByPID(ctx context.Context, pid uint32) (dbus.ObjectPath, error) {
	var result dbus.ObjectPath

	err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.GetUnitByPID", 0, pid).Store(&result)

	return result, err
}

// GetUnitNameByPID returns the name of the unit a process ID belongs to. It
// takes a UNIX PID and returns the unit name. The PID must refer to an
// existing system process.
func (c *Conn) GetUnitNameByPID(ctx context.Context, pid uint32) (string, error) {
	path, err := c.GetUnitByPID(ctx, pid)
	if err != nil {
		return "", err
	}

	return unitName(path), nil
}

// Deprecated: use ListUnitsContext instead.
func (c *Conn) ListUnits() ([]UnitStatus, error) {
	return c.ListUnitsContext(context.Background())
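Illustration (editor's sketch, not part of the diff): calling the new GetUnitNameByPID against a running systemd instance. The go-systemd import path is an assumption.

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/coreos/go-systemd/v22/dbus" // assumed import path
)

func main() {
	ctx := context.Background()
	conn, err := dbus.NewWithContext(ctx) // connect to the system bus
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// Look up the unit the current process belongs to.
	name, err := conn.GetUnitNameByPID(ctx, uint32(os.Getpid()))
	if err != nil {
		panic(err)
	}
	fmt.Println("running inside unit:", name)
}
```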
@@ -828,3 +851,14 @@ func (c *Conn) listJobsInternal(ctx context.Context) ([]JobStatus, error) {

	return status, nil
}

// FreezeUnit freezes the cgroup associated with the unit.
// Note that FreezeUnit and ThawUnit are only supported on systems running with cgroup v2.
func (c *Conn) FreezeUnit(ctx context.Context, unit string) error {
	return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.FreezeUnit", 0, unit).Store()
}

// ThawUnit unfreezes the cgroup associated with the unit.
func (c *Conn) ThawUnit(ctx context.Context, unit string) error {
	return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ThawUnit", 0, unit).Store()
}
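Illustration (editor's sketch, not part of the diff): a freeze/thaw cycle on a hypothetical unit name, valid only on cgroup v2 hosts. The import path is an assumption.

```go
package main

import (
	"context"
	"log"

	"github.com/coreos/go-systemd/v22/dbus" // assumed import path
)

func main() {
	ctx := context.Background()
	conn, err := dbus.NewWithContext(ctx)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	const unit = "example.service" // hypothetical unit name

	// Freeze suspends all processes in the unit's cgroup (cgroup v2 only)...
	if err := conn.FreezeUnit(ctx, unit); err != nil {
		log.Fatal(err)
	}
	// ...and Thaw resumes them.
	if err := conn.ThawUnit(ctx, unit); err != nil {
		log.Fatal(err)
	}
}
```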
File diff suppressed because one or more lines are too long
@@ -1,4 +1,4 @@
// Copyright (c) 2015-2021 The Decred developers
// Copyright (c) 2015-2022 The Decred developers
// Copyright 2013-2014 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
@@ -7,7 +7,7 @@ package secp256k1

import (
	"encoding/hex"
	"math/big"
	"math/bits"
)

// References:
@@ -18,6 +18,9 @@ import (
//
// [BRID]: On Binary Representations of Integers with Digits -1, 0, 1
//   (Prodinger, Helmut)
//
// [STWS]: Secure-TWS: Authenticating Node to Multi-user Communication in
//   Shared Sensor Networks (Oliveira, Leonardo B. et al)

// All group operations are performed using Jacobian coordinates. For a given
// (x, y) position on the curve, the Jacobian coordinates are (x1, y1, z1)
@@ -39,29 +42,59 @@ func hexToFieldVal(s string) *FieldVal {
	return &f
}

var (
	// Next 6 constants are from Hal Finney's bitcointalk.org post:
	// https://bitcointalk.org/index.php?topic=3238.msg45565#msg45565
	// May he rest in peace.
	//
	// They have also been independently derived from the code in the
	// EndomorphismVectors function in genstatics.go.
	endomorphismLambda = fromHex("5363ad4cc05c30e0a5261c028812645a122e22ea20816678df02967c1b23bd72")
	endomorphismBeta   = hexToFieldVal("7ae96a2b657c07106e64479eac3434e99cf0497512f58995c1396c28719501ee")
	endomorphismA1     = fromHex("3086d221a7d46bcde86c90e49284eb15")
	endomorphismB1     = fromHex("-e4437ed6010e88286f547fa90abfe4c3")
	endomorphismA2     = fromHex("114ca50f7a8e2f3f657c1108d9d44cfd8")
	endomorphismB2     = fromHex("3086d221a7d46bcde86c90e49284eb15")

// hexToModNScalar converts the passed hex string into a ModNScalar and will
// panic if there is an error. This is only provided for the hard-coded
// constants so errors in the source code can be detected. It will only (and
// must only) be called with hard-coded values.
func hexToModNScalar(s string) *ModNScalar {
	var isNegative bool
	if len(s) > 0 && s[0] == '-' {
		isNegative = true
		s = s[1:]
	}
	if len(s)%2 != 0 {
		s = "0" + s
	}
	b, err := hex.DecodeString(s)
	if err != nil {
		panic("invalid hex in source file: " + s)
	}
	var scalar ModNScalar
	if overflow := scalar.SetByteSlice(b); overflow {
		panic("hex in source file overflows mod N scalar: " + s)
	}
	if isNegative {
		scalar.Negate()
	}
	return &scalar
}

	// Alternatively, the following parameters are valid as well, however, they
	// seem to be about 8% slower in practice.
var (
	// The following constants are used to accelerate scalar point
	// multiplication through the use of the endomorphism:
	//
	// endomorphismLambda = fromHex("AC9C52B33FA3CF1F5AD9E3FD77ED9BA4A880B9FC8EC739C2E0CFC810B51283CE")
	// endomorphismBeta = hexToFieldVal("851695D49A83F8EF919BB86153CBCB16630FB68AED0A766A3EC693D68E6AFA40")
	// endomorphismA1 = fromHex("E4437ED6010E88286F547FA90ABFE4C3")
	// endomorphismB1 = fromHex("-3086D221A7D46BCDE86C90E49284EB15")
	// endomorphismA2 = fromHex("3086D221A7D46BCDE86C90E49284EB15")
	// endomorphismB2 = fromHex("114CA50F7A8E2F3F657C1108D9D44CFD8")
	//   φ(Q) ⟼ λ*Q = (β*Q.x mod p, Q.y)
	//
	// See the code in the deriveEndomorphismParams function in genprecomps.go
	// for details on their derivation.
	//
	// Additionally, see the scalar multiplication function in this file for
	// details on how they are used.
	endoNegLambda = hexToModNScalar("-5363ad4cc05c30e0a5261c028812645a122e22ea20816678df02967c1b23bd72")
	endoBeta      = hexToFieldVal("7ae96a2b657c07106e64479eac3434e99cf0497512f58995c1396c28719501ee")
	endoNegB1     = hexToModNScalar("e4437ed6010e88286f547fa90abfe4c3")
	endoNegB2     = hexToModNScalar("-3086d221a7d46bcde86c90e49284eb15")
	endoZ1        = hexToModNScalar("3086d221a7d46bcde86c90e49284eb153daa8a1471e8ca7f")
	endoZ2        = hexToModNScalar("e4437ed6010e88286f547fa90abfe4c4221208ac9df506c6")

	// Alternatively, the following parameters are valid as well, however,
	// benchmarks show them to be about 2% slower in practice.
	// endoNegLambda = hexToModNScalar("-ac9c52b33fa3cf1f5ad9e3fd77ed9ba4a880b9fc8ec739c2e0cfc810b51283ce")
	// endoBeta      = hexToFieldVal("851695d49a83f8ef919bb86153cbcb16630fb68aed0a766a3ec693d68e6afa40")
	// endoNegB1     = hexToModNScalar("3086d221a7d46bcde86c90e49284eb15")
	// endoNegB2     = hexToModNScalar("-114ca50f7a8e2f3f657c1108d9d44cfd8")
	// endoZ1        = hexToModNScalar("114ca50f7a8e2f3f657c1108d9d44cfd95fbc92c10fddd145")
	// endoZ2        = hexToModNScalar("3086d221a7d46bcde86c90e49284eb153daa8a1471e8ca7f")
)

// JacobianPoint is an element of the group formed by the secp256k1 curve in
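Illustration (editor's sketch, not part of the diff): a big.Int sanity check that the hard-coded λ above is a non-trivial cube root of unity modulo the secp256k1 group order n, which is the property the endomorphism-based decomposition relies on. The value of n is the well-known secp256k1 group order.

```go
package main

import (
	"fmt"
	"math/big"
)

func main() {
	// secp256k1 group order n (well-known constant).
	n, _ := new(big.Int).SetString(
		"fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16)

	// λ from the endomorphism constants above.
	lambda, _ := new(big.Int).SetString(
		"5363ad4cc05c30e0a5261c028812645a122e22ea20816678df02967c1b23bd72", 16)

	// λ != 1 and λ^3 ≡ 1 (mod n), i.e. a non-trivial cube root of unity,
	// which is what makes the k = k1 + k2*λ decomposition possible.
	cube := new(big.Int).Exp(lambda, big.NewInt(3), n)
	fmt.Println(lambda.Cmp(big.NewInt(1)) != 0 && cube.Cmp(big.NewInt(1)) == 0) // true
}
```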
@@ -178,7 +211,7 @@ func addZ1AndZ2EqualsOne(p1, p2, result *JacobianPoint) {
	y3.Set(&v).Add(&negX3).Mul(&r).Add(&j) // Y3 = r*(V-X3)-2*Y1*J (mag: 4)
	z3.Set(&h).MulInt(2)                   // Z3 = 2*H (mag: 6)

	// Normalize the resulting field values to a magnitude of 1 as needed.
	// Normalize the resulting field values as needed.
	x3.Normalize()
	y3.Normalize()
	z3.Normalize()
@@ -196,7 +229,7 @@ func addZ1EqualsZ2(p1, p2, result *JacobianPoint) {
	// the equation into intermediate elements which are used to minimize
	// the number of field multiplications using a slightly modified version
	// of the method shown at:
	// https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-mmadd-2007-bl
	// https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-zadd-2007-m
	//
	// In particular it performs the calculations using the following:
	// A = X2-X1, B = A^2, C=Y2-Y1, D = C^2, E = X1*B, F = X2*B
@@ -242,13 +275,13 @@ func addZ1EqualsZ2(p1, p2, result *JacobianPoint) {
	e.Mul2(x1, &b)                         // E = X1*B (mag: 1)
	negE.Set(&e).Negate(1)                 // negE = -E (mag: 2)
	f.Mul2(x2, &b)                         // F = X2*B (mag: 1)
	x3.Add2(&e, &f).Negate(3).Add(&d)      // X3 = D-E-F (mag: 5)
	negX3.Set(x3).Negate(5).Normalize()    // negX3 = -X3 (mag: 1)
	y3.Set(y1).Mul(f.Add(&negE)).Negate(3) // Y3 = -(Y1*(F-E)) (mag: 4)
	y3.Add(e.Add(&negX3).Mul(&c))          // Y3 = C*(E-X3)+Y3 (mag: 5)
	x3.Add2(&e, &f).Negate(2).Add(&d)      // X3 = D-E-F (mag: 4)
	negX3.Set(x3).Negate(4)                // negX3 = -X3 (mag: 5)
	y3.Set(y1).Mul(f.Add(&negE)).Negate(1) // Y3 = -(Y1*(F-E)) (mag: 2)
	y3.Add(e.Add(&negX3).Mul(&c))          // Y3 = C*(E-X3)+Y3 (mag: 3)
	z3.Mul2(z1, &a)                        // Z3 = Z1*A (mag: 1)

	// Normalize the resulting field values to a magnitude of 1 as needed.
	// Normalize the resulting field values as needed.
	x3.Normalize()
	y3.Normalize()
	z3.Normalize()
@@ -330,7 +363,7 @@ func addZ2EqualsOne(p1, p2, result *JacobianPoint) {
	z3.Add2(z1, &h).Square()        // Z3 = (Z1+H)^2 (mag: 1)
	z3.Add(z1z1.Add(&hh).Negate(2)) // Z3 = Z3-(Z1Z1+HH) (mag: 4)

	// Normalize the resulting field values to a magnitude of 1 as needed.
	// Normalize the resulting field values as needed.
	x3.Normalize()
	y3.Normalize()
	z3.Normalize()
@@ -397,7 +430,7 @@ func addGeneric(p1, p2, result *JacobianPoint) {
	var negU1, negS1, negX3 FieldVal
	negU1.Set(&u1).Negate(1)         // negU1 = -U1 (mag: 2)
	h.Add2(&u2, &negU1)              // H = U2-U1 (mag: 3)
	i.Set(&h).MulInt(2).Square()     // I = (2*H)^2 (mag: 2)
	i.Set(&h).MulInt(2).Square()     // I = (2*H)^2 (mag: 1)
	j.Mul2(&h, &i)                   // J = H*I (mag: 1)
	negS1.Set(&s1).Negate(1)         // negS1 = -S1 (mag: 2)
	r.Set(&s2).Add(&negS1).MulInt(2) // r = 2*(S2-S1) (mag: 6)
@@ -412,7 +445,7 @@ func addGeneric(p1, p2, result *JacobianPoint) {
	z3.Add(z1z1.Add(&z2z2).Negate(2)) // Z3 = Z3-(Z1Z1+Z2Z2) (mag: 4)
	z3.Mul(&h)                        // Z3 = Z3*H (mag: 1)

	// Normalize the resulting field values to a magnitude of 1 as needed.
	// Normalize the resulting field values as needed.
	x3.Normalize()
	y3.Normalize()
	z3.Normalize()
@@ -424,7 +457,7 @@ func addGeneric(p1, p2, result *JacobianPoint) {
// NOTE: The points must be normalized for this function to return the correct
// result. The resulting point will be normalized.
func AddNonConst(p1, p2, result *JacobianPoint) {
	// A point at infinity is the identity according to the group law for
	// The point at infinity is the identity according to the group law for
	// elliptic curve cryptography. Thus, ∞ + P = P and P + ∞ = P.
	if (p1.X.IsZero() && p1.Y.IsZero()) || p1.Z.IsZero() {
		result.Set(p2)
@@ -508,7 +541,7 @@ func doubleZ1EqualsOne(p, result *JacobianPoint) {
	y3.Set(&c).MulInt(8).Negate(8) // Y3 = -(8*C) (mag: 9)
	y3.Add(f.Mul(&e))              // Y3 = E*F+Y3 (mag: 10)

	// Normalize the field values back to a magnitude of 1.
	// Normalize the resulting field values as needed.
	x3.Normalize()
	y3.Normalize()
	z3.Normalize()
@@ -562,7 +595,7 @@ func doubleGeneric(p, result *JacobianPoint) {
	y3.Set(&c).MulInt(8).Negate(8) // Y3 = -(8*C) (mag: 9)
	y3.Add(f.Mul(&e))              // Y3 = E*F+Y3 (mag: 10)

	// Normalize the field values back to a magnitude of 1.
	// Normalize the resulting field values as needed.
	x3.Normalize()
	y3.Normalize()
	z3.Normalize()
@@ -574,7 +607,7 @@ func doubleGeneric(p, result *JacobianPoint) {
// NOTE: The point must be normalized for this function to return the correct
// result. The resulting point will be normalized.
func DoubleNonConst(p, result *JacobianPoint) {
	// Doubling a point at infinity is still infinity.
	// Doubling the point at infinity is still infinity.
	if p.Y.IsZero() || p.Z.IsZero() {
		result.X.SetInt(0)
		result.Y.SetInt(0)
@@ -596,47 +629,283 @@ func DoubleNonConst(p, result *JacobianPoint) {
	doubleGeneric(p, result)
}

// splitK returns a balanced length-two representation of k and their signs.
// This is algorithm 3.74 from [GECC].
//
// One thing of note about this algorithm is that no matter what c1 and c2 are,
// the final equation of k = k1 + k2 * lambda (mod n) will hold. This is
// provable mathematically due to how a1/b1/a2/b2 are computed.
//
// c1 and c2 are chosen to minimize the max(k1,k2).
func splitK(k []byte) ([]byte, []byte, int, int) {
	// All math here is done with big.Int, which is slow.
	// At some point, it might be useful to write something similar to
	// FieldVal but for N instead of P as the prime field if this ends up
	// being a bottleneck.
	bigIntK := new(big.Int)
	c1, c2 := new(big.Int), new(big.Int)
	tmp1, tmp2 := new(big.Int), new(big.Int)
	k1, k2 := new(big.Int), new(big.Int)

// mulAdd64 multiplies the two passed base 2^64 digits together, adds the given
// value to the result, and returns the 128-bit result via a (hi, lo) tuple
// where the upper half of the bits are returned in hi and the lower half in lo.
func mulAdd64(digit1, digit2, m uint64) (hi, lo uint64) {
	// Note the carry on the final add is safe to discard because the maximum
	// possible value is:
	//   (2^64 - 1)(2^64 - 1) + (2^64 - 1) = 2^128 - 2^64
	// and:
	//   2^128 - 2^64 < 2^128.
	var c uint64
	hi, lo = bits.Mul64(digit1, digit2)
	lo, c = bits.Add64(lo, m, 0)
	hi, _ = bits.Add64(hi, 0, c)
	return hi, lo
}
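Illustration (editor's sketch, not part of the diff): a self-contained check of the mulAdd64 identity hi*2^64 + lo = digit1*digit2 + m against math/big.

```go
package main

import (
	"fmt"
	"math/big"
	"math/bits"
)

// mulAdd64 is the same shape as the function added above.
func mulAdd64(d1, d2, m uint64) (hi, lo uint64) {
	var c uint64
	hi, lo = bits.Mul64(d1, d2)
	lo, c = bits.Add64(lo, m, 0)
	hi, _ = bits.Add64(hi, 0, c)
	return hi, lo
}

func main() {
	d1 := uint64(1)<<63 + 12345 // arbitrary sample inputs
	d2 := uint64(987654321)
	m := uint64(1) << 40

	hi, lo := mulAdd64(d1, d2, m)

	// Reassemble hi:lo as a 128-bit integer and compare with d1*d2 + m.
	got := new(big.Int).Lsh(new(big.Int).SetUint64(hi), 64)
	got.Add(got, new(big.Int).SetUint64(lo))
	want := new(big.Int).Mul(new(big.Int).SetUint64(d1), new(big.Int).SetUint64(d2))
	want.Add(want, new(big.Int).SetUint64(m))

	fmt.Println(got.Cmp(want) == 0) // true
}
```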
	bigIntK.SetBytes(k)
	// c1 = round(b2 * k / n) from step 4.
	// Rounding isn't really necessary and costs too much, hence skipped
	c1.Mul(endomorphismB2, bigIntK)
	c1.Div(c1, curveParams.N)
	// c2 = round(b1 * k / n) from step 4 (sign reversed to optimize one step)
	// Rounding isn't really necessary and costs too much, hence skipped
	c2.Mul(endomorphismB1, bigIntK)
	c2.Div(c2, curveParams.N)
	// k1 = k - c1 * a1 - c2 * a2 from step 5 (note c2's sign is reversed)
	tmp1.Mul(c1, endomorphismA1)
	tmp2.Mul(c2, endomorphismA2)
	k1.Sub(bigIntK, tmp1)
	k1.Add(k1, tmp2)
	// k2 = - c1 * b1 - c2 * b2 from step 5 (note c2's sign is reversed)
	tmp1.Mul(c1, endomorphismB1)
	tmp2.Mul(c2, endomorphismB2)
	k2.Sub(tmp2, tmp1)

// mulAdd64Carry multiplies the two passed base 2^64 digits together, adds both
// the given value and carry to the result, and returns the 128-bit result via a
// (hi, lo) tuple where the upper half of the bits are returned in hi and the
// lower half in lo.
func mulAdd64Carry(digit1, digit2, m, c uint64) (hi, lo uint64) {
	// Note the carry on the high order add is safe to discard because the
	// maximum possible value is:
	//   (2^64 - 1)(2^64 - 1) + 2*(2^64 - 1) = 2^128 - 1
	// and:
	//   2^128 - 1 < 2^128.
	var c2 uint64
	hi, lo = mulAdd64(digit1, digit2, m)
	lo, c2 = bits.Add64(lo, c, 0)
	hi, _ = bits.Add64(hi, 0, c2)
	return hi, lo
}
	// Note Bytes() throws out the sign of k1 and k2. This matters
	// since k1 and/or k2 can be negative. Hence, we pass that
	// back separately.
	return k1.Bytes(), k2.Bytes(), k1.Sign(), k2.Sign()

// mul512Rsh320Round computes the full 512-bit product of the two given scalars,
// right shifts the result by 320 bits, rounds to the nearest integer, and
// returns the result in constant time.
//
// Note that despite the inputs and output being mod n scalars, the 512-bit
// product is NOT reduced mod N prior to the right shift. This is intentional
// because it is used for replacing division with multiplication and thus the
// intermediate results must be done via a field extension to a larger field.
func mul512Rsh320Round(n1, n2 *ModNScalar) ModNScalar {
	// Convert n1 and n2 to base 2^64 digits.
	n1Digit0 := uint64(n1.n[0]) | uint64(n1.n[1])<<32
	n1Digit1 := uint64(n1.n[2]) | uint64(n1.n[3])<<32
	n1Digit2 := uint64(n1.n[4]) | uint64(n1.n[5])<<32
	n1Digit3 := uint64(n1.n[6]) | uint64(n1.n[7])<<32
	n2Digit0 := uint64(n2.n[0]) | uint64(n2.n[1])<<32
	n2Digit1 := uint64(n2.n[2]) | uint64(n2.n[3])<<32
	n2Digit2 := uint64(n2.n[4]) | uint64(n2.n[5])<<32
	n2Digit3 := uint64(n2.n[6]) | uint64(n2.n[7])<<32

	// Compute the full 512-bit product n1*n2.
	var r0, r1, r2, r3, r4, r5, r6, r7, c uint64

	// Terms resulting from the product of the first digit of the second number
	// by all digits of the first number.
	//
	// Note that r0 is ignored because it is not needed to compute the higher
	// terms and it is shifted out below anyway.
	c, _ = bits.Mul64(n2Digit0, n1Digit0)
	c, r1 = mulAdd64(n2Digit0, n1Digit1, c)
	c, r2 = mulAdd64(n2Digit0, n1Digit2, c)
	r4, r3 = mulAdd64(n2Digit0, n1Digit3, c)

	// Terms resulting from the product of the second digit of the second number
	// by all digits of the first number.
	//
	// Note that r1 is ignored because it is no longer needed to compute the
	// higher terms and it is shifted out below anyway.
	c, _ = mulAdd64(n2Digit1, n1Digit0, r1)
	c, r2 = mulAdd64Carry(n2Digit1, n1Digit1, r2, c)
	c, r3 = mulAdd64Carry(n2Digit1, n1Digit2, r3, c)
	r5, r4 = mulAdd64Carry(n2Digit1, n1Digit3, r4, c)

	// Terms resulting from the product of the third digit of the second number
	// by all digits of the first number.
	//
	// Note that r2 is ignored because it is no longer needed to compute the
	// higher terms and it is shifted out below anyway.
	c, _ = mulAdd64(n2Digit2, n1Digit0, r2)
	c, r3 = mulAdd64Carry(n2Digit2, n1Digit1, r3, c)
	c, r4 = mulAdd64Carry(n2Digit2, n1Digit2, r4, c)
	r6, r5 = mulAdd64Carry(n2Digit2, n1Digit3, r5, c)

	// Terms resulting from the product of the fourth digit of the second number
	// by all digits of the first number.
	//
	// Note that r3 is ignored because it is no longer needed to compute the
	// higher terms and it is shifted out below anyway.
	c, _ = mulAdd64(n2Digit3, n1Digit0, r3)
	c, r4 = mulAdd64Carry(n2Digit3, n1Digit1, r4, c)
	c, r5 = mulAdd64Carry(n2Digit3, n1Digit2, r5, c)
	r7, r6 = mulAdd64Carry(n2Digit3, n1Digit3, r6, c)

	// At this point the upper 256 bits of the full 512-bit product n1*n2 are in
	// r4..r7 (recall the low order results were discarded as noted above).
	//
	// Right shift the result 320 bits. Note that the MSB of r4 determines
	// whether or not to round because it is the final bit that is shifted out.
	//
	// Also, notice that r3..r7 would also ordinarily be set to 0 as well for
	// the full shift, but that is skipped since they are no longer used as
	// their values are known to be zero.
	roundBit := r4 >> 63
	r2, r1, r0 = r7, r6, r5

	// Conditionally add 1 depending on the round bit in constant time.
	r0, c = bits.Add64(r0, roundBit, 0)
	r1, c = bits.Add64(r1, 0, c)
	r2, r3 = bits.Add64(r2, 0, c)

	// Finally, convert the result to a mod n scalar.
	//
	// No modular reduction is needed because the result is guaranteed to be
	// less than the group order given the group order is > 2^255 and the
	// maximum possible value of the result is 2^192.
	var result ModNScalar
	result.n[0] = uint32(r0)
	result.n[1] = uint32(r0 >> 32)
	result.n[2] = uint32(r1)
	result.n[3] = uint32(r1 >> 32)
	result.n[4] = uint32(r2)
	result.n[5] = uint32(r2 >> 32)
	result.n[6] = uint32(r3)
	result.n[7] = uint32(r3 >> 32)
	return result
}
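Illustration (editor's sketch, not part of the diff): a big.Int model of the shift-and-round step mul512Rsh320Round performs. Rounding the 320-bit right shift to nearest is the same as adding 2^319 (half of 2^320) before shifting, which matches the round bit taken from the MSB of r4 above. The sample inputs are arbitrary illustrative values.

```go
package main

import (
	"fmt"
	"math/big"
)

// round512Rsh320 computes round(n1*n2 / 2^320) with big.Int by adding
// half of the divisor before the floored shift.
func round512Rsh320(n1, n2 *big.Int) *big.Int {
	prod := new(big.Int).Mul(n1, n2)
	half := new(big.Int).Lsh(big.NewInt(1), 319) // 2^319 == half of 2^320
	prod.Add(prod, half)
	return prod.Rsh(prod, 320)
}

func main() {
	// Arbitrary sample inputs (illustrative values only).
	n1, _ := new(big.Int).SetString(
		"3086d221a7d46bcde86c90e49284eb153daa8a1471e8ca7f", 16)
	n2, _ := new(big.Int).SetString(
		"5363ad4cc05c30e0a5261c028812645a122e22ea20816678df02967c1b23bd72", 16)

	fmt.Println(round512Rsh320(n1, n2)) // the rounded high part of the product
}
```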
// splitK returns two scalars (k1 and k2) that are a balanced length-two
|
||||
// representation of the provided scalar such that k ≡ k1 + k2*λ (mod N), where
|
||||
// N is the secp256k1 group order.
|
||||
func splitK(k *ModNScalar) (ModNScalar, ModNScalar) {
|
||||
// The ultimate goal is to decompose k into two scalars that are around
|
||||
// half the bit length of k such that the following equation is satisfied:
|
||||
//
|
||||
// k1 + k2*λ ≡ k (mod n)
|
||||
//
|
||||
// The strategy used here is based on algorithm 3.74 from [GECC] with a few
|
||||
// modifications to make use of the more efficient mod n scalar type, avoid
|
||||
// some costly long divisions, and minimize the number of calculations.
|
||||
//
|
||||
// Start by defining a function that takes a vector v = <a,b> ∈ ℤ⨯ℤ:
|
||||
//
|
||||
// f(v) = a + bλ (mod n)
|
||||
//
|
||||
// Then, find two vectors, v1 = <a1,b1>, and v2 = <a2,b2> in ℤ⨯ℤ such that:
|
||||
// 1) v1 and v2 are linearly independent
|
||||
// 2) f(v1) = f(v2) = 0
|
||||
// 3) v1 and v2 have small Euclidean norm
|
||||
//
|
||||
// The vectors that satisfy these properties are found via the Euclidean
|
||||
// algorithm and are precomputed since both n and λ are fixed values for the
|
||||
// secp256k1 curve. See genprecomps.go for derivation details.
|
||||
//
|
||||
// Next, consider k as a vector <k, 0> in ℚ⨯ℚ and by linear algebra write:
|
||||
//
|
||||
// <k, 0> = g1*v1 + g2*v2, where g1, g2 ∈ ℚ
|
||||
//
|
||||
// Note that, per above, the components of vector v1 are a1 and b1 while the
|
||||
// components of vector v2 are a2 and b2. Given the vectors v1 and v2 were
|
||||
// generated such that a1*b2 - a2*b1 = n, solving the equation for g1 and g2
|
||||
// yields:
|
||||
//
|
||||
// g1 = b2*k / n
|
||||
// g2 = -b1*k / n
|
||||
//
|
||||
// Observe:
|
||||
// <k, 0> = g1*v1 + g2*v2
|
||||
// = (b2*k/n)*<a1,b1> + (-b1*k/n)*<a2,b2> | substitute
|
||||
// = <a1*b2*k/n, b1*b2*k/n> + <-a2*b1*k/n, -b2*b1*k/n> | scalar mul
|
||||
// = <a1*b2*k/n - a2*b1*k/n, b1*b2*k/n - b2*b1*k/n> | vector add
|
||||
// = <[a1*b2*k - a2*b1*k]/n, 0> | simplify
|
||||
// = <k*[a1*b2 - a2*b1]/n, 0> | factor out k
|
||||
// = <k*n/n, 0> | substitute
|
||||
// = <k, 0> | simplify
|
||||
//
|
||||
// Now, consider an integer-valued vector v:
|
||||
//
|
||||
// v = c1*v1 + c2*v2, where c1, c2 ∈ ℤ (mod n)
|
||||
//
|
||||
// Since vectors v1 and v2 are linearly independent and were generated such
|
||||
// that f(v1) = f(v2) = 0, all possible scalars c1 and c2 also produce a
|
||||
// vector v such that f(v) = 0.
|
||||
//
|
||||
// In other words, c1 and c2 can be any integers and the resulting
|
||||
// decomposition will still satisfy the required equation. However, since
|
||||
// the goal is to produce a balanced decomposition that provides a
|
||||
// performance advantage by minimizing max(k1, k2), c1 and c2 need to be
|
||||
// integers close to g1 and g2, respectively, so the resulting vector v is
|
||||
// an integer-valued vector that is close to <k, 0>.
|
||||
//
|
||||
// Finally, consider the vector u:
|
||||
//
|
||||
// u = <k, 0> - v
|
||||
//
|
||||
// It follows that f(u) = k and thus the two components of vector u satisfy
|
||||
// the required equation:
|
||||
//
|
||||
// k1 + k2*λ ≡ k (mod n)
|
||||
//
|
||||
// Choosing c1 and c2:
|
||||
// -------------------
|
||||
//
|
||||
// As mentioned above, c1 and c2 need to be integers close to g1 and g2,
|
||||
// respectively. The algorithm in [GECC] chooses the following values:
|
||||
//
|
||||
// c1 = round(g1) = round(b2*k / n)
|
||||
// c2 = round(g2) = round(-b1*k / n)
|
||||
//
|
||||
// However, as section 3.4.2 of [STWS] notes, the aforementioned approach
|
||||
// requires costly long divisions that can be avoided by precomputing
|
||||
// rounded estimates as follows:
|
||||
//
|
||||
// t = bitlen(n) + 1
|
||||
// z1 = round(2^t * b2 / n)
|
||||
// z2 = round(2^t * -b1 / n)
|
||||
//
|
||||
// Then, use those precomputed estimates to perform a multiplication by k
|
||||
// along with a floored division by 2^t, which is a simple right shift by t:
|
||||
//
|
||||
// c1 = floor(k * z1 / 2^t) = (k * z1) >> t
|
||||
// c2 = floor(k * z2 / 2^t) = (k * z2) >> t
|
||||
//
	// Finally, round up by adding 1 if the last bit discarded in the right
	// shift by t is set.
	//
	// As a further optimization, rather than setting t = bitlen(n) + 1 = 257 as
	// stated by [STWS], this implementation uses a higher precision estimate of
	// t = bitlen(n) + 64 = 320 because it allows simplification of the shifts
	// in the internal calculations that are done via uint64s and also allows
	// the use of floor in the precomputations.
	//
	// Thus, the calculations this implementation uses are:
	//
	//  z1 = floor(b2<<320 / n)    | precomputed
	//  z2 = floor((-b1)<<320 / n) | precomputed
	//  c1 = ((k * z1) >> 320) + (((k * z1) >> 319) & 1)
	//  c2 = ((k * z2) >> 320) + (((k * z2) >> 319) & 1)
	//
	// Putting it all together:
	// ------------------------
	//
	// Calculate the following vectors using the values discussed above:
	//
	//  v = c1*v1 + c2*v2
	//  u = <k, 0> - v
	//
	// The two components of the resulting vector v are:
	//  va = c1*a1 + c2*a2
	//  vb = c1*b1 + c2*b2
	//
	// Thus, the two components of the resulting vector u are:
	//  k1 = k - va
	//  k2 = 0 - vb = -vb
	//
	// As some final optimizations:
	//
	// 1) Note that k1 + k2*λ ≡ k (mod n) means that k1 ≡ k - k2*λ (mod n).
	//    Therefore, the computation of va can be avoided to save two
	//    field multiplications and a field addition.
	//
	// 2) Since k1 = k - k2*λ = k + k2*(-λ), an additional field negation is
	//    saved by storing and using the negative version of λ.
	//
	// 3) Since k2 = -vb = -(c1*b1 + c2*b2) = c1*(-b1) + c2*(-b2), one more
	//    field negation is saved by storing and using the negative versions of
	//    b1 and b2.
	//
	//  k2 = c1*(-b1) + c2*(-b2)
	//  k1 = k + k2*(-λ)
	var k1, k2 ModNScalar
	c1 := mul512Rsh320Round(k, endoZ1)
	c2 := mul512Rsh320Round(k, endoZ2)
	k2.Add2(c1.Mul(endoNegB1), c2.Mul(endoNegB2))
	k1.Mul2(&k2, endoNegLambda).Add(k)
	return k1, k2
}
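
The balanced decomposition can be sanity-checked independently with math/big. The constants below (the group order n, λ, and the vector components a1, b1, a2, b2 = a1) are the well-known published values for the secp256k1 endomorphism and are hard-coded here as assumptions; splitK itself avoids these long divisions via the precomputed endoZ1/endoZ2 estimates:

	package main

	import (
		"fmt"
		"math/big"
	)

	func main() {
		n, _ := new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16)
		lambda, _ := new(big.Int).SetString("5363ad4cc05c30e0a5261c028812645a122e22ea20816678df02967c1b23bd72", 16)
		a1, _ := new(big.Int).SetString("3086d221a7d46bcde86c90e49284eb15", 16)
		b1, _ := new(big.Int).SetString("-e4437ed6010e88286f547fa90abfe4c3", 16)
		a2, _ := new(big.Int).SetString("114ca50f7a8e2f3f657c1108d9d44cfd8", 16)
		b2 := new(big.Int).Set(a1)

		k := new(big.Int).Exp(big.NewInt(7), big.NewInt(100), n) // arbitrary scalar < n

		// c1 = round(b2*k / n), c2 = round(-b1*k / n) per algorithm 3.74.
		round := func(num *big.Int) *big.Int { // num must be non-negative
			q, r := new(big.Int).QuoRem(num, n, new(big.Int))
			if r.Lsh(r, 1).Cmp(n) >= 0 {
				q.Add(q, big.NewInt(1))
			}
			return q
		}
		c1 := round(new(big.Int).Mul(b2, k))
		c2 := round(new(big.Int).Neg(new(big.Int).Mul(b1, k)))

		// u = <k, 0> - (c1*v1 + c2*v2), so k1 = k - c1*a1 - c2*a2 and
		// k2 = -(c1*b1 + c2*b2).
		k1 := new(big.Int).Sub(k, new(big.Int).Mul(c1, a1))
		k1.Sub(k1, new(big.Int).Mul(c2, a2))
		k2 := new(big.Int).Add(new(big.Int).Mul(c1, b1), new(big.Int).Mul(c2, b2))
		k2.Neg(k2)

		// Check k1 + k2*λ ≡ k (mod n) and that both halves are half-length.
		chk := new(big.Int).Add(k1, new(big.Int).Mul(k2, lambda))
		chk.Sub(chk, k).Mod(chk, n)
		fmt.Println(chk.Sign() == 0, k1.BitLen() <= 129, k2.BitLen() <= 129)
	}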

// nafScalar represents a positive integer up to a maximum value of 2^256 - 1

@@ -775,70 +1044,132 @@ func naf(k []byte) nafScalar {
	return result
}
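
For reference, the non-adjacent form property that naf produces (signed digits in {-1, 0, 1} with no two adjacent non-zeros, so only about 1/3 of the digits are non-zero on average) can be computed for small values as follows; this digit-oriented sketch is illustrative only and is not the packed byte encoding nafScalar uses:

	package main

	import "fmt"

	// nafDigits returns the non-adjacent form of k, least-significant digit
	// first. Assumes k is small enough that the k+1 carry cannot overflow.
	func nafDigits(k uint64) []int8 {
		var digits []int8
		for k > 0 {
			switch {
			case k&1 == 0:
				digits = append(digits, 0)
			case k&3 == 3:
				// Emit -1 and carry so the next two digits become zero.
				digits = append(digits, -1)
				k++
			default:
				digits = append(digits, 1)
				k--
			}
			k >>= 1
		}
		return digits
	}

	func main() {
		d := nafDigits(7) // 7 = 8 - 1
		v := int64(0)
		for i := len(d) - 1; i >= 0; i-- {
			v = 2*v + int64(d[i])
		}
		fmt.Println(d, v) // [-1 0 0 1] 7
	}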

// ScalarMultNonConst multiplies k*P where k is a big endian integer modulo the
// curve order and P is a point in Jacobian projective coordinates and stores
// the result in the provided Jacobian point.
// ScalarMultNonConst multiplies k*P where k is a scalar modulo the curve order
// and P is a point in Jacobian projective coordinates and stores the result in
// the provided Jacobian point.
//
// NOTE: The point must be normalized for this function to return the correct
// result. The resulting point will be normalized.
func ScalarMultNonConst(k *ModNScalar, point, result *JacobianPoint) {
	// Decompose K into k1 and k2 in order to halve the number of EC ops.
	// See Algorithm 3.74 in [GECC].
	kBytes := k.Bytes()
	k1, k2, signK1, signK2 := splitK(kBytes[:])
	zeroArray32(&kBytes)

	// The main equation here to remember is:
	//  k * P = k1 * P + k2 * ϕ(P)
	// -------------------------------------------------------------------------
	// This makes use of the following efficiently-computable endomorphism to
	// accelerate the computation:
	//
	//  P1 below is P in the equation, P2 below is ϕ(P) in the equation
	//  φ(P) ⟼ λ*P = (β*P.x mod p, P.y)
	//
	// In other words, there is a special scalar λ that every point on the
	// elliptic curve can be multiplied by that will result in the same point as
	// performing a single field multiplication of the point's X coordinate by
	// the special value β.
	//
	// This is useful because scalar point multiplication is significantly more
	// expensive than a single field multiplication given the former involves a
	// series of point doublings and additions which themselves consist of a
	// combination of several field multiplications, squarings, and additions.
	//
	// So, the idea behind making use of the endomorphism is thus to decompose
	// the scalar into two scalars that are each about half the bit length of
	// the original scalar such that:
	//
	//  k ≡ k1 + k2*λ (mod n)
	//
	// This in turn allows the scalar point multiplication to be performed as a
	// sum of two smaller half-length multiplications as follows:
	//
	//  k*P = (k1 + k2*λ)*P
	//      = k1*P + k2*λ*P
	//      = k1*P + k2*φ(P)
	//
	// Thus, a speedup is achieved so long as it's faster to decompose the
	// scalar, compute φ(P), and perform a simultaneous multiply of the
	// half-length point multiplications than it is to compute a full width
	// point multiplication.
	//
	// In practice, benchmarks show the current implementation provides a
	// speedup of around 30-35% versus not using the endomorphism.
	//
	// See section 3.5 in [GECC] for a more rigorous treatment.
	// -------------------------------------------------------------------------

	// Per above, the main equation here to remember is:
	//  k*P = k1*P + k2*φ(P)
	//
	// p1 below is P in the equation while p2 is φ(P) in the equation.
	//
	// NOTE: φ(x,y) = (β*x,y). The Jacobian z coordinates are the same, so this
	// math goes through.
	//
	// Also, calculate -p1 and -p2 for use in the NAF optimization.
	p1, p1Neg := new(JacobianPoint), new(JacobianPoint)
	p1.Set(point)
	p1Neg.Set(p1)
	p1Neg.Y.Negate(1).Normalize()

	// NOTE: ϕ(x,y) = (βx,y). The Jacobian z coordinates are the same, so this
	// math goes through.
	p2, p2Neg := new(JacobianPoint), new(JacobianPoint)
	p2.Set(p1)
	p2.X.Mul(endomorphismBeta).Normalize()
	p2.X.Mul(endoBeta).Normalize()
	p2Neg.Set(p2)
	p2Neg.Y.Negate(1).Normalize()

	// Flip the positive and negative values of the points as needed
	// depending on the signs of k1 and k2. As mentioned in the equation
	// above, each of k1 and k2 are multiplied by the respective point.
	// Since -k * P is the same thing as k * -P, and the group law for
	// elliptic curves states that P(x, y) = -P(x, -y), it's faster and
	// simplifies the code to just make the point negative.
	if signK1 == -1 {
	// Decompose k into k1 and k2 such that k = k1 + k2*λ (mod n) where k1 and
	// k2 are around half the bit length of k in order to halve the number of EC
	// operations.
	//
	// Notice that this also flips the sign of the scalars and points as needed
	// to minimize the bit lengths of the scalars k1 and k2.
	//
	// This is done because the scalars are operating modulo the group order
	// which means that when they would otherwise be a small negative magnitude
	// they will instead be a large positive magnitude. Since the goal is for
	// the scalars to have a small magnitude to achieve a performance boost, use
	// their negation when they are greater than the half order of the group and
	// flip the positive and negative values of the corresponding point that
	// will be multiplied by to compensate.
	//
	// In other words, transform the calc when k1 is over the half order to:
	//  k1*P = -k1*-P
	//
	// Similarly, transform the calc when k2 is over the half order to:
	//  k2*φ(P) = -k2*-φ(P)
	k1, k2 := splitK(k)
	if k1.IsOverHalfOrder() {
		k1.Negate()
		p1, p1Neg = p1Neg, p1
	}
	if signK2 == -1 {
	if k2.IsOverHalfOrder() {
		k2.Negate()
		p2, p2Neg = p2Neg, p2
	}

	// NAF versions of k1 and k2 should have a lot more zeros.
	// Convert k1 and k2 into their NAF representations since NAF has a lot more
	// zeros overall on average which minimizes the number of required point
	// additions in exchange for a mix of fewer point additions and subtractions
	// at the cost of one additional point doubling.
	//
	// The Pos version of the bytes contain the +1s and the Neg versions
	// contain the -1s.
	k1NAF, k2NAF := naf(k1), naf(k2)
	// This is an excellent tradeoff because subtraction of points has the same
	// computational complexity as addition of points and point doubling is
	// faster than both.
	//
	// Concretely, on average, 1/2 of all bits will be non-zero with the normal
	// binary representation whereas only 1/3rd of the bits will be non-zero
	// with NAF.
	//
	// The Pos version of the bytes contain the +1s and the Neg versions contain
	// the -1s.
	k1Bytes, k2Bytes := k1.Bytes(), k2.Bytes()
	k1NAF, k2NAF := naf(k1Bytes[:]), naf(k2Bytes[:])
	k1PosNAF, k1NegNAF := k1NAF.Pos(), k1NAF.Neg()
	k2PosNAF, k2NegNAF := k2NAF.Pos(), k2NAF.Neg()
	k1Len, k2Len := len(k1PosNAF), len(k2PosNAF)

	// Add left-to-right using the NAF optimization. See algorithm 3.77 from
	// [GECC].
	//
	// Point Q = ∞ (point at infinity).
	var q JacobianPoint
	m := k1Len
	if m < k2Len {
		m = k2Len
	}

	// Point Q = ∞ (point at infinity).
	var q JacobianPoint

	// Add left-to-right using the NAF optimization. See algorithm 3.77
	// from [GECC]. This should be faster overall since there will be a lot
	// more instances of 0, hence reducing the number of Jacobian additions
	// at the cost of 1 possible extra doubling.
	for i := 0; i < m; i++ {
		// Since k1 and k2 are potentially different lengths and the calculation
		// is being done left to right, pad the front of the shorter one with

@@ -850,7 +1181,8 @@ func ScalarMultNonConst(k *ModNScalar, point, result *JacobianPoint) {
		if i >= m-k2Len {
			k2BytePos, k2ByteNeg = k2PosNAF[i-(m-k2Len)], k2NegNAF[i-(m-k2Len)]
		}
		for bit, mask := 7, uint8(1<<7); bit >= 0; bit, mask = bit-1, mask>>1 {

		for mask := uint8(1 << 7); mask > 0; mask >>= 1 {
			// Q = 2 * Q
			DoubleNonConst(&q, &q)

@@ -883,31 +1215,28 @@ func ScalarMultNonConst(k *ModNScalar, point, result *JacobianPoint) {
	result.Set(&q)
}
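
The endomorphism identity φ(P) = λ*P = (β*x, y) can be spot-checked through the public API of github.com/decred/dcrd/dcrec/secp256k1/v4 (the module this commit upgrades to). The λ and β hex strings are the well-known secp256k1 endomorphism constants, restated here as assumptions rather than read from the package:

	package main

	import (
		"encoding/hex"
		"fmt"

		secp "github.com/decred/dcrd/dcrec/secp256k1/v4"
	)

	func main() {
		// Use P = G for the check.
		var one secp.ModNScalar
		one.SetInt(1)
		var p secp.JacobianPoint
		secp.ScalarBaseMultNonConst(&one, &p)
		p.ToAffine()

		// Left side: λ*P via full scalar multiplication.
		lb, _ := hex.DecodeString("5363ad4cc05c30e0a5261c028812645a122e22ea20816678df02967c1b23bd72")
		var lambda secp.ModNScalar
		lambda.SetByteSlice(lb)
		var lhs secp.JacobianPoint
		secp.ScalarMultNonConst(&lambda, &p, &lhs)
		lhs.ToAffine()

		// Right side: a single field multiplication of the x coordinate by β.
		bb, _ := hex.DecodeString("7ae96a2b657c07106e64479eac3434e99cf0497512f58995c1396c28719501ee")
		var beta secp.FieldVal
		beta.SetByteSlice(bb)
		var rhs secp.JacobianPoint
		rhs.Set(&p)
		rhs.X.Mul(&beta).Normalize()

		fmt.Println(lhs.X.Equals(&rhs.X) && lhs.Y.Equals(&rhs.Y)) // true
	}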

// ScalarBaseMultNonConst multiplies k*G where G is the base point of the group
// and k is a big endian integer. The result is stored in Jacobian coordinates
// (x1, y1, z1).
// ScalarBaseMultNonConst multiplies k*G where k is a scalar modulo the curve
// order and G is the base point of the group and stores the result in the
// provided Jacobian point.
//
// NOTE: The resulting point will be normalized.
func ScalarBaseMultNonConst(k *ModNScalar, result *JacobianPoint) {
	bytePoints := s256BytePoints()

	// Point Q = ∞ (point at infinity).
	var q JacobianPoint
	// Start with the point at infinity.
	result.X.Zero()
	result.Y.Zero()
	result.Z.Zero()

	// curve.bytePoints has all 256 byte points for each 8-bit window. The
	// strategy is to add up the byte points. This is best understood by
	// expressing k in base-256 which it already sort of is. Each "digit" in
	// the 8-bit window can be looked up using bytePoints and added together.
	var pt JacobianPoint
	for i, byteVal := range k.Bytes() {
		p := bytePoints[i][byteVal]
		pt.X.Set(&p[0])
		pt.Y.Set(&p[1])
		pt.Z.SetInt(1)
		AddNonConst(&q, &pt, &q)
	// bytePoints has all 256 byte points for each 8-bit window. The strategy
	// is to add up the byte points. This is best understood by expressing k in
	// base-256 which it already sort of is. Each "digit" in the 8-bit window
	// can be looked up using bytePoints and added together.
	kb := k.Bytes()
	for i := 0; i < len(kb); i++ {
		pt := &bytePoints[i][kb[i]]
		AddNonConst(result, pt, result)
	}

	result.Set(&q)
}
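
The byte-point strategy is easiest to see with plain integers standing in for curve points, where addition plays the role of the group operation and tbl[i][b] plays the role of bytePoints[i][b] = b*256^(31-i)*G; summing one table entry per byte of the scalar reconstructs it exactly:

	package main

	import (
		"fmt"
		"math/big"
	)

	func main() {
		// Precompute all 256 "points" for each of the 32 byte windows.
		var tbl [32][256]*big.Int
		for i := 0; i < 32; i++ {
			base := new(big.Int).Exp(big.NewInt(256), big.NewInt(int64(31-i)), nil)
			for b := 0; b < 256; b++ {
				tbl[i][b] = new(big.Int).Mul(big.NewInt(int64(b)), base)
			}
		}

		k := new(big.Int).SetBytes([]byte{0x01, 0x02, 0x03}) // example scalar
		kb := make([]byte, 32)
		k.FillBytes(kb)

		// The analogue of AddNonConst(result, pt, result) per 8-bit digit.
		sum := new(big.Int)
		for i, b := range kb {
			sum.Add(sum, tbl[i][b])
		}
		fmt.Println(sum.Cmp(k) == 0) // true
	}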

// isOnCurve returns whether or not the affine point (x,y) is on the curve.

@@ -1,10 +1,11 @@
// Copyright (c) 2013-2014 The btcsuite developers
// Copyright (c) 2015-2019 The Decred developers
// Copyright (c) 2015-2022 The Decred developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.

/*
Package secp256k1 implements optimized secp256k1 elliptic curve operations.
Package secp256k1 implements optimized secp256k1 elliptic curve operations in
pure Go.

This package provides an optimized pure Go implementation of elliptic curve
cryptography operations over the secp256k1 curve as well as data structures and

@@ -18,22 +19,22 @@ packages for more details about those aspects.

An overview of the features provided by this package is as follows:

  - Private key generation, serialization, and parsing
  - Public key generation, serialization and parsing per ANSI X9.62-1998
      - Parses uncompressed, compressed, and hybrid public keys
      - Serializes uncompressed and compressed public keys
  - Specialized types for performing optimized and constant time field operations
      - FieldVal type for working modulo the secp256k1 field prime
      - ModNScalar type for working modulo the secp256k1 group order
  - Elliptic curve operations in Jacobian projective coordinates
      - Point addition
      - Point doubling
      - Scalar multiplication with an arbitrary point
      - Scalar multiplication with the base point (group generator)
  - Point decompression from a given x coordinate
  - Nonce generation via RFC6979 with support for extra data and version
    information that can be used to prevent nonce reuse between signing
    algorithms

It also provides an implementation of the Go standard library crypto/elliptic
Curve interface via the S256 function so that it may be used with other packages

@@ -49,7 +50,7 @@ use optimized secp256k1 elliptic curve cryptography.

Finally, a comprehensive suite of tests is included to provide a high level of
quality assurance.

Use of secp256k1 in Decred
# Use of secp256k1 in Decred

At the time of this writing, the primary public key cryptography in widespread
use on the Decred network used to secure coins is based on elliptic curves

@@ -21,7 +21,7 @@ hash combination.

A comprehensive suite of tests is provided to ensure proper functionality.

ECDSA use in Decred
# ECDSA use in Decred

At the time of this writing, ECDSA signatures are heavily used for proving coin
ownership in Decred as the vast majority of transactions consist of what is

@@ -30,7 +30,7 @@ private key only known to the recipient of the coins along with an encumbrance

that requires an ECDSA signature that proves the new owner possesses the private
key without actually revealing it.

Errors
# Errors

Errors returned by this package are of type ecdsa.Error and fully support the
standard library errors.Is and errors.As functions. This allows the caller to

@@ -1,4 +1,4 @@
// Copyright (c) 2020 The Decred developers
// Copyright (c) 2020-2022 The Decred developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.

@@ -85,6 +85,24 @@ const (
	// ErrSigSTooBig is returned when a signature has S with a value that is
	// greater than or equal to the group order.
	ErrSigSTooBig = ErrorKind("ErrSigSTooBig")

	// ErrSigInvalidLen is returned when a signature that should be a compact
	// signature is not the required length.
	ErrSigInvalidLen = ErrorKind("ErrSigInvalidLen")

	// ErrSigInvalidRecoveryCode is returned when a signature that should be a
	// compact signature has an invalid value for the public key recovery code.
	ErrSigInvalidRecoveryCode = ErrorKind("ErrSigInvalidRecoveryCode")

	// ErrSigOverflowsPrime is returned when a signature that should be a
	// compact signature has the overflow bit set but adding the order to it
	// would overflow the underlying field prime.
	ErrSigOverflowsPrime = ErrorKind("ErrSigOverflowsPrime")

	// ErrPointNotOnCurve is returned when attempting to recover a public key
	// from a compact signature results in a point that is not on the elliptic
	// curve.
	ErrPointNotOnCurve = ErrorKind("ErrPointNotOnCurve")
)
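
Because each kind above is an ErrorKind wrapped by Error, callers can match on them with the standard library; a minimal sketch using the new ErrSigInvalidLen kind against a deliberately malformed compact signature:

	package main

	import (
		"errors"
		"fmt"

		"github.com/decred/dcrd/dcrec/secp256k1/v4/ecdsa"
	)

	func main() {
		// Compact signatures must be 65 bytes, so 10 bytes must fail.
		_, _, err := ecdsa.RecoverCompact(make([]byte, 10), make([]byte, 32))
		fmt.Println(errors.Is(err, ecdsa.ErrSigInvalidLen)) // true
	}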

// Error satisfies the error interface and prints human-readable errors.

@@ -1,12 +1,11 @@
// Copyright (c) 2013-2014 The btcsuite developers
// Copyright (c) 2015-2020 The Decred developers
// Copyright (c) 2015-2022 The Decred developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.

package ecdsa

import (
	"errors"
	"fmt"

	"github.com/decred/dcrd/dcrec/secp256k1/v4"

@@ -30,10 +29,10 @@ var (
	// orderAsFieldVal is the order of the secp256k1 curve group stored as a
	// field value. It is provided here to avoid the need to create it multiple
	// times.
	orderAsFieldVal = func() *secp256k1.FieldVal {
	orderAsFieldVal = func() secp256k1.FieldVal {
		var f secp256k1.FieldVal
		f.SetByteSlice(secp256k1.Params().N.Bytes())
		return &f
		return f
	}()
)

@@ -283,7 +282,7 @@ func (sig *Signature) Verify(hash []byte, pubKey *secp256k1.PublicKey) bool {
	// Step 10.
	//
	// Verified if (R + N) * z == X.x (mod P)
	sigRModP.Add(orderAsFieldVal)
	sigRModP.Add(&orderAsFieldVal)
	result.Mul2(&sigRModP, z).Normalize()
	return result.Equals(&X.X)
}

@@ -535,8 +534,130 @@ func ParseDERSignature(sig []byte) (*Signature, error) {
	return NewSignature(&r, s), nil
}

// sign generates an ECDSA signature over the secp256k1 curve for the provided
// hash (which should be the result of hashing a larger message) using the given
// nonce and private key and returns it along with an additional public key
// recovery code and success indicator. Upon success, the produced signature is
// deterministic (same message, nonce, and key yield the same signature) and
// canonical in accordance with BIP0062.
//
// Note that signRFC6979 makes use of this function as it is the primary ECDSA
// signing logic. It differs in that it accepts a nonce to use when signing and
// may not successfully produce a valid signature for the given nonce. It is
// primarily separated for testing purposes.
func sign(privKey, nonce *secp256k1.ModNScalar, hash []byte) (*Signature, byte, bool) {
	// The algorithm for producing an ECDSA signature is given as algorithm 4.29
	// in [GECC].
	//
	// The following is a paraphrased version for reference:
	//
	//  G = curve generator
	//  N = curve order
	//  d = private key
	//  m = message
	//  r, s = signature
	//
	//  1. Select random nonce k in [1, N-1]
	//  2. Compute kG
	//  3. r = kG.x mod N (kG.x is the x coordinate of the point kG)
	//     Repeat from step 1 if r = 0
	//  4. e = H(m)
	//  5. s = k^-1(e + dr) mod N
	//     Repeat from step 1 if s = 0
	//  6. Return (r,s)
	//
	// This is slightly modified here to conform to RFC6979 and BIP 62 as
	// follows:
	//
	//  A. Instead of selecting a random nonce in step 1, use RFC6979 to generate
	//     a deterministic nonce in [1, N-1] parameterized by the private key,
	//     message being signed, and an iteration count for the repeat cases
	//  B. Negate s calculated in step 5 if it is > N/2
	//     This is done because both s and its negation are valid signatures
	//     modulo the curve order N, so it forces a consistent choice to reduce
	//     signature malleability

	// NOTE: Step 1 is performed by the caller.
	//
	// Step 2.
	//
	// Compute kG
	//
	// Note that the point must be in affine coordinates.
	k := nonce
	var kG secp256k1.JacobianPoint
	secp256k1.ScalarBaseMultNonConst(k, &kG)
	kG.ToAffine()

	// Step 3.
	//
	// r = kG.x mod N
	// Repeat from step 1 if r = 0
	r, overflow := fieldToModNScalar(&kG.X)
	if r.IsZero() {
		return nil, 0, false
	}

	// Since the secp256k1 curve has a cofactor of 1, when recovering a
	// public key from an ECDSA signature over it, there are four possible
	// candidates corresponding to the following cases:
	//
	// 1) The X coord of the random point is < N and its Y coord even
	// 2) The X coord of the random point is < N and its Y coord is odd
	// 3) The X coord of the random point is >= N and its Y coord is even
	// 4) The X coord of the random point is >= N and its Y coord is odd
	//
	// Rather than forcing the recovery procedure to check all possible
	// cases, this creates a recovery code that uniquely identifies which of
	// the cases apply by making use of 2 bits. Bit 0 identifies the
	// oddness case and Bit 1 identifies the overflow case (aka when the X
	// coord >= N).
	//
	// It is also worth noting that making use of Hasse's theorem shows
	// there are around log_2((p-n)/p) ~= -127.65 ~= 1 in 2^127 points where
	// the X coordinate is >= N. It is not possible to calculate these
	// points since that would require breaking the ECDLP, but, in practice
	// this strongly implies with extremely high probability that there are
	// only a few actual points for which this case is true.
	pubKeyRecoveryCode := byte(overflow<<1) | byte(kG.Y.IsOddBit())

	// Step 4.
	//
	// e = H(m)
	//
	// Note that this actually sets e = H(m) mod N which is correct since
	// it is only used in step 5 which itself is mod N.
	var e secp256k1.ModNScalar
	e.SetByteSlice(hash)

	// Step 5 with modification B.
	//
	// s = k^-1(e + dr) mod N
	// Repeat from step 1 if s = 0
	// s = -s if s > N/2
	kinv := new(secp256k1.ModNScalar).InverseValNonConst(k)
	s := new(secp256k1.ModNScalar).Mul2(privKey, &r).Add(&e).Mul(kinv)
	if s.IsZero() {
		return nil, 0, false
	}
	if s.IsOverHalfOrder() {
		s.Negate()

		// Negating s corresponds to the random point that would have been
		// generated by -k (mod N), which necessarily has the opposite
		// oddness since N is prime, thus flip the pubkey recovery code
		// oddness bit accordingly.
		pubKeyRecoveryCode ^= 0x01
	}

	// Step 6.
	//
	// Return (r,s)
	return NewSignature(&r, s), pubKeyRecoveryCode, true
}
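
Why verification works already falls out of the scalar algebra in step 5: if s = k^-1(e + d*r), then u1 = e*s^-1 and u2 = r*s^-1 satisfy u1 + u2*d ≡ k (mod N), which is why u1*G + u2*Q reproduces the nonce point whose x coordinate is r. A math/big sketch of that identity (point operations elided; r here is an arbitrary stand-in rather than a real x coordinate):

	package main

	import (
		"fmt"
		"math/big"
	)

	func main() {
		n, _ := new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16)
		d := big.NewInt(1234567) // private key
		k := big.NewInt(987654)  // nonce
		e := big.NewInt(555555)  // hashed message as a scalar
		r := big.NewInt(424242)  // stand-in for (kG).x mod n

		// Step 5: s = k^-1(e + d*r) mod n.
		s := new(big.Int).Mul(d, r)
		s.Add(s, e)
		s.Mul(s, new(big.Int).ModInverse(k, n))
		s.Mod(s, n)

		// Verifier side: u1 = e/s and u2 = r/s recover k in the exponent.
		sInv := new(big.Int).ModInverse(s, n)
		u1 := new(big.Int).Mul(e, sInv)
		u2 := new(big.Int).Mul(r, sInv)
		chk := new(big.Int).Add(u1, new(big.Int).Mul(u2, d))
		fmt.Println(chk.Mod(chk, n).Cmp(k) == 0) // true
	}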

// signRFC6979 generates a deterministic ECDSA signature according to RFC 6979
// and BIP 62 and returns it along with an additional public key recovery code
// and BIP0062 and returns it along with an additional public key recovery code
// for efficiently recovering the public key from the signature.
func signRFC6979(privKey *secp256k1.PrivateKey, hash []byte) (*Signature, byte) {
	// The algorithm for producing an ECDSA signature is given as algorithm 4.29

@@ -581,82 +702,14 @@ func signRFC6979(privKey *secp256k1.PrivateKey, hash []byte) (*Signature, byte)
		// private key, message being signed, and iteration count.
		k := secp256k1.NonceRFC6979(privKeyBytes[:], hash, nil, nil, iteration)

		// Step 2.
		//
		// Compute kG
		//
		// Note that the point must be in affine coordinates.
		var kG secp256k1.JacobianPoint
		secp256k1.ScalarBaseMultNonConst(k, &kG)
		kG.ToAffine()

		// Step 3.
		//
		// r = kG.x mod N
		// Repeat from step 1 if r = 0
		r, overflow := fieldToModNScalar(&kG.X)
		if r.IsZero() {
			k.Zero()
			continue
		}

		// Since the secp256k1 curve has a cofactor of 1, when recovering a
		// public key from an ECDSA signature over it, there are four possible
		// candidates corresponding to the following cases:
		//
		// 1) The X coord of the random point is < N and its Y coord even
		// 2) The X coord of the random point is < N and its Y coord is odd
		// 3) The X coord of the random point is >= N and its Y coord is even
		// 4) The X coord of the random point is >= N and its Y coord is odd
		//
		// Rather than forcing the recovery procedure to check all possible
		// cases, this creates a recovery code that uniquely identifies which of
		// the cases apply by making use of 2 bits. Bit 0 identifies the
		// oddness case and Bit 1 identifies the overflow case (aka when the X
		// coord >= N).
		//
		// It is also worth noting that making use of Hasse's theorem shows
		// there are around log_2((p-n)/p) ~= -127.65 ~= 1 in 2^127 points where
		// the X coordinate is >= N. It is not possible to calculate these
		// points since that would require breaking the ECDLP, but, in practice
		// this strongly implies with extremely high probability that there are
		// only a few actual points for which this case is true.
		pubKeyRecoveryCode := byte(overflow<<1) | byte(kG.Y.IsOddBit())

		// Step 4.
		//
		// e = H(m)
		//
		// Note that this actually sets e = H(m) mod N which is correct since
		// it is only used in step 5 which itself is mod N.
		var e secp256k1.ModNScalar
		e.SetByteSlice(hash)

		// Step 5 with modification B.
		//
		// s = k^-1(e + dr) mod N
		// Repeat from step 1 if s = 0
		// s = -s if s > N/2
		kInv := new(secp256k1.ModNScalar).InverseValNonConst(k)
		// Steps 2-6.
		sig, pubKeyRecoveryCode, success := sign(privKeyScalar, k, hash)
		k.Zero()
		s := new(secp256k1.ModNScalar).Mul2(privKeyScalar, &r).Add(&e).Mul(kInv)
		if s.IsZero() {
		if !success {
			continue
		}
		if s.IsOverHalfOrder() {
			s.Negate()

			// Negating s corresponds to the random point that would have been
			// generated by -k (mod N), which necessarily has the opposite
			// oddness since N is prime, thus flip the pubkey recovery code
			// oddness bit accordingly.
			pubKeyRecoveryCode ^= 0x01
		}

		// Step 6.
		//
		// Return (r,s)
		return NewSignature(&r, s), pubKeyRecoveryCode
		return sig, pubKeyRecoveryCode
	}
}

@@ -769,7 +822,7 @@ func RecoverCompact(signature, hash []byte) (*secp256k1.PublicKey, bool, error)
	// The method described by section 4.1.6 of [SEC1] to determine which one is
	// the correct one involves calculating each possibility as a candidate
	// public key and comparing the candidate to the authentic public key. It
	// also hints that is is possible to generate the signature in a such a
	// also hints that it is possible to generate the signature in a such a
	// way that only one of the candidate public keys is viable.
	//
	// A more efficient approach that is specific to the secp256k1 curve is used

@@ -797,7 +850,9 @@ func RecoverCompact(signature, hash []byte) (*secp256k1.PublicKey, bool, error)
	// A compact signature consists of a recovery byte followed by the R and
	// S components serialized as 32-byte big-endian values.
	if len(signature) != compactSigSize {
		return nil, false, errors.New("invalid compact signature size")
		str := fmt.Sprintf("malformed signature: wrong size: %d != %d",
			len(signature), compactSigSize)
		return nil, false, signatureError(ErrSigInvalidLen, str)
	}

	// Parse and validate the compact signature recovery code.

@@ -807,7 +862,10 @@ func RecoverCompact(signature, hash []byte) (*secp256k1.PublicKey, bool, error)
	)
	sigRecoveryCode := signature[0]
	if sigRecoveryCode < minValidCode || sigRecoveryCode > maxValidCode {
		return nil, false, errors.New("invalid compact signature recovery code")
		str := fmt.Sprintf("invalid signature: public key recovery code %d is "+
			"not in the valid range [%d, %d]", sigRecoveryCode, minValidCode,
			maxValidCode)
		return nil, false, signatureError(ErrSigInvalidRecoveryCode, str)
	}
	sigRecoveryCode -= compactSigMagicOffset
	wasCompressed := sigRecoveryCode&compactSigCompPubKey != 0

@@ -820,16 +878,20 @@ func RecoverCompact(signature, hash []byte) (*secp256k1.PublicKey, bool, error)
	// Fail if r and s are not in [1, N-1].
	var r, s secp256k1.ModNScalar
	if overflow := r.SetByteSlice(signature[1:33]); overflow {
		return nil, false, errors.New("signature R is >= curve order")
		str := "invalid signature: R >= group order"
		return nil, false, signatureError(ErrSigRTooBig, str)
	}
	if r.IsZero() {
		return nil, false, errors.New("signature R is 0")
		str := "invalid signature: R is 0"
		return nil, false, signatureError(ErrSigRIsZero, str)
	}
	if overflow := s.SetByteSlice(signature[33:]); overflow {
		return nil, false, errors.New("signature S is >= curve order")
		str := "invalid signature: S >= group order"
		return nil, false, signatureError(ErrSigSTooBig, str)
	}
	if s.IsZero() {
		return nil, false, errors.New("signature S is 0")
		str := "invalid signature: S is 0"
		return nil, false, signatureError(ErrSigSIsZero, str)
	}

	// Step 2.

@@ -850,13 +912,14 @@ func RecoverCompact(signature, hash []byte) (*secp256k1.PublicKey, bool, error)
	// would exceed the field prime since R originally came from the X
	// coordinate of a random point on the curve.
	if fieldR.IsGtOrEqPrimeMinusOrder() {
		return nil, false, errors.New("signature R + N >= P")
		str := "invalid signature: signature R + N >= P"
		return nil, false, signatureError(ErrSigOverflowsPrime, str)
	}

	// Step 3.2.
	//
	// r = r + N (mod P)
	fieldR.Add(orderAsFieldVal)
	fieldR.Add(&orderAsFieldVal)
	}

	// Step 4.

@@ -871,15 +934,16 @@ func RecoverCompact(signature, hash []byte) (*secp256k1.PublicKey, bool, error)
	oddY := pubKeyRecoveryCode&pubKeyRecoveryCodeOddnessBit != 0
	var y secp256k1.FieldVal
	if valid := secp256k1.DecompressY(&fieldR, oddY, &y); !valid {
		return nil, false, errors.New("signature is not for a valid curve point")
		str := "invalid signature: not for a valid curve point"
		return nil, false, signatureError(ErrPointNotOnCurve, str)
	}

	// Step 5.
	//
	// X = (r, y)
	var X secp256k1.JacobianPoint
	X.X.Set(&fieldR)
	X.Y.Set(&y)
	X.X.Set(fieldR.Normalize())
	X.Y.Set(y.Normalize())
	X.Z.SetInt(1)

	// Step 6.

@@ -915,7 +979,8 @@ func RecoverCompact(signature, hash []byte) (*secp256k1.PublicKey, bool, error)
	// Either the signature or the pubkey recovery code must be invalid if the
	// recovered pubkey is the point at infinity.
	if (Q.X.IsZero() && Q.Y.IsZero()) || Q.Z.IsZero() {
		return nil, false, errors.New("recovered pubkey is the point at infinity")
		str := "invalid signature: recovered pubkey is the point at infinity"
		return nil, false, signatureError(ErrPointNotOnCurve, str)
	}

	// Notice that the public key is in affine coordinates.
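
The compact layout being validated here is one header byte (compactSigMagicOffset plus the recovery code, plus compactSigCompPubKey when the key is compressed) followed by 32-byte big-endian R and S. A standalone sketch of just the header handling; the concrete values 65, 27, and 4 match the usual compact-signature convention and are restated here as assumptions since the diff only shows the constant names:

	package main

	import "fmt"

	func parseCompactHeader(sig []byte) (recCode byte, compressed bool, err error) {
		const (
			compactSigSize        = 65
			compactSigMagicOffset = 27
			compactSigCompPubKey  = 4
		)
		if len(sig) != compactSigSize {
			return 0, false, fmt.Errorf("wrong size: %d != %d", len(sig), compactSigSize)
		}
		code := sig[0]
		if code < compactSigMagicOffset || code > compactSigMagicOffset+compactSigCompPubKey+3 {
			return 0, false, fmt.Errorf("recovery code %d out of range", code)
		}
		code -= compactSigMagicOffset
		return code & 3, code&compactSigCompPubKey != 0, nil
	}

	func main() {
		sig := make([]byte, 65)
		sig[0] = 27 + 4 + 1 // compressed pubkey, recovery code 1
		code, compressed, _ := parseCompactHeader(sig)
		fmt.Println(code, compressed) // 1 true
	}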

@@ -1,4 +1,4 @@
// Copyright 2020-2021 The Decred developers
// Copyright 2020-2022 The Decred developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.

@@ -102,7 +102,7 @@ func (curve *KoblitzCurve) IsOnCurve(x, y *big.Int) bool {
//
// This is part of the elliptic.Curve interface implementation.
func (curve *KoblitzCurve) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) {
	// A point at infinity is the identity according to the group law for
	// The point at infinity is the identity according to the group law for
	// elliptic curve cryptography. Thus, ∞ + P = P and P + ∞ = P.
	if x1.Sign() == 0 && y1.Sign() == 0 {
		return x2, y2

@@ -249,7 +249,7 @@ var secp256k1 = &KoblitzCurve{
	},
}

// S256 returns a Curve which implements secp256k1.
// S256 returns an elliptic.Curve which implements secp256k1.
func S256() *KoblitzCurve {
	return secp256k1
}
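
A usage sketch of the elliptic.Curve implementation; the import path is hypothetical since the diff does not show this package's module path:

	package main

	import (
		"fmt"

		koblitz "example.com/yourmodule/bec" // hypothetical import path
	)

	func main() {
		curve := koblitz.S256()
		// Double the generator through the standard interface and confirm
		// the result is still a point on the curve.
		x, y := curve.Double(curve.Params().Gx, curve.Params().Gy)
		fmt.Println(curve.IsOnCurve(x, y)) // true
	}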

@@ -1,6 +1,6 @@
// Copyright (c) 2013-2014 The btcsuite developers
// Copyright (c) 2015-2021 The Decred developers
// Copyright (c) 2013-2021 Dave Collins
// Copyright (c) 2015-2022 The Decred developers
// Copyright (c) 2013-2022 Dave Collins
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.

@@ -103,7 +103,8 @@ const (

// FieldVal implements optimized fixed-precision arithmetic over the
// secp256k1 finite field. This means all arithmetic is performed modulo
//
// 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f.
//
// WARNING: Since it is so important for the field arithmetic to be extremely
// fast for high performance crypto, this type does not perform any validation

@@ -172,9 +173,9 @@ type FieldVal struct {

// String returns the field value as a normalized human-readable hex string.
//
// Preconditions: None
// Output Normalized: Field is not modified -- same as input value
// Output Max Magnitude: Field is not modified -- same as input value
func (f FieldVal) String() string {
	// f is a copy, so it's safe to normalize it without mutating the original.
	f.Normalize()

@@ -185,9 +186,9 @@ func (f FieldVal) String() string {
// value is already set to zero. This function can be useful to clear an
// existing field value for reuse.
//
// Preconditions: None
// Output Normalized: Yes
// Output Max Magnitude: 1
func (f *FieldVal) Zero() {
	f.n[0] = 0
	f.n[1] = 0

@@ -208,9 +209,9 @@ func (f *FieldVal) Zero() {
// f := new(FieldVal).Set(f2).Add(1) so that f = f2 + 1 where f2 is not
// modified.
//
// Preconditions: None
// Output Normalized: Same as input value
// Output Max Magnitude: Same as input value
func (f *FieldVal) Set(val *FieldVal) *FieldVal {
	*f = *val
	return f

@@ -223,9 +224,9 @@ func (f *FieldVal) Set(val *FieldVal) *FieldVal {
// The field value is returned to support chaining. This enables syntax such
// as f := new(FieldVal).SetInt(2).Mul(f2) so that f = 2 * f2.
//
// Preconditions: None
// Output Normalized: Yes
// Output Max Magnitude: 1
func (f *FieldVal) SetInt(ui uint16) *FieldVal {
	f.Zero()
	f.n[0] = uint32(ui)

@@ -242,9 +243,9 @@ func (f *FieldVal) SetInt(ui uint16) *FieldVal {
// from a bool to numeric value in constant time and many constant-time
// operations require a numeric value.
//
// Preconditions: None
// Output Normalized: Yes if no overflow, no otherwise
// Output Max Magnitude: 1
func (f *FieldVal) SetBytes(b *[32]byte) uint32 {
	// Pack the 256 total bits across the 10 uint32 words with a max of
	// 26-bits per word. This could be done with a couple of for loops,

@@ -311,9 +312,9 @@ func (f *FieldVal) SetBytes(b *[32]byte) uint32 {
// or if it is acceptable to use this function with the described truncation and
// overflow behavior.
//
// Preconditions: None
// Output Normalized: Yes if no overflow, no otherwise
// Output Max Magnitude: 1
func (f *FieldVal) SetByteSlice(b []byte) bool {
	var b32 [32]byte
	b = b[:constantTimeMin(uint32(len(b)), 32)]

@@ -328,9 +329,9 @@ func (f *FieldVal) SetByteSlice(b []byte) bool {
// performs fast modular reduction over the secp256k1 prime by making use of the
// special form of the prime in constant time.
//
// Preconditions: None
// Output Normalized: Yes
// Output Max Magnitude: 1
func (f *FieldVal) Normalize() *FieldVal {
	// The field representation leaves 6 bits of overflow in each word so
	// intermediate calculations can be performed without needing to

@@ -441,9 +442,9 @@ func (f *FieldVal) Normalize() *FieldVal {
// to write directly into part of a larger buffer without needing a separate
// allocation.
//
// Preconditions:
//   - The field value MUST be normalized
//   - The target slice MUST have at least 32 bytes available
func (f *FieldVal) PutBytesUnchecked(b []byte) {
	// Unpack the 256 total bits from the 10 uint32 words with a max of
	// 26-bits per word. This could be done with a couple of for loops,

@@ -495,8 +496,8 @@ func (f *FieldVal) PutBytesUnchecked(b []byte) {
// array and returns that which can sometimes be more ergonomic in applications
// that aren't concerned about an additional copy.
//
// Preconditions:
//   - The field value MUST be normalized
func (f *FieldVal) PutBytes(b *[32]byte) {
	f.PutBytesUnchecked(b[:])
}

@@ -508,8 +509,8 @@ func (f *FieldVal) PutBytes(b *[32]byte) {
// allowing the caller to reuse a buffer or write directly into part of a larger
// buffer.
//
// Preconditions:
//   - The field value MUST be normalized
func (f *FieldVal) Bytes() *[32]byte {
	b := new([32]byte)
	f.PutBytesUnchecked(b[:])

@@ -524,8 +525,8 @@ func (f *FieldVal) Bytes() *[32]byte {
// operations require a numeric value. See IsZero for the version that returns
// a bool.
//
// Preconditions:
//   - The field value MUST be normalized
func (f *FieldVal) IsZeroBit() uint32 {
	// The value can only be zero if no bits are set in any of the words.
	// This is a constant time implementation.

@@ -538,8 +539,8 @@ func (f *FieldVal) IsZeroBit() uint32 {
// IsZero returns whether or not the field value is equal to zero in constant
// time.
//
// Preconditions:
//   - The field value MUST be normalized
func (f *FieldVal) IsZero() bool {
	// The value can only be zero if no bits are set in any of the words.
	// This is a constant time implementation.

@@ -557,8 +558,8 @@ func (f *FieldVal) IsZero() bool {
// operations require a numeric value. See IsOne for the version that returns a
// bool.
//
// Preconditions:
//   - The field value MUST be normalized
func (f *FieldVal) IsOneBit() uint32 {
	// The value can only be one if the single lowest significant bit is set in
	// the first word and no other bits are set in any of the other words.

@@ -572,8 +573,8 @@ func (f *FieldVal) IsOneBit() uint32 {
// IsOne returns whether or not the field value is equal to one in constant
// time.
//
// Preconditions:
//   - The field value MUST be normalized
func (f *FieldVal) IsOne() bool {
	// The value can only be one if the single lowest significant bit is set in
	// the first word and no other bits are set in any of the other words.

@@ -592,8 +593,8 @@ func (f *FieldVal) IsOne() bool {
// operations require a numeric value. See IsOdd for the version that returns a
// bool.
//
// Preconditions:
//   - The field value MUST be normalized
func (f *FieldVal) IsOddBit() uint32 {
	// Only odd numbers have the bottom bit set.
	return f.n[0] & 1

@@ -602,8 +603,8 @@ func (f *FieldVal) IsOddBit() uint32 {
// IsOdd returns whether or not the field value is an odd number in constant
// time.
//
// Preconditions:
//   - The field value MUST be normalized
func (f *FieldVal) IsOdd() bool {
	// Only odd numbers have the bottom bit set.
	return f.n[0]&1 == 1

@@ -612,8 +613,8 @@ func (f *FieldVal) IsOdd() bool {
// Equals returns whether or not the two field values are the same in constant
// time.
//
// Preconditions:
//   - Both field values being compared MUST be normalized
func (f *FieldVal) Equals(val *FieldVal) bool {
	// Xor only sets bits when they are different, so the two field values
	// can only be the same if no bits are set after xoring each word.

@@ -633,10 +634,10 @@ func (f *FieldVal) Equals(val *FieldVal) bool {
// The field value is returned to support chaining. This enables syntax like:
// f.NegateVal(f2).AddInt(1) so that f = -f2 + 1.
//
// Preconditions:
//   - The max magnitude MUST be 63
// Output Normalized: No
// Output Max Magnitude: Input magnitude + 1
func (f *FieldVal) NegateVal(val *FieldVal, magnitude uint32) *FieldVal {
	// Negation in the field is just the prime minus the value. However,
	// in order to allow negation against a field value without having to

@@ -677,10 +678,10 @@ func (f *FieldVal) NegateVal(val *FieldVal, magnitude uint32) *FieldVal {
// The field value is returned to support chaining. This enables syntax like:
// f.Negate().AddInt(1) so that f = -f + 1.
//
// Preconditions:
//   - The max magnitude MUST be 63
// Output Normalized: No
// Output Max Magnitude: Input magnitude + 1
func (f *FieldVal) Negate(magnitude uint32) *FieldVal {
	return f.NegateVal(f, magnitude)
}
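
The chaining style and magnitude bookkeeping these comments describe can be exercised through the exported API; this sketch computes f = -f2 + 1 and normalizes both sides before comparing, per the documented preconditions:

	package main

	import (
		"fmt"

		secp "github.com/decred/dcrd/dcrec/secp256k1/v4"
	)

	func main() {
		var f2 secp.FieldVal
		f2.SetInt(5) // magnitude 1, normalized

		// f = -f2 + 1 = -4 ≡ p - 4 (mod p). NegateVal is told f2's magnitude.
		var f secp.FieldVal
		f.NegateVal(&f2, 1).AddInt(1)
		f.Normalize()

		// Compute p - 4 a second way and compare (Equals requires both
		// values to be normalized).
		var want secp.FieldVal
		want.SetInt(4)
		want.Negate(1)
		want.Normalize()
		fmt.Println(f.Equals(&want)) // true
	}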

@@ -692,10 +693,10 @@ func (f *FieldVal) Negate(magnitude uint32) *FieldVal {
// The field value is returned to support chaining. This enables syntax like:
// f.AddInt(1).Add(f2) so that f = f + 1 + f2.
//
// Preconditions:
//   - The field value MUST have a max magnitude of 63
// Output Normalized: No
// Output Max Magnitude: Existing field magnitude + 1
func (f *FieldVal) AddInt(ui uint16) *FieldVal {
	// Since the field representation intentionally provides overflow bits,
	// it's ok to use carryless addition as the carry bit is safely part of

@@ -711,10 +712,10 @@ func (f *FieldVal) AddInt(ui uint16) *FieldVal {
// The field value is returned to support chaining. This enables syntax like:
// f.Add(f2).AddInt(1) so that f = f + f2 + 1.
//
// Preconditions:
//   - The sum of the magnitudes of the two field values MUST be a max of 64
// Output Normalized: No
// Output Max Magnitude: Sum of the magnitude of the two individual field values
func (f *FieldVal) Add(val *FieldVal) *FieldVal {
	// Since the field representation intentionally provides overflow bits,
	// it's ok to use carryless addition as the carry bit is safely part of

@@ -740,10 +741,10 @@ func (f *FieldVal) Add(val *FieldVal) *FieldVal {
// The field value is returned to support chaining. This enables syntax like:
// f3.Add2(f, f2).AddInt(1) so that f3 = f + f2 + 1.
//
// Preconditions:
//   - The sum of the magnitudes of the two field values MUST be a max of 64
// Output Normalized: No
// Output Max Magnitude: Sum of the magnitude of the two field values
func (f *FieldVal) Add2(val *FieldVal, val2 *FieldVal) *FieldVal {
	// Since the field representation intentionally provides overflow bits,
	// it's ok to use carryless addition as the carry bit is safely part of

@@ -772,10 +773,10 @@ func (f *FieldVal) Add2(val *FieldVal, val2 *FieldVal) *FieldVal {
// The field value is returned to support chaining. This enables syntax like:
// f.MulInt(2).Add(f2) so that f = 2 * f + f2.
//
// Preconditions:
//   - The field value magnitude multiplied by given val MUST be a max of 64
// Output Normalized: No
// Output Max Magnitude: Existing field magnitude times the provided integer val
func (f *FieldVal) MulInt(val uint8) *FieldVal {
	// Since each word of the field representation can hold up to
	// 32 - fieldBase extra bits which will be normalized out, it's safe

@@ -807,10 +808,10 @@ func (f *FieldVal) MulInt(val uint8) *FieldVal {
// The field value is returned to support chaining. This enables syntax like:
// f.Mul(f2).AddInt(1) so that f = (f * f2) + 1.
//
// Preconditions:
//   - Both field values MUST have a max magnitude of 8
// Output Normalized: No
// Output Max Magnitude: 1
func (f *FieldVal) Mul(val *FieldVal) *FieldVal {
	return f.Mul2(f, val)
}

@@ -824,10 +825,10 @@ func (f *FieldVal) Mul(val *FieldVal) *FieldVal {
// The field value is returned to support chaining. This enables syntax like:
// f3.Mul2(f, f2).AddInt(1) so that f3 = (f * f2) + 1.
//
// Preconditions:
//   - Both input field values MUST have a max magnitude of 8
// Output Normalized: No
// Output Max Magnitude: 1
func (f *FieldVal) Mul2(val *FieldVal, val2 *FieldVal) *FieldVal {
	// This could be done with a couple of for loops and an array to store
	// the intermediate terms, but this unrolled version is significantly

@@ -1101,10 +1102,10 @@ func (f *FieldVal) Mul2(val *FieldVal, val2 *FieldVal) *FieldVal {
// field must be a max of 8 to prevent overflow. The magnitude of the result
// will be 1.
//
// Preconditions:
//   - The input field value MUST have a max magnitude of 8
// Output Normalized: No
// Output Max Magnitude: 1
func (f *FieldVal) SquareRootVal(val *FieldVal) bool {
	// This uses the Tonelli-Shanks method for calculating the square root of
	// the value when it exists. The key principles of the method follow.

@@ -1156,7 +1157,7 @@ func (f *FieldVal) SquareRootVal(val *FieldVal) bool {
	//  10111111 11111111 11111111 00001100
	//
	// Notice that it can be broken up into three windows of consecutive 1s (in
	// order of least to most signifcant) as:
	// order of least to most significant) as:
	//
	//  6-bit window with two bits set (bits 4, 5, 6, 7 unset)
	//  23-bit window with 22 bits set (bit 30 unset)

@@ -1262,10 +1263,10 @@ func (f *FieldVal) SquareRootVal(val *FieldVal) bool {
// The field value is returned to support chaining. This enables syntax like:
// f.Square().Mul(f2) so that f = f^2 * f2.
//
// Preconditions:
//   - The field value MUST have a max magnitude of 8
// Output Normalized: No
// Output Max Magnitude: 1
func (f *FieldVal) Square() *FieldVal {
	return f.SquareVal(f)
}
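
The candidate square root computed by the windowed exponentiation is val^((p+1)/4), which is valid because the secp256k1 prime satisfies p ≡ 3 (mod 4); the identity itself is easy to confirm with math/big (the field code merely evaluates the same fixed exponent using the addition chain sketched in the comments):

	package main

	import (
		"fmt"
		"math/big"
	)

	func main() {
		p, _ := new(big.Int).SetString("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", 16)
		a := big.NewInt(9) // any quadratic residue works for the demo

		exp := new(big.Int).Rsh(new(big.Int).Add(p, big.NewInt(1)), 2) // (p+1)/4
		root := new(big.Int).Exp(a, exp, p)

		// root^2 ≡ a (mod p); for a non-residue this check would fail.
		sq := new(big.Int).Mul(root, root)
		fmt.Println(sq.Mod(sq, p).Cmp(a) == 0) // true
	}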

@@ -1278,10 +1279,10 @@ func (f *FieldVal) Square() *FieldVal {
// The field value is returned to support chaining. This enables syntax like:
// f3.SquareVal(f).Mul(f) so that f3 = f^2 * f = f^3.
//
// Preconditions:
//   - The input field value MUST have a max magnitude of 8
// Output Normalized: No
// Output Max Magnitude: 1
func (f *FieldVal) SquareVal(val *FieldVal) *FieldVal {
	// This could be done with a couple of for loops and an array to store
	// the intermediate terms, but this unrolled version is significantly

@@ -1503,10 +1504,10 @@ func (f *FieldVal) SquareVal(val *FieldVal) *FieldVal {
// The field value is returned to support chaining. This enables syntax like:
// f.Inverse().Mul(f2) so that f = f^-1 * f2.
//
// Preconditions:
//   - The field value MUST have a max magnitude of 8
// Output Normalized: No
// Output Max Magnitude: 1
func (f *FieldVal) Inverse() *FieldVal {
	// Fermat's little theorem states that for a nonzero number a and prime
	// p, a^(p-1) = 1 (mod p). Since the multiplicative inverse is

@@ -1614,8 +1615,8 @@ func (f *FieldVal) Inverse() *FieldVal {
// IsGtOrEqPrimeMinusOrder returns whether or not the field value exceeds the
// group order divided by 2 in constant time.
//
// Preconditions:
//   - The field value MUST be normalized
func (f *FieldVal) IsGtOrEqPrimeMinusOrder() bool {
	// The secp256k1 prime is equivalent to 2^256 - 4294968273 and the group
	// order is 2^256 - 432420386565659656852420866394968145599. Thus,

@@ -1,196 +0,0 @@
// Copyright (c) 2014-2015 The btcsuite developers
// Copyright (c) 2015-2021 The Decred developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.

// This file is ignored during the regular build due to the following build tag.
// This build tag is set during go generate.
//go:build gensecp256k1
// +build gensecp256k1

package secp256k1

// References:
//   [GECC]: Guide to Elliptic Curve Cryptography (Hankerson, Menezes, Vanstone)

import (
	"encoding/binary"
	"math/big"
)

// compressedBytePoints are dummy points used so the code which generates the
// real values can compile.
var compressedBytePoints = ""

// SerializedBytePoints returns a serialized byte slice which contains all of
|
||||
// the possible points per 8-bit window. This is used when generating
|
||||
// compressedbytepoints.go.
|
||||
func SerializedBytePoints() []byte {
|
||||
// Calculate G^(2^i) for i in 0..255. These are used to avoid recomputing
|
||||
// them for each digit of the 8-bit windows.
|
||||
doublingPoints := make([]JacobianPoint, curveParams.BitSize)
|
||||
var q JacobianPoint
|
||||
bigAffineToJacobian(curveParams.Gx, curveParams.Gy, &q)
|
||||
for i := 0; i < curveParams.BitSize; i++ {
|
||||
// Q = 2*Q.
|
||||
doublingPoints[i] = q
|
||||
DoubleNonConst(&q, &q)
|
||||
}
|
||||
|
||||
// Separate the bits into byte-sized windows.
|
||||
curveByteSize := curveParams.BitSize / 8
|
||||
serialized := make([]byte, curveByteSize*256*2*10*4)
|
||||
offset := 0
|
||||
for byteNum := 0; byteNum < curveByteSize; byteNum++ {
|
||||
// Grab the 8 bits that make up this byte from doubling points.
|
||||
startingBit := 8 * (curveByteSize - byteNum - 1)
|
||||
windowPoints := doublingPoints[startingBit : startingBit+8]
|
||||
|
||||
// Compute all points in this window, convert them to affine, and
|
||||
// serialize them.
|
||||
for i := 0; i < 256; i++ {
|
||||
var point JacobianPoint
|
||||
for bit := 0; bit < 8; bit++ {
|
||||
if i>>uint(bit)&1 == 1 {
|
||||
AddNonConst(&point, &windowPoints[bit], &point)
|
||||
}
|
||||
}
|
||||
point.ToAffine()
|
||||
|
||||
for i := 0; i < len(point.X.n); i++ {
|
||||
binary.LittleEndian.PutUint32(serialized[offset:], point.X.n[i])
|
||||
offset += 4
|
||||
}
|
||||
for i := 0; i < len(point.Y.n); i++ {
|
||||
binary.LittleEndian.PutUint32(serialized[offset:], point.Y.n[i])
|
||||
offset += 4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return serialized
|
||||
}
|
||||
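The table works because a scalar decomposes over its bytes as k = sum(b_i * 256^i), so k*G is the sum of one precomputed point per byte. An integer stand-in for the same bookkeeping (ordinary addition playing the role of point addition; two windows to keep it small):

package main

import "fmt"

func main() {
	const windows = 2
	var table [windows][256]int
	mult := 1
	for i := 0; i < windows; i++ {
		for b := 0; b < 256; b++ {
			table[i][b] = b * mult // stands in for (b * 256^i) * G
		}
		mult *= 256
	}

	k := 0x1a2b
	sum := table[0][k&0xff] + table[1][(k>>8)&0xff]
	fmt.Println(sum == k) // true: the per-byte entries sum back to k
}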
|
||||
// sqrt returns the square root of the provided big integer using Newton's
|
||||
// method. It's only compiled and used during generation of pre-computed
|
||||
// values, so speed is not a huge concern.
|
||||
func sqrt(n *big.Int) *big.Int {
|
||||
// Initial guess = 2^(log_2(n)/2)
|
||||
guess := big.NewInt(2)
|
||||
guess.Exp(guess, big.NewInt(int64(n.BitLen()/2)), nil)
|
||||
|
||||
// Now refine using Newton's method.
|
||||
big2 := big.NewInt(2)
|
||||
prevGuess := big.NewInt(0)
|
||||
for {
|
||||
prevGuess.Set(guess)
|
||||
guess.Add(guess, new(big.Int).Div(n, guess))
|
||||
guess.Div(guess, big2)
|
||||
if guess.Cmp(prevGuess) == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return guess
|
||||
}
|
||||
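For instance, sqrt(big.NewInt(144)) settles at 12, and a non-square such as 10 converges to the floor of the real root (3), which is all the bound check in EndomorphismVectors below needs.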
|
||||
// EndomorphismVectors runs the first 3 steps of algorithm 3.74 from [GECC] to
|
||||
// generate the linearly independent vectors needed to generate a balanced
|
||||
// length-two representation of a multiplier such that k = k1 + k2λ (mod N) and
|
||||
// returns them. Since the values will always be the same given the fact that N
|
||||
// and λ are fixed, the final results can be accelerated by storing the
|
||||
// precomputed values.
|
||||
func EndomorphismVectors() (a1, b1, a2, b2 *big.Int) {
|
||||
bigMinus1 := big.NewInt(-1)
|
||||
|
||||
// This section uses an extended Euclidean algorithm to generate a
|
||||
// sequence of equations:
|
||||
// s[i] * N + t[i] * λ = r[i]
|
||||
|
||||
nSqrt := sqrt(curveParams.N)
|
||||
u, v := new(big.Int).Set(curveParams.N), new(big.Int).Set(endomorphismLambda)
|
||||
x1, y1 := big.NewInt(1), big.NewInt(0)
|
||||
x2, y2 := big.NewInt(0), big.NewInt(1)
|
||||
q, r := new(big.Int), new(big.Int)
|
||||
qu, qx1, qy1 := new(big.Int), new(big.Int), new(big.Int)
|
||||
s, t := new(big.Int), new(big.Int)
|
||||
ri, ti := new(big.Int), new(big.Int)
|
||||
a1, b1, a2, b2 = new(big.Int), new(big.Int), new(big.Int), new(big.Int)
|
||||
found, oneMore := false, false
|
||||
for u.Sign() != 0 {
|
||||
// q = v/u
|
||||
q.Div(v, u)
|
||||
|
||||
// r = v - q*u
|
||||
qu.Mul(q, u)
|
||||
r.Sub(v, qu)
|
||||
|
||||
// s = x2 - q*x1
|
||||
qx1.Mul(q, x1)
|
||||
s.Sub(x2, qx1)
|
||||
|
||||
// t = y2 - q*y1
|
||||
qy1.Mul(q, y1)
|
||||
t.Sub(y2, qy1)
|
||||
|
||||
// v = u, u = r, x2 = x1, x1 = s, y2 = y1, y1 = t
|
||||
v.Set(u)
|
||||
u.Set(r)
|
||||
x2.Set(x1)
|
||||
x1.Set(s)
|
||||
y2.Set(y1)
|
||||
y1.Set(t)
|
||||
|
||||
// As soon as the remainder is less than the sqrt of n, the
|
||||
// values of a1 and b1 are known.
|
||||
if !found && r.Cmp(nSqrt) < 0 {
|
||||
// When this condition executes ri and ti represent the
|
||||
// r[i] and t[i] values such that i is the greatest
|
||||
// index for which r >= sqrt(n). Meanwhile, the current
|
||||
// r and t values are r[i+1] and t[i+1], respectively.
|
||||
|
||||
// a1 = r[i+1], b1 = -t[i+1]
|
||||
a1.Set(r)
|
||||
b1.Mul(t, bigMinus1)
|
||||
found = true
|
||||
oneMore = true
|
||||
|
||||
// Skip to the next iteration so ri and ti are not
|
||||
// modified.
|
||||
continue
|
||||
|
||||
} else if oneMore {
|
||||
// When this condition executes ri and ti still
|
||||
// represent the r[i] and t[i] values while the current
|
||||
// r and t are r[i+2] and t[i+2], respectively.
|
||||
|
||||
// sum1 = r[i]^2 + t[i]^2
|
||||
rSquared := new(big.Int).Mul(ri, ri)
|
||||
tSquared := new(big.Int).Mul(ti, ti)
|
||||
sum1 := new(big.Int).Add(rSquared, tSquared)
|
||||
|
||||
// sum2 = r[i+2]^2 + t[i+2]^2
|
||||
r2Squared := new(big.Int).Mul(r, r)
|
||||
t2Squared := new(big.Int).Mul(t, t)
|
||||
sum2 := new(big.Int).Add(r2Squared, t2Squared)
|
||||
|
||||
// if (r[i]^2 + t[i]^2) <= (r[i+2]^2 + t[i+2]^2)
|
||||
if sum1.Cmp(sum2) <= 0 {
|
||||
// a2 = r[i], b2 = -t[i]
|
||||
a2.Set(ri)
|
||||
b2.Mul(ti, bigMinus1)
|
||||
} else {
|
||||
// a2 = r[i+2], b2 = -t[i+2]
|
||||
a2.Set(r)
|
||||
b2.Mul(t, bigMinus1)
|
||||
}
|
||||
|
||||
// All done.
|
||||
break
|
||||
}
|
||||
|
||||
ri.Set(r)
|
||||
ti.Set(t)
|
||||
}
|
||||
|
||||
return a1, b1, a2, b2
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
// Copyright 2015 The btcsuite developers
|
||||
// Copyright (c) 2015-2021 The Decred developers
|
||||
// Copyright (c) 2015-2022 The Decred developers
|
||||
// Use of this source code is governed by an ISC
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
|
@ -8,17 +8,21 @@ package secp256k1
|
|||
import (
|
||||
"compress/zlib"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
//go:generate go run -tags gensecp256k1 genprecomps.go
|
||||
//go:generate go run genprecomps.go
|
||||
|
||||
// bytePointTable describes a table used to house pre-computed values for
|
||||
// accelerating scalar base multiplication.
|
||||
type bytePointTable [32][256][2]FieldVal
|
||||
type bytePointTable [32][256]JacobianPoint
|
||||
|
||||
// compressedBytePointsFn is set to a real function by the code generation to
|
||||
// return the compressed pre-computed values for accelerating scalar base
|
||||
// multiplication.
|
||||
var compressedBytePointsFn func() string
|
||||
|
||||
// s256BytePoints houses pre-computed values used to accelerate scalar base
|
||||
// multiplication such that they are only loaded on first use.
|
||||
|
@ -38,10 +42,10 @@ var s256BytePoints = func() func() *bytePointTable {
|
|||
var data *bytePointTable
|
||||
mustLoadBytePoints := func() {
|
||||
// There will be no byte points to load when generating them.
|
||||
bp := compressedBytePoints
|
||||
if len(bp) == 0 {
|
||||
if compressedBytePointsFn == nil {
|
||||
return
|
||||
}
|
||||
bp := compressedBytePointsFn()
|
||||
|
||||
// Decompress the pre-computed table used to accelerate scalar base
|
||||
// multiplication.
|
||||
|
@ -62,16 +66,12 @@ var s256BytePoints = func() func() *bytePointTable {
|
|||
for byteNum := 0; byteNum < len(bytePoints); byteNum++ {
|
||||
// All points in this window.
|
||||
for i := 0; i < len(bytePoints[byteNum]); i++ {
|
||||
px := &bytePoints[byteNum][i][0]
|
||||
py := &bytePoints[byteNum][i][1]
|
||||
for i := 0; i < len(px.n); i++ {
|
||||
px.n[i] = binary.LittleEndian.Uint32(serialized[offset:])
|
||||
offset += 4
|
||||
}
|
||||
for i := 0; i < len(py.n); i++ {
|
||||
py.n[i] = binary.LittleEndian.Uint32(serialized[offset:])
|
||||
offset += 4
|
||||
}
|
||||
p := &bytePoints[byteNum][i]
|
||||
p.X.SetByteSlice(serialized[offset:])
|
||||
offset += 32
|
||||
p.Y.SetByteSlice(serialized[offset:])
|
||||
offset += 32
|
||||
p.Z.SetInt(1)
|
||||
}
|
||||
}
|
||||
data = &bytePoints
|
||||
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2020 The Decred developers
|
||||
// Copyright (c) 2020-2022 The Decred developers
|
||||
// Use of this source code is governed by an ISC
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
|
@ -100,7 +100,7 @@ var (
|
|||
// arithmetic over the secp256k1 group order. This means all arithmetic is
|
||||
// performed modulo:
|
||||
//
|
||||
// 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141
|
||||
//
|
||||
// It only implements the arithmetic needed for elliptic curve operations,
|
||||
// however, the operations that are not implemented can typically be worked
|
||||
|
@ -174,6 +174,19 @@ func (s *ModNScalar) Zero() {
|
|||
s.n[7] = 0
|
||||
}
|
||||
|
||||
// IsZeroBit returns 1 when the scalar is equal to zero or 0 otherwise in
|
||||
// constant time.
|
||||
//
|
||||
// Note that a bool is not used here because it is not possible in Go to convert
|
||||
// from a bool to a numeric value in constant time, and many constant-time
|
||||
// operations require a numeric value. See IsZero for the version that returns
|
||||
// a bool.
|
||||
func (s *ModNScalar) IsZeroBit() uint32 {
|
||||
// The scalar can only be zero if no bits are set in any of the words.
|
||||
bits := s.n[0] | s.n[1] | s.n[2] | s.n[3] | s.n[4] | s.n[5] | s.n[6] | s.n[7]
|
||||
return constantTimeEq(bits, 0)
|
||||
}
|
||||
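A branch-free sketch of the kind of zero test this relies on, assuming constantTimeEq uses the usual sign-bit trick (this helper is an illustration, not the package's implementation):

package main

import "fmt"

// isZeroBit32 returns 1 when w == 0 and 0 otherwise, without branching:
// for any nonzero w, w | -w always has its top bit set.
func isZeroBit32(w uint32) uint32 {
	return 1 ^ ((w | -w) >> 31)
}

func main() {
	fmt.Println(isZeroBit32(0), isZeroBit32(42)) // 1 0
}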
|
||||
// IsZero returns whether or not the scalar is equal to zero in constant time.
|
||||
func (s *ModNScalar) IsZero() bool {
|
||||
// The scalar can only be zero if no bits are set in any of the words.
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
// Copyright (c) 2013-2014 The btcsuite developers
|
||||
// Copyright (c) 2015-2020 The Decred developers
|
||||
// Copyright (c) 2015-2022 The Decred developers
|
||||
// Use of this source code is governed by an ISC
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package secp256k1
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
csprng "crypto/rand"
|
||||
)
|
||||
|
||||
// PrivateKey provides facilities for working with secp256k1 private keys within
|
||||
|
@ -40,14 +39,27 @@ func PrivKeyFromBytes(privKeyBytes []byte) *PrivateKey {
|
|||
return &privKey
|
||||
}
|
||||
|
||||
// GeneratePrivateKey returns a private key that is suitable for use with
|
||||
// secp256k1.
|
||||
// GeneratePrivateKey generates and returns a new cryptographically secure
|
||||
// private key that is suitable for use with secp256k1.
|
||||
func GeneratePrivateKey() (*PrivateKey, error) {
|
||||
key, err := ecdsa.GenerateKey(S256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// The group order is close enough to 2^256 that there is only roughly a 1
|
||||
// in 2^128 chance of generating an invalid private key, so this loop will
|
||||
// virtually never run more than a single iteration in practice.
|
||||
var key PrivateKey
|
||||
var b32 [32]byte
|
||||
for valid := false; !valid; {
|
||||
if _, err := csprng.Read(b32[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The private key is only valid when it is in the range [1, N-1], where
|
||||
// N is the order of the curve.
|
||||
overflow := key.Key.SetBytes(&b32)
|
||||
valid = (key.Key.IsZeroBit() | overflow) == 0
|
||||
}
|
||||
return PrivKeyFromBytes(key.D.Bytes()), nil
|
||||
zeroArray32(&b32)
|
||||
|
||||
return &key, nil
|
||||
}
|
||||
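The same accept/retry pattern in isolation, written with math/big for readability (the real code keeps the candidate in a fixed-size ModNScalar so the secret never lands on the heap):

package main

import (
	"crypto/rand"
	"fmt"
	"math/big"
)

func main() {
	// secp256k1 group order N.
	n, _ := new(big.Int).SetString(
		"fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16)

	buf := make([]byte, 32)
	k := new(big.Int)
	for {
		if _, err := rand.Read(buf); err != nil {
			panic(err)
		}
		k.SetBytes(buf)
		// Accept only candidates in [1, N-1]; zero or overflow retries.
		if k.Sign() != 0 && k.Cmp(n) < 0 {
			break
		}
	}
	fmt.Println(k.Sign() == 1 && k.Cmp(n) < 0) // true
}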
|
||||
// PubKey computes and returns the public key corresponding to this private key.
|
||||
@ -1,5 +1,5 @@
|
|||
// Copyright (c) 2013-2014 The btcsuite developers
|
||||
// Copyright (c) 2015-2021 The Decred developers
|
||||
// Copyright (c) 2015-2022 The Decred developers
|
||||
// Use of this source code is governed by an ISC
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
|
@ -92,11 +92,16 @@ func NewPublicKey(x, y *FieldVal) *PublicKey {
|
|||
// hybrid formats as follows:
|
||||
//
|
||||
// Compressed:
//   <format byte = 0x02/0x03><32-byte X coordinate>
//
// Uncompressed:
//   <format byte = 0x04><32-byte X coordinate><32-byte Y coordinate>
//
// Hybrid:
//   <format byte = 0x05/0x06><32-byte X coordinate><32-byte Y coordinate>
//
// NOTE: The hybrid format makes little sense in practice and therefore this
|
||||
// package will not produce public keys serialized in this format. However,
|
||||
|
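A minimal classifier over those encodings, using only the lengths and prefix bytes listed above (illustrative, not this package's parser):

package main

import "fmt"

func pubKeyFormat(b []byte) string {
	switch {
	case len(b) == 33 && (b[0] == 0x02 || b[0] == 0x03):
		return "compressed"
	case len(b) == 65 && b[0] == 0x04:
		return "uncompressed"
	case len(b) == 65 && (b[0] == 0x05 || b[0] == 0x06):
		return "hybrid"
	default:
		return "invalid"
	}
}

func main() {
	key := append([]byte{0x02}, make([]byte, 32)...)
	fmt.Println(pubKeyFormat(key)) // compressed
}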
@ -209,9 +214,9 @@ func (p PublicKey) SerializeCompressed() []byte {
|
|||
return b[:]
|
||||
}
|
||||
|
||||
// IsEqual compares this PublicKey instance to the one passed, returning true if
|
||||
// both PublicKeys are equivalent. A PublicKey is equivalent to another, if they
|
||||
// both have the same X and Y coordinate.
|
||||
// IsEqual compares this public key instance to the one passed, returning true
|
||||
// if both public keys are equivalent. A public key is equivalent to another,
|
||||
// if they both have the same X and Y coordinates.
|
||||
func (p *PublicKey) IsEqual(otherPubKey *PublicKey) bool {
|
||||
return p.x.Equals(&otherPubKey.x) && p.y.Equals(&otherPubKey.y)
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package units
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
@ -26,16 +25,17 @@ const (
|
|||
PiB = 1024 * TiB
|
||||
)
|
||||
|
||||
type unitMap map[string]int64
|
||||
type unitMap map[byte]int64
|
||||
|
||||
var (
|
||||
decimalMap = unitMap{"k": KB, "m": MB, "g": GB, "t": TB, "p": PB}
|
||||
binaryMap = unitMap{"k": KiB, "m": MiB, "g": GiB, "t": TiB, "p": PiB}
|
||||
sizeRegex = regexp.MustCompile(`^(\d+(\.\d+)*) ?([kKmMgGtTpP])?[iI]?[bB]?$`)
|
||||
decimalMap = unitMap{'k': KB, 'm': MB, 'g': GB, 't': TB, 'p': PB}
|
||||
binaryMap = unitMap{'k': KiB, 'm': MiB, 'g': GiB, 't': TiB, 'p': PiB}
|
||||
)
|
||||
|
||||
var decimapAbbrs = []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}
|
||||
var binaryAbbrs = []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"}
|
||||
var (
|
||||
decimapAbbrs = []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}
|
||||
binaryAbbrs = []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"}
|
||||
)
|
||||
|
||||
func getSizeAndUnit(size float64, base float64, _map []string) (float64, string) {
|
||||
i := 0
|
||||
|
@ -89,20 +89,66 @@ func RAMInBytes(size string) (int64, error) {
|
|||
|
||||
// Parses the human-readable size string into the amount it represents.
|
||||
func parseSize(sizeStr string, uMap unitMap) (int64, error) {
|
||||
matches := sizeRegex.FindStringSubmatch(sizeStr)
|
||||
if len(matches) != 4 {
|
||||
// TODO: rewrite to use strings.Cut if there's a space
|
||||
// once Go < 1.18 is deprecated.
|
||||
sep := strings.LastIndexAny(sizeStr, "01234567890. ")
|
||||
if sep == -1 {
|
||||
// There should be at least a digit.
|
||||
return -1, fmt.Errorf("invalid size: '%s'", sizeStr)
|
||||
}
|
||||
var num, sfx string
|
||||
if sizeStr[sep] != ' ' {
|
||||
num = sizeStr[:sep+1]
|
||||
sfx = sizeStr[sep+1:]
|
||||
} else {
|
||||
// Omit the space separator.
|
||||
num = sizeStr[:sep]
|
||||
sfx = sizeStr[sep+1:]
|
||||
}
|
||||
|
||||
size, err := strconv.ParseFloat(matches[1], 64)
|
||||
size, err := strconv.ParseFloat(num, 64)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
// Backward compatibility: reject negative sizes.
|
||||
if size < 0 {
|
||||
return -1, fmt.Errorf("invalid size: '%s'", sizeStr)
|
||||
}
|
||||
|
||||
unitPrefix := strings.ToLower(matches[3])
|
||||
if mul, ok := uMap[unitPrefix]; ok {
|
||||
if len(sfx) == 0 {
|
||||
return int64(size), nil
|
||||
}
|
||||
|
||||
// Process the suffix.
|
||||
|
||||
if len(sfx) > 3 { // Too long.
|
||||
goto badSuffix
|
||||
}
|
||||
sfx = strings.ToLower(sfx)
|
||||
// Trivial case: b suffix.
|
||||
if sfx[0] == 'b' {
|
||||
if len(sfx) > 1 { // no extra characters allowed after b.
|
||||
goto badSuffix
|
||||
}
|
||||
return int64(size), nil
|
||||
}
|
||||
// A suffix from the map.
|
||||
if mul, ok := uMap[sfx[0]]; ok {
|
||||
size *= float64(mul)
|
||||
} else {
|
||||
goto badSuffix
|
||||
}
|
||||
|
||||
// The suffix may have extra "b" or "ib" (e.g. KiB or MB).
|
||||
switch {
|
||||
case len(sfx) == 2 && sfx[1] != 'b':
|
||||
goto badSuffix
|
||||
case len(sfx) == 3 && sfx[1:] != "ib":
|
||||
goto badSuffix
|
||||
}
|
||||
|
||||
return int64(size), nil
|
||||
|
||||
badSuffix:
|
||||
return -1, fmt.Errorf("invalid suffix: '%s'", sfx)
|
||||
}
|
||||
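Through the package's exported helpers the rewritten parser accepts the same shapes as the old regex did; a quick check, assuming a go-units version containing this change:

package main

import (
	"fmt"

	units "github.com/docker/go-units"
)

func main() {
	// Binary table: suffixes are powers of 1024.
	n, err := units.RAMInBytes("32GiB")
	fmt.Println(n, err) // 34359738368 <nil>

	// Decimal table: suffixes are powers of 1000.
	m, _ := units.FromHumanSize("32GB")
	fmt.Println(m) // 32000000000
}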
@ -1,102 +0,0 @@
|
|||
// Protocol Buffers for Go with Gadgets
|
||||
//
|
||||
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
|
||||
// http://github.com/gogo/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package io
|
||||
|
||||
import (
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"io"
|
||||
)
|
||||
|
||||
func NewFullWriter(w io.Writer) WriteCloser {
|
||||
return &fullWriter{w, nil}
|
||||
}
|
||||
|
||||
type fullWriter struct {
|
||||
w io.Writer
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
func (this *fullWriter) WriteMsg(msg proto.Message) (err error) {
|
||||
var data []byte
|
||||
if m, ok := msg.(marshaler); ok {
|
||||
n, ok := getSize(m)
|
||||
if !ok {
|
||||
data, err = proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if n >= len(this.buffer) {
|
||||
this.buffer = make([]byte, n)
|
||||
}
|
||||
_, err = m.MarshalTo(this.buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = this.buffer[:n]
|
||||
} else {
|
||||
data, err = proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err = this.w.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *fullWriter) Close() error {
|
||||
if closer, ok := this.w.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type fullReader struct {
|
||||
r io.Reader
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func NewFullReader(r io.Reader, maxSize int) ReadCloser {
|
||||
return &fullReader{r, make([]byte, maxSize)}
|
||||
}
|
||||
|
||||
func (this *fullReader) ReadMsg(msg proto.Message) error {
|
||||
length, err := this.r.Read(this.buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return proto.Unmarshal(this.buf[:length], msg)
|
||||
}
|
||||
|
||||
func (this *fullReader) Close() error {
|
||||
if closer, ok := this.r.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,70 +0,0 @@
|
|||
// Protocol Buffers for Go with Gadgets
|
||||
//
|
||||
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
|
||||
// http://github.com/gogo/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package io
|
||||
|
||||
import (
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Writer interface {
|
||||
WriteMsg(proto.Message) error
|
||||
}
|
||||
|
||||
type WriteCloser interface {
|
||||
Writer
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type Reader interface {
|
||||
ReadMsg(msg proto.Message) error
|
||||
}
|
||||
|
||||
type ReadCloser interface {
|
||||
Reader
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type marshaler interface {
|
||||
MarshalTo(data []byte) (n int, err error)
|
||||
}
|
||||
|
||||
func getSize(v interface{}) (int, bool) {
|
||||
if sz, ok := v.(interface {
|
||||
Size() (n int)
|
||||
}); ok {
|
||||
return sz.Size(), true
|
||||
} else if sz, ok := v.(interface {
|
||||
ProtoSize() (n int)
|
||||
}); ok {
|
||||
return sz.ProtoSize(), true
|
||||
} else {
|
||||
return 0, false
|
||||
}
|
||||
}
|
|
@ -1,138 +0,0 @@
|
|||
// Protocol Buffers for Go with Gadgets
|
||||
//
|
||||
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
|
||||
// http://github.com/gogo/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package io
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
)
|
||||
|
||||
const uint32BinaryLen = 4
|
||||
|
||||
func NewUint32DelimitedWriter(w io.Writer, byteOrder binary.ByteOrder) WriteCloser {
|
||||
return &uint32Writer{w, byteOrder, nil, make([]byte, uint32BinaryLen)}
|
||||
}
|
||||
|
||||
func NewSizeUint32DelimitedWriter(w io.Writer, byteOrder binary.ByteOrder, size int) WriteCloser {
|
||||
return &uint32Writer{w, byteOrder, make([]byte, size), make([]byte, uint32BinaryLen)}
|
||||
}
|
||||
|
||||
type uint32Writer struct {
|
||||
w io.Writer
|
||||
byteOrder binary.ByteOrder
|
||||
buffer []byte
|
||||
lenBuf []byte
|
||||
}
|
||||
|
||||
func (this *uint32Writer) writeFallback(msg proto.Message) error {
|
||||
data, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
length := uint32(len(data))
|
||||
this.byteOrder.PutUint32(this.lenBuf, length)
|
||||
if _, err = this.w.Write(this.lenBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = this.w.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *uint32Writer) WriteMsg(msg proto.Message) error {
|
||||
m, ok := msg.(marshaler)
|
||||
if !ok {
|
||||
return this.writeFallback(msg)
|
||||
}
|
||||
|
||||
n, ok := getSize(m)
|
||||
if !ok {
|
||||
return this.writeFallback(msg)
|
||||
}
|
||||
|
||||
size := n + uint32BinaryLen
|
||||
if size > len(this.buffer) {
|
||||
this.buffer = make([]byte, size)
|
||||
}
|
||||
|
||||
this.byteOrder.PutUint32(this.buffer, uint32(n))
|
||||
if _, err := m.MarshalTo(this.buffer[uint32BinaryLen:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := this.w.Write(this.buffer[:size])
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *uint32Writer) Close() error {
|
||||
if closer, ok := this.w.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type uint32Reader struct {
|
||||
r io.Reader
|
||||
byteOrder binary.ByteOrder
|
||||
lenBuf []byte
|
||||
buf []byte
|
||||
maxSize int
|
||||
}
|
||||
|
||||
func NewUint32DelimitedReader(r io.Reader, byteOrder binary.ByteOrder, maxSize int) ReadCloser {
|
||||
return &uint32Reader{r, byteOrder, make([]byte, 4), nil, maxSize}
|
||||
}
|
||||
|
||||
func (this *uint32Reader) ReadMsg(msg proto.Message) error {
|
||||
if _, err := io.ReadFull(this.r, this.lenBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
length32 := this.byteOrder.Uint32(this.lenBuf)
|
||||
length := int(length32)
|
||||
if length < 0 || length > this.maxSize {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
if length > len(this.buf) {
|
||||
this.buf = make([]byte, length)
|
||||
}
|
||||
_, err := io.ReadFull(this.r, this.buf[:length])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return proto.Unmarshal(this.buf[:length], msg)
|
||||
}
|
||||
|
||||
func (this *uint32Reader) Close() error {
|
||||
if closer, ok := this.r.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
|
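The wire format here is nothing more than a 4-byte length prefix followed by the payload. A self-contained round trip over a bytes.Buffer, with a raw byte slice standing in for the proto message:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

func main() {
	var buf bytes.Buffer
	payload := []byte("hello")

	// Write: big-endian uint32 length, then the message bytes.
	lenBuf := make([]byte, 4)
	binary.BigEndian.PutUint32(lenBuf, uint32(len(payload)))
	buf.Write(lenBuf)
	buf.Write(payload)

	// Read: the length first, then exactly that many bytes.
	if _, err := io.ReadFull(&buf, lenBuf); err != nil {
		panic(err)
	}
	msg := make([]byte, binary.BigEndian.Uint32(lenBuf))
	if _, err := io.ReadFull(&buf, msg); err != nil {
		panic(err)
	}
	fmt.Println(string(msg)) // hello
}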
@ -1,133 +0,0 @@
|
|||
// Protocol Buffers for Go with Gadgets
|
||||
//
|
||||
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
|
||||
// http://github.com/gogo/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package io
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
errSmallBuffer = errors.New("Buffer Too Small")
|
||||
errLargeValue = errors.New("Value is Larger than 64 bits")
|
||||
)
|
||||
|
||||
func NewDelimitedWriter(w io.Writer) WriteCloser {
|
||||
return &varintWriter{w, make([]byte, binary.MaxVarintLen64), nil}
|
||||
}
|
||||
|
||||
type varintWriter struct {
|
||||
w io.Writer
|
||||
lenBuf []byte
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
func (this *varintWriter) WriteMsg(msg proto.Message) (err error) {
|
||||
var data []byte
|
||||
if m, ok := msg.(marshaler); ok {
|
||||
n, ok := getSize(m)
|
||||
if ok {
|
||||
if n+binary.MaxVarintLen64 >= len(this.buffer) {
|
||||
this.buffer = make([]byte, n+binary.MaxVarintLen64)
|
||||
}
|
||||
lenOff := binary.PutUvarint(this.buffer, uint64(n))
|
||||
_, err = m.MarshalTo(this.buffer[lenOff:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = this.w.Write(this.buffer[:lenOff+n])
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// fallback
|
||||
data, err = proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
length := uint64(len(data))
|
||||
n := binary.PutUvarint(this.lenBuf, length)
|
||||
_, err = this.w.Write(this.lenBuf[:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = this.w.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *varintWriter) Close() error {
|
||||
if closer, ok := this.w.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewDelimitedReader(r io.Reader, maxSize int) ReadCloser {
|
||||
var closer io.Closer
|
||||
if c, ok := r.(io.Closer); ok {
|
||||
closer = c
|
||||
}
|
||||
return &varintReader{bufio.NewReader(r), nil, maxSize, closer}
|
||||
}
|
||||
|
||||
type varintReader struct {
|
||||
r *bufio.Reader
|
||||
buf []byte
|
||||
maxSize int
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func (this *varintReader) ReadMsg(msg proto.Message) error {
|
||||
length64, err := binary.ReadUvarint(this.r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
length := int(length64)
|
||||
if length < 0 || length > this.maxSize {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
if len(this.buf) < length {
|
||||
this.buf = make([]byte, length)
|
||||
}
|
||||
buf := this.buf[:length]
|
||||
if _, err := io.ReadFull(this.r, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return proto.Unmarshal(buf, msg)
|
||||
}
|
||||
|
||||
func (this *varintReader) Close() error {
|
||||
if this.closer != nil {
|
||||
return this.closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
|
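The varint flavor differs from the uint32 one only in how the length prefix is encoded; the same round trip with binary.PutUvarint and binary.ReadUvarint (which needs the bufio wrapper to get an io.ByteReader):

package main

import (
	"bufio"
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

func main() {
	var buf bytes.Buffer
	payload := []byte("hello")

	// Write: uvarint length prefix, then the message bytes.
	lenBuf := make([]byte, binary.MaxVarintLen64)
	n := binary.PutUvarint(lenBuf, uint64(len(payload)))
	buf.Write(lenBuf[:n])
	buf.Write(payload)

	// Read: decode the prefix, then read exactly that many bytes.
	r := bufio.NewReader(&buf)
	length, err := binary.ReadUvarint(r)
	if err != nil {
		panic(err)
	}
	msg := make([]byte, length)
	if _, err := io.ReadFull(r, msg); err != nil {
		panic(err)
	}
	fmt.Println(string(msg)) // hello
}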
@ -0,0 +1,701 @@
|
|||
// Copyright 2010 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// MockGen generates mock implementations of Go interfaces.
|
||||
package main
|
||||
|
||||
// TODO: This does not support recursive embedded interfaces.
|
||||
// TODO: This does not support embedding package-local interfaces in a separate file.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/token"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/golang/mock/mockgen/model"
|
||||
|
||||
"golang.org/x/mod/modfile"
|
||||
toolsimports "golang.org/x/tools/imports"
|
||||
)
|
||||
|
||||
const (
|
||||
gomockImportPath = "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
var (
|
||||
version = ""
|
||||
commit = "none"
|
||||
date = "unknown"
|
||||
)
|
||||
|
||||
var (
|
||||
source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.")
|
||||
destination = flag.String("destination", "", "Output file; defaults to stdout.")
|
||||
mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.")
|
||||
packageOut = flag.String("package", "", "Package of the generated code; defaults to the package of the input with a 'mock_' prefix.")
|
||||
selfPackage = flag.String("self_package", "", "The full package import path for the generated code. The purpose of this flag is to prevent import cycles in the generated code by trying to include its own package. This can happen if the mock's package is set to one of its inputs (usually the main one) and the output is stdio so mockgen cannot detect the final output package. Setting this flag will then tell mockgen which import to exclude.")
|
||||
writePkgComment = flag.Bool("write_package_comment", true, "Writes package documentation comment (godoc) if true.")
|
||||
copyrightFile = flag.String("copyright_file", "", "Copyright file used to add copyright header")
|
||||
|
||||
debugParser = flag.Bool("debug_parser", false, "Print out parser results only.")
|
||||
showVersion = flag.Bool("version", false, "Print version.")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Usage = usage
|
||||
flag.Parse()
|
||||
|
||||
if *showVersion {
|
||||
printVersion()
|
||||
return
|
||||
}
|
||||
|
||||
var pkg *model.Package
|
||||
var err error
|
||||
var packageName string
|
||||
if *source != "" {
|
||||
pkg, err = sourceMode(*source)
|
||||
} else {
|
||||
if flag.NArg() != 2 {
|
||||
usage()
|
||||
log.Fatal("Expected exactly two arguments")
|
||||
}
|
||||
packageName = flag.Arg(0)
|
||||
interfaces := strings.Split(flag.Arg(1), ",")
|
||||
if packageName == "." {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("Get current directory failed: %v", err)
|
||||
}
|
||||
packageName, err = packageNameOfDir(dir)
|
||||
if err != nil {
|
||||
log.Fatalf("Parse package name failed: %v", err)
|
||||
}
|
||||
}
|
||||
pkg, err = reflectMode(packageName, interfaces)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("Loading input failed: %v", err)
|
||||
}
|
||||
|
||||
if *debugParser {
|
||||
pkg.Print(os.Stdout)
|
||||
return
|
||||
}
|
||||
|
||||
dst := os.Stdout
|
||||
if len(*destination) > 0 {
|
||||
if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil {
|
||||
log.Fatalf("Unable to create directory: %v", err)
|
||||
}
|
||||
f, err := os.Create(*destination)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed opening destination file: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
dst = f
|
||||
}
|
||||
|
||||
outputPackageName := *packageOut
|
||||
if outputPackageName == "" {
|
||||
// pkg.Name in reflect mode is the base name of the import path,
|
||||
// which might have characters that are illegal to have in package names.
|
||||
outputPackageName = "mock_" + sanitize(pkg.Name)
|
||||
}
|
||||
|
||||
// outputPackagePath represents the fully qualified name of the package of
|
||||
// the generated code. Its purposes are to prevent the module from importing
|
||||
// itself and to prevent qualifying type names that come from its own
|
||||
// package (i.e. if there is a type called X then we want to print "X" not
|
||||
// "package.X" since "package" is this package). This can happen if the mock
|
||||
// is output into an already existing package.
|
||||
outputPackagePath := *selfPackage
|
||||
if outputPackagePath == "" && *destination != "" {
|
||||
dstPath, err := filepath.Abs(filepath.Dir(*destination))
|
||||
if err == nil {
|
||||
pkgPath, err := parsePackageImport(dstPath)
|
||||
if err == nil {
|
||||
outputPackagePath = pkgPath
|
||||
} else {
|
||||
log.Println("Unable to infer -self_package from destination file path:", err)
|
||||
}
|
||||
} else {
|
||||
log.Println("Unable to determine destination file path:", err)
|
||||
}
|
||||
}
|
||||
|
||||
g := new(generator)
|
||||
if *source != "" {
|
||||
g.filename = *source
|
||||
} else {
|
||||
g.srcPackage = packageName
|
||||
g.srcInterfaces = flag.Arg(1)
|
||||
}
|
||||
g.destination = *destination
|
||||
|
||||
if *mockNames != "" {
|
||||
g.mockNames = parseMockNames(*mockNames)
|
||||
}
|
||||
if *copyrightFile != "" {
|
||||
header, err := ioutil.ReadFile(*copyrightFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed reading copyright file: %v", err)
|
||||
}
|
||||
|
||||
g.copyrightHeader = string(header)
|
||||
}
|
||||
if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil {
|
||||
log.Fatalf("Failed generating mock: %v", err)
|
||||
}
|
||||
if _, err := dst.Write(g.Output()); err != nil {
|
||||
log.Fatalf("Failed writing to destination: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func parseMockNames(names string) map[string]string {
|
||||
mocksMap := make(map[string]string)
|
||||
for _, kv := range strings.Split(names, ",") {
|
||||
parts := strings.SplitN(kv, "=", 2)
|
||||
if len(parts) != 2 || parts[1] == "" {
|
||||
log.Fatalf("bad mock names spec: %v", kv)
|
||||
}
|
||||
mocksMap[parts[0]] = parts[1]
|
||||
}
|
||||
return mocksMap
|
||||
}
|
||||
|
||||
func usage() {
|
||||
_, _ = io.WriteString(os.Stderr, usageText)
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
|
||||
const usageText = `mockgen has two modes of operation: source and reflect.
|
||||
|
||||
Source mode generates mock interfaces from a source file.
|
||||
It is enabled by using the -source flag. Other flags that
|
||||
may be useful in this mode are -imports and -aux_files.
|
||||
Example:
|
||||
mockgen -source=foo.go [other options]
|
||||
|
||||
Reflect mode generates mock interfaces by building a program
|
||||
that uses reflection to understand interfaces. It is enabled
|
||||
by passing two non-flag arguments: an import path, and a
|
||||
comma-separated list of symbols.
|
||||
Example:
|
||||
mockgen database/sql/driver Conn,Driver
|
||||
|
||||
`
|
||||
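In a codebase both modes are typically driven by go:generate directives; the file paths and package names below are hypothetical, purely for illustration:

// Source mode, reading the interfaces from a file:
//go:generate mockgen -source=store.go -destination=mocks/store_mock.go -package=mocks

// Reflect mode, naming the import path and symbols directly:
//go:generate mockgen -destination=mocks/driver_mock.go -package=mocks database/sql/driver Conn,Driver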
|
||||
type generator struct {
|
||||
buf bytes.Buffer
|
||||
indent string
|
||||
mockNames map[string]string // may be empty
|
||||
filename string // may be empty
|
||||
destination string // may be empty
|
||||
srcPackage, srcInterfaces string // may be empty
|
||||
copyrightHeader string
|
||||
|
||||
packageMap map[string]string // map from import path to package name
|
||||
}
|
||||
|
||||
func (g *generator) p(format string, args ...interface{}) {
|
||||
fmt.Fprintf(&g.buf, g.indent+format+"\n", args...)
|
||||
}
|
||||
|
||||
func (g *generator) in() {
|
||||
g.indent += "\t"
|
||||
}
|
||||
|
||||
func (g *generator) out() {
|
||||
if len(g.indent) > 0 {
|
||||
g.indent = g.indent[0 : len(g.indent)-1]
|
||||
}
|
||||
}
|
||||
|
||||
// sanitize cleans up a string to make a suitable package name.
|
||||
func sanitize(s string) string {
|
||||
t := ""
|
||||
for _, r := range s {
|
||||
if t == "" {
|
||||
if unicode.IsLetter(r) || r == '_' {
|
||||
t += string(r)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' {
|
||||
t += string(r)
|
||||
continue
|
||||
}
|
||||
}
|
||||
t += "_"
|
||||
}
|
||||
if t == "_" {
|
||||
t = "x"
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error {
|
||||
if outputPkgName != pkg.Name && *selfPackage == "" {
|
||||
// reset outputPackagePath if it's not passed in through -self_package
|
||||
outputPackagePath = ""
|
||||
}
|
||||
|
||||
if g.copyrightHeader != "" {
|
||||
lines := strings.Split(g.copyrightHeader, "\n")
|
||||
for _, line := range lines {
|
||||
g.p("// %s", line)
|
||||
}
|
||||
g.p("")
|
||||
}
|
||||
|
||||
g.p("// Code generated by MockGen. DO NOT EDIT.")
|
||||
if g.filename != "" {
|
||||
g.p("// Source: %v", g.filename)
|
||||
} else {
|
||||
g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces)
|
||||
}
|
||||
g.p("")
|
||||
|
||||
// Get all required imports, and generate unique names for them all.
|
||||
im := pkg.Imports()
|
||||
im[gomockImportPath] = true
|
||||
|
||||
// Only import reflect if it's used. We only use reflect in mocked methods
|
||||
// so only import if any of the mocked interfaces have methods.
|
||||
for _, intf := range pkg.Interfaces {
|
||||
if len(intf.Methods) > 0 {
|
||||
im["reflect"] = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Sort keys to make import alias generation predictable
|
||||
sortedPaths := make([]string, len(im))
|
||||
x := 0
|
||||
for pth := range im {
|
||||
sortedPaths[x] = pth
|
||||
x++
|
||||
}
|
||||
sort.Strings(sortedPaths)
|
||||
|
||||
packagesName := createPackageMap(sortedPaths)
|
||||
|
||||
g.packageMap = make(map[string]string, len(im))
|
||||
localNames := make(map[string]bool, len(im))
|
||||
for _, pth := range sortedPaths {
|
||||
base, ok := packagesName[pth]
|
||||
if !ok {
|
||||
base = sanitize(path.Base(pth))
|
||||
}
|
||||
|
||||
// Local names for an imported package can usually be the basename of the import path.
|
||||
// A couple of situations don't permit that, such as duplicate local names
|
||||
// (e.g. importing "html/template" and "text/template"), or where the basename is
|
||||
// a keyword (e.g. "foo/case").
|
||||
// try base0, base1, ...
|
||||
pkgName := base
|
||||
i := 0
|
||||
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() {
|
||||
pkgName = base + strconv.Itoa(i)
|
||||
i++
|
||||
}
|
||||
|
||||
// Avoid importing package if source pkg == output pkg
|
||||
if pth == pkg.PkgPath && outputPackagePath == pkg.PkgPath {
|
||||
continue
|
||||
}
|
||||
|
||||
g.packageMap[pth] = pkgName
|
||||
localNames[pkgName] = true
|
||||
}
|
||||
|
||||
if *writePkgComment {
|
||||
g.p("// Package %v is a generated GoMock package.", outputPkgName)
|
||||
}
|
||||
g.p("package %v", outputPkgName)
|
||||
g.p("")
|
||||
g.p("import (")
|
||||
g.in()
|
||||
for pkgPath, pkgName := range g.packageMap {
|
||||
if pkgPath == outputPackagePath {
|
||||
continue
|
||||
}
|
||||
g.p("%v %q", pkgName, pkgPath)
|
||||
}
|
||||
for _, pkgPath := range pkg.DotImports {
|
||||
g.p(". %q", pkgPath)
|
||||
}
|
||||
g.out()
|
||||
g.p(")")
|
||||
|
||||
for _, intf := range pkg.Interfaces {
|
||||
if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// The name of the mock type to use for the given interface identifier.
|
||||
func (g *generator) mockName(typeName string) string {
|
||||
if mockName, ok := g.mockNames[typeName]; ok {
|
||||
return mockName
|
||||
}
|
||||
|
||||
return "Mock" + typeName
|
||||
}
|
||||
|
||||
func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error {
|
||||
mockType := g.mockName(intf.Name)
|
||||
|
||||
g.p("")
|
||||
g.p("// %v is a mock of %v interface.", mockType, intf.Name)
|
||||
g.p("type %v struct {", mockType)
|
||||
g.in()
|
||||
g.p("ctrl *gomock.Controller")
|
||||
g.p("recorder *%vMockRecorder", mockType)
|
||||
g.out()
|
||||
g.p("}")
|
||||
g.p("")
|
||||
|
||||
g.p("// %vMockRecorder is the mock recorder for %v.", mockType, mockType)
|
||||
g.p("type %vMockRecorder struct {", mockType)
|
||||
g.in()
|
||||
g.p("mock *%v", mockType)
|
||||
g.out()
|
||||
g.p("}")
|
||||
g.p("")
|
||||
|
||||
g.p("// New%v creates a new mock instance.", mockType)
|
||||
g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType)
|
||||
g.in()
|
||||
g.p("mock := &%v{ctrl: ctrl}", mockType)
|
||||
g.p("mock.recorder = &%vMockRecorder{mock}", mockType)
|
||||
g.p("return mock")
|
||||
g.out()
|
||||
g.p("}")
|
||||
g.p("")
|
||||
|
||||
// XXX: possible name collision here if someone has EXPECT in their interface.
|
||||
g.p("// EXPECT returns an object that allows the caller to indicate expected use.")
|
||||
g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType)
|
||||
g.in()
|
||||
g.p("return m.recorder")
|
||||
g.out()
|
||||
g.p("}")
|
||||
|
||||
g.GenerateMockMethods(mockType, intf, outputPackagePath)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type byMethodName []*model.Method
|
||||
|
||||
func (b byMethodName) Len() int { return len(b) }
|
||||
func (b byMethodName) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
|
||||
func (b byMethodName) Less(i, j int) bool { return b[i].Name < b[j].Name }
|
||||
|
||||
func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) {
|
||||
sort.Sort(byMethodName(intf.Methods))
|
||||
for _, m := range intf.Methods {
|
||||
g.p("")
|
||||
_ = g.GenerateMockMethod(mockType, m, pkgOverride)
|
||||
g.p("")
|
||||
_ = g.GenerateMockRecorderMethod(mockType, m)
|
||||
}
|
||||
}
|
||||
|
||||
func makeArgString(argNames, argTypes []string) string {
|
||||
args := make([]string, len(argNames))
|
||||
for i, name := range argNames {
|
||||
// specify the type only once for consecutive args of the same type
|
||||
if i+1 < len(argTypes) && argTypes[i] == argTypes[i+1] {
|
||||
args[i] = name
|
||||
} else {
|
||||
args[i] = name + " " + argTypes[i]
|
||||
}
|
||||
}
|
||||
return strings.Join(args, ", ")
|
||||
}
|
||||
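The collapsing matches Go's own parameter-list syntax; worked by hand from the loop above:

// makeArgString([]string{"a", "b", "c"}, []string{"int", "int", "string"})
// returns "a, b int, c string": consecutive equal types are printed once.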
|
||||
// GenerateMockMethod generates a mock method implementation.
|
||||
// If non-empty, pkgOverride is the package in which unqualified types reside.
|
||||
func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error {
|
||||
argNames := g.getArgNames(m)
|
||||
argTypes := g.getArgTypes(m, pkgOverride)
|
||||
argString := makeArgString(argNames, argTypes)
|
||||
|
||||
rets := make([]string, len(m.Out))
|
||||
for i, p := range m.Out {
|
||||
rets[i] = p.Type.String(g.packageMap, pkgOverride)
|
||||
}
|
||||
retString := strings.Join(rets, ", ")
|
||||
if len(rets) > 1 {
|
||||
retString = "(" + retString + ")"
|
||||
}
|
||||
if retString != "" {
|
||||
retString = " " + retString
|
||||
}
|
||||
|
||||
ia := newIdentifierAllocator(argNames)
|
||||
idRecv := ia.allocateIdentifier("m")
|
||||
|
||||
g.p("// %v mocks base method.", m.Name)
|
||||
g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString)
|
||||
g.in()
|
||||
g.p("%s.ctrl.T.Helper()", idRecv)
|
||||
|
||||
var callArgs string
|
||||
if m.Variadic == nil {
|
||||
if len(argNames) > 0 {
|
||||
callArgs = ", " + strings.Join(argNames, ", ")
|
||||
}
|
||||
} else {
|
||||
// Non-trivial. The generated code must build a []interface{},
|
||||
// but the variadic argument may be any type.
|
||||
idVarArgs := ia.allocateIdentifier("varargs")
|
||||
idVArg := ia.allocateIdentifier("a")
|
||||
g.p("%s := []interface{}{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", "))
|
||||
g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1])
|
||||
g.in()
|
||||
g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg)
|
||||
g.out()
|
||||
g.p("}")
|
||||
callArgs = ", " + idVarArgs + "..."
|
||||
}
|
||||
if len(m.Out) == 0 {
|
||||
g.p(`%v.ctrl.Call(%v, %q%v)`, idRecv, idRecv, m.Name, callArgs)
|
||||
} else {
|
||||
idRet := ia.allocateIdentifier("ret")
|
||||
g.p(`%v := %v.ctrl.Call(%v, %q%v)`, idRet, idRecv, idRecv, m.Name, callArgs)
|
||||
|
||||
// Go does not allow "naked" type assertions on nil values, so we use the two-value form here.
|
||||
// The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T.
|
||||
// Happily, this coincides with the semantics we want here.
|
||||
retNames := make([]string, len(rets))
|
||||
for i, t := range rets {
|
||||
retNames[i] = ia.allocateIdentifier(fmt.Sprintf("ret%d", i))
|
||||
g.p("%s, _ := %s[%d].(%s)", retNames[i], idRet, i, t)
|
||||
}
|
||||
g.p("return " + strings.Join(retNames, ", "))
|
||||
}
|
||||
|
||||
g.out()
|
||||
g.p("}")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error {
|
||||
argNames := g.getArgNames(m)
|
||||
|
||||
var argString string
|
||||
if m.Variadic == nil {
|
||||
argString = strings.Join(argNames, ", ")
|
||||
} else {
|
||||
argString = strings.Join(argNames[:len(argNames)-1], ", ")
|
||||
}
|
||||
if argString != "" {
|
||||
argString += " interface{}"
|
||||
}
|
||||
|
||||
if m.Variadic != nil {
|
||||
if argString != "" {
|
||||
argString += ", "
|
||||
}
|
||||
argString += fmt.Sprintf("%s ...interface{}", argNames[len(argNames)-1])
|
||||
}
|
||||
|
||||
ia := newIdentifierAllocator(argNames)
|
||||
idRecv := ia.allocateIdentifier("mr")
|
||||
|
||||
g.p("// %v indicates an expected call of %v.", m.Name, m.Name)
|
||||
g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString)
|
||||
g.in()
|
||||
g.p("%s.mock.ctrl.T.Helper()", idRecv)
|
||||
|
||||
var callArgs string
|
||||
if m.Variadic == nil {
|
||||
if len(argNames) > 0 {
|
||||
callArgs = ", " + strings.Join(argNames, ", ")
|
||||
}
|
||||
} else {
|
||||
if len(argNames) == 1 {
|
||||
// Easy: just use ... to push the arguments through.
|
||||
callArgs = ", " + argNames[0] + "..."
|
||||
} else {
|
||||
// Hard: create a temporary slice.
|
||||
idVarArgs := ia.allocateIdentifier("varargs")
|
||||
g.p("%s := append([]interface{}{%s}, %s...)",
|
||||
idVarArgs,
|
||||
strings.Join(argNames[:len(argNames)-1], ", "),
|
||||
argNames[len(argNames)-1])
|
||||
callArgs = ", " + idVarArgs + "..."
|
||||
}
|
||||
}
|
||||
g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs)
|
||||
|
||||
g.out()
|
||||
g.p("}")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *generator) getArgNames(m *model.Method) []string {
|
||||
argNames := make([]string, len(m.In))
|
||||
for i, p := range m.In {
|
||||
name := p.Name
|
||||
if name == "" || name == "_" {
|
||||
name = fmt.Sprintf("arg%d", i)
|
||||
}
|
||||
argNames[i] = name
|
||||
}
|
||||
if m.Variadic != nil {
|
||||
name := m.Variadic.Name
|
||||
if name == "" {
|
||||
name = fmt.Sprintf("arg%d", len(m.In))
|
||||
}
|
||||
argNames = append(argNames, name)
|
||||
}
|
||||
return argNames
|
||||
}
|
||||
|
||||
func (g *generator) getArgTypes(m *model.Method, pkgOverride string) []string {
|
||||
argTypes := make([]string, len(m.In))
|
||||
for i, p := range m.In {
|
||||
argTypes[i] = p.Type.String(g.packageMap, pkgOverride)
|
||||
}
|
||||
if m.Variadic != nil {
|
||||
argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride))
|
||||
}
|
||||
return argTypes
|
||||
}
|
||||
|
||||
type identifierAllocator map[string]struct{}
|
||||
|
||||
func newIdentifierAllocator(taken []string) identifierAllocator {
|
||||
a := make(identifierAllocator, len(taken))
|
||||
for _, s := range taken {
|
||||
a[s] = struct{}{}
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func (o identifierAllocator) allocateIdentifier(want string) string {
|
||||
id := want
|
||||
for i := 2; ; i++ {
|
||||
if _, ok := o[id]; !ok {
|
||||
o[id] = struct{}{}
|
||||
return id
|
||||
}
|
||||
id = want + "_" + strconv.Itoa(i)
|
||||
}
|
||||
}
|
||||
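So a request for a name that is already taken picks up a numeric suffix starting at 2:

// ia := newIdentifierAllocator([]string{"m"})
// ia.allocateIdentifier("m")   // "m_2" ("m" was taken by an argument)
// ia.allocateIdentifier("ret") // "ret"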
|
||||
// Output returns the generator's output, formatted in the standard Go style.
|
||||
func (g *generator) Output() []byte {
|
||||
src, err := toolsimports.Process(g.destination, g.buf.Bytes(), nil)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String())
|
||||
}
|
||||
return src
|
||||
}
|
||||
|
||||
// createPackageMap returns a map of import path to package name
|
||||
// for specified importPaths.
|
||||
func createPackageMap(importPaths []string) map[string]string {
|
||||
var pkg struct {
|
||||
Name string
|
||||
ImportPath string
|
||||
}
|
||||
pkgMap := make(map[string]string)
|
||||
b := bytes.NewBuffer(nil)
|
||||
args := []string{"list", "-json"}
|
||||
args = append(args, importPaths...)
|
||||
cmd := exec.Command("go", args...)
|
||||
cmd.Stdout = b
|
||||
cmd.Run()
|
||||
dec := json.NewDecoder(b)
|
||||
for dec.More() {
|
||||
err := dec.Decode(&pkg)
|
||||
if err != nil {
|
||||
log.Printf("failed to decode 'go list' output: %v", err)
|
||||
continue
|
||||
}
|
||||
pkgMap[pkg.ImportPath] = pkg.Name
|
||||
}
|
||||
return pkgMap
|
||||
}
|
||||
|
||||
func printVersion() {
|
||||
if version != "" {
|
||||
fmt.Printf("v%s\nCommit: %s\nDate: %s\n", version, commit, date)
|
||||
} else {
|
||||
printModuleVersion()
|
||||
}
|
||||
}
|
||||
|
||||
// parsePackageImport gets the package import path for the given source directory;
|
||||
// an alternative implementation is to use:
|
||||
// cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir}
|
||||
// pkgs, err := packages.Load(cfg, "file="+source)
|
||||
// However, it will call "go list" and slow down the performance
|
||||
func parsePackageImport(srcDir string) (string, error) {
|
||||
moduleMode := os.Getenv("GO111MODULE")
|
||||
// trying to find the module
|
||||
if moduleMode != "off" {
|
||||
currentDir := srcDir
|
||||
for {
|
||||
dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod"))
|
||||
if os.IsNotExist(err) {
|
||||
if currentDir == filepath.Dir(currentDir) {
|
||||
// at the root
|
||||
break
|
||||
}
|
||||
currentDir = filepath.Dir(currentDir)
|
||||
continue
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
modulePath := modfile.ModulePath(dat)
|
||||
return filepath.ToSlash(filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir))), nil
|
||||
}
|
||||
}
|
||||
// fall back to GOPATH mode
|
||||
goPaths := os.Getenv("GOPATH")
|
||||
if goPaths == "" {
|
||||
return "", fmt.Errorf("GOPATH is not set")
|
||||
}
|
||||
goPathList := strings.Split(goPaths, string(os.PathListSeparator))
|
||||
for _, goPath := range goPathList {
|
||||
sourceRoot := filepath.Join(goPath, "src") + string(os.PathSeparator)
|
||||
if strings.HasPrefix(srcDir, sourceRoot) {
|
||||
return filepath.ToSlash(strings.TrimPrefix(srcDir, sourceRoot)), nil
|
||||
}
|
||||
}
|
||||
return "", errOutsideGoPath
|
||||
}
|
|
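The module branch above boils down to: walk up to the nearest go.mod, take its module path, and append the directory's suffix relative to the module root. A compressed sketch of just that final join, with a synthetic go.mod and error handling elided:

package main

import (
	"fmt"
	"path/filepath"
	"strings"

	"golang.org/x/mod/modfile"
)

// importPathFor derives the import path of dir given the contents of the
// go.mod found at modRoot, mirroring the walk in parsePackageImport.
func importPathFor(gomod []byte, modRoot, dir string) string {
	module := modfile.ModulePath(gomod)
	rel := strings.TrimPrefix(dir, modRoot)
	return filepath.ToSlash(filepath.Join(module, rel))
}

func main() {
	gomod := []byte("module example.com/project\n")
	fmt.Println(importPathFor(gomod, "/src/project", "/src/project/internal/db"))
	// example.com/project/internal/db
}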
@ -0,0 +1,495 @@
|
|||
// Copyright 2012 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package model contains the data model necessary for generating mock implementations.
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// pkgPath is the importable path for package model
|
||||
const pkgPath = "github.com/golang/mock/mockgen/model"
|
||||
|
||||
// Package is a Go package. It may be a subset.
|
||||
type Package struct {
|
||||
Name string
|
||||
PkgPath string
|
||||
Interfaces []*Interface
|
||||
DotImports []string
|
||||
}
|
||||
|
||||
// Print writes the package name and its exported interfaces.
|
||||
func (pkg *Package) Print(w io.Writer) {
|
||||
_, _ = fmt.Fprintf(w, "package %s\n", pkg.Name)
|
||||
for _, intf := range pkg.Interfaces {
|
||||
intf.Print(w)
|
||||
}
|
||||
}
|
||||
|
||||
// Imports returns the imports needed by the Package as a set of import paths.
|
||||
func (pkg *Package) Imports() map[string]bool {
|
||||
im := make(map[string]bool)
|
||||
for _, intf := range pkg.Interfaces {
|
||||
intf.addImports(im)
|
||||
}
|
||||
return im
|
||||
}
|
||||
|
||||
// Interface is a Go interface.
|
||||
type Interface struct {
|
||||
Name string
|
||||
Methods []*Method
|
||||
}
|
||||
|
||||
// Print writes the interface name and its methods.
|
||||
func (intf *Interface) Print(w io.Writer) {
|
||||
_, _ = fmt.Fprintf(w, "interface %s\n", intf.Name)
|
||||
for _, m := range intf.Methods {
|
||||
m.Print(w)
|
||||
}
|
||||
}
|
||||
|
||||
func (intf *Interface) addImports(im map[string]bool) {
|
||||
for _, m := range intf.Methods {
|
||||
m.addImports(im)
|
||||
}
|
||||
}
|
||||
|
||||
// AddMethod adds a new method, de-duplicating by method name.
|
||||
func (intf *Interface) AddMethod(m *Method) {
|
||||
for _, me := range intf.Methods {
|
||||
if me.Name == m.Name {
|
||||
return
|
||||
}
|
||||
}
|
||||
intf.Methods = append(intf.Methods, m)
|
||||
}
|
||||
|
||||
// Method is a single method of an interface.
|
||||
type Method struct {
|
||||
Name string
|
||||
In, Out []*Parameter
|
||||
Variadic *Parameter // may be nil
|
||||
}
|
||||
|
||||
// Print writes the method name and its signature.
|
||||
func (m *Method) Print(w io.Writer) {
|
||||
_, _ = fmt.Fprintf(w, " - method %s\n", m.Name)
|
||||
if len(m.In) > 0 {
|
||||
_, _ = fmt.Fprintf(w, " in:\n")
|
||||
for _, p := range m.In {
|
||||
p.Print(w)
|
||||
}
|
||||
}
|
||||
if m.Variadic != nil {
|
||||
_, _ = fmt.Fprintf(w, " ...:\n")
|
||||
m.Variadic.Print(w)
|
||||
}
|
||||
if len(m.Out) > 0 {
|
||||
_, _ = fmt.Fprintf(w, " out:\n")
|
||||
for _, p := range m.Out {
|
||||
p.Print(w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Method) addImports(im map[string]bool) {
|
||||
for _, p := range m.In {
|
||||
p.Type.addImports(im)
|
||||
}
|
||||
if m.Variadic != nil {
|
||||
m.Variadic.Type.addImports(im)
|
||||
}
|
||||
for _, p := range m.Out {
|
||||
p.Type.addImports(im)
|
||||
}
|
||||
}
|
||||
|
||||
// Parameter is an argument or return parameter of a method.
|
||||
type Parameter struct {
|
||||
Name string // may be empty
|
||||
Type Type
|
||||
}
|
||||
|
||||
// Print writes a method parameter.
|
||||
func (p *Parameter) Print(w io.Writer) {
|
||||
n := p.Name
|
||||
if n == "" {
|
||||
n = `""`
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, " - %v: %v\n", n, p.Type.String(nil, ""))
|
||||
}
|
||||
|
||||
// Type is a Go type.
|
||||
type Type interface {
|
||||
String(pm map[string]string, pkgOverride string) string
|
||||
addImports(im map[string]bool)
|
||||
}
|
||||
|
||||
func init() {
|
||||
gob.Register(&ArrayType{})
|
||||
gob.Register(&ChanType{})
|
||||
gob.Register(&FuncType{})
|
||||
gob.Register(&MapType{})
|
||||
gob.Register(&NamedType{})
|
||||
gob.Register(&PointerType{})
|
||||
|
||||
// Call gob.RegisterName to make sure the same name is registered for
|
||||
// both the gob encoder and decoder.
|
||||
//
|
||||
// For a non-pointer type, gob.Register will try to get package full path by
|
||||
// calling rt.PkgPath() for a name to register. If your project has a vendor
|
||||
// directory, it is possible that PkgPath will get a path like this:
|
||||
// ../../../vendor/github.com/golang/mock/mockgen/model
|
||||
gob.RegisterName(pkgPath+".PredeclaredType", PredeclaredType(""))
|
||||
}
|
||||
|
||||
// ArrayType is an array or slice type.
|
||||
type ArrayType struct {
|
||||
Len int // -1 for slices, >= 0 for arrays
|
||||
Type Type
|
||||
}
|
||||
|
||||
func (at *ArrayType) String(pm map[string]string, pkgOverride string) string {
|
||||
s := "[]"
|
||||
if at.Len > -1 {
|
||||
s = fmt.Sprintf("[%d]", at.Len)
|
||||
}
|
||||
return s + at.Type.String(pm, pkgOverride)
|
||||
}
|
||||
|
||||
func (at *ArrayType) addImports(im map[string]bool) { at.Type.addImports(im) }
|
||||
|
||||
// ChanType is a channel type.
|
||||
type ChanType struct {
|
||||
Dir ChanDir // 0, 1 or 2
|
||||
Type Type
|
||||
}
|
||||
|
||||
func (ct *ChanType) String(pm map[string]string, pkgOverride string) string {
|
||||
s := ct.Type.String(pm, pkgOverride)
|
||||
if ct.Dir == RecvDir {
|
||||
return "<-chan " + s
|
||||
}
|
||||
if ct.Dir == SendDir {
|
||||
return "chan<- " + s
|
||||
}
|
||||
return "chan " + s
|
||||
}
|
||||
|
||||
func (ct *ChanType) addImports(im map[string]bool) { ct.Type.addImports(im) }
|
||||
|
||||
// ChanDir is a channel direction.
|
||||
type ChanDir int
|
||||
|
||||
// Constants for channel directions.
|
||||
const (
|
||||
RecvDir ChanDir = 1
|
||||
SendDir ChanDir = 2
|
||||
)
|
||||
|
||||
// FuncType is a function type.
|
||||
type FuncType struct {
|
||||
In, Out []*Parameter
|
||||
Variadic *Parameter // may be nil
|
||||
}
|
||||
|
||||
func (ft *FuncType) String(pm map[string]string, pkgOverride string) string {
|
||||
args := make([]string, len(ft.In))
|
||||
for i, p := range ft.In {
|
||||
args[i] = p.Type.String(pm, pkgOverride)
|
||||
}
|
||||
if ft.Variadic != nil {
|
||||
args = append(args, "..."+ft.Variadic.Type.String(pm, pkgOverride))
|
||||
}
|
||||
rets := make([]string, len(ft.Out))
|
||||
for i, p := range ft.Out {
|
||||
rets[i] = p.Type.String(pm, pkgOverride)
|
||||
}
|
||||
retString := strings.Join(rets, ", ")
|
||||
if nOut := len(ft.Out); nOut == 1 {
|
||||
retString = " " + retString
|
||||
} else if nOut > 1 {
|
||||
retString = " (" + retString + ")"
|
||||
}
|
||||
return "func(" + strings.Join(args, ", ") + ")" + retString
|
||||
}
|
||||
|
||||
func (ft *FuncType) addImports(im map[string]bool) {
|
||||
for _, p := range ft.In {
|
||||
p.Type.addImports(im)
|
||||
}
|
||||
if ft.Variadic != nil {
|
||||
ft.Variadic.Type.addImports(im)
|
||||
}
|
||||
for _, p := range ft.Out {
|
||||
p.Type.addImports(im)
|
||||
}
|
||||
}
|
||||
|
||||
// MapType is a map type.
|
||||
type MapType struct {
|
||||
Key, Value Type
|
||||
}
|
||||
|
||||
func (mt *MapType) String(pm map[string]string, pkgOverride string) string {
|
||||
return "map[" + mt.Key.String(pm, pkgOverride) + "]" + mt.Value.String(pm, pkgOverride)
|
||||
}
|
||||
|
||||
func (mt *MapType) addImports(im map[string]bool) {
|
||||
mt.Key.addImports(im)
|
||||
mt.Value.addImports(im)
|
||||
}
|
||||
|
||||
// NamedType is an exported type in a package.
|
||||
type NamedType struct {
|
||||
Package string // may be empty
|
||||
Type string
|
||||
}
|
||||
|
||||
func (nt *NamedType) String(pm map[string]string, pkgOverride string) string {
|
||||
if pkgOverride == nt.Package {
|
||||
return nt.Type
|
||||
}
|
||||
prefix := pm[nt.Package]
|
||||
if prefix != "" {
|
||||
return prefix + "." + nt.Type
|
||||
}
|
||||
|
||||
return nt.Type
|
||||
}
|
||||
|
||||
func (nt *NamedType) addImports(im map[string]bool) {
|
||||
if nt.Package != "" {
|
||||
im[nt.Package] = true
|
||||
}
|
||||
}
|
||||
|
||||
// PointerType is a pointer to another type.
|
||||
type PointerType struct {
|
||||
Type Type
|
||||
}
|
||||
|
||||
func (pt *PointerType) String(pm map[string]string, pkgOverride string) string {
|
||||
return "*" + pt.Type.String(pm, pkgOverride)
|
||||
}
|
||||
func (pt *PointerType) addImports(im map[string]bool) { pt.Type.addImports(im) }
|
||||
|
||||
// PredeclaredType is a predeclared type such as "int".
|
||||
type PredeclaredType string
|
||||
|
||||
func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) }
|
||||
func (pt PredeclaredType) addImports(map[string]bool) {}
|
||||
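To make the data model concrete, a minimal sketch that assembles a Package by hand and prints it. It assumes this vendored package is importable as github.com/golang/mock/mockgen/model:

package main

import (
	"os"

	"github.com/golang/mock/mockgen/model"
)

func main() {
	pkg := &model.Package{
		Name: "store",
		Interfaces: []*model.Interface{{
			Name: "Getter",
			Methods: []*model.Method{{
				Name: "Get",
				In:   []*model.Parameter{{Name: "key", Type: model.PredeclaredType("string")}},
				Out: []*model.Parameter{
					{Type: &model.ArrayType{Len: -1, Type: model.PredeclaredType("byte")}},
					{Type: model.PredeclaredType("error")},
				},
			}},
		}},
	}
	// Prints "package store", "interface Getter", then the method with
	// its in/out parameter lists.
	pkg.Print(os.Stdout)
}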
|
||||
// The following code is intended to be called by the program generated by ../reflect.go.
|
||||
|
||||
// InterfaceFromInterfaceType returns a pointer to an interface for the
|
||||
// given reflection interface type.
|
||||
func InterfaceFromInterfaceType(it reflect.Type) (*Interface, error) {
|
||||
if it.Kind() != reflect.Interface {
|
||||
return nil, fmt.Errorf("%v is not an interface", it)
|
||||
}
|
||||
intf := &Interface{}
|
||||
|
||||
for i := 0; i < it.NumMethod(); i++ {
|
||||
mt := it.Method(i)
|
||||
// TODO: need to skip unexported methods? or just raise an error?
|
||||
m := &Method{
|
||||
Name: mt.Name,
|
||||
}
|
||||
|
||||
var err error
|
||||
m.In, m.Variadic, m.Out, err = funcArgsFromType(mt.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
intf.AddMethod(m)
|
||||
}
|
||||
|
||||
return intf, nil
|
||||
}
|
||||
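A sketch of driving the reflection entry point above directly, using io.ReadWriter from the standard library (same vendored import path assumed):

package main

import (
	"fmt"
	"io"
	"os"
	"reflect"

	"github.com/golang/mock/mockgen/model"
)

func main() {
	it := reflect.TypeOf((*io.ReadWriter)(nil)).Elem()
	intf, err := model.InterfaceFromInterfaceType(it)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	intf.Name = "ReadWriter" // reflect gives the method set, not the interface's own name
	intf.Print(os.Stdout)    // lists Read and Write with their parameters
}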
|
||||
// t's Kind must be a reflect.Func.
|
||||
func funcArgsFromType(t reflect.Type) (in []*Parameter, variadic *Parameter, out []*Parameter, err error) {
|
||||
nin := t.NumIn()
|
||||
if t.IsVariadic() {
|
||||
nin--
|
||||
}
|
||||
var p *Parameter
|
||||
for i := 0; i < nin; i++ {
|
||||
p, err = parameterFromType(t.In(i))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
in = append(in, p)
|
||||
}
|
||||
if t.IsVariadic() {
|
||||
p, err = parameterFromType(t.In(nin).Elem())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
variadic = p
|
||||
}
|
||||
for i := 0; i < t.NumOut(); i++ {
|
||||
p, err = parameterFromType(t.Out(i))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parameterFromType(t reflect.Type) (*Parameter, error) {
|
||||
tt, err := typeFromType(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Parameter{Type: tt}, nil
|
||||
}
|
||||
|
||||
var errorType = reflect.TypeOf((*error)(nil)).Elem()
|
||||
|
||||
var byteType = reflect.TypeOf(byte(0))
|
||||
|
||||
func typeFromType(t reflect.Type) (Type, error) {
|
||||
// Hack workaround for https://golang.org/issue/3853.
|
||||
// This explicit check should not be necessary.
|
||||
if t == byteType {
|
||||
return PredeclaredType("byte"), nil
|
||||
}
|
||||
|
||||
if imp := t.PkgPath(); imp != "" {
|
||||
return &NamedType{
|
||||
Package: impPath(imp),
|
||||
Type: t.Name(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// only unnamed or predeclared types after here
|
||||
|
||||
// Lots of types have element types. Let's do the parsing and error checking for all of them.
|
||||
var elemType Type
|
||||
switch t.Kind() {
|
||||
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||
var err error
|
||||
elemType, err = typeFromType(t.Elem())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Array:
|
||||
return &ArrayType{
|
||||
Len: t.Len(),
|
||||
Type: elemType,
|
||||
}, nil
|
||||
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
|
||||
reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128, reflect.String:
|
||||
return PredeclaredType(t.Kind().String()), nil
|
||||
case reflect.Chan:
|
||||
var dir ChanDir
|
||||
switch t.ChanDir() {
|
||||
case reflect.RecvDir:
|
||||
dir = RecvDir
|
||||
case reflect.SendDir:
|
||||
dir = SendDir
|
||||
}
|
||||
return &ChanType{
|
||||
Dir: dir,
|
||||
Type: elemType,
|
||||
}, nil
|
||||
case reflect.Func:
|
||||
in, variadic, out, err := funcArgsFromType(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &FuncType{
|
||||
In: in,
|
||||
Out: out,
|
||||
Variadic: variadic,
|
||||
}, nil
|
||||
case reflect.Interface:
|
||||
// Two special interfaces.
|
||||
if t.NumMethod() == 0 {
|
||||
return PredeclaredType("interface{}"), nil
|
||||
}
|
||||
if t == errorType {
|
||||
return PredeclaredType("error"), nil
|
||||
}
|
||||
case reflect.Map:
|
||||
kt, err := typeFromType(t.Key())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MapType{
|
||||
Key: kt,
|
||||
Value: elemType,
|
||||
}, nil
|
||||
case reflect.Ptr:
|
||||
return &PointerType{
|
||||
Type: elemType,
|
||||
}, nil
|
||||
case reflect.Slice:
|
||||
return &ArrayType{
|
||||
Len: -1,
|
||||
Type: elemType,
|
||||
}, nil
|
||||
case reflect.Struct:
|
||||
if t.NumField() == 0 {
|
||||
return PredeclaredType("struct{}"), nil
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Struct, UnsafePointer
|
||||
return nil, fmt.Errorf("can't yet turn %v (%v) into a model.Type", t, t.Kind())
|
||||
}
|
||||
|
||||
// impPath sanitizes the package path returned by `PkgPath` method of a reflect Type so that
|
||||
// it is importable. PkgPath might return a path that includes "vendor". These paths do not
|
||||
// compile, so we need to remove everything up to and including "/vendor/".
|
||||
// See https://github.com/golang/go/issues/12019.
|
||||
func impPath(imp string) string {
|
||||
if strings.HasPrefix(imp, "vendor/") {
|
||||
imp = "/" + imp
|
||||
}
|
||||
if i := strings.LastIndex(imp, "/vendor/"); i != -1 {
|
||||
imp = imp[i+len("/vendor/"):]
|
||||
}
|
||||
return imp
|
||||
}
|
||||
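For example, the vendor stripping above behaves as follows; impPath is unexported, so the body is reproduced to keep the sketch self-contained:

package main

import (
	"fmt"
	"strings"
)

func impPath(imp string) string {
	if strings.HasPrefix(imp, "vendor/") {
		imp = "/" + imp
	}
	if i := strings.LastIndex(imp, "/vendor/"); i != -1 {
		imp = imp[i+len("/vendor/"):]
	}
	return imp
}

func main() {
	fmt.Println(impPath("myapp/vendor/github.com/golang/mock/mockgen/model"))
	// github.com/golang/mock/mockgen/model
	fmt.Println(impPath("vendor/golang.org/x/net/context"))
	// golang.org/x/net/context
	fmt.Println(impPath("github.com/plain/path")) // unchanged
}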
|
||||
// ErrorInterface represents the built-in error interface.
|
||||
var ErrorInterface = Interface{
|
||||
Name: "error",
|
||||
Methods: []*Method{
|
||||
{
|
||||
Name: "Error",
|
||||
Out: []*Parameter{
|
||||
{
|
||||
Name: "",
|
||||
Type: PredeclaredType("string"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
|
@@ -0,0 +1,644 @@
|
|||
// Copyright 2012 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
// This file contains the model construction by parsing source files.
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/build"
|
||||
"go/importer"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/golang/mock/mockgen/model"
|
||||
)
|
||||
|
||||
var (
|
||||
imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.")
|
||||
auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.")
|
||||
)
|
||||
|
||||
// sourceMode generates mocks via source file.
|
||||
func sourceMode(source string) (*model.Package, error) {
|
||||
srcDir, err := filepath.Abs(filepath.Dir(source))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting source directory: %v", err)
|
||||
}
|
||||
|
||||
packageImport, err := parsePackageImport(srcDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fs := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fs, source, nil, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing source file %v: %v", source, err)
|
||||
}
|
||||
|
||||
p := &fileParser{
|
||||
fileSet: fs,
|
||||
imports: make(map[string]importedPackage),
|
||||
importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
|
||||
auxInterfaces: make(map[string]map[string]*ast.InterfaceType),
|
||||
srcDir: srcDir,
|
||||
}
|
||||
|
||||
// Handle -imports.
|
||||
dotImports := make(map[string]bool)
|
||||
if *imports != "" {
|
||||
for _, kv := range strings.Split(*imports, ",") {
|
||||
eq := strings.Index(kv, "=")
|
||||
k, v := kv[:eq], kv[eq+1:]
|
||||
if k == "." {
|
||||
dotImports[v] = true
|
||||
} else {
|
||||
p.imports[k] = importedPkg{path: v}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle -aux_files.
|
||||
if err := p.parseAuxFiles(*auxFiles); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.addAuxInterfacesFromFile(packageImport, file) // this file
|
||||
|
||||
pkg, err := p.parseFile(packageImport, file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for pkgPath := range dotImports {
|
||||
pkg.DotImports = append(pkg.DotImports, pkgPath)
|
||||
}
|
||||
return pkg, nil
|
||||
}
|
||||
|
||||
type importedPackage interface {
|
||||
Path() string
|
||||
Parser() *fileParser
|
||||
}
|
||||
|
||||
type importedPkg struct {
|
||||
path string
|
||||
parser *fileParser
|
||||
}
|
||||
|
||||
func (i importedPkg) Path() string { return i.path }
|
||||
func (i importedPkg) Parser() *fileParser { return i.parser }
|
||||
|
||||
// duplicateImport is a bit of a misnomer. Currently the parser can't
|
||||
// handle cases of multi-file packages importing different packages
|
||||
// under the same name. Often these imports would not be problematic,
|
||||
// so this type lets us defer raising an error unless the package name
|
||||
// is actually used.
|
||||
type duplicateImport struct {
|
||||
name string
|
||||
duplicates []string
|
||||
}
|
||||
|
||||
func (d duplicateImport) Error() string {
|
||||
return fmt.Sprintf("%q is ambiguous because of duplicate imports: %v", d.name, d.duplicates)
|
||||
}
|
||||
|
||||
func (d duplicateImport) Path() string { log.Fatal(d.Error()); return "" }
|
||||
func (d duplicateImport) Parser() *fileParser { log.Fatal(d.Error()); return nil }
|
||||
|
||||
type fileParser struct {
|
||||
fileSet *token.FileSet
|
||||
imports map[string]importedPackage // package name => imported package
|
||||
importedInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface
|
||||
|
||||
auxFiles []*ast.File
|
||||
auxInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface
|
||||
|
||||
srcDir string
|
||||
}
|
||||
|
||||
func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error {
|
||||
ps := p.fileSet.Position(pos)
|
||||
format = "%s:%d:%d: " + format
|
||||
args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...)
|
||||
return fmt.Errorf(format, args...)
|
||||
}
|
||||
|
||||
func (p *fileParser) parseAuxFiles(auxFiles string) error {
|
||||
auxFiles = strings.TrimSpace(auxFiles)
|
||||
if auxFiles == "" {
|
||||
return nil
|
||||
}
|
||||
for _, kv := range strings.Split(auxFiles, ",") {
|
||||
parts := strings.SplitN(kv, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("bad aux file spec: %v", kv)
|
||||
}
|
||||
pkg, fpath := parts[0], parts[1]
|
||||
|
||||
file, err := parser.ParseFile(p.fileSet, fpath, nil, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.auxFiles = append(p.auxFiles, file)
|
||||
p.addAuxInterfacesFromFile(pkg, file)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) {
|
||||
if _, ok := p.auxInterfaces[pkg]; !ok {
|
||||
p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType)
|
||||
}
|
||||
for ni := range iterInterfaces(file) {
|
||||
p.auxInterfaces[pkg][ni.name.Name] = ni.it
|
||||
}
|
||||
}
|
||||
|
||||
// parseFile loads the file's imports and the auxiliary files' imports into the
|
||||
// fileParser, parses all interfaces in the file, and returns the package model.
|
||||
func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) {
|
||||
allImports, dotImports := importsOfFile(file)
|
||||
// Don't stomp imports provided by -imports. Those should take precedence.
|
||||
for pkg, pkgI := range allImports {
|
||||
if _, ok := p.imports[pkg]; !ok {
|
||||
p.imports[pkg] = pkgI
|
||||
}
|
||||
}
|
||||
// Add imports from auxiliary files, which might be needed for embedded interfaces.
|
||||
// Don't stomp any other imports.
|
||||
for _, f := range p.auxFiles {
|
||||
auxImports, _ := importsOfFile(f)
|
||||
for pkg, pkgI := range auxImports {
|
||||
if _, ok := p.imports[pkg]; !ok {
|
||||
p.imports[pkg] = pkgI
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var is []*model.Interface
|
||||
for ni := range iterInterfaces(file) {
|
||||
i, err := p.parseInterface(ni.name.String(), importPath, ni.it)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
is = append(is, i)
|
||||
}
|
||||
return &model.Package{
|
||||
Name: file.Name.String(),
|
||||
PkgPath: importPath,
|
||||
Interfaces: is,
|
||||
DotImports: dotImports,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parsePackage loads package specified by path, parses it and returns
|
||||
// a new fileParser with the parsed imports and interfaces.
|
||||
func (p *fileParser) parsePackage(path string) (*fileParser, error) {
|
||||
newP := &fileParser{
|
||||
fileSet: token.NewFileSet(),
|
||||
imports: make(map[string]importedPackage),
|
||||
importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
|
||||
auxInterfaces: make(map[string]map[string]*ast.InterfaceType),
|
||||
srcDir: p.srcDir,
|
||||
}
|
||||
|
||||
var pkgs map[string]*ast.Package
|
||||
if imp, err := build.Import(path, newP.srcDir, build.FindOnly); err != nil {
|
||||
return nil, err
|
||||
} else if pkgs, err = parser.ParseDir(newP.fileSet, imp.Dir, nil, 0); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, pkg := range pkgs {
|
||||
file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates)
|
||||
if _, ok := newP.importedInterfaces[path]; !ok {
|
||||
newP.importedInterfaces[path] = make(map[string]*ast.InterfaceType)
|
||||
}
|
||||
for ni := range iterInterfaces(file) {
|
||||
newP.importedInterfaces[path][ni.name.Name] = ni.it
|
||||
}
|
||||
imports, _ := importsOfFile(file)
|
||||
for pkgName, pkgI := range imports {
|
||||
newP.imports[pkgName] = pkgI
|
||||
}
|
||||
}
|
||||
return newP, nil
|
||||
}
|
||||
|
||||
func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) {
|
||||
iface := &model.Interface{Name: name}
|
||||
for _, field := range it.Methods.List {
|
||||
switch v := field.Type.(type) {
|
||||
case *ast.FuncType:
|
||||
if nn := len(field.Names); nn != 1 {
|
||||
return nil, fmt.Errorf("expected one name for interface %v, got %d", iface.Name, nn)
|
||||
}
|
||||
m := &model.Method{
|
||||
Name: field.Names[0].String(),
|
||||
}
|
||||
var err error
|
||||
m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iface.AddMethod(m)
|
||||
case *ast.Ident:
|
||||
// Embedded interface in this package.
|
||||
embeddedIfaceType := p.auxInterfaces[pkg][v.String()]
|
||||
if embeddedIfaceType == nil {
|
||||
embeddedIfaceType = p.importedInterfaces[pkg][v.String()]
|
||||
}
|
||||
|
||||
var embeddedIface *model.Interface
|
||||
if embeddedIfaceType != nil {
|
||||
var err error
|
||||
embeddedIface, err = p.parseInterface(v.String(), pkg, embeddedIfaceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// This is built-in error interface.
|
||||
if v.String() == model.ErrorInterface.Name {
|
||||
embeddedIface = &model.ErrorInterface
|
||||
} else {
|
||||
return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String())
|
||||
}
|
||||
}
|
||||
// Copy the methods.
|
||||
for _, m := range embeddedIface.Methods {
|
||||
iface.AddMethod(m)
|
||||
}
|
||||
case *ast.SelectorExpr:
|
||||
// Embedded interface in another package.
|
||||
filePkg, sel := v.X.(*ast.Ident).String(), v.Sel.String()
|
||||
embeddedPkg, ok := p.imports[filePkg]
|
||||
if !ok {
|
||||
return nil, p.errorf(v.X.Pos(), "unknown package %s", filePkg)
|
||||
}
|
||||
|
||||
var embeddedIface *model.Interface
|
||||
var err error
|
||||
embeddedIfaceType := p.auxInterfaces[filePkg][sel]
|
||||
if embeddedIfaceType != nil {
|
||||
embeddedIface, err = p.parseInterface(sel, filePkg, embeddedIfaceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
path := embeddedPkg.Path()
|
||||
parser := embeddedPkg.Parser()
|
||||
if parser == nil {
|
||||
ip, err := p.parsePackage(path)
|
||||
if err != nil {
|
||||
return nil, p.errorf(v.Pos(), "could not parse package %s: %v", path, err)
|
||||
}
|
||||
parser = ip
|
||||
p.imports[filePkg] = importedPkg{
|
||||
path: embeddedPkg.Path(),
|
||||
parser: parser,
|
||||
}
|
||||
}
|
||||
if embeddedIfaceType = parser.importedInterfaces[path][sel]; embeddedIfaceType == nil {
|
||||
return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel)
|
||||
}
|
||||
embeddedIface, err = parser.parseInterface(sel, path, embeddedIfaceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Copy the methods.
|
||||
// TODO: apply shadowing rules.
|
||||
for _, m := range embeddedIface.Methods {
|
||||
iface.AddMethod(m)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
|
||||
}
|
||||
}
|
||||
return iface, nil
|
||||
}
|
||||
|
||||
func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (inParam []*model.Parameter, variadic *model.Parameter, outParam []*model.Parameter, err error) {
|
||||
if f.Params != nil {
|
||||
regParams := f.Params.List
|
||||
if isVariadic(f) {
|
||||
n := len(regParams)
|
||||
varParams := regParams[n-1:]
|
||||
regParams = regParams[:n-1]
|
||||
vp, err := p.parseFieldList(pkg, varParams)
|
||||
if err != nil {
|
||||
return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err)
|
||||
}
|
||||
variadic = vp[0]
|
||||
}
|
||||
inParam, err = p.parseFieldList(pkg, regParams)
|
||||
if err != nil {
|
||||
return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err)
|
||||
}
|
||||
}
|
||||
if f.Results != nil {
|
||||
outParam, err = p.parseFieldList(pkg, f.Results.List)
|
||||
if err != nil {
|
||||
return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) {
|
||||
nf := 0
|
||||
for _, f := range fields {
|
||||
nn := len(f.Names)
|
||||
if nn == 0 {
|
||||
nn = 1 // anonymous parameter
|
||||
}
|
||||
nf += nn
|
||||
}
|
||||
if nf == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
ps := make([]*model.Parameter, nf)
|
||||
i := 0 // destination index
|
||||
for _, f := range fields {
|
||||
t, err := p.parseType(pkg, f.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(f.Names) == 0 {
|
||||
// anonymous arg
|
||||
ps[i] = &model.Parameter{Type: t}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
for _, name := range f.Names {
|
||||
ps[i] = &model.Parameter{Name: name.Name, Type: t}
|
||||
i++
|
||||
}
|
||||
}
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
|
||||
switch v := typ.(type) {
|
||||
case *ast.ArrayType:
|
||||
ln := -1
|
||||
if v.Len != nil {
|
||||
var value string
|
||||
switch val := v.Len.(type) {
|
||||
case (*ast.BasicLit):
|
||||
value = val.Value
|
||||
case (*ast.Ident):
|
||||
// when the length is a const defined locally
|
||||
value = val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value
|
||||
case (*ast.SelectorExpr):
|
||||
// when the length is a const defined in an external package
|
||||
usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X))
|
||||
if err != nil {
|
||||
return nil, p.errorf(v.Len.Pos(), "unknown package in array length: %v", err)
|
||||
}
|
||||
ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name)
|
||||
if err != nil {
|
||||
return nil, p.errorf(v.Len.Pos(), "unknown constant in array length: %v", err)
|
||||
}
|
||||
value = ev.Value.String()
|
||||
}
|
||||
|
||||
x, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err)
|
||||
}
|
||||
ln = x
|
||||
}
|
||||
t, err := p.parseType(pkg, v.Elt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.ArrayType{Len: ln, Type: t}, nil
|
||||
case *ast.ChanType:
|
||||
t, err := p.parseType(pkg, v.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var dir model.ChanDir
|
||||
if v.Dir == ast.SEND {
|
||||
dir = model.SendDir
|
||||
}
|
||||
if v.Dir == ast.RECV {
|
||||
dir = model.RecvDir
|
||||
}
|
||||
return &model.ChanType{Dir: dir, Type: t}, nil
|
||||
case *ast.Ellipsis:
|
||||
// assume we're parsing a variadic argument
|
||||
return p.parseType(pkg, v.Elt)
|
||||
case *ast.FuncType:
|
||||
in, variadic, out, err := p.parseFunc(pkg, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil
|
||||
case *ast.Ident:
|
||||
if v.IsExported() {
|
||||
// `pkg` may be an aliased imported pkg
|
||||
// if so, patch the import w/ the fully qualified import
|
||||
maybeImportedPkg, ok := p.imports[pkg]
|
||||
if ok {
|
||||
pkg = maybeImportedPkg.Path()
|
||||
}
|
||||
// assume type in this package
|
||||
return &model.NamedType{Package: pkg, Type: v.Name}, nil
|
||||
}
|
||||
|
||||
// assume predeclared type
|
||||
return model.PredeclaredType(v.Name), nil
|
||||
case *ast.InterfaceType:
|
||||
if v.Methods != nil && len(v.Methods.List) > 0 {
|
||||
return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types")
|
||||
}
|
||||
return model.PredeclaredType("interface{}"), nil
|
||||
case *ast.MapType:
|
||||
key, err := p.parseType(pkg, v.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
value, err := p.parseType(pkg, v.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.MapType{Key: key, Value: value}, nil
|
||||
case *ast.SelectorExpr:
|
||||
pkgName := v.X.(*ast.Ident).String()
|
||||
pkg, ok := p.imports[pkgName]
|
||||
if !ok {
|
||||
return nil, p.errorf(v.Pos(), "unknown package %q", pkgName)
|
||||
}
|
||||
return &model.NamedType{Package: pkg.Path(), Type: v.Sel.String()}, nil
|
||||
case *ast.StarExpr:
|
||||
t, err := p.parseType(pkg, v.X)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.PointerType{Type: t}, nil
|
||||
case *ast.StructType:
|
||||
if v.Fields != nil && len(v.Fields.List) > 0 {
|
||||
return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types")
|
||||
}
|
||||
return model.PredeclaredType("struct{}"), nil
|
||||
case *ast.ParenExpr:
|
||||
return p.parseType(pkg, v.X)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("don't know how to parse type %T", typ)
|
||||
}
|
||||
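The three array-length branches above correspond to interface declarations like the following sketch; source mode resolves each length to a concrete int for model.ArrayType:

package hashing

import "crypto/sha256"

const blockLen = 64

type Hasher interface {
	Literal(buf [32]byte)           // *ast.BasicLit: length written in place
	LocalConst(buf [blockLen]byte)  // *ast.Ident: const declared in this package
	External(sum [sha256.Size]byte) // *ast.SelectorExpr: const from another package
}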
|
||||
// importsOfFile returns a map of package name to import path
|
||||
// of the imports in file.
|
||||
func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) {
|
||||
var importPaths []string
|
||||
for _, is := range file.Imports {
|
||||
if is.Name != nil {
|
||||
continue
|
||||
}
|
||||
importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes
|
||||
importPaths = append(importPaths, importPath)
|
||||
}
|
||||
packagesName := createPackageMap(importPaths)
|
||||
normalImports = make(map[string]importedPackage)
|
||||
dotImports = make([]string, 0)
|
||||
for _, is := range file.Imports {
|
||||
var pkgName string
|
||||
importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes
|
||||
|
||||
if is.Name != nil {
|
||||
// Named imports are always certain.
|
||||
if is.Name.Name == "_" {
|
||||
continue
|
||||
}
|
||||
pkgName = is.Name.Name
|
||||
} else {
|
||||
pkg, ok := packagesName[importPath]
|
||||
if !ok {
|
||||
// Fall back to the import path suffix. Note that this is uncertain.
|
||||
_, last := path.Split(importPath)
|
||||
// If the last path component has dots, the first dot-delimited
|
||||
// field is used as the name.
|
||||
pkgName = strings.SplitN(last, ".", 2)[0]
|
||||
} else {
|
||||
pkgName = pkg
|
||||
}
|
||||
}
|
||||
|
||||
if pkgName == "." {
|
||||
dotImports = append(dotImports, importPath)
|
||||
} else {
|
||||
if pkg, ok := normalImports[pkgName]; ok {
|
||||
switch p := pkg.(type) {
|
||||
case duplicateImport:
|
||||
normalImports[pkgName] = duplicateImport{
|
||||
name: p.name,
|
||||
duplicates: append([]string{importPath}, p.duplicates...),
|
||||
}
|
||||
case importedPkg:
|
||||
normalImports[pkgName] = duplicateImport{
|
||||
name: pkgName,
|
||||
duplicates: []string{p.path, importPath},
|
||||
}
|
||||
}
|
||||
} else {
|
||||
normalImports[pkgName] = importedPkg{path: importPath}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
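When "go list" cannot report a package's real name, the fallback above guesses from the import path suffix; versioned paths in the "pkg.v2" style keep only the part before the first dot. A sketch of just that heuristic:

package main

import (
	"fmt"
	"path"
	"strings"
)

// guessName mirrors the fallback in importsOfFile: last path element,
// truncated at the first dot.
func guessName(importPath string) string {
	_, last := path.Split(importPath)
	return strings.SplitN(last, ".", 2)[0]
}

func main() {
	fmt.Println(guessName("github.com/golang/mock/gomock")) // gomock
	fmt.Println(guessName("gopkg.in/yaml.v2"))              // yaml
}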
|
||||
type namedInterface struct {
|
||||
name *ast.Ident
|
||||
it *ast.InterfaceType
|
||||
}
|
||||
|
||||
// iterInterfaces creates an iterator over all interfaces in file.
|
||||
func iterInterfaces(file *ast.File) <-chan namedInterface {
|
||||
ch := make(chan namedInterface)
|
||||
go func() {
|
||||
for _, decl := range file.Decls {
|
||||
gd, ok := decl.(*ast.GenDecl)
|
||||
if !ok || gd.Tok != token.TYPE {
|
||||
continue
|
||||
}
|
||||
for _, spec := range gd.Specs {
|
||||
ts, ok := spec.(*ast.TypeSpec)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
it, ok := ts.Type.(*ast.InterfaceType)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
ch <- namedInterface{ts.Name, it}
|
||||
}
|
||||
}
|
||||
close(ch)
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
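For reference, a standalone sketch of the same declaration walk over a parsed source string, without the channel plumbing:

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

const src = `package demo

type Reader interface{ Read(p []byte) (int, error) }
type notAnInterface struct{}
`

func main() {
	fs := token.NewFileSet()
	file, err := parser.ParseFile(fs, "demo.go", src, 0)
	if err != nil {
		panic(err)
	}
	// Same walk as iterInterfaces: type declarations whose spec is an interface.
	for _, decl := range file.Decls {
		gd, ok := decl.(*ast.GenDecl)
		if !ok || gd.Tok != token.TYPE {
			continue
		}
		for _, spec := range gd.Specs {
			if ts, ok := spec.(*ast.TypeSpec); ok {
				if _, ok := ts.Type.(*ast.InterfaceType); ok {
					fmt.Println(ts.Name.Name) // Reader
				}
			}
		}
	}
}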
|
||||
// isVariadic returns whether the function is variadic.
|
||||
func isVariadic(f *ast.FuncType) bool {
|
||||
nargs := len(f.Params.List)
|
||||
if nargs == 0 {
|
||||
return false
|
||||
}
|
||||
_, ok := f.Params.List[nargs-1].Type.(*ast.Ellipsis)
|
||||
return ok
|
||||
}
|
||||
|
||||
// packageNameOfDir gets the package import path for the given directory.
|
||||
func packageNameOfDir(srcDir string) (string, error) {
|
||||
files, err := ioutil.ReadDir(srcDir)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var goFilePath string
|
||||
for _, file := range files {
|
||||
if !file.IsDir() && strings.HasSuffix(file.Name(), ".go") {
|
||||
goFilePath = file.Name()
|
||||
break
|
||||
}
|
||||
}
|
||||
if goFilePath == "" {
|
||||
return "", fmt.Errorf("go source file not found %s", srcDir)
|
||||
}
|
||||
|
||||
packageImport, err := parsePackageImport(srcDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return packageImport, nil
|
||||
}
|
||||
|
||||
var errOutsideGoPath = errors.New("source directory is outside GOPATH")
|
|
@@ -0,0 +1,256 @@
|
|||
// Copyright 2012 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
// This file contains the model construction by reflection.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/build"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/golang/mock/mockgen/model"
|
||||
)
|
||||
|
||||
var (
|
||||
progOnly = flag.Bool("prog_only", false, "(reflect mode) Only generate the reflection program; write it to stdout and exit.")
|
||||
execOnly = flag.String("exec_only", "", "(reflect mode) If set, execute this reflection program.")
|
||||
buildFlags = flag.String("build_flags", "", "(reflect mode) Additional flags for go build.")
|
||||
)
|
||||
|
||||
// reflectMode generates mocks via reflection on an interface.
|
||||
func reflectMode(importPath string, symbols []string) (*model.Package, error) {
|
||||
if *execOnly != "" {
|
||||
return run(*execOnly)
|
||||
}
|
||||
|
||||
program, err := writeProgram(importPath, symbols)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if *progOnly {
|
||||
if _, err := os.Stdout.Write(program); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
wd, _ := os.Getwd()
|
||||
|
||||
// Try to run the reflection program in the current working directory.
|
||||
if p, err := runInDir(program, wd); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Try to run the program in the same directory as the input package.
|
||||
if p, err := build.Import(importPath, wd, build.FindOnly); err == nil {
|
||||
dir := p.Dir
|
||||
if p, err := runInDir(program, dir); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Try to run it in a standard temp directory.
|
||||
return runInDir(program, "")
|
||||
}
|
||||
|
||||
func writeProgram(importPath string, symbols []string) ([]byte, error) {
|
||||
var program bytes.Buffer
|
||||
data := reflectData{
|
||||
ImportPath: importPath,
|
||||
Symbols: symbols,
|
||||
}
|
||||
if err := reflectProgram.Execute(&program, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return program.Bytes(), nil
|
||||
}
|
||||
|
||||
// run the given program and parse the output as a model.Package.
|
||||
func run(program string) (*model.Package, error) {
|
||||
f, err := ioutil.TempFile("", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filename := f.Name()
|
||||
defer os.Remove(filename)
|
||||
if err := f.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Run the program.
|
||||
cmd := exec.Command(program, "-output", filename)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err = os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process output.
|
||||
var pkg model.Package
|
||||
if err := gob.NewDecoder(f).Decode(&pkg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pkg, nil
|
||||
}
|
||||
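The file handed between the generated reflection program and run above is a gob stream of model.Package; gob is used because model.Type is an interface, which encoding/json cannot round-trip without extra machinery. A sketch of the same round-trip in memory (vendored import path assumed; the concrete Type implementations are registered by model's init):

package main

import (
	"bytes"
	"encoding/gob"
	"fmt"

	"github.com/golang/mock/mockgen/model"
)

func main() {
	src := &model.Package{
		Name: "demo",
		Interfaces: []*model.Interface{{
			Name: "Pinger",
			Methods: []*model.Method{{
				Name: "Ping",
				Out:  []*model.Parameter{{Type: model.PredeclaredType("error")}},
			}},
		}},
	}

	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(src); err != nil {
		panic(err) // PredeclaredType is gob-registered by model's init()
	}

	var dst model.Package
	if err := gob.NewDecoder(&buf).Decode(&dst); err != nil {
		panic(err)
	}
	fmt.Println(dst.Interfaces[0].Name) // Pinger
}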
|
||||
// runInDir writes the given program into the given dir, runs it there, and
|
||||
// parses the output as a model.Package.
|
||||
func runInDir(program []byte, dir string) (*model.Package, error) {
|
||||
// We use TempDir instead of TempFile so we can control the filename.
|
||||
tmpDir, err := ioutil.TempDir(dir, "gomock_reflect_")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
log.Printf("failed to remove temp directory: %s", err)
|
||||
}
|
||||
}()
|
||||
const progSource = "prog.go"
|
||||
var progBinary = "prog.bin"
|
||||
if runtime.GOOS == "windows" {
|
||||
// Windows won't execute a program unless it has a ".exe" suffix.
|
||||
progBinary += ".exe"
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile(filepath.Join(tmpDir, progSource), program, 0600); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmdArgs := []string{}
|
||||
cmdArgs = append(cmdArgs, "build")
|
||||
if *buildFlags != "" {
|
||||
cmdArgs = append(cmdArgs, strings.Split(*buildFlags, " ")...)
|
||||
}
|
||||
cmdArgs = append(cmdArgs, "-o", progBinary, progSource)
|
||||
|
||||
// Build the program.
|
||||
buf := bytes.NewBuffer(nil)
|
||||
cmd := exec.Command("go", cmdArgs...)
|
||||
cmd.Dir = tmpDir
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = io.MultiWriter(os.Stderr, buf)
|
||||
if err := cmd.Run(); err != nil {
|
||||
sErr := buf.String()
|
||||
if strings.Contains(sErr, `cannot find package "."`) &&
|
||||
strings.Contains(sErr, "github.com/golang/mock/mockgen/model") {
|
||||
fmt.Fprint(os.Stderr, "Please reference the steps in the README to fix this error:\n\thttps://github.com/golang/mock#reflect-vendoring-error.")
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return run(filepath.Join(tmpDir, progBinary))
|
||||
}
|
||||
|
||||
type reflectData struct {
|
||||
ImportPath string
|
||||
Symbols []string
|
||||
}
|
||||
|
||||
// This program reflects on an interface value, and prints the
|
||||
// gob encoding of a model.Package to standard output.
|
||||
// JSON doesn't work because of the model.Type interface.
|
||||
var reflectProgram = template.Must(template.New("program").Parse(`
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
|
||||
"github.com/golang/mock/mockgen/model"
|
||||
|
||||
pkg_ {{printf "%q" .ImportPath}}
|
||||
)
|
||||
|
||||
var output = flag.String("output", "", "The output file name, or empty to use stdout.")
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
its := []struct{
|
||||
sym string
|
||||
typ reflect.Type
|
||||
}{
|
||||
{{range .Symbols}}
|
||||
{ {{printf "%q" .}}, reflect.TypeOf((*pkg_.{{.}})(nil)).Elem()},
|
||||
{{end}}
|
||||
}
|
||||
pkg := &model.Package{
|
||||
// NOTE: This behaves contrary to documented behaviour if the
|
||||
// package name is not the final component of the import path.
|
||||
// The reflect package doesn't expose the package name, though.
|
||||
Name: path.Base({{printf "%q" .ImportPath}}),
|
||||
}
|
||||
|
||||
for _, it := range its {
|
||||
intf, err := model.InterfaceFromInterfaceType(it.typ)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Reflection: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
intf.Name = it.sym
|
||||
pkg.Interfaces = append(pkg.Interfaces, intf)
|
||||
}
|
||||
|
||||
outfile := os.Stdout
|
||||
if len(*output) != 0 {
|
||||
var err error
|
||||
outfile, err = os.Create(*output)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to open output file %q", *output)
|
||||
}
|
||||
defer func() {
|
||||
if err := outfile.Close(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to close output file %q", *output)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err := gob.NewEncoder(outfile).Encode(pkg); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "gob encode: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
`))
|
|
@@ -0,0 +1,26 @@
|
|||
// Copyright 2019 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// +build !go1.12
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
)
|
||||
|
||||
func printModuleVersion() {
|
||||
log.Printf("No version information is available for Mockgen compiled with " +
|
||||
"version 1.11")
|
||||
}
|
|
@@ -0,0 +1,35 @@
|
|||
// Copyright 2019 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
// +build go1.12
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
func printModuleVersion() {
|
||||
if bi, exists := debug.ReadBuildInfo(); exists {
|
||||
fmt.Println(bi.Main.Version)
|
||||
} else {
|
||||
log.Printf("No version information found. Make sure to use " +
|
||||
"GO111MODULE=on when running 'go get' in order to use specific " +
|
||||
"version of the binary.")
|
||||
}
|
||||
|
||||
}
|
|
@@ -10,7 +10,7 @@
|
|||
//
|
||||
// A CIDv1 has four parts:
|
||||
//
|
||||
// <cidv1> ::= <multibase-prefix><cid-version><multicodec-packed-content-type><multihash-content-address>
|
||||
//
|
||||
// As shown above, the CID implementation relies heavily on Multiformats,
|
||||
// particularly Multibase
|
||||
|
@@ -181,10 +181,19 @@ func Parse(v interface{}) (Cid, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// MustParse calls Parse but will panic on error.
|
||||
func MustParse(v interface{}) Cid {
|
||||
c, err := Parse(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return c
|
||||
}
|
||||
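A hedged usage sketch for Parse and MustParse; the CID string is only illustrative:

package main

import (
	"fmt"

	cid "github.com/ipfs/go-cid"
)

func main() {
	// Parse accepts a string, []byte, multihash, or another Cid.
	c, err := cid.Parse("QmdfTbBqBPQ7VNxZEYEj14VmRuZBkqFbiwReogJgS1zR1n")
	if err != nil {
		fmt.Println("bad cid:", err)
		return
	}
	fmt.Println(c.Version()) // 0: Qm... strings decode as CIDv0

	// MustParse suits package-level values, where a typo should fail fast.
	wellKnown := cid.MustParse(c.String())
	_ = wellKnown
}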
|
||||
// Decode parses a Cid-encoded string and returns a Cid object.
|
||||
// For CidV1, a Cid-encoded string is primarily a multibase string:
|
||||
//
|
||||
// <multibase-type-code><base-encoded-string>
|
||||
//
|
||||
// The base-encoded string represents a:
|
||||
//
|
||||
|
@@ -240,7 +249,7 @@ func ExtractEncoding(v string) (mbase.Encoding, error) {
|
|||
// Cast takes a Cid data slice, parses it and returns a Cid.
|
||||
// For CidV1, the data buffer is in the form:
|
||||
//
|
||||
// <version><codec-type><multihash>
|
||||
//
|
||||
// CidV0 are also supported. In particular, data buffers starting
|
||||
// with length 34 bytes, which starts with bytes [18,32...] are considered
|
||||
|
@@ -369,7 +378,13 @@ func (c Cid) Hash() mh.Multihash {
|
|||
// Bytes returns the byte representation of a Cid.
|
||||
// The output of bytes can be parsed back into a Cid
|
||||
// with Cast().
|
||||
//
|
||||
// If c.Defined() == false, it returns a nil slice, which may not
|
||||
// be parsable with Cast().
|
||||
func (c Cid) Bytes() []byte {
|
||||
if !c.Defined() {
|
||||
return nil
|
||||
}
|
||||
return []byte(c.str)
|
||||
}
|
||||
|
||||
|
@@ -450,7 +465,7 @@ func (c *Cid) UnmarshalJSON(b []byte) error {
|
|||
|
||||
// MarshalJSON produces a JSON representation of a Cid, which looks as follows:
|
||||
//
|
||||
// { "/": "<cid-string>" }
|
||||
//
|
||||
// Note that this formatting comes from the IPLD specification
|
||||
// (https://github.com/ipld/specs/tree/master/ipld)
|
||||
|
@@ -507,7 +522,8 @@ func (c Cid) Prefix() Prefix {
|
|||
// and the Multihash length. It does not contain
|
||||
// any actual content information.
|
||||
// NOTE: The use of -1 in MhLength to mean the default length is deprecated;
|
||||
// use the V0Builder or V1Builder structures instead.
|
||||
type Prefix struct {
|
||||
Version uint64
|
||||
Codec uint64
|
||||
|
@@ -546,7 +562,7 @@ func (p Prefix) Sum(data []byte) (Cid, error) {
|
|||
|
||||
// Bytes returns a byte representation of a Prefix. It looks like:
|
||||
//
|
||||
// <version><codec><mh-type><mh-length>
|
||||
func (p Prefix) Bytes() []byte {
|
||||
size := varint.UvarintSize(p.Version)
|
||||
size += varint.UvarintSize(p.Codec)
|
||||
|
|
|
@@ -1,5 +1,4 @@
|
|||
//go:build gofuzz
|
||||
// +build gofuzz
|
||||
|
||||
package cid
|
||||
|
||||
|
|
|
@@ -1,3 +1,3 @@
|
|||
{
|
||||
"version": "v0.2.0"
|
||||
"version": "v0.3.2"
|
||||
}
|
||||
|
|
|
@@ -1 +0,0 @@
|
|||
*.swp
|
|
@@ -1,21 +0,0 @@
|
|||
The MIT License
|
||||
|
||||
Copyright (c) 2016 Juan Batiz-Benet
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
|
@@ -1,47 +0,0 @@
|
|||
# go-datastore
|
||||
|
||||
[![](https://img.shields.io/badge/made%20by-Protocol%20Labs-blue.svg?style=flat-square)](http://ipn.io)
|
||||
[![](https://img.shields.io/badge/project-IPFS-blue.svg?style=flat-square)](http://ipfs.io/)
|
||||
[![](https://img.shields.io/badge/freenode-%23ipfs-blue.svg?style=flat-square)](http://webchat.freenode.net/?channels=%23ipfs)
|
||||
[![standard-readme compliant](https://img.shields.io/badge/standard--readme-OK-green.svg?style=flat-square)](https://github.com/RichardLitt/standard-readme)
|
||||
[![GoDoc](https://godoc.org/github.com/ipfs/go-datastore?status.svg)](https://godoc.org/github.com/ipfs/go-datastore)
|
||||
|
||||
> key-value datastore interfaces
|
||||
|
||||
## Lead Maintainer
|
||||
|
||||
[Steven Allen](https://github.com/Stebalien)
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Background](#background)
|
||||
- [Documentation](#documentation)
|
||||
- [Contribute](#contribute)
|
||||
- [License](#license)
|
||||
|
||||
## Background
|
||||
|
||||
Datastore is a generic layer of abstraction for data store and database access. It is a simple API with the aim to enable application development in a datastore-agnostic way, allowing datastores to be swapped seamlessly without changing application code. Thus, one can leverage different datastores with different strengths without committing the application to one datastore throughout its lifetime.
|
||||
|
||||
In addition, grouped datastores significantly simplify interesting data access patterns (such as caching and sharding).
|
||||
|
||||
Based on [datastore.py](https://github.com/datastore/datastore).
|
||||
|
||||
## Documentation
|
||||
|
||||
https://godoc.org/github.com/ipfs/go-datastore
|
||||
|
||||
## Contribute
|
||||
|
||||
Feel free to join in. All welcome. Open an [issue](https://github.com/ipfs/go-datastore/issues)!
|
||||
|
||||
This repository falls under the IPFS [Code of Conduct](https://github.com/ipfs/community/blob/master/code-of-conduct.md).
|
||||
|
||||
### Want to hack on IPFS?
|
||||
|
||||
[![](https://cdn.rawgit.com/jbenet/contribute-ipfs-gif/master/img/contribute.gif)](https://github.com/ipfs/community/blob/master/contributing.md)
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
|
@@ -1,304 +0,0 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
dsq "github.com/ipfs/go-datastore/query"
|
||||
)
|
||||
|
||||
// Here are some basic datastore implementations.
|
||||
|
||||
// MapDatastore uses a standard Go map for internal storage.
|
||||
type MapDatastore struct {
|
||||
values map[Key][]byte
|
||||
}
|
||||
|
||||
var _ Datastore = (*MapDatastore)(nil)
|
||||
var _ Batching = (*MapDatastore)(nil)
|
||||
|
||||
// NewMapDatastore constructs a MapDatastore. It is _not_ thread-safe by
|
||||
// default, wrap using sync.MutexWrap if you need thread safety (the answer here
|
||||
// is usually yes).
|
||||
func NewMapDatastore() (d *MapDatastore) {
|
||||
return &MapDatastore{
|
||||
values: make(map[Key][]byte),
|
||||
}
|
||||
}
|
||||
|
||||
// Put implements Datastore.Put
|
||||
func (d *MapDatastore) Put(ctx context.Context, key Key, value []byte) (err error) {
|
||||
d.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync implements Datastore.Sync
|
||||
func (d *MapDatastore) Sync(ctx context.Context, prefix Key) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get implements Datastore.Get
|
||||
func (d *MapDatastore) Get(ctx context.Context, key Key) (value []byte, err error) {
|
||||
val, found := d.values[key]
|
||||
if !found {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Has implements Datastore.Has
|
||||
func (d *MapDatastore) Has(ctx context.Context, key Key) (exists bool, err error) {
|
||||
_, found := d.values[key]
|
||||
return found, nil
|
||||
}
|
||||
|
||||
// GetSize implements Datastore.GetSize
|
||||
func (d *MapDatastore) GetSize(ctx context.Context, key Key) (size int, err error) {
|
||||
if v, found := d.values[key]; found {
|
||||
return len(v), nil
|
||||
}
|
||||
return -1, ErrNotFound
|
||||
}
|
||||
|
||||
// Delete implements Datastore.Delete
|
||||
func (d *MapDatastore) Delete(ctx context.Context, key Key) (err error) {
|
||||
delete(d.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query implements Datastore.Query
|
||||
func (d *MapDatastore) Query(ctx context.Context, q dsq.Query) (dsq.Results, error) {
|
||||
re := make([]dsq.Entry, 0, len(d.values))
|
||||
for k, v := range d.values {
|
||||
e := dsq.Entry{Key: k.String(), Size: len(v)}
|
||||
if !q.KeysOnly {
|
||||
e.Value = v
|
||||
}
|
||||
re = append(re, e)
|
||||
}
|
||||
r := dsq.ResultsWithEntries(q, re)
|
||||
r = dsq.NaiveQueryApply(q, r)
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (d *MapDatastore) Batch(ctx context.Context) (Batch, error) {
|
||||
return NewBasicBatch(d), nil
|
||||
}
|
||||
|
||||
func (d *MapDatastore) Close() error {
|
||||
return nil
|
||||
}
|
||||
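A hedged usage sketch of the MapDatastore; ds.NewKey comes from the same go-datastore package, and since MapDatastore is not thread-safe, concurrent callers would wrap it with the package's sync wrapper first:

package main

import (
	"context"
	"fmt"

	ds "github.com/ipfs/go-datastore"
)

func main() {
	ctx := context.Background()
	store := ds.NewMapDatastore()
	defer store.Close()

	key := ds.NewKey("/greetings/en")
	if err := store.Put(ctx, key, []byte("hello")); err != nil {
		panic(err)
	}

	val, err := store.Get(ctx, key)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s\n", val) // hello
}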
|
||||
// NullDatastore stores nothing, but conforms to the API.
|
||||
// Useful to test with.
|
||||
type NullDatastore struct {
|
||||
}
|
||||
|
||||
var _ Datastore = (*NullDatastore)(nil)
|
||||
var _ Batching = (*NullDatastore)(nil)
|
||||
|
||||
// NewNullDatastore constructs a null datastore.
|
||||
func NewNullDatastore() *NullDatastore {
|
||||
return &NullDatastore{}
|
||||
}
|
||||
|
||||
// Put implements Datastore.Put
|
||||
func (d *NullDatastore) Put(ctx context.Context, key Key, value []byte) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync implements Datastore.Sync
|
||||
func (d *NullDatastore) Sync(ctx context.Context, prefix Key) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get implements Datastore.Get
|
||||
func (d *NullDatastore) Get(ctx context.Context, key Key) (value []byte, err error) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
// Has implements Datastore.Has
|
||||
func (d *NullDatastore) Has(ctx context.Context, key Key) (exists bool, err error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetSize implements Datastore.GetSize
|
||||
func (d *NullDatastore) GetSize(ctx context.Context, key Key) (size int, err error) {
|
||||
return -1, ErrNotFound
|
||||
}
|
||||
|
||||
// Delete implements Datastore.Delete
|
||||
func (d *NullDatastore) Delete(ctx context.Context, key Key) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query implements Datastore.Query
|
||||
func (d *NullDatastore) Query(ctx context.Context, q dsq.Query) (dsq.Results, error) {
|
||||
return dsq.ResultsWithEntries(q, nil), nil
|
||||
}
|
||||
|
||||
func (d *NullDatastore) Batch(ctx context.Context) (Batch, error) {
|
||||
return NewBasicBatch(d), nil
|
||||
}
|
||||
|
||||
func (d *NullDatastore) Close() error {
|
||||
return nil
|
||||
}

// LogDatastore logs all accesses through the datastore.
type LogDatastore struct {
	Name  string
	child Datastore
}

var _ Datastore = (*LogDatastore)(nil)
var _ Batching = (*LogDatastore)(nil)
var _ GCDatastore = (*LogDatastore)(nil)
var _ PersistentDatastore = (*LogDatastore)(nil)
var _ ScrubbedDatastore = (*LogDatastore)(nil)
var _ CheckedDatastore = (*LogDatastore)(nil)
var _ Shim = (*LogDatastore)(nil)

// Shim is a datastore which has a child.
type Shim interface {
	Datastore

	Children() []Datastore
}

// NewLogDatastore constructs a log datastore.
func NewLogDatastore(ds Datastore, name string) *LogDatastore {
	if len(name) < 1 {
		name = "LogDatastore"
	}
	return &LogDatastore{Name: name, child: ds}
}

// Children implements Shim
func (d *LogDatastore) Children() []Datastore {
	return []Datastore{d.child}
}

// Put implements Datastore.Put
func (d *LogDatastore) Put(ctx context.Context, key Key, value []byte) (err error) {
	log.Printf("%s: Put %s\n", d.Name, key)
	// log.Printf("%s: Put %s ```%s```", d.Name, key, value)
	return d.child.Put(ctx, key, value)
}

// Sync implements Datastore.Sync
func (d *LogDatastore) Sync(ctx context.Context, prefix Key) error {
	log.Printf("%s: Sync %s\n", d.Name, prefix)
	return d.child.Sync(ctx, prefix)
}

// Get implements Datastore.Get
func (d *LogDatastore) Get(ctx context.Context, key Key) (value []byte, err error) {
	log.Printf("%s: Get %s\n", d.Name, key)
	return d.child.Get(ctx, key)
}

// Has implements Datastore.Has
func (d *LogDatastore) Has(ctx context.Context, key Key) (exists bool, err error) {
	log.Printf("%s: Has %s\n", d.Name, key)
	return d.child.Has(ctx, key)
}

// GetSize implements Datastore.GetSize
func (d *LogDatastore) GetSize(ctx context.Context, key Key) (size int, err error) {
	log.Printf("%s: GetSize %s\n", d.Name, key)
	return d.child.GetSize(ctx, key)
}

// Delete implements Datastore.Delete
func (d *LogDatastore) Delete(ctx context.Context, key Key) (err error) {
	log.Printf("%s: Delete %s\n", d.Name, key)
	return d.child.Delete(ctx, key)
}

// DiskUsage implements the PersistentDatastore interface.
func (d *LogDatastore) DiskUsage(ctx context.Context) (uint64, error) {
	log.Printf("%s: DiskUsage\n", d.Name)
	return DiskUsage(ctx, d.child)
}

// Query implements Datastore.Query
func (d *LogDatastore) Query(ctx context.Context, q dsq.Query) (dsq.Results, error) {
	log.Printf("%s: Query\n", d.Name)
	log.Printf("%s: q.Prefix: %s\n", d.Name, q.Prefix)
	log.Printf("%s: q.KeysOnly: %v\n", d.Name, q.KeysOnly)
	log.Printf("%s: q.Filters: %d\n", d.Name, len(q.Filters))
	log.Printf("%s: q.Orders: %d\n", d.Name, len(q.Orders))
	log.Printf("%s: q.Offset: %d\n", d.Name, q.Offset)

	return d.child.Query(ctx, q)
}

// LogBatch logs all accesses through the batch.
type LogBatch struct {
	Name  string
	child Batch
}

var _ Batch = (*LogBatch)(nil)

func (d *LogDatastore) Batch(ctx context.Context) (Batch, error) {
	log.Printf("%s: Batch\n", d.Name)
	if bds, ok := d.child.(Batching); ok {
		b, err := bds.Batch(ctx)

		if err != nil {
			return nil, err
		}
		return &LogBatch{
			Name:  d.Name,
			child: b,
		}, nil
	}
	return nil, ErrBatchUnsupported
}

// Put implements Batch.Put
func (d *LogBatch) Put(ctx context.Context, key Key, value []byte) (err error) {
	log.Printf("%s: BatchPut %s\n", d.Name, key)
	// log.Printf("%s: Put %s ```%s```", d.Name, key, value)
	return d.child.Put(ctx, key, value)
}

// Delete implements Batch.Delete
func (d *LogBatch) Delete(ctx context.Context, key Key) (err error) {
	log.Printf("%s: BatchDelete %s\n", d.Name, key)
	return d.child.Delete(ctx, key)
}

// Commit implements Batch.Commit
func (d *LogBatch) Commit(ctx context.Context) (err error) {
	log.Printf("%s: BatchCommit\n", d.Name)
	return d.child.Commit(ctx)
}

func (d *LogDatastore) Close() error {
	log.Printf("%s: Close\n", d.Name)
	return d.child.Close()
}

func (d *LogDatastore) Check(ctx context.Context) error {
	if c, ok := d.child.(CheckedDatastore); ok {
		return c.Check(ctx)
	}
	return nil
}

func (d *LogDatastore) Scrub(ctx context.Context) error {
	if c, ok := d.child.(ScrubbedDatastore); ok {
		return c.Scrub(ctx)
	}
	return nil
}

func (d *LogDatastore) CollectGarbage(ctx context.Context) error {
	if c, ok := d.child.(GCDatastore); ok {
		return c.CollectGarbage(ctx)
	}
	return nil
}
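
// Example (illustrative sketch, not part of the original file): wrapping any
// datastore so that every access is logged; the wrapped MapDatastore and the
// "mapds" label are arbitrary.
func exampleLoggedDatastore() Datastore {
	return NewLogDatastore(NewMapDatastore(), "mapds")
}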

@ -1,53 +0,0 @@
package datastore

import (
	"context"
)

type op struct {
	delete bool
	value  []byte
}

// basicBatch implements the transaction interface for datastores that do
// not have any sort of underlying transactional support
type basicBatch struct {
	ops map[Key]op

	target Datastore
}

var _ Batch = (*basicBatch)(nil)

func NewBasicBatch(ds Datastore) Batch {
	return &basicBatch{
		ops:    make(map[Key]op),
		target: ds,
	}
}

func (bt *basicBatch) Put(ctx context.Context, key Key, val []byte) error {
	bt.ops[key] = op{value: val}
	return nil
}

func (bt *basicBatch) Delete(ctx context.Context, key Key) error {
	bt.ops[key] = op{delete: true}
	return nil
}

func (bt *basicBatch) Commit(ctx context.Context) error {
	var err error
	for k, op := range bt.ops {
		if op.delete {
			err = bt.target.Delete(ctx, k)
		} else {
			err = bt.target.Put(ctx, k, op.value)
		}
		if err != nil {
			break
		}
	}

	return err
}
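
// Example (illustrative sketch, not part of the original file): buffering
// several writes through a basic batch and flushing them with one Commit.
func examplePutMany(ctx context.Context, target Datastore, vals map[Key][]byte) error {
	b := NewBasicBatch(target)
	for k, v := range vals {
		if err := b.Put(ctx, k, v); err != nil {
			return err
		}
	}
	return b.Commit(ctx)
}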

@ -1,252 +0,0 @@
package datastore

import (
	"context"
	"errors"
	"io"
	"time"

	query "github.com/ipfs/go-datastore/query"
)

/*
Datastore represents storage for any key-value pair.

Datastores are general enough to be backed by all kinds of different storage:
in-memory caches, databases, a remote datastore, flat files on disk, etc.

The general idea is to wrap a more complicated storage facility in a simple,
uniform interface, keeping the freedom of using the right tools for the job.
In particular, a Datastore can aggregate other datastores in interesting ways,
like sharded (to distribute load) or tiered access (caches before databases).

While Datastores should be written general enough to accept all sorts of
values, some implementations will undoubtedly have to be specific (e.g. SQL
databases where fields should be decomposed into columns), particularly to
support queries efficiently. Moreover, certain datastores may enforce certain
types of values (e.g. requiring an io.Reader, a specific struct, etc) or
serialization formats (JSON, Protobufs, etc).

IMPORTANT: No Datastore should ever Panic! This is a cross-module interface,
and thus it should behave predictably and handle exceptional conditions with
proper error reporting. Thus, all Datastore calls may return errors, which
should be checked by callers.
*/
type Datastore interface {
	Read
	Write
	// Sync guarantees that any Put or Delete calls under prefix that returned
	// before Sync(prefix) was called will be observed after Sync(prefix)
	// returns, even if the program crashes. If Put/Delete operations already
	// satisfy these requirements then Sync may be a no-op.
	//
	// If the prefix fails to Sync this method returns an error.
	Sync(ctx context.Context, prefix Key) error
	io.Closer
}

// Write is the write-side of the Datastore interface.
type Write interface {
	// Put stores the object `value` named by `key`.
	//
	// The generalized Datastore interface does not impose a value type,
	// allowing various datastore middleware implementations (which do not
	// handle the values directly) to be composed together.
	//
	// Ultimately, the lowest-level datastore will need to do some value checking
	// or risk getting incorrect values. It may also be useful to expose a more
	// type-safe interface to your application, and do the checking up-front.
	Put(ctx context.Context, key Key, value []byte) error

	// Delete removes the value for given `key`. If the key is not in the
	// datastore, this method returns no error.
	Delete(ctx context.Context, key Key) error
}

// Read is the read-side of the Datastore interface.
type Read interface {
	// Get retrieves the object `value` named by `key`.
	// Get will return ErrNotFound if the key is not mapped to a value.
	Get(ctx context.Context, key Key) (value []byte, err error)

	// Has returns whether the `key` is mapped to a `value`.
	// In some contexts, it may be much cheaper only to check for existence of
	// a value, rather than retrieving the value itself. (e.g. HTTP HEAD).
	// The default implementation is found in `GetBackedHas`.
	Has(ctx context.Context, key Key) (exists bool, err error)

	// GetSize returns the size of the `value` named by `key`.
	// In some contexts, it may be much cheaper to only get the size of the
	// value rather than retrieving the value itself.
	GetSize(ctx context.Context, key Key) (size int, err error)

	// Query searches the datastore and returns a query result. This function
	// may return before the query actually runs. To wait for the query:
	//
	//   result, _ := ds.Query(q)
	//
	//   // use the channel interface; result may come in at different times
	//   for entry := range result.Next() { ... }
	//
	//   // or wait for the query to be completely done
	//   entries, _ := result.Rest()
	//   for entry := range entries { ... }
	//
	Query(ctx context.Context, q query.Query) (query.Results, error)
}

// Batching datastores support deferred, grouped updates to the database.
// `Batch`es do NOT have transactional semantics: updates to the underlying
// datastore are not guaranteed to occur in the same iota of time. Similarly,
// batched updates will not be flushed to the underlying datastore until
// `Commit` has been called. `Txn`s from a `TxnDatastore` have all the
// capabilities of a `Batch`, but the reverse is NOT true.
type Batching interface {
	Datastore

	Batch(ctx context.Context) (Batch, error)
}

// ErrBatchUnsupported is returned by Batch if the Datastore doesn't
// actually support batching.
var ErrBatchUnsupported = errors.New("this datastore does not support batching")

// CheckedDatastore is an interface that should be implemented by datastores
// which may need checking on-disk data integrity.
type CheckedDatastore interface {
	Datastore

	Check(ctx context.Context) error
}

// ScrubbedDatastore is an interface that should be implemented by datastores
// which want to provide a mechanism to check data integrity and/or
// error correction.
type ScrubbedDatastore interface {
	Datastore

	Scrub(ctx context.Context) error
}

// GCDatastore is an interface that should be implemented by datastores which
// don't free disk space by just removing data from them.
type GCDatastore interface {
	Datastore

	CollectGarbage(ctx context.Context) error
}

// PersistentDatastore is an interface that should be implemented by datastores
// which can report disk usage.
type PersistentDatastore interface {
	Datastore

	// DiskUsage returns the space used by a datastore, in bytes.
	DiskUsage(ctx context.Context) (uint64, error)
}

// DiskUsage checks if a Datastore is a
// PersistentDatastore and returns its DiskUsage(),
// otherwise returns 0.
func DiskUsage(ctx context.Context, d Datastore) (uint64, error) {
	persDs, ok := d.(PersistentDatastore)
	if !ok {
		return 0, nil
	}
	return persDs.DiskUsage(ctx)
}

// TTLDatastore is an interface that should be implemented by datastores that
// support expiring entries.
type TTLDatastore interface {
	Datastore
	TTL
}

// TTL encapsulates the methods that deal with entries with time-to-live.
type TTL interface {
	PutWithTTL(ctx context.Context, key Key, value []byte, ttl time.Duration) error
	SetTTL(ctx context.Context, key Key, ttl time.Duration) error
	GetExpiration(ctx context.Context, key Key) (time.Time, error)
}

// Txn extends the Datastore type. Txns allow users to batch queries and
// mutations to the Datastore into atomic groups, or transactions. Actions
// performed on a transaction will not take hold until a successful call to
// Commit has been made. Likewise, transactions can be aborted by calling
// Discard before a successful Commit has been made.
type Txn interface {
	Read
	Write

	// Commit finalizes a transaction, attempting to commit it to the Datastore.
	// May return an error if the transaction has gone stale. The presence of an
	// error is an indication that the data was not committed to the Datastore.
	Commit(ctx context.Context) error
	// Discard throws away changes recorded in a transaction without committing
	// them to the underlying Datastore. Any calls made to Discard after Commit
	// has been successfully called will have no effect on the transaction and
	// state of the Datastore, making it safe to defer.
	Discard(ctx context.Context)
}

// TxnDatastore is an interface that should be implemented by datastores that
// support transactions.
type TxnDatastore interface {
	Datastore

	NewTransaction(ctx context.Context, readOnly bool) (Txn, error)
}
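
// Example (illustrative sketch, not part of the original file): the intended
// commit/discard pattern for a TxnDatastore. Discard after a successful
// Commit has no effect, which is what makes the defer safe.
func examplePutInTxn(ctx context.Context, tds TxnDatastore, key Key, value []byte) error {
	txn, err := tds.NewTransaction(ctx, false)
	if err != nil {
		return err
	}
	defer txn.Discard(ctx)

	if err := txn.Put(ctx, key, value); err != nil {
		return err
	}
	return txn.Commit(ctx)
}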

// Errors

type dsError struct {
	error
	isNotFound bool
}

func (e *dsError) NotFound() bool {
	return e.isNotFound
}

// ErrNotFound is returned by Get and GetSize when a datastore does not map the
// given key to a value.
var ErrNotFound error = &dsError{error: errors.New("datastore: key not found"), isNotFound: true}

// GetBackedHas provides a default Datastore.Has implementation.
// It exists so Datastore.Has implementations can use it, like so:
//
//   func (d *SomeDatastore) Has(key Key) (exists bool, err error) {
//     return GetBackedHas(d, key)
//   }
func GetBackedHas(ctx context.Context, ds Read, key Key) (bool, error) {
	_, err := ds.Get(ctx, key)
	switch err {
	case nil:
		return true, nil
	case ErrNotFound:
		return false, nil
	default:
		return false, err
	}
}

// GetBackedSize provides a default Datastore.GetSize implementation.
// It exists so Datastore.GetSize implementations can use it, like so:
//
//   func (d *SomeDatastore) GetSize(key Key) (size int, err error) {
//     return GetBackedSize(d, key)
//   }
func GetBackedSize(ctx context.Context, ds Read, key Key) (int, error) {
	value, err := ds.Get(ctx, key)
	if err == nil {
		return len(value), nil
	}
	return -1, err
}

type Batch interface {
	Write

	Commit(ctx context.Context) error
}

@ -1,309 +0,0 @@
package datastore

import (
	"encoding/json"
	"path"
	"strings"

	dsq "github.com/ipfs/go-datastore/query"

	"github.com/google/uuid"
)

/*
A Key represents the unique identifier of an object.
Our Key scheme is inspired by file systems and the Google App Engine key model.

Keys are meant to be unique across a system. Keys are hierarchical,
incorporating more and more specific namespaces. Thus keys can be deemed
'children' or 'ancestors' of other keys::

    Key("/Comedy")
    Key("/Comedy/MontyPython")

Also, every namespace can be parametrized to embed relevant object
information. For example, the Key `name` (most specific namespace) could
include the object type::

    Key("/Comedy/MontyPython/Actor:JohnCleese")
    Key("/Comedy/MontyPython/Sketch:CheeseShop")
    Key("/Comedy/MontyPython/Sketch:CheeseShop/Character:Mousebender")

*/
type Key struct {
	string
}

// NewKey constructs a key from string. It will clean the value.
func NewKey(s string) Key {
	k := Key{s}
	k.Clean()
	return k
}
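
// Example (illustrative sketch, not part of the original file): building a
// hierarchical key with the methods defined below. NewKey cleans the input,
// so a missing leading slash is tolerated.
func exampleKeyHierarchy() (Key, bool) {
	k := NewKey("Comedy/MontyPython") // cleaned to "/Comedy/MontyPython"
	child := k.ChildString("Actor:JohnCleese")
	// child.Type() == "Actor", child.Name() == "JohnCleese"
	return child, k.IsAncestorOf(child) // true
}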

// RawKey creates a new Key without safety checking the input. Use with care.
func RawKey(s string) Key {
	// accept an empty string and fix it to avoid special cases
	// elsewhere
	if len(s) == 0 {
		return Key{"/"}
	}

	// perform a quick sanity check that the key is in the correct
	// format, if it is not then it is a programmer error and it is
	// okay to panic
	if len(s) == 0 || s[0] != '/' || (len(s) > 1 && s[len(s)-1] == '/') {
		panic("invalid datastore key: " + s)
	}

	return Key{s}
}

// KeyWithNamespaces constructs a key out of a namespace slice.
func KeyWithNamespaces(ns []string) Key {
	return NewKey(strings.Join(ns, "/"))
}

// Clean up a Key, using path.Clean.
func (k *Key) Clean() {
	switch {
	case len(k.string) == 0:
		k.string = "/"
	case k.string[0] == '/':
		k.string = path.Clean(k.string)
	default:
		k.string = path.Clean("/" + k.string)
	}
}

// String returns the string value of Key.
func (k Key) String() string {
	return k.string
}

// Bytes returns the string value of Key as a []byte
func (k Key) Bytes() []byte {
	return []byte(k.string)
}

// Equal checks equality of two keys
func (k Key) Equal(k2 Key) bool {
	return k.string == k2.string
}

// Less checks whether this key is sorted lower than another.
func (k Key) Less(k2 Key) bool {
	list1 := k.List()
	list2 := k2.List()
	for i, c1 := range list1 {
		if len(list2) < (i + 1) {
			return false
		}

		c2 := list2[i]
		if c1 < c2 {
			return true
		} else if c1 > c2 {
			return false
		}
		// c1 == c2, continue
	}

	// list1 is shorter or exactly the same.
	return len(list1) < len(list2)
}

// List returns the `list` representation of this Key.
//   NewKey("/Comedy/MontyPython/Actor:JohnCleese").List()
//   ["Comedy", "MontyPython", "Actor:JohnCleese"]
func (k Key) List() []string {
	return strings.Split(k.string, "/")[1:]
}

// Reverse returns the reverse of this Key.
//   NewKey("/Comedy/MontyPython/Actor:JohnCleese").Reverse()
//   NewKey("/Actor:JohnCleese/MontyPython/Comedy")
func (k Key) Reverse() Key {
	l := k.List()
	r := make([]string, len(l))
	for i, e := range l {
		r[len(l)-i-1] = e
	}
	return KeyWithNamespaces(r)
}

// Namespaces returns the `namespaces` making up this Key.
//   NewKey("/Comedy/MontyPython/Actor:JohnCleese").Namespaces()
//   ["Comedy", "MontyPython", "Actor:JohnCleese"]
func (k Key) Namespaces() []string {
	return k.List()
}

// BaseNamespace returns the "base" namespace of this key (path.Base(filename))
//   NewKey("/Comedy/MontyPython/Actor:JohnCleese").BaseNamespace()
//   "Actor:JohnCleese"
func (k Key) BaseNamespace() string {
	n := k.Namespaces()
	return n[len(n)-1]
}

// Type returns the "type" of this key (value of last namespace).
//   NewKey("/Comedy/MontyPython/Actor:JohnCleese").Type()
//   "Actor"
func (k Key) Type() string {
	return NamespaceType(k.BaseNamespace())
}

// Name returns the "name" of this key (field of last namespace).
//   NewKey("/Comedy/MontyPython/Actor:JohnCleese").Name()
//   "JohnCleese"
func (k Key) Name() string {
	return NamespaceValue(k.BaseNamespace())
}

// Instance returns an "instance" of this type key (appends value to namespace).
//   NewKey("/Comedy/MontyPython/Actor").Instance("JohnCleese")
//   NewKey("/Comedy/MontyPython/Actor:JohnCleese")
func (k Key) Instance(s string) Key {
	return NewKey(k.string + ":" + s)
}

// Path returns the "path" of this key (parent + type).
//   NewKey("/Comedy/MontyPython/Actor:JohnCleese").Path()
//   NewKey("/Comedy/MontyPython/Actor")
func (k Key) Path() Key {
	s := k.Parent().string + "/" + NamespaceType(k.BaseNamespace())
	return NewKey(s)
}

// Parent returns the `parent` Key of this Key.
//   NewKey("/Comedy/MontyPython/Actor:JohnCleese").Parent()
//   NewKey("/Comedy/MontyPython")
func (k Key) Parent() Key {
	n := k.List()
	if len(n) == 1 {
		return RawKey("/")
	}
	return NewKey(strings.Join(n[:len(n)-1], "/"))
}

// Child returns the `child` Key of this Key.
//   NewKey("/Comedy/MontyPython").Child(NewKey("Actor:JohnCleese"))
//   NewKey("/Comedy/MontyPython/Actor:JohnCleese")
func (k Key) Child(k2 Key) Key {
	switch {
	case k.string == "/":
		return k2
	case k2.string == "/":
		return k
	default:
		return RawKey(k.string + k2.string)
	}
}

// ChildString returns the `child` Key of this Key -- string helper.
//   NewKey("/Comedy/MontyPython").ChildString("Actor:JohnCleese")
//   NewKey("/Comedy/MontyPython/Actor:JohnCleese")
func (k Key) ChildString(s string) Key {
	return NewKey(k.string + "/" + s)
}

// IsAncestorOf returns whether this key is a prefix of `other`
//   NewKey("/Comedy").IsAncestorOf("/Comedy/MontyPython")
//   true
func (k Key) IsAncestorOf(other Key) bool {
	// equivalent to HasPrefix(other, k.string + "/")

	if len(other.string) <= len(k.string) {
		// We're not long enough to be a child.
		return false
	}

	if k.string == "/" {
		// We're the root and the other key is longer.
		return true
	}

	// "other" starts with /k.string/
	return other.string[len(k.string)] == '/' && other.string[:len(k.string)] == k.string
}

// IsDescendantOf returns whether this key contains another as a prefix.
//   NewKey("/Comedy/MontyPython").IsDescendantOf("/Comedy")
//   true
func (k Key) IsDescendantOf(other Key) bool {
	return other.IsAncestorOf(k)
}

// IsTopLevel returns whether this key has only one namespace.
func (k Key) IsTopLevel() bool {
	return len(k.List()) == 1
}

// MarshalJSON implements the json.Marshaler interface;
// keys are represented as JSON strings.
func (k Key) MarshalJSON() ([]byte, error) {
	return json.Marshal(k.String())
}

// UnmarshalJSON implements the json.Unmarshaler interface;
// keys are parsed from JSON strings.
func (k *Key) UnmarshalJSON(data []byte) error {
	var key string
	if err := json.Unmarshal(data, &key); err != nil {
		return err
	}
	*k = NewKey(key)
	return nil
}

// RandomKey returns a randomly (uuid) generated key.
//   RandomKey()
//   NewKey("/f98719ea086343f7b71f32ea9d9d521d")
func RandomKey() Key {
	return NewKey(strings.Replace(uuid.New().String(), "-", "", -1))
}

/*
A Key Namespace is like a path element.
A namespace can optionally include a type (delimited by ':')

    > NamespaceValue("Song:PhilosopherSong")
    PhilosopherSong
    > NamespaceType("Song:PhilosopherSong")
    Song
    > NamespaceType("Music:Song:PhilosopherSong")
    Music:Song
*/

// NamespaceType is the first component of a namespace. `foo` in `foo:bar`
func NamespaceType(namespace string) string {
	parts := strings.Split(namespace, ":")
	if len(parts) < 2 {
		return ""
	}
	return strings.Join(parts[0:len(parts)-1], ":")
}

// NamespaceValue returns the last component of a namespace. `baz` in `f:b:baz`
func NamespaceValue(namespace string) string {
	parts := strings.Split(namespace, ":")
	return parts[len(parts)-1]
}

// KeySlice attaches the methods of sort.Interface to []Key,
// sorting in increasing order.
type KeySlice []Key

func (p KeySlice) Len() int           { return len(p) }
func (p KeySlice) Less(i, j int) bool { return p[i].Less(p[j]) }
func (p KeySlice) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }

// EntryKeys returns the keys of the given query entries.
func EntryKeys(e []dsq.Entry) []Key {
	ks := make([]Key, len(e))
	for i, e := range e {
		ks[i] = NewKey(e.Key)
	}
	return ks
}

@ -1,102 +0,0 @@
package query

import (
	"bytes"
	"fmt"
	"strings"
)

// Filter is an object that tests ResultEntries
type Filter interface {
	// Filter returns whether an entry passes the filter
	Filter(e Entry) bool
}

// Op is a comparison operator
type Op string

var (
	Equal              = Op("==")
	NotEqual           = Op("!=")
	GreaterThan        = Op(">")
	GreaterThanOrEqual = Op(">=")
	LessThan           = Op("<")
	LessThanOrEqual    = Op("<=")
)

// FilterValueCompare is used to signal to datastores they
// should apply internal comparisons. Unfortunately, there
// is no way to apply comparisons* to interface{} types in
// Go, so if the datastore doesn't have a special way to
// handle these comparisons, you must provide the
// TypedFilter to actually do filtering.
//
// [*] other than == and !=, which use reflect.DeepEqual.
type FilterValueCompare struct {
	Op    Op
	Value []byte
}

func (f FilterValueCompare) Filter(e Entry) bool {
	cmp := bytes.Compare(e.Value, f.Value)
	switch f.Op {
	case Equal:
		return cmp == 0
	case NotEqual:
		return cmp != 0
	case LessThan:
		return cmp < 0
	case LessThanOrEqual:
		return cmp <= 0
	case GreaterThan:
		return cmp > 0
	case GreaterThanOrEqual:
		return cmp >= 0
	default:
		panic(fmt.Errorf("unknown operation: %s", f.Op))
	}
}

func (f FilterValueCompare) String() string {
	return fmt.Sprintf("VALUE %s %q", f.Op, string(f.Value))
}

type FilterKeyCompare struct {
	Op  Op
	Key string
}

func (f FilterKeyCompare) Filter(e Entry) bool {
	switch f.Op {
	case Equal:
		return e.Key == f.Key
	case NotEqual:
		return e.Key != f.Key
	case GreaterThan:
		return e.Key > f.Key
	case GreaterThanOrEqual:
		return e.Key >= f.Key
	case LessThan:
		return e.Key < f.Key
	case LessThanOrEqual:
		return e.Key <= f.Key
	default:
		panic(fmt.Errorf("unknown op '%s'", f.Op))
	}
}

func (f FilterKeyCompare) String() string {
	return fmt.Sprintf("KEY %s %q", f.Op, f.Key)
}

type FilterKeyPrefix struct {
	Prefix string
}

func (f FilterKeyPrefix) Filter(e Entry) bool {
	return strings.HasPrefix(e.Key, f.Prefix)
}

func (f FilterKeyPrefix) String() string {
	return fmt.Sprintf("PREFIX(%q)", f.Prefix)
}
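
// Example (illustrative sketch, not part of the original file): applying a
// Filter by hand over a slice of entries, e.g.
// filterEntries(es, FilterKeyPrefix{Prefix: "/foo/"}).
func filterEntries(entries []Entry, f Filter) []Entry {
	var out []Entry
	for _, e := range entries {
		if f.Filter(e) {
			out = append(out, e)
		}
	}
	return out
}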

@ -1,94 +0,0 @@
package query

import (
	"bytes"
	"sort"
	"strings"
)

// Order is an object used to order objects
type Order interface {
	Compare(a, b Entry) int
}

// OrderByFunction orders the results based on the result of the given function.
type OrderByFunction func(a, b Entry) int

func (o OrderByFunction) Compare(a, b Entry) int {
	return o(a, b)
}

func (OrderByFunction) String() string {
	return "FN"
}

// OrderByValue is used to signal to datastores they should apply internal
// orderings.
type OrderByValue struct{}

func (o OrderByValue) Compare(a, b Entry) int {
	return bytes.Compare(a.Value, b.Value)
}

func (OrderByValue) String() string {
	return "VALUE"
}

// OrderByValueDescending is used to signal to datastores they
// should apply internal orderings.
type OrderByValueDescending struct{}

func (o OrderByValueDescending) Compare(a, b Entry) int {
	return -bytes.Compare(a.Value, b.Value)
}

func (OrderByValueDescending) String() string {
	return "desc(VALUE)"
}

// OrderByKey orders results by key, ascending.
type OrderByKey struct{}

func (o OrderByKey) Compare(a, b Entry) int {
	return strings.Compare(a.Key, b.Key)
}

func (OrderByKey) String() string {
	return "KEY"
}

// OrderByKeyDescending orders results by key, descending.
type OrderByKeyDescending struct{}

func (o OrderByKeyDescending) Compare(a, b Entry) int {
	return -strings.Compare(a.Key, b.Key)
}

func (OrderByKeyDescending) String() string {
	return "desc(KEY)"
}

// Less returns true if a comes before b with the requested orderings.
func Less(orders []Order, a, b Entry) bool {
	for _, cmp := range orders {
		switch cmp.Compare(a, b) {
		case 0:
		case -1:
			return true
		case 1:
			return false
		}
	}

	// This gives us a *stable* sort for free. We don't care about
	// preserving the order from the underlying datastore
	// because it's undefined.
	return a.Key < b.Key
}

// Sort sorts the given entries using the given orders.
func Sort(orders []Order, entries []Entry) {
	sort.Slice(entries, func(i int, j int) bool {
		return Less(orders, entries[i], entries[j])
	})
}
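
// Example (illustrative sketch, not part of the original file): a custom
// Order built from OrderByFunction, sorting entries by size, largest first.
func sortBySizeDescending(entries []Entry) {
	bySize := OrderByFunction(func(a, b Entry) int {
		switch {
		case a.Size > b.Size:
			return -1
		case a.Size < b.Size:
			return 1
		default:
			return 0
		}
	})
	Sort([]Order{bySize}, entries)
}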

@ -1,426 +0,0 @@
package query

import (
	"fmt"
	"time"

	goprocess "github.com/jbenet/goprocess"
)

/*
Query represents storage for any key-value pair.

tl;dr:

  queries are supported across datastores.
  Cheap on top of relational dbs, and expensive otherwise.
  Pick the right tool for the job!

In addition to the key-value store get and set semantics, datastore
provides an interface to retrieve multiple records at a time through
the use of queries. The datastore Query model gleans a common set of
operations performed when querying. To avoid pasting here years of
database research, let's summarize the operations datastore supports.

Query Operations, applied in-order:

  * prefix - scope the query to a given path prefix
  * filters - select a subset of values by applying constraints
  * orders - sort the results by applying sort conditions, hierarchically.
  * offset - skip a number of results (for efficient pagination)
  * limit - impose a numeric limit on the number of results

Datastore combines these operations into a simple Query class that allows
applications to define their constraints in a simple, generic, way without
introducing datastore specific calls, languages, etc.

However, take heed: not all datastores support efficiently performing these
operations. Pick a datastore based on your needs. If you need efficient look-ups,
go for a simple key/value store. If you need efficient queries, consider an SQL
backed datastore.

Notes:

  * Prefix: When a query filters by prefix, it selects keys that are strict
    children of the prefix. For example, a prefix "/foo" would select "/foo/bar"
    but not "/foobar" or "/foo".
  * Orders: Orders are applied hierarchically. Results are sorted by the first
    ordering, then entries equal under the first ordering are sorted with the
    second ordering, etc.
  * Limits & Offset: Limits and offsets are applied after everything else.
*/
type Query struct {
	Prefix            string   // namespaces the query to results whose keys have Prefix
	Filters           []Filter // filter results. apply sequentially
	Orders            []Order  // order results. apply hierarchically
	Limit             int      // maximum number of results
	Offset            int      // skip given number of results
	KeysOnly          bool     // return only keys.
	ReturnExpirations bool     // return expirations (see TTLDatastore)
	ReturnsSizes      bool     // always return sizes. If not set, datastore impl can return
	// it anyway if it doesn't involve a performance cost. If KeysOnly
	// is not set, Size should always be set.
}
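
// Example (illustrative sketch, not part of the original file): a Query that
// selects keys and values under /foo, ordered by key descending, skipping
// the first 10 matches and returning at most 5.
func exampleQuery() Query {
	return Query{
		Prefix: "/foo",
		Orders: []Order{OrderByKeyDescending{}},
		Offset: 10,
		Limit:  5,
	}
}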

// String returns a string representation of the Query for debugging/validation
// purposes. Do not use it for SQL queries.
func (q Query) String() string {
	s := "SELECT keys"
	if !q.KeysOnly {
		s += ",vals"
	}
	if q.ReturnExpirations {
		s += ",exps"
	}

	s += " "

	if q.Prefix != "" {
		s += fmt.Sprintf("FROM %q ", q.Prefix)
	}

	if len(q.Filters) > 0 {
		s += fmt.Sprintf("FILTER [%s", q.Filters[0])
		for _, f := range q.Filters[1:] {
			s += fmt.Sprintf(", %s", f)
		}
		s += "] "
	}

	if len(q.Orders) > 0 {
		s += fmt.Sprintf("ORDER [%s", q.Orders[0])
		for _, f := range q.Orders[1:] {
			s += fmt.Sprintf(", %s", f)
		}
		s += "] "
	}

	if q.Offset > 0 {
		s += fmt.Sprintf("OFFSET %d ", q.Offset)
	}

	if q.Limit > 0 {
		s += fmt.Sprintf("LIMIT %d ", q.Limit)
	}
	// Will always end with a space, strip it.
	return s[:len(s)-1]
}

// Entry is a query result entry.
type Entry struct {
	Key        string    // can't be ds.Key because of circular imports
	Value      []byte    // Will be nil if KeysOnly has been passed.
	Expiration time.Time // Entry expiration timestamp if requested and supported (see TTLDatastore).
	Size       int       // Might be -1 if the datastore doesn't support listing the size with KeysOnly
	// or if ReturnsSizes is not set
}

// Result is a special entry that includes an error, so that the client
// may be warned about internal errors. If Error is non-nil, Entry must be
// empty.
type Result struct {
	Entry

	Error error
}

// Results is a set of Query results. This is the interface for clients.
// Example:
//
//   qr, _ := myds.Query(q)
//   for r := range qr.Next() {
//     if r.Error != nil {
//       // handle.
//       break
//     }
//
//     fmt.Println(r.Entry.Key, r.Entry.Value)
//   }
//
// or, wait on all results at once:
//
//   qr, _ := myds.Query(q)
//   es, _ := qr.Rest()
//   for _, e := range es {
//     fmt.Println(e.Key, e.Value)
//   }
//
type Results interface {
	Query() Query             // the query these Results correspond to
	Next() <-chan Result      // returns a channel to wait for the next result
	NextSync() (Result, bool) // blocks and waits to return the next result, second parameter returns false when results are exhausted
	Rest() ([]Entry, error)   // waits till processing finishes, returns all entries at once.
	Close() error             // client may call Close to signal early exit

	// Process returns a goprocess.Process associated with these results.
	// Most users will not need this function (Close is all they want),
	// but it's here in case you want to connect the results to other
	// goprocess-friendly things.
	Process() goprocess.Process
}

// results implements Results
type results struct {
	query Query
	proc  goprocess.Process
	res   <-chan Result
}

func (r *results) Next() <-chan Result {
	return r.res
}

func (r *results) NextSync() (Result, bool) {
	val, ok := <-r.res
	return val, ok
}

func (r *results) Rest() ([]Entry, error) {
	var es []Entry
	for e := range r.res {
		if e.Error != nil {
			return es, e.Error
		}
		es = append(es, e.Entry)
	}
	<-r.proc.Closed() // wait till the processing finishes.
	return es, nil
}

func (r *results) Process() goprocess.Process {
	return r.proc
}

func (r *results) Close() error {
	return r.proc.Close()
}

func (r *results) Query() Query {
	return r.query
}

// ResultBuilder is what implementors use to construct results.
// Implementors of datastores and their clients must respect the
// Process of the Request:
//
//   * clients must call r.Process().Close() on an early exit, so
//     implementations can reclaim resources.
//   * if the Entries are read to completion (channel closed), Process
//     should be closed automatically.
//   * datastores must respect <-Process.Closing(), which intermediates
//     an early close signal from the client.
//
type ResultBuilder struct {
	Query   Query
	Process goprocess.Process
	Output  chan Result
}

// Results returns a Results object for this builder.
func (rb *ResultBuilder) Results() Results {
	return &results{
		query: rb.Query,
		proc:  rb.Process,
		res:   rb.Output,
	}
}

const NormalBufSize = 1
const KeysOnlyBufSize = 128

func NewResultBuilder(q Query) *ResultBuilder {
	bufSize := NormalBufSize
	if q.KeysOnly {
		bufSize = KeysOnlyBufSize
	}
	b := &ResultBuilder{
		Query:  q,
		Output: make(chan Result, bufSize),
	}
	b.Process = goprocess.WithTeardown(func() error {
		close(b.Output)
		return nil
	})
	return b
}
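
// Example (illustrative sketch, not part of the original file): emitting a
// fixed set of entries through a ResultBuilder while honoring the early
// close signal described in the ResultBuilder contract above.
func resultsFromStatic(q Query, entries []Entry) Results {
	b := NewResultBuilder(q)
	b.Process.Go(func(worker goprocess.Process) {
		for _, e := range entries {
			select {
			case b.Output <- Result{Entry: e}:
			case <-worker.Closing(): // client told us to close early
				return
			}
		}
	})
	go b.Process.CloseAfterChildren() //nolint
	return b.Results()
}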

// ResultsWithChan returns a Results object from a channel
// of Result entries.
//
// DEPRECATED: This iterator is impossible to cancel correctly. Canceling it
// will leave anything trying to write to the result channel hanging.
func ResultsWithChan(q Query, res <-chan Result) Results {
	return ResultsWithProcess(q, func(worker goprocess.Process, out chan<- Result) {
		for {
			select {
			case <-worker.Closing(): // client told us to close early
				return
			case e, more := <-res:
				if !more {
					return
				}

				select {
				case out <- e:
				case <-worker.Closing(): // client told us to close early
					return
				}
			}
		}
	})
}

// ResultsWithProcess returns a Results object with the results generated by the
// passed subprocess.
func ResultsWithProcess(q Query, proc func(goprocess.Process, chan<- Result)) Results {
	b := NewResultBuilder(q)

	// go consume all the entries and add them to the results.
	b.Process.Go(func(worker goprocess.Process) {
		proc(worker, b.Output)
	})

	go b.Process.CloseAfterChildren() //nolint
	return b.Results()
}

// ResultsWithEntries returns a Results object from a list of entries
func ResultsWithEntries(q Query, res []Entry) Results {
	i := 0
	return ResultsFromIterator(q, Iterator{
		Next: func() (Result, bool) {
			if i >= len(res) {
				return Result{}, false
			}
			next := res[i]
			i++
			return Result{Entry: next}, true
		},
	})
}

func ResultsReplaceQuery(r Results, q Query) Results {
	switch r := r.(type) {
	case *results:
		// note: not using field names to make sure all fields are copied
		return &results{q, r.proc, r.res}
	case *resultsIter:
		// note: not using field names to make sure all fields are copied
		lr := r.legacyResults
		if lr != nil {
			lr = &results{q, lr.proc, lr.res}
		}
		return &resultsIter{q, r.next, r.close, lr}
	default:
		panic("unknown results type")
	}
}

// ResultsFromIterator provides an alternative way to construct
// results without the use of channels.
func ResultsFromIterator(q Query, iter Iterator) Results {
	if iter.Close == nil {
		iter.Close = noopClose
	}
	return &resultsIter{
		query: q,
		next:  iter.Next,
		close: iter.Close,
	}
}

func noopClose() error {
	return nil
}

type Iterator struct {
	Next  func() (Result, bool)
	Close func() error // note: might be called more than once
}

type resultsIter struct {
	query         Query
	next          func() (Result, bool)
	close         func() error
	legacyResults *results
}

func (r *resultsIter) Next() <-chan Result {
	r.useLegacyResults()
	return r.legacyResults.Next()
}

func (r *resultsIter) NextSync() (Result, bool) {
	if r.legacyResults != nil {
		return r.legacyResults.NextSync()
	} else {
		res, ok := r.next()
		if !ok {
			r.close()
		}
		return res, ok
	}
}

func (r *resultsIter) Rest() ([]Entry, error) {
	var es []Entry
	for {
		e, ok := r.NextSync()
		if !ok {
			break
		}
		if e.Error != nil {
			return es, e.Error
		}
		es = append(es, e.Entry)
	}
	return es, nil
}

func (r *resultsIter) Process() goprocess.Process {
	r.useLegacyResults()
	return r.legacyResults.Process()
}

func (r *resultsIter) Close() error {
	if r.legacyResults != nil {
		return r.legacyResults.Close()
	} else {
		return r.close()
	}
}

func (r *resultsIter) Query() Query {
	return r.query
}

func (r *resultsIter) useLegacyResults() {
	if r.legacyResults != nil {
		return
	}

	b := NewResultBuilder(r.query)

	// go consume all the entries and add them to the results.
	b.Process.Go(func(worker goprocess.Process) {
		defer r.close()
		for {
			e, ok := r.next()
			if !ok {
				break
			}
			select {
			case b.Output <- e:
			case <-worker.Closing(): // client told us to close early
				return
			}
		}
	})

	go b.Process.CloseAfterChildren() //nolint

	r.legacyResults = b.Results().(*results)
}

@ -1,158 +0,0 @@
package query

import (
	"path"

	goprocess "github.com/jbenet/goprocess"
)

// NaiveFilter applies a filter to the results.
func NaiveFilter(qr Results, filter Filter) Results {
	return ResultsFromIterator(qr.Query(), Iterator{
		Next: func() (Result, bool) {
			for {
				e, ok := qr.NextSync()
				if !ok {
					return Result{}, false
				}
				if e.Error != nil || filter.Filter(e.Entry) {
					return e, true
				}
			}
		},
		Close: func() error {
			return qr.Close()
		},
	})
}

// NaiveLimit truncates the results to a given int limit
func NaiveLimit(qr Results, limit int) Results {
	if limit == 0 {
		// 0 means no limit
		return qr
	}
	closed := false
	return ResultsFromIterator(qr.Query(), Iterator{
		Next: func() (Result, bool) {
			if limit == 0 {
				if !closed {
					closed = true
					err := qr.Close()
					if err != nil {
						return Result{Error: err}, true
					}
				}
				return Result{}, false
			}
			limit--
			return qr.NextSync()
		},
		Close: func() error {
			if closed {
				return nil
			}
			closed = true
			return qr.Close()
		},
	})
}

// NaiveOffset skips a given number of results
func NaiveOffset(qr Results, offset int) Results {
	return ResultsFromIterator(qr.Query(), Iterator{
		Next: func() (Result, bool) {
			for ; offset > 0; offset-- {
				res, ok := qr.NextSync()
				if !ok || res.Error != nil {
					return res, ok
				}
			}
			return qr.NextSync()
		},
		Close: func() error {
			return qr.Close()
		},
	})
}

// NaiveOrder reorders results according to given orders.
// WARNING: this is the only non-stream friendly operation!
func NaiveOrder(qr Results, orders ...Order) Results {
	// Short circuit.
	if len(orders) == 0 {
		return qr
	}

	return ResultsWithProcess(qr.Query(), func(worker goprocess.Process, out chan<- Result) {
		defer qr.Close()
		var entries []Entry
	collect:
		for {
			select {
			case <-worker.Closing():
				return
			case e, ok := <-qr.Next():
				if !ok {
					break collect
				}
				if e.Error != nil {
					out <- e
					continue
				}
				entries = append(entries, e.Entry)
			}
		}

		Sort(orders, entries)

		for _, e := range entries {
			select {
			case <-worker.Closing():
				return
			case out <- Result{Entry: e}:
			}
		}
	})
}

func NaiveQueryApply(q Query, qr Results) Results {
	if q.Prefix != "" {
		// Clean the prefix as a key and append / so a prefix of /bar
		// only finds /bar/baz, not /barbaz.
		prefix := q.Prefix
		if len(prefix) == 0 {
			prefix = "/"
		} else {
			if prefix[0] != '/' {
				prefix = "/" + prefix
			}
			prefix = path.Clean(prefix)
		}
		// If the prefix is empty, ignore it.
		if prefix != "/" {
			qr = NaiveFilter(qr, FilterKeyPrefix{prefix + "/"})
		}
	}
	for _, f := range q.Filters {
		qr = NaiveFilter(qr, f)
	}
	if len(q.Orders) > 0 {
		qr = NaiveOrder(qr, q.Orders...)
	}
	if q.Offset != 0 {
		qr = NaiveOffset(qr, q.Offset)
	}
	if q.Limit != 0 {
		qr = NaiveLimit(qr, q.Limit)
	}
	return qr
}

func ResultEntriesFrom(keys []string, vals [][]byte) []Entry {
	re := make([]Entry, len(keys))
	for i, k := range keys {
		re[i] = Entry{Key: k, Size: len(vals[i]), Value: vals[i]}
	}
	return re
}
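
// Example (illustrative sketch, not part of the original file): a datastore
// without native query support can build raw results from its entries and
// let NaiveQueryApply run the whole pipeline (prefix, filters, orders,
// offset, limit) client-side.
func applyClientSide(q Query, keys []string, vals [][]byte) ([]Entry, error) {
	raw := ResultsWithEntries(q, ResultEntriesFrom(keys, vals))
	return NaiveQueryApply(q, raw).Rest()
}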

@ -1,3 +0,0 @@
{
  "version": "v0.5.1"
}

@ -1,21 +0,0 @@
The MIT License (MIT)

Copyright (c) 2016 Jeromy Johnson

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

@ -1,130 +0,0 @@
# SQL Datastore

[![CircleCI](https://circleci.com/gh/ipfs/go-ds-sql.svg?style=shield)](https://circleci.com/gh/ipfs/go-ds-sql)
[![Coverage](https://codecov.io/gh/ipfs/go-ds-sql/branch/master/graph/badge.svg)](https://codecov.io/gh/ipfs/go-ds-sql)
[![Standard README](https://img.shields.io/badge/readme%20style-standard-brightgreen.svg)](https://github.com/RichardLitt/standard-readme)
[![GoDoc](http://img.shields.io/badge/godoc-reference-5272B4.svg)](https://godoc.org/github.com/ipfs/go-ds-sql)
[![golang version](https://img.shields.io/badge/golang-%3E%3D1.14.0-orange.svg)](https://golang.org/)
[![Go Report Card](https://goreportcard.com/badge/github.com/ipfs/go-ds-sql)](https://goreportcard.com/report/github.com/ipfs/go-ds-sql)

An implementation of [the datastore interface](https://github.com/ipfs/go-datastore)
that can be backed by any SQL database.

## Install

```sh
go get github.com/ipfs/go-ds-sql
```

## Usage

### PostgreSQL

Ensure a database is created and a table exists with `key` and `data` columns. For example, in PostgreSQL you can create a table with the following structure (replacing `table_name` with the name of the table the datastore will use - by default this is `blocks`):

```sql
CREATE TABLE IF NOT EXISTS table_name (key TEXT NOT NULL UNIQUE, data BYTEA)
```

It's recommended to create an index on the `key` column that is optimised for prefix scans. For example, in PostgreSQL you can create a `text_pattern_ops` index on the table:

```sql
CREATE INDEX IF NOT EXISTS table_name_key_text_pattern_ops_idx ON table_name (key text_pattern_ops)
```

Import and use in your application:

```go
import (
	"database/sql"
	"github.com/ipfs/go-ds-sql"
	pg "github.com/ipfs/go-ds-sql/postgres"
)

mydb, _ := sql.Open("yourdb", "yourdbparameters")

// Implement the Queries interface for your SQL impl.
// ...or use the provided PostgreSQL queries
queries := pg.NewQueries("blocks")

ds := sqlds.NewDatastore(mydb, queries)
```
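
For illustration only (this snippet is not from the original README; the key and value are arbitrary), the resulting datastore is then used like any other go-datastore:

```go
import (
	"context"

	ds "github.com/ipfs/go-datastore"
)

func roundTrip(ctx context.Context, d *sqlds.Datastore) ([]byte, error) {
	key := ds.NewKey("/example")
	if err := d.Put(ctx, key, []byte("hello")); err != nil {
		return nil, err
	}
	return d.Get(ctx, key)
}
```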

### SQLite

The [SQLite](https://sqlite.org) wrapper tries to create the table automatically.

Prefix scans are optimized by using GLOB.

Import and use in your application:

```go
package main

import (
	sqliteds "github.com/ipfs/go-ds-sql/sqlite"
	_ "github.com/mattn/go-sqlite3"
)

func main() {
	opts := &sqliteds.Options{
		DSN: "db.sqlite",
	}

	ds, err := opts.Create()
	if err != nil {
		panic(err)
	}
	defer func() {
		if err := ds.Close(); err != nil {
			panic(err)
		}
	}()
}
```

If no `DSN` is specified, a unique in-memory database will be created.

### SQLCipher

The SQLite wrapper also supports the [SQLCipher](https://www.zetetic.net/sqlcipher/) extension.

Import and use in your application:

```go
package main

import (
	sqliteds "github.com/ipfs/go-ds-sql/sqlite"
	_ "github.com/mutecomm/go-sqlcipher/v4"
)

func main() {
	opts := &sqliteds.Options{
		DSN: "encdb.sqlite",
		Key: ([]byte)("32_very_secure_bytes_0123456789a"),
	}

	ds, err := opts.Create()
	if err != nil {
		panic(err)
	}
	defer func() {
		if err := ds.Close(); err != nil {
			panic(err)
		}
	}()
}
```

## API

[GoDoc Reference](https://godoc.org/github.com/ipfs/go-ds-sql)

## Contribute

Feel free to dive in! [Open an issue](https://github.com/ipfs/go-ds-sql/issues/new) or submit PRs.

## License

[MIT](LICENSE)

@ -1,65 +0,0 @@
package sqlds

import (
	"context"

	ds "github.com/ipfs/go-datastore"
)

type op struct {
	delete bool
	value  []byte
}

type batch struct {
	ds  *Datastore
	ops map[ds.Key]op
}

// Batch creates a set of deferred updates to the database.
// Since SQL does not support a true batch of updates,
// operations are buffered and then executed sequentially
// over a single connection when Commit is called.
func (d *Datastore) Batch(ctx context.Context) (ds.Batch, error) {
	return &batch{
		ds:  d,
		ops: make(map[ds.Key]op),
	}, nil
}

func (bt *batch) Put(ctx context.Context, key ds.Key, val []byte) error {
	bt.ops[key] = op{value: val}
	return nil
}

func (bt *batch) Delete(ctx context.Context, key ds.Key) error {
	bt.ops[key] = op{delete: true}
	return nil
}

func (bt *batch) Commit(ctx context.Context) error {
	return bt.CommitContext(ctx)
}

func (bt *batch) CommitContext(ctx context.Context) error {
	conn, err := bt.ds.db.Conn(ctx)
	if err != nil {
		return err
	}
	defer conn.Close()

	for k, op := range bt.ops {
		if op.delete {
			_, err = conn.ExecContext(ctx, bt.ds.queries.Delete(), k.String())
		} else {
			_, err = conn.ExecContext(ctx, bt.ds.queries.Put(), k.String(), op.value)
		}
		if err != nil {
			break
		}
	}

	return err
}

var _ ds.Batching = (*Datastore)(nil)
|
|
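Because `ops` is a map keyed by `ds.Key`, only the last buffered operation per key survives to `Commit`. A minimal usage sketch (the `putMany` helper and its names are assumptions, not part of the package):

```go
func putMany(ctx context.Context, d *Datastore, vals map[string][]byte) error {
	b, err := d.Batch(ctx)
	if err != nil {
		return err
	}
	for k, v := range vals {
		if err := b.Put(ctx, ds.NewKey(k), v); err != nil {
			return err
		}
	}
	// The buffered puts run sequentially over one connection here.
	return b.Commit(ctx)
}
```
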
@ -1,204 +0,0 @@
package sqlds

import (
	"context"
	"database/sql"
	"fmt"

	ds "github.com/ipfs/go-datastore"
	dsq "github.com/ipfs/go-datastore/query"
)

// Queries generates SQL queries for datastore operations.
type Queries interface {
	Delete() string
	Exists() string
	Get() string
	Put() string
	Query() string
	Prefix() string
	Limit() string
	Offset() string
	GetSize() string
}

// Datastore is a SQL backed datastore.
type Datastore struct {
	db      *sql.DB
	queries Queries
}

// NewDatastore returns a new SQL datastore.
func NewDatastore(db *sql.DB, queries Queries) *Datastore {
	return &Datastore{db: db, queries: queries}
}

// Close closes the underlying SQL database.
func (d *Datastore) Close() error {
	return d.db.Close()
}

// Delete removes a row from the SQL database by the given key.
func (d *Datastore) Delete(ctx context.Context, key ds.Key) error {
	_, err := d.db.ExecContext(ctx, d.queries.Delete(), key.String())
	if err != nil {
		return err
	}

	return nil
}

// Get retrieves a value from the SQL database by the given key.
func (d *Datastore) Get(ctx context.Context, key ds.Key) (value []byte, err error) {
	row := d.db.QueryRowContext(ctx, d.queries.Get(), key.String())
	var out []byte

	switch err := row.Scan(&out); err {
	case sql.ErrNoRows:
		return nil, ds.ErrNotFound
	case nil:
		return out, nil
	default:
		return nil, err
	}
}

// Has determines if a value for the given key exists in the SQL database.
func (d *Datastore) Has(ctx context.Context, key ds.Key) (exists bool, err error) {
	row := d.db.QueryRowContext(ctx, d.queries.Exists(), key.String())

	switch err := row.Scan(&exists); err {
	case sql.ErrNoRows:
		return exists, nil
	case nil:
		return exists, nil
	default:
		return exists, err
	}
}

// Put "upserts" a row into the SQL database.
func (d *Datastore) Put(ctx context.Context, key ds.Key, value []byte) error {
	_, err := d.db.ExecContext(ctx, d.queries.Put(), key.String(), value)
	if err != nil {
		return err
	}

	return nil
}

// Query returns multiple rows from the SQL database based on the passed query parameters.
func (d *Datastore) Query(ctx context.Context, q dsq.Query) (dsq.Results, error) {
	raw, err := d.rawQuery(ctx, q)
	if err != nil {
		return nil, err
	}

	for _, f := range q.Filters {
		raw = dsq.NaiveFilter(raw, f)
	}

	raw = dsq.NaiveOrder(raw, q.Orders...)

	// if we have filters or orders, offset and limit won't have been applied in the query
	if len(q.Filters) > 0 || len(q.Orders) > 0 {
		if q.Offset != 0 {
			raw = dsq.NaiveOffset(raw, q.Offset)
		}
		if q.Limit != 0 {
			raw = dsq.NaiveLimit(raw, q.Limit)
		}
	}

	return raw, nil
}

func (d *Datastore) rawQuery(ctx context.Context, q dsq.Query) (dsq.Results, error) {
	rows, err := queryWithParams(ctx, d, q)
	if err != nil {
		return nil, err
	}

	it := dsq.Iterator{
		Next: func() (dsq.Result, bool) {
			if !rows.Next() {
				return dsq.Result{}, false
			}

			var key string
			var out []byte

			err := rows.Scan(&key, &out)
			if err != nil {
				return dsq.Result{Error: err}, false
			}

			entry := dsq.Entry{Key: key}

			if !q.KeysOnly {
				entry.Value = out
			}
			if q.ReturnsSizes {
				entry.Size = len(out)
			}

			return dsq.Result{Entry: entry}, true
		},
		Close: func() error {
			return rows.Close()
		},
	}

	return dsq.ResultsFromIterator(q, it), nil
}

// Sync is a noop for SQL databases.
func (d *Datastore) Sync(ctx context.Context, key ds.Key) error {
	return nil
}

// GetSize determines the size in bytes of the value for a given key.
func (d *Datastore) GetSize(ctx context.Context, key ds.Key) (int, error) {
	row := d.db.QueryRowContext(ctx, d.queries.GetSize(), key.String())
	var size int

	switch err := row.Scan(&size); err {
	case sql.ErrNoRows:
		return -1, ds.ErrNotFound
	case nil:
		return size, nil
	default:
		return 0, err
	}
}

// queryWithParams applies prefix, limit, and offset params to the query.
func queryWithParams(ctx context.Context, d *Datastore, q dsq.Query) (*sql.Rows, error) {
	qNew := d.queries.Query()

	if q.Prefix != "" {
		// normalize
		prefix := ds.NewKey(q.Prefix).String()
		if prefix != "/" {
			qNew += fmt.Sprintf(d.queries.Prefix(), prefix+"/")
		}
	}

	// only apply limit and offset if we do not have to naive filter/order the results
	if len(q.Filters) == 0 && len(q.Orders) == 0 {
		if q.Limit != 0 {
			qNew += fmt.Sprintf(d.queries.Limit(), q.Limit)
		}
		if q.Offset != 0 {
			qNew += fmt.Sprintf(d.queries.Offset(), q.Offset)
		}
	}

	return d.db.QueryContext(ctx, qNew)
}

var _ ds.Datastore = (*Datastore)(nil)

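Filters and orders fall back to the naive in-memory helpers, so pushing work into the SQL via `Prefix` and `Limit` is preferable where possible. A small usage sketch (the helper name is an assumption):

```go
func listPrefix(ctx context.Context, d *Datastore, prefix string, limit int) ([]dsq.Entry, error) {
	// With no filters or orders, Prefix and Limit are applied inside the
	// generated SQL rather than in memory.
	res, err := d.Query(ctx, dsq.Query{Prefix: prefix, Limit: limit})
	if err != nil {
		return nil, err
	}
	defer res.Close()
	return res.Rest()
}
```
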
@ -1,30 +0,0 @@
{
  "author": "whyrusleeping",
  "bugs": {
    "url": "https://github.com/whyrusleeping/sql-datastore"
  },
  "gx": {
    "dvcsimport": "github.com/whyrusleeping/sql-datastore"
  },
  "gxDependencies": [
    {
      "author": "magik6k",
      "hash": "QmfJhaxwzBCorUmZNRmY87z4mD6roRrHFMqddhiS4D4XWr",
      "name": "pq",
      "version": "1.0.0"
    },
    {
      "author": "jbenet",
      "hash": "QmPGYyi1DtuWyUkG3PtvLz1xb4ScjnUvwJMCoX3cxeyxNr",
      "name": "go-datastore",
      "version": "3.5.0"
    }
  ],
  "gxVersion": "0.14.0",
  "language": "go",
  "license": "",
  "name": "sql-datastore",
  "releaseCmd": "git commit -a -m \"gx publish $VERSION\"",
  "version": "1.0.2"
}

@ -1,122 +0,0 @@
package sqlds

import (
	"context"
	"database/sql"
	"fmt"

	datastore "github.com/ipfs/go-datastore"
	dsq "github.com/ipfs/go-datastore/query"
)

// ErrNotImplemented is returned when the SQL datastore does not yet implement the function call.
var ErrNotImplemented = fmt.Errorf("not implemented")

type txn struct {
	db      *sql.DB
	queries Queries
	txn     *sql.Tx
}

// NewTransaction creates a new database transaction; note that the readOnly
// parameter is ignored by this implementation.
func (ds *Datastore) NewTransaction(ctx context.Context, _ bool) (datastore.Txn, error) {
	sqlTxn, err := ds.db.BeginTx(ctx, nil)
	if err != nil {
		if sqlTxn != nil {
			// nothing we can do about this error.
			_ = sqlTxn.Rollback()
		}

		return nil, err
	}

	return &txn{
		db:      ds.db,
		queries: ds.queries,
		txn:     sqlTxn,
	}, nil
}

func (t *txn) Get(ctx context.Context, key datastore.Key) ([]byte, error) {
	row := t.txn.QueryRowContext(ctx, t.queries.Get(), key.String())
	var out []byte

	switch err := row.Scan(&out); err {
	case sql.ErrNoRows:
		return nil, datastore.ErrNotFound
	case nil:
		return out, nil
	default:
		return nil, err
	}
}

func (t *txn) Has(ctx context.Context, key datastore.Key) (bool, error) {
	row := t.txn.QueryRowContext(ctx, t.queries.Exists(), key.String())
	var exists bool

	switch err := row.Scan(&exists); err {
	case sql.ErrNoRows:
		return exists, nil
	case nil:
		return exists, nil
	default:
		return exists, err
	}
}

func (t *txn) GetSize(ctx context.Context, key datastore.Key) (int, error) {
	row := t.txn.QueryRowContext(ctx, t.queries.GetSize(), key.String())
	var size int

	switch err := row.Scan(&size); err {
	case sql.ErrNoRows:
		return -1, datastore.ErrNotFound
	case nil:
		return size, nil
	default:
		return 0, err
	}
}

func (t *txn) Query(ctx context.Context, q dsq.Query) (dsq.Results, error) {
	return nil, ErrNotImplemented
}

// Put adds a value to the datastore identified by the given key.
func (t *txn) Put(ctx context.Context, key datastore.Key, val []byte) error {
	_, err := t.txn.ExecContext(ctx, t.queries.Put(), key.String(), val)
	if err != nil {
		_ = t.txn.Rollback()
		return err
	}
	return nil
}

// Delete removes a value from the datastore that matches the given key.
func (t *txn) Delete(ctx context.Context, key datastore.Key) error {
	_, err := t.txn.ExecContext(ctx, t.queries.Delete(), key.String())
	if err != nil {
		_ = t.txn.Rollback()
		return err
	}
	return nil
}

// Commit finalizes a transaction.
func (t *txn) Commit(ctx context.Context) error {
	err := t.txn.Commit()
	if err != nil {
		_ = t.txn.Rollback()
		return err
	}
	return nil
}

// Discard throws away changes recorded in a transaction without committing
// them to the underlying Datastore.
func (t *txn) Discard(ctx context.Context) {
	_ = t.txn.Rollback()
}

var _ datastore.TxnDatastore = (*Datastore)(nil)

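A transaction usage sketch (the `move` helper is an assumption, not part of the package); note that `Put` and `Delete` above already roll the transaction back on failure:

```go
func move(ctx context.Context, d *Datastore, from, to datastore.Key) error {
	t, err := d.NewTransaction(ctx, false)
	if err != nil {
		return err
	}
	// Discard is harmless after a successful Commit.
	defer t.Discard(ctx)

	v, err := t.Get(ctx, from)
	if err != nil {
		return err
	}
	if err := t.Put(ctx, to, v); err != nil {
		return err
	}
	if err := t.Delete(ctx, from); err != nil {
		return err
	}
	return t.Commit(ctx)
}
```
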
@ -1,3 +0,0 @@
{
  "version": "v0.3.0"
}

@ -1,9 +0,0 @@
sudo: false

language: go

go:
  - 1.12

script:
  - go test -race -v ./...

@ -1,21 +0,0 @@
The MIT License (MIT)

Copyright (c) 2014 Juan Batiz-Benet

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

@ -1,132 +0,0 @@
# goprocess - lifecycles in go

[![travisbadge](https://travis-ci.org/jbenet/goprocess.svg)](https://travis-ci.org/jbenet/goprocess)

(Based on https://github.com/jbenet/go-ctxgroup)

- Godoc: https://godoc.org/github.com/jbenet/goprocess

`goprocess` introduces a way to manage process lifecycles in go. It is
much like [go.net/context](https://godoc.org/code.google.com/p/go.net/context)
(it actually uses a Context), but it is more like a Context-WaitGroup hybrid.
`goprocess` is about being able to start and stop units of work, which may
receive `Close` signals from many clients. Think of it like a UNIX process
tree, but inside go.

`goprocess` seeks to minimally affect your objects, so you can use it
with either embedding or composition. At the heart of `goprocess` is the
`Process` interface:

```Go
// Process is the basic unit of work in goprocess. It defines a computation
// with a lifecycle:
// - running (before calling Close),
// - closing (after calling Close at least once),
// - closed (after Close returns, and all teardown has _completed_).
//
// More specifically, it fits this:
//
//   p := WithTeardown(tf)  // new process is created, it is now running.
//   p.AddChild(q)          // can register children **before** Closing.
//   go p.Close()           // blocks until done running teardown func.
//   <-p.Closing()          // would now return true.
//   <-p.childrenDone()     // wait on all children to be done
//   p.teardown()           // runs the user's teardown function tf.
//   p.Close()              // now returns, with error teardown returned.
//   <-p.Closed()           // would now return true.
//
// Processes can be arranged in a process "tree", where children are
// automatically Closed if their parents are closed. (Note, it is actually
// a Process DAG, children may have multiple parents). A process may also
// optionally wait for another to fully Close before beginning to Close.
// This makes it easy to ensure order of operations and proper sequential
// teardown of resources. For example:
//
//   p1 := goprocess.WithTeardown(func() error {
//     fmt.Println("closing 1")
//   })
//   p2 := goprocess.WithTeardown(func() error {
//     fmt.Println("closing 2")
//   })
//   p3 := goprocess.WithTeardown(func() error {
//     fmt.Println("closing 3")
//   })
//
//   p1.AddChild(p2)
//   p2.AddChild(p3)
//
//   go p1.Close()
//   go p2.Close()
//   go p3.Close()
//
//   // Output:
//   // closing 3
//   // closing 2
//   // closing 1
//
// Process is modelled after the UNIX process group idea, and heavily
// informed by sync.WaitGroup and go.net/context.Context.
//
// In the function documentation of this interface, `p` always refers to
// the self Process.
type Process interface {

	// WaitFor makes p wait for q before exiting. Thus, p will _always_ close
	// _after_ q. Note well: a waiting cycle is deadlock.
	//
	// If q is already Closed, WaitFor calls p.Close().
	// If p is already Closing or Closed, WaitFor panics. This is the same thing
	// as calling Add(1) _after_ calling Done() on a wait group. Calling WaitFor
	// on an already-closed process is a programming error likely due to bad
	// synchronization.
	WaitFor(q Process)

	// AddChildNoWait registers child as a "child" of Process. As in UNIX,
	// when parent is Closed, child is Closed -- child may Close beforehand.
	// This is the equivalent of calling:
	//
	//  go func(parent, child Process) {
	//    <-parent.Closing()
	//    child.Close()
	//  }(p, q)
	//
	// Note: the naming of functions is `AddChildNoWait` and `AddChild` (instead
	// of `AddChild` and `AddChildWaitFor`) because:
	// - it is the more common operation,
	// - explicitness is helpful in the less common case (no waiting), and
	// - usual "child" semantics imply parent Processes should wait for children.
	AddChildNoWait(q Process)

	// AddChild is the equivalent of calling:
	//  parent.AddChildNoWait(q)
	//  parent.WaitFor(q)
	AddChild(q Process)

	// Go creates a new process, adds it as a child, and spawns the ProcessFunc f
	// in its own goroutine. It is equivalent to:
	//
	//  GoChild(p, f)
	//
	// It is useful to construct simple asynchronous workers, children of p.
	Go(f ProcessFunc) Process

	// Close ends the process. Close blocks until the process has completely
	// shut down, and any teardown has run _exactly once_. The returned error
	// is available indefinitely: calling Close twice returns the same error.
	// If the process has already been closed, Close returns immediately.
	Close() error

	// Closing is a signal to wait upon. The returned channel is closed
	// _after_ Close has been called at least once, but teardown may or may
	// not be done yet. The primary use case of Closing is for children who
	// need to know when a parent is shutting down, and therefore also shut
	// down.
	Closing() <-chan struct{}

	// Closed is a signal to wait upon. The returned channel is closed
	// _after_ Close has completed; teardown has finished. The primary use case
	// of Closed is waiting for a Process to Close without _causing_ the Close.
	Closed() <-chan struct{}
}
```

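A minimal usage sketch, assuming only the API above:

```go
package main

import (
	"fmt"

	goprocess "github.com/jbenet/goprocess"
)

func main() {
	p := goprocess.WithTeardown(func() error {
		fmt.Println("teardown runs exactly once")
		return nil
	})

	// Spawn a child worker that exits when p starts closing.
	p.Go(func(worker goprocess.Process) {
		<-worker.Closing()
	})

	// Close blocks until the worker has exited and teardown has run.
	if err := p.Close(); err != nil {
		fmt.Println("teardown error:", err)
	}
}
```
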
@ -1,33 +0,0 @@
package goprocess

// Background returns the "bgProcess" Process: a statically allocated
// process that can _never_ close. It also never enters Closing() state.
// Calling Background().Close() will hang indefinitely.
func Background() Process {
	return background
}

var background = new(bgProcess)

type bgProcess struct{}

func (*bgProcess) WaitFor(q Process)          {}
func (*bgProcess) AddChildNoWait(q Process)   {}
func (*bgProcess) AddChild(q Process)         {}
func (*bgProcess) Close() error               { select {} }
func (*bgProcess) CloseAfterChildren() error  { select {} }
func (*bgProcess) Closing() <-chan struct{}   { return nil }
func (*bgProcess) Closed() <-chan struct{}    { return nil }
func (*bgProcess) Err() error                 { select {} }

func (*bgProcess) SetTeardown(tf TeardownFunc) {
	panic("can't set teardown on bgProcess process")
}
func (*bgProcess) Go(f ProcessFunc) Process {
	child := newProcess(nil)
	go func() {
		f(child)
		child.Close()
	}()
	return child
}

@ -1,263 +0,0 @@
// Package goprocess introduces a Process abstraction that allows simple
// organization and orchestration of work. It is much like a WaitGroup,
// and much like a context.Context, but also ensures safe **exactly-once**,
// and well-ordered teardown semantics.
package goprocess

import (
	"os"
	"os/signal"
)

// Process is the basic unit of work in goprocess. It defines a computation
// with a lifecycle:
// - running (before calling Close),
// - closing (after calling Close at least once),
// - closed (after Close returns, and all teardown has _completed_).
//
// More specifically, it fits this:
//
//	p := WithTeardown(tf)  // new process is created, it is now running.
//	p.AddChild(q)          // can register children **before** Closed().
//	go p.Close()           // blocks until done running teardown func.
//	<-p.Closing()          // would now return true.
//	<-p.childrenDone()     // wait on all children to be done
//	p.teardown()           // runs the user's teardown function tf.
//	p.Close()              // now returns, with error teardown returned.
//	<-p.Closed()           // would now return true.
//
// Processes can be arranged in a process "tree", where children are
// automatically Closed if their parents are closed. (Note, it is actually
// a Process DAG, children may have multiple parents). A process may also
// optionally wait for another to fully Close before beginning to Close.
// This makes it easy to ensure order of operations and proper sequential
// teardown of resources. For example:
//
//	p1 := goprocess.WithTeardown(func() error {
//		fmt.Println("closing 1")
//	})
//	p2 := goprocess.WithTeardown(func() error {
//		fmt.Println("closing 2")
//	})
//	p3 := goprocess.WithTeardown(func() error {
//		fmt.Println("closing 3")
//	})
//
//	p1.AddChild(p2)
//	p2.AddChild(p3)
//
//	go p1.Close()
//	go p2.Close()
//	go p3.Close()
//
//	// Output:
//	// closing 3
//	// closing 2
//	// closing 1
//
// Process is modelled after the UNIX process group idea, and heavily
// informed by sync.WaitGroup and go.net/context.Context.
//
// In the function documentation of this interface, `p` always refers to
// the self Process.
type Process interface {

	// WaitFor makes p wait for q before exiting. Thus, p will _always_ close
	// _after_ q. Note well: a waiting cycle is deadlock.
	//
	// If p is already Closed, WaitFor panics. This is the same thing as
	// calling Add(1) _after_ calling Done() on a wait group. Calling
	// WaitFor on an already-closed process is a programming error likely
	// due to bad synchronization.
	WaitFor(q Process)

	// AddChildNoWait registers child as a "child" of Process. As in UNIX,
	// when parent is Closed, child is Closed -- child may Close beforehand.
	// This is the equivalent of calling:
	//
	//	go func(parent, child Process) {
	//		<-parent.Closing()
	//		child.Close()
	//	}(p, q)
	//
	// Note: the naming of functions is `AddChildNoWait` and `AddChild` (instead
	// of `AddChild` and `AddChildWaitFor`) because:
	// - it is the more common operation,
	// - explicitness is helpful in the less common case (no waiting), and
	// - usual "child" semantics imply parent Processes should wait for children.
	AddChildNoWait(q Process)

	// AddChild is the equivalent of calling:
	//	parent.AddChildNoWait(q)
	//	parent.WaitFor(q)
	//
	// It will _panic_ if the parent is already closed.
	AddChild(q Process)

	// Go is much like `go`, as it runs a function in a newly spawned goroutine.
	// The neat part of Process.Go is that the Process object you call it on will:
	// - construct a child Process, and call AddChild(child) on it
	// - spawn a goroutine, and call the given function
	// - Close the child when the function exits.
	// This way, you can rest assured each goroutine you spawn has its very own
	// Process context, and that it will be closed when the function exits.
	// It is the function's responsibility to respect the Closing of its Process,
	// namely it should exit (return) when <-Closing() is ready. It is basically:
	//
	//	func (p Process) Go(f ProcessFunc) Process {
	//		child := WithParent(p)
	//		go func() {
	//			f(child)
	//			child.Close()
	//		}()
	//	}
	//
	// It is useful to construct simple asynchronous workers, children of p.
	Go(f ProcessFunc) Process

	// SetTeardown sets the process's teardown to tf.
	SetTeardown(tf TeardownFunc)

	// Close ends the process. Close blocks until the process has completely
	// shut down, and any teardown has run _exactly once_. The returned error
	// is available indefinitely: calling Close twice returns the same error.
	// If the process has already been closed, Close returns immediately.
	Close() error

	// CloseAfterChildren calls Close _after_ its children have Closed
	// normally (i.e. it _does not_ attempt to close them).
	CloseAfterChildren() error

	// Closing is a signal to wait upon. The returned channel is closed
	// _after_ Close has been called at least once, but teardown may or may
	// not be done yet. The primary use case of Closing is for children who
	// need to know when a parent is shutting down, and therefore also shut
	// down.
	Closing() <-chan struct{}

	// Closed is a signal to wait upon. The returned channel is closed
	// _after_ Close has completed; teardown has finished. The primary use case
	// of Closed is waiting for a Process to Close without _causing_ the Close.
	Closed() <-chan struct{}

	// Err waits until the process is closed, and then returns any error that
	// occurred during shutdown.
	Err() error
}

// TeardownFunc is a function used to cleanup state at the end of the
// lifecycle of a Process.
type TeardownFunc func() error

// ProcessFunc is a function that takes a process. Its main use case is goprocess.Go,
// which spawns a ProcessFunc in its own goroutine, and returns a corresponding
// Process object.
type ProcessFunc func(proc Process)

var nilProcessFunc = func(Process) {}

// Go is much like `go`: it runs a function in a newly spawned goroutine. The neat
// part of Go is that it provides a Process object to communicate between the
// function and the outside world. Thus, callers can easily WaitFor, or Close the
// function. It is the function's responsibility to respect the Closing of its Process,
// namely it should exit (return) when <-Closing() is ready. It is simply:
//
//	func Go(f ProcessFunc) Process {
//		p := WithParent(Background())
//		p.Go(f)
//		return p
//	}
//
// Note that a naive implementation of Go like the following would not work:
//
//	func Go(f ProcessFunc) Process {
//		return Background().Go(f)
//	}
//
// This is because having the process you
func Go(f ProcessFunc) Process {
	// return GoChild(Background(), f)

	// we use two processes, one for communication, and
	// one for ensuring we wait on the function (unclosable from the outside).
	p := newProcess(nil)
	waitFor := newProcess(nil)
	p.WaitFor(waitFor) // prevent p from closing
	go func() {
		f(p)
		waitFor.Close() // allow p to close.
		p.Close()       // ensure p closes.
	}()
	return p
}

// GoChild is like Go, but it registers the returned Process as a child of parent,
// **before** spawning the goroutine, which ensures proper synchronization with parent.
// It is somewhat like:
//
//	func GoChild(parent Process, f ProcessFunc) Process {
//		p := WithParent(parent)
//		p.Go(f)
//		return p
//	}
//
// And it is similar to the classic WaitGroup use case:
//
//	func WaitGroupGo(wg sync.WaitGroup, child func()) {
//		wg.Add(1)
//		go func() {
//			child()
//			wg.Done()
//		}()
//	}
func GoChild(parent Process, f ProcessFunc) Process {
	p := WithParent(parent)
	p.Go(f)
	return p
}

// Spawn is an alias of `Go`. In many contexts, Spawn is a
// well-known Process launching word, which fits our use case.
var Spawn = Go

// SpawnChild is an alias of `GoChild`. In many contexts, Spawn is a
// well-known Process launching word, which fits our use case.
var SpawnChild = GoChild

// WithTeardown constructs and returns a Process with a TeardownFunc.
// TeardownFunc tf will be called **exactly-once** when Process is
// Closing, after all Children have fully closed, and before p is Closed.
// In fact, Process p will not be Closed until tf runs and exits.
// See lifecycle in Process doc.
func WithTeardown(tf TeardownFunc) Process {
	if tf == nil {
		panic("nil tf TeardownFunc")
	}
	return newProcess(tf)
}

// WithParent constructs and returns a Process with a given parent.
func WithParent(parent Process) Process {
	if parent == nil {
		panic("nil parent Process")
	}
	q := newProcess(nil)
	parent.AddChild(q)
	return q
}

// WithSignals returns a Process that will Close() when any given signal fires.
// This is useful to bind Process trees to syscall.SIGTERM, SIGKILL, etc.
func WithSignals(sig ...os.Signal) Process {
	p := WithParent(Background())
	c := make(chan os.Signal, 1)
	signal.Notify(c, sig...)
	go func() {
		<-c
		signal.Stop(c)
		p.Close()
	}()
	return p
}

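A sketch of binding a process tree to OS signals with `WithSignals` (the `runUntilInterrupt` helper is an assumption, not part of the package):

```go
func runUntilInterrupt() error {
	p := WithSignals(os.Interrupt)
	p.Go(func(proc Process) {
		<-proc.Closing() // stop work once a signal fires
	})
	// Err blocks until the tree has fully closed, then reports any
	// teardown error.
	return p.Err()
}
```
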
@ -1,299 +0,0 @@
package goprocess

import (
	"sync"
)

// process implements Process
type process struct {
	children map[*processLink]struct{} // process to close with us
	waitfors map[*processLink]struct{} // process to only wait for
	waiters  []*processLink            // processes that wait for us. for gc.

	teardown TeardownFunc  // called to run the teardown logic.
	closing  chan struct{} // closed once close starts.
	closed   chan struct{} // closed once close is done.
	closeErr error         // error to return to clients of Close()

	sync.Mutex
}

// newProcess constructs and returns a Process.
// It will call tf TeardownFunc exactly once:
// **after** all children have fully Closed,
// **after** entering <-Closing(), and
// **before** <-Closed().
func newProcess(tf TeardownFunc) *process {
	return &process{
		teardown: tf,
		closed:   make(chan struct{}),
		closing:  make(chan struct{}),
		waitfors: make(map[*processLink]struct{}),
		children: make(map[*processLink]struct{}),
	}
}

func (p *process) WaitFor(q Process) {
	if q == nil {
		panic("waiting for nil process")
	}

	p.Lock()
	defer p.Unlock()

	select {
	case <-p.Closed():
		panic("Process cannot wait after being closed")
	default:
	}

	pl := newProcessLink(p, q)
	if p.waitfors == nil {
		// This may be nil when we're closing. In close, we'll keep
		// reading this map till it stays nil.
		p.waitfors = make(map[*processLink]struct{}, 1)
	}
	p.waitfors[pl] = struct{}{}
	go pl.AddToChild()
}

func (p *process) AddChildNoWait(child Process) {
	if child == nil {
		panic("adding nil child process")
	}

	p.Lock()
	defer p.Unlock()

	select {
	case <-p.Closing():
		// Either closed or closing, close child immediately. This is
		// correct because we aren't asked to _wait_ on this child.
		go child.Close()
		// Wait for the child to start closing so the child is in the
		// "correct" state after this function finishes (see #17).
		<-child.Closing()
		return
	default:
	}

	pl := newProcessLink(p, child)
	p.children[pl] = struct{}{}
	go pl.AddToChild()
}

func (p *process) AddChild(child Process) {
	if child == nil {
		panic("adding nil child process")
	}

	p.Lock()
	defer p.Unlock()

	pl := newProcessLink(p, child)

	select {
	case <-p.Closed():
		// AddChild must not be called on a dead process. Maybe that's
		// too strict?
		panic("Process cannot add children after being closed")
	default:
	}

	select {
	case <-p.Closing():
		// Already closing, close child in background.
		go child.Close()
		// Wait for the child to start closing so the child is in the
		// "correct" state after this function finishes (see #17).
		<-child.Closing()
	default:
		// Only add the child when not closing. When closing, just add
		// it to the "waitfors" list.
		p.children[pl] = struct{}{}
	}

	if p.waitfors == nil {
		// This may be nil when we're closing. In close, we'll keep
		// reading this map till it stays nil.
		p.waitfors = make(map[*processLink]struct{}, 1)
	}
	p.waitfors[pl] = struct{}{}
	go pl.AddToChild()
}

func (p *process) Go(f ProcessFunc) Process {
	child := newProcess(nil)
	waitFor := newProcess(nil)
	child.WaitFor(waitFor) // prevent child from closing

	// add child last, to prevent a closing parent from
	// closing all of them prematurely, before running the func.
	p.AddChild(child)
	go func() {
		f(child)
		waitFor.Close()            // allow child to close.
		child.CloseAfterChildren() // close to tear down.
	}()
	return child
}

// SetTeardown assigns a teardown function.
func (p *process) SetTeardown(tf TeardownFunc) {
	if tf == nil {
		panic("cannot set nil TeardownFunc")
	}

	p.Lock()
	if p.teardown != nil {
		panic("cannot SetTeardown twice")
	}

	p.teardown = tf
	select {
	case <-p.Closed():
		// Call the teardown function, but don't set the error. We can't
		// change that after we shut down.
		tf()
	default:
	}
	p.Unlock()
}

// Close is the external close function.
// It's a wrapper around doClose that waits on Closed().
func (p *process) Close() error {
	p.Lock()

	// if already closing, or closed, get out. (but wait!)
	select {
	case <-p.Closing():
		p.Unlock()
		<-p.Closed()
		return p.closeErr
	default:
	}

	p.doClose()
	p.Unlock()
	return p.closeErr
}

func (p *process) Closing() <-chan struct{} {
	return p.closing
}

func (p *process) Closed() <-chan struct{} {
	return p.closed
}

func (p *process) Err() error {
	<-p.Closed()
	return p.closeErr
}

// doClose is the _actual_ close process.
func (p *process) doClose() {
	// this function is only called once (protected by p.Lock()),
	// and it will panic (on closing channels) otherwise.

	close(p.closing) // signal that we're shutting down (Closing)

	// We won't add any children after we start closing so we can do this
	// once.
	for plc := range p.children {
		child := plc.Child()
		if child != nil { // check because child may already have been removed.
			go child.Close() // force all children to shut down
		}

		// safe to call multiple times per link
		plc.ParentClear()
	}
	p.children = nil // clear them. release memory.

	// We may repeatedly continue to add waiters while we wait to close so
	// we have to do this in a loop.
	for len(p.waitfors) > 0 {
		// we must be careful not to iterate over waitfors directly, as it may
		// change under our feet.
		wf := p.waitfors
		p.waitfors = nil // clear them. release memory.
		for w := range wf {
			// Here, we wait UNLOCKED, so that waitfors who are in the middle of
			// adding a child to us can finish. we will immediately close the child.
			p.Unlock()
			<-w.ChildClosed() // wait till all waitfors are fully closed (before teardown)
			p.Lock()

			// safe to call multiple times per link
			w.ParentClear()
		}
	}

	if p.teardown != nil {
		p.closeErr = p.teardown() // actually run the close logic (ok safe to teardown)
	}
	close(p.closed) // signal that we're shut down (Closed)

	// go remove all the parents from the process links. optimization.
	go func(waiters []*processLink) {
		for _, pl := range waiters {
			pl.ClearChild()
			pr, ok := pl.Parent().(*process)
			if !ok {
				// parent has already been called to close
				continue
			}
			pr.Lock()
			delete(pr.waitfors, pl)
			delete(pr.children, pl)
			pr.Unlock()
		}
	}(p.waiters) // pass in so the goroutine keeps a reference after we clear p.waiters below.
	p.waiters = nil // clear them. release memory.
}

// We will only wait on the children we have now.
// We will not wait on children added subsequently.
// This may change in the future.
func (p *process) CloseAfterChildren() error {
	p.Lock()
	select {
	case <-p.Closed():
		p.Unlock()
		return p.Close() // get error. safe, after p.Closed()
	default:
	}
	p.Unlock()

	// here only from one goroutine.

	nextToWaitFor := func() Process {
		p.Lock()
		defer p.Unlock()
		for e := range p.waitfors {
			c := e.Child()
			if c == nil {
				continue
			}

			select {
			case <-c.Closed():
			default:
				return c
			}
		}
		return nil
	}

	// wait until all processes we're waiting for are closed.
	// the semantics here are simple: we will _only_ close
	// if there are no processes we are currently waiting for.
	for next := nextToWaitFor(); next != nil; next = nextToWaitFor() {
		<-next.Closed()
	}

	// YAY! we're done. close
	return p.Close()
}

@ -1,128 +0,0 @@
package goprocess

import (
	"sync"
)

// closedCh is an already-closed channel, used to return
// in cases where we already know a channel is closed.
var closedCh chan struct{}

func init() {
	closedCh = make(chan struct{})
	close(closedCh)
}

// a processLink is an internal bookkeeping data structure.
// it's used to form a relationship between two processes.
// It is mostly for keeping memory usage down (letting
// children close and be garbage-collected).
type processLink struct {
	// guards all fields.
	// DO NOT HOLD while holding process locks.
	// it may be slow, and could deadlock if not careful.
	sync.Mutex
	parent Process
	child  Process
}

func newProcessLink(p, c Process) *processLink {
	return &processLink{
		parent: p,
		child:  c,
	}
}

// ChildClosing returns the child's Closing channel.
func (pl *processLink) ChildClosing() <-chan struct{} {
	// grab a hold of it, and unlock, as .Closing may block.
	pl.Lock()
	child := pl.child
	pl.Unlock()

	if child == nil { // already closed? memory optimization.
		return closedCh
	}
	return child.Closing()
}

func (pl *processLink) ChildClosed() <-chan struct{} {
	// grab a hold of it, and unlock, as .Closed may block.
	pl.Lock()
	child := pl.child
	pl.Unlock()

	if child == nil { // already closed? memory optimization.
		return closedCh
	}
	return child.Closed()
}

func (pl *processLink) ChildClose() {
	// grab a hold of it, and unlock, as .Close may block.
	pl.Lock()
	child := pl.child
	pl.Unlock()

	if child != nil { // already closed? memory optimization.
		child.Close()
	}
}

func (pl *processLink) ClearChild() {
	pl.Lock()
	pl.child = nil
	pl.Unlock()
}

func (pl *processLink) ParentClear() {
	pl.Lock()
	pl.parent = nil
	pl.Unlock()
}

func (pl *processLink) Child() Process {
	pl.Lock()
	defer pl.Unlock()
	return pl.child
}

func (pl *processLink) Parent() Process {
	pl.Lock()
	defer pl.Unlock()
	return pl.parent
}

func (pl *processLink) AddToChild() {
	cp := pl.Child()

	// is it a *process ? if not... panic.
	var c *process
	switch cp := cp.(type) {
	case *process:
		c = cp
	case *bgProcess:
		// Background process never closes so we don't need to do
		// anything.
		return
	default:
		panic("goprocess does not yet support other process impls.")
	}

	// first, is it Closed?
	c.Lock()
	select {
	case <-c.Closed():
		c.Unlock()

		// already closed. must not add.
		// we must clear it, though. do so without the lock.
		pl.ClearChild()
		return

	default:
		// put the process link into q's waiters
		c.waiters = append(c.waiters, pl)
		c.Unlock()
	}
}

@ -1,14 +0,0 @@
{
  "author": "whyrusleeping",
  "bugs": {
    "url": "https://github.com/jbenet/goprocess"
  },
  "gx": {
    "dvcsimport": "github.com/jbenet/goprocess"
  },
  "gxVersion": "0.8.0",
  "language": "go",
  "license": "",
  "name": "goprocess",
  "version": "1.0.0"
}

@ -23,3 +23,10 @@ _testmain.go
*.test
*.prof
/s2/cmd/_s2sx/sfx-exe

# Linux perf files
perf.data
perf.data.old

# gdb history
.gdb_history

@ -17,6 +17,78 @@ This package provides various compression algorithms.

# changelog

* July 21, 2022 (v1.15.9)

	* zstd: Fix decoder crash on amd64 (no BMI) on invalid input https://github.com/klauspost/compress/pull/645
	* zstd: Disable decoder extended memory copies (amd64) due to possible crashes https://github.com/klauspost/compress/pull/644
	* zstd: Allow single segments up to "max decoded size" by @klauspost in https://github.com/klauspost/compress/pull/643

* July 13, 2022 (v1.15.8)

	* gzip: fix stack exhaustion bug in Reader.Read https://github.com/klauspost/compress/pull/641
	* s2: Add Index header trim/restore https://github.com/klauspost/compress/pull/638
	* zstd: Optimize seqdeq amd64 asm by @greatroar in https://github.com/klauspost/compress/pull/636
	* zstd: Improve decoder memcopy https://github.com/klauspost/compress/pull/637
	* huff0: Pass a single bitReader pointer to asm by @greatroar in https://github.com/klauspost/compress/pull/634
	* zstd: Branchless getBits for amd64 w/o BMI2 by @greatroar in https://github.com/klauspost/compress/pull/640
	* gzhttp: Remove header before writing https://github.com/klauspost/compress/pull/639

* June 29, 2022 (v1.15.7)

	* s2: Fix absolute forward seeks https://github.com/klauspost/compress/pull/633
	* zip: Merge upstream https://github.com/klauspost/compress/pull/631
	* zip: Re-add zip64 fix https://github.com/klauspost/compress/pull/624
	* zstd: translate fseDecoder.buildDtable into asm by @WojciechMula in https://github.com/klauspost/compress/pull/598
	* flate: Faster histograms https://github.com/klauspost/compress/pull/620
	* deflate: Use compound hcode https://github.com/klauspost/compress/pull/622

* June 3, 2022 (v1.15.6)
	* s2: Improve coding for long, close matches https://github.com/klauspost/compress/pull/613
	* s2c: Add Snappy/S2 stream recompression https://github.com/klauspost/compress/pull/611
	* zstd: Always use configured block size https://github.com/klauspost/compress/pull/605
	* zstd: Fix incorrect hash table placement for dict encoding in default https://github.com/klauspost/compress/pull/606
	* zstd: Apply default config to ZipDecompressor without options https://github.com/klauspost/compress/pull/608
	* gzhttp: Exclude more common archive formats https://github.com/klauspost/compress/pull/612
	* s2: Add ReaderIgnoreCRC https://github.com/klauspost/compress/pull/609
	* s2: Remove sanity load on index creation https://github.com/klauspost/compress/pull/607
	* snappy: Use dedicated function for scoring https://github.com/klauspost/compress/pull/614
	* s2c+s2d: Use official snappy framed extension https://github.com/klauspost/compress/pull/610

* May 25, 2022 (v1.15.5)
	* s2: Add concurrent stream decompression https://github.com/klauspost/compress/pull/602
	* s2: Fix final emit oob read crash on amd64 https://github.com/klauspost/compress/pull/601
	* huff0: asm implementation of Decompress1X by @WojciechMula https://github.com/klauspost/compress/pull/596
	* zstd: Use 1 less goroutine for stream decoding https://github.com/klauspost/compress/pull/588
	* zstd: Copy literal in 16 byte blocks when possible https://github.com/klauspost/compress/pull/592
	* zstd: Speed up when WithDecoderLowmem(false) https://github.com/klauspost/compress/pull/599
	* zstd: faster next state update in BMI2 version of decode by @WojciechMula in https://github.com/klauspost/compress/pull/593
	* huff0: Do not check max size when reading table. https://github.com/klauspost/compress/pull/586
	* flate: Inplace hashing for level 7-9 by @klauspost in https://github.com/klauspost/compress/pull/590

* May 11, 2022 (v1.15.4)
	* huff0: decompress directly into output by @WojciechMula in [#577](https://github.com/klauspost/compress/pull/577)
	* inflate: Keep dict on stack [#581](https://github.com/klauspost/compress/pull/581)
	* zstd: Faster decoding memcopy in asm [#583](https://github.com/klauspost/compress/pull/583)
	* zstd: Fix ignored crc [#580](https://github.com/klauspost/compress/pull/580)

* May 5, 2022 (v1.15.3)
	* zstd: Allow to ignore checksum checking by @WojciechMula [#572](https://github.com/klauspost/compress/pull/572)
	* s2: Fix incorrect seek for io.SeekEnd in [#575](https://github.com/klauspost/compress/pull/575)

* Apr 26, 2022 (v1.15.2)
	* zstd: Add x86-64 assembly for decompression on streams and blocks. Contributed by [@WojciechMula](https://github.com/WojciechMula). Typically 2x faster. [#528](https://github.com/klauspost/compress/pull/528) [#531](https://github.com/klauspost/compress/pull/531) [#545](https://github.com/klauspost/compress/pull/545) [#537](https://github.com/klauspost/compress/pull/537)
	* zstd: Add options to ZipDecompressor and fixes [#539](https://github.com/klauspost/compress/pull/539)
	* s2: Use sorted search for index [#555](https://github.com/klauspost/compress/pull/555)
	* Minimum version is Go 1.16, added CI test on 1.18.

* Mar 11, 2022 (v1.15.1)
	* huff0: Add x86 assembly of Decode4X by @WojciechMula in [#512](https://github.com/klauspost/compress/pull/512)
	* zstd: Reuse zip decoders in [#514](https://github.com/klauspost/compress/pull/514)
	* zstd: Detect extra block data and report as corrupted in [#520](https://github.com/klauspost/compress/pull/520)
	* zstd: Handle zero sized frame content size stricter in [#521](https://github.com/klauspost/compress/pull/521)
	* zstd: Add stricter block size checks in [#523](https://github.com/klauspost/compress/pull/523)

* Mar 3, 2022 (v1.15.0)
	* zstd: Refactor decoder by @klauspost in [#498](https://github.com/klauspost/compress/pull/498)
	* zstd: Add stream encoding without goroutines by @klauspost in [#505](https://github.com/klauspost/compress/pull/505)

@ -60,6 +132,9 @@ While the release has been extensively tested, it is recommended to testing when
	* zstd: add arm64 xxhash assembly in [#464](https://github.com/klauspost/compress/pull/464)
	* Add garbled for binaries for s2 in [#445](https://github.com/klauspost/compress/pull/445)

<details>
	<summary>See changes to v1.13.x</summary>

* Aug 30, 2021 (v1.13.5)
	* gz/zlib/flate: Alias stdlib errors [#425](https://github.com/klauspost/compress/pull/425)
	* s2: Add block support to commandline tools [#413](https://github.com/klauspost/compress/pull/413)

@ -88,6 +163,8 @@ While the release has been extensively tested, it is recommended to testing when
	* Added [gzhttp](https://github.com/klauspost/compress/tree/master/gzhttp#gzip-handler) which allows wrapping HTTP servers and clients with GZIP compressors.
	* zstd: Detect short invalid signatures [#382](https://github.com/klauspost/compress/pull/382)
	* zstd: Spawn decoder goroutine only if needed. [#380](https://github.com/klauspost/compress/pull/380)
</details>

<details>
	<summary>See changes to v1.12.x</summary>

@ -1,5 +0,0 @@
package huff0

//go:generate go run generate.go
//go:generate asmfmt -w decompress_amd64.s
//go:generate asmfmt -w decompress_8b_amd64.s

@ -165,11 +165,6 @@ func (b *bitReaderShifted) peekBitsFast(n uint8) uint16 {
	return uint16(b.value >> ((64 - n) & 63))
}

// peekTopBits(n) is equivalent to peekBitsFast(64 - n)
func (b *bitReaderShifted) peekTopBits(n uint8) uint16 {
	return uint16(b.value >> n)
}

func (b *bitReaderShifted) advance(n uint8) {
	b.bitsRead += n
	b.value <<= n & 63

@ -220,11 +215,6 @@ func (b *bitReaderShifted) fill() {
	}
}

// finished returns true if all bits have been read from the bit stream.
func (b *bitReaderShifted) finished() bool {
	return b.off == 0 && b.bitsRead >= 64
}

func (b *bitReaderShifted) remaining() uint {
	return b.off*8 + uint(64-b.bitsRead)
}

@ -5,8 +5,6 @@

package huff0

import "fmt"

// bitWriter will write bits.
// First bit will be LSB of the first byte of output.
type bitWriter struct {

@ -23,14 +21,6 @@ var bitMask16 = [32]uint16{
	0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
	0xFFFF, 0xFFFF} /* up to 16 bits */

// addBits16NC will add up to 16 bits.
// It will not check if there is space for them,
// so the caller must ensure that it has flushed recently.
func (b *bitWriter) addBits16NC(value uint16, bits uint8) {
	b.bitContainer |= uint64(value&bitMask16[bits&31]) << (b.nBits & 63)
	b.nBits += bits
}

// addBits16Clean will add up to 16 bits. value may not contain more set bits than indicated.
// It will not check if there is space for them, so the caller must ensure that it has flushed recently.
func (b *bitWriter) addBits16Clean(value uint16, bits uint8) {

@ -70,104 +60,6 @@ func (b *bitWriter) encTwoSymbols(ct cTable, av, bv byte) {
	b.nBits += encA.nBits + encB.nBits
}

// addBits16ZeroNC will add up to 16 bits.
// It will not check if there is space for them,
// so the caller must ensure that it has flushed recently.
// This is fastest if bits can be zero.
func (b *bitWriter) addBits16ZeroNC(value uint16, bits uint8) {
	if bits == 0 {
		return
	}
	value <<= (16 - bits) & 15
	value >>= (16 - bits) & 15
	b.bitContainer |= uint64(value) << (b.nBits & 63)
	b.nBits += bits
}

// flush will flush all pending full bytes.
// There will be at least 56 bits available for writing when this has been called.
// Using flush32 is faster, but leaves less space for writing.
func (b *bitWriter) flush() {
	v := b.nBits >> 3
	switch v {
	case 0:
		return
	case 1:
		b.out = append(b.out,
			byte(b.bitContainer),
		)
		b.bitContainer >>= 1 << 3
	case 2:
		b.out = append(b.out,
			byte(b.bitContainer),
			byte(b.bitContainer>>8),
		)
		b.bitContainer >>= 2 << 3
	case 3:
		b.out = append(b.out,
			byte(b.bitContainer),
			byte(b.bitContainer>>8),
			byte(b.bitContainer>>16),
		)
		b.bitContainer >>= 3 << 3
	case 4:
		b.out = append(b.out,
			byte(b.bitContainer),
			byte(b.bitContainer>>8),
			byte(b.bitContainer>>16),
			byte(b.bitContainer>>24),
		)
		b.bitContainer >>= 4 << 3
	case 5:
		b.out = append(b.out,
			byte(b.bitContainer),
			byte(b.bitContainer>>8),
			byte(b.bitContainer>>16),
			byte(b.bitContainer>>24),
			byte(b.bitContainer>>32),
		)
		b.bitContainer >>= 5 << 3
	case 6:
		b.out = append(b.out,
			byte(b.bitContainer),
			byte(b.bitContainer>>8),
			byte(b.bitContainer>>16),
			byte(b.bitContainer>>24),
			byte(b.bitContainer>>32),
			byte(b.bitContainer>>40),
		)
		b.bitContainer >>= 6 << 3
	case 7:
		b.out = append(b.out,
			byte(b.bitContainer),
			byte(b.bitContainer>>8),
			byte(b.bitContainer>>16),
			byte(b.bitContainer>>24),
			byte(b.bitContainer>>32),
			byte(b.bitContainer>>40),
			byte(b.bitContainer>>48),
		)
		b.bitContainer >>= 7 << 3
	case 8:
		b.out = append(b.out,
			byte(b.bitContainer),
			byte(b.bitContainer>>8),
			byte(b.bitContainer>>16),
			byte(b.bitContainer>>24),
			byte(b.bitContainer>>32),
			byte(b.bitContainer>>40),
			byte(b.bitContainer>>48),
			byte(b.bitContainer>>56),
		)
		b.bitContainer = 0
		b.nBits = 0
		return
	default:
		panic(fmt.Errorf("bits (%d) > 64", b.nBits))
	}
	b.nBits &= 7
}

// flush32 will flush out, so there are at least 32 bits available for writing.
func (b *bitWriter) flush32() {
	if b.nBits < 32 {

@ -201,10 +93,3 @@ func (b *bitWriter) close() error {
	b.flushAlign()
	return nil
}

// reset and continue writing by appending to out.
func (b *bitWriter) reset(out []byte) {
	b.bitContainer = 0
	b.nBits = 0
	b.out = out
}

@ -20,11 +20,6 @@ func (b *byteReader) init(in []byte) {
	b.off = 0
}

// advance the stream b n bytes.
func (b *byteReader) advance(n uint) {
	b.off += int(n)
}

// Int32 returns a little endian int32 starting at current offset.
func (b byteReader) Int32() int32 {
	v3 := int32(b.b[b.off+3])

@ -43,11 +38,6 @@ func (b byteReader) Uint32() uint32 {
	return (v3 << 24) | (v2 << 16) | (v1 << 8) | v0
}

// unread returns the unread portion of the input.
func (b byteReader) unread() []byte {
	return b.b[b.off:]
}

// remain will return the number of bytes remaining.
func (b byteReader) remain() int {
	return len(b.b) - b.off

@ -404,6 +404,7 @@ func (s *Scratch) canUseTable(c cTable) bool {
	return true
}

//lint:ignore U1000 used for debugging
func (s *Scratch) validateTable(c cTable) bool {
	if len(c) < int(s.symbolLen) {
		return false

@ -11,7 +11,6 @@ import (

type dTable struct {
	single []dEntrySingle
	double []dEntryDouble
}

// single-symbols decoding

@ -19,13 +18,6 @@ type dEntrySingle struct {
	entry uint16
}

// double-symbols decoding
type dEntryDouble struct {
	seq   [4]byte
	nBits uint8
	len   uint8
}

// Uses special code for all tables that are < 8 bits.
const use8BitTables = true

@ -35,7 +27,7 @@ const use8BitTables = true
// If no Scratch is provided a new one is allocated.
// The returned Scratch can be used for encoding or decoding input using this table.
func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
	s, err = s.prepare(in)
	s, err = s.prepare(nil)
	if err != nil {
		return s, nil, err
	}

@ -236,108 +228,6 @@ func (d *Decoder) buffer() *[4][256]byte {
|
|||
return &[4][256]byte{}
|
||||
}
|
||||
|
||||
// Decompress1X will decompress a 1X encoded stream.
|
||||
// The cap of the output buffer will be the maximum decompressed size.
|
||||
// The length of the supplied input must match the end of a block exactly.
|
||||
func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
|
||||
if len(d.dt.single) == 0 {
|
||||
return nil, errors.New("no table loaded")
|
||||
}
|
||||
if use8BitTables && d.actualTableLog <= 8 {
|
||||
return d.decompress1X8Bit(dst, src)
|
||||
}
|
||||
var br bitReaderShifted
|
||||
err := br.init(src)
|
||||
if err != nil {
|
||||
return dst, err
|
||||
}
|
||||
maxDecodedSize := cap(dst)
|
||||
dst = dst[:0]
|
||||
|
||||
// Avoid bounds check by always having full sized table.
|
||||
const tlSize = 1 << tableLogMax
|
||||
const tlMask = tlSize - 1
|
||||
dt := d.dt.single[:tlSize]
|
||||
|
||||
// Use temp table to avoid bound checks/append penalty.
|
||||
bufs := d.buffer()
|
||||
buf := &bufs[0]
|
||||
var off uint8
|
||||
|
||||
for br.off >= 8 {
|
||||
br.fillFast()
|
||||
v := dt[br.peekBitsFast(d.actualTableLog)&tlMask]
|
||||
br.advance(uint8(v.entry))
|
||||
buf[off+0] = uint8(v.entry >> 8)
|
||||
|
||||
v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
|
||||
br.advance(uint8(v.entry))
|
||||
buf[off+1] = uint8(v.entry >> 8)
|
||||
|
||||
// Refill
|
||||
br.fillFast()
|
||||
|
||||
v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
|
||||
br.advance(uint8(v.entry))
|
||||
buf[off+2] = uint8(v.entry >> 8)
|
||||
|
||||
v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
|
||||
br.advance(uint8(v.entry))
|
||||
buf[off+3] = uint8(v.entry >> 8)
|
||||
|
||||
off += 4
|
||||
if off == 0 {
|
||||
if len(dst)+256 > maxDecodedSize {
|
||||
br.close()
|
||||
d.bufs.Put(bufs)
|
||||
return nil, ErrMaxDecodedSizeExceeded
|
||||
}
|
||||
dst = append(dst, buf[:]...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(dst)+int(off) > maxDecodedSize {
|
||||
d.bufs.Put(bufs)
|
||||
br.close()
|
||||
return nil, ErrMaxDecodedSizeExceeded
|
||||
}
|
||||
dst = append(dst, buf[:off]...)
|
||||
|
||||
// br.off < 8, so uint8 is fine
|
||||
bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
|
||||
for bitsLeft > 0 {
|
||||
br.fill()
|
||||
if false && br.bitsRead >= 32 {
|
||||
if br.off >= 4 {
|
||||
v := br.in[br.off-4:]
|
||||
v = v[:4]
|
||||
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
|
||||
br.value = (br.value << 32) | uint64(low)
|
||||
br.bitsRead -= 32
|
||||
br.off -= 4
|
||||
} else {
|
||||
for br.off > 0 {
|
||||
br.value = (br.value << 8) | uint64(br.in[br.off-1])
|
||||
br.bitsRead -= 8
|
||||
br.off--
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(dst) >= maxDecodedSize {
|
||||
d.bufs.Put(bufs)
|
||||
br.close()
|
||||
return nil, ErrMaxDecodedSizeExceeded
|
||||
}
|
||||
v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
|
||||
nBits := uint8(v.entry)
|
||||
br.advance(nBits)
|
||||
bitsLeft -= nBits
|
||||
dst = append(dst, uint8(v.entry>>8))
|
||||
}
|
||||
d.bufs.Put(bufs)
|
||||
return dst, br.close()
|
||||
}
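A minimal round trip of the public API around this decoder may help orient the reader. This is a hedged sketch using huff0's exported Compress1X/ReadTable entry points as documented in this file, not code from the commit:

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/klauspost/compress/huff0"
)

func main() {
	in := bytes.Repeat([]byte("abcd"), 64) // compressible sample input
	var s huff0.Scratch

	// Compress1X emits the Huffman table followed by a single (1X) stream.
	comp, _, err := huff0.Compress1X(in, &s)
	if err != nil {
		panic(err) // e.g. huff0.ErrIncompressible for high-entropy input
	}

	// ReadTable rebuilds the table and returns the payload that follows it.
	dec, data, err := huff0.ReadTable(comp, nil)
	if err != nil {
		panic(err)
	}
	// Decompress1X treats cap(dst) as the maximum decompressed size.
	out, err := dec.Decoder().Decompress1X(make([]byte, 0, len(in)), data)
	if err != nil {
		panic(err)
	}
	fmt.Println(bytes.Equal(in, out)) // true
}
```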
|
||||
|
||||
// decompress1X8Bit will decompress a 1X encoded stream with tablelog <= 8.
|
||||
// The cap of the output buffer will be the maximum decompressed size.
|
||||
// The length of the supplied input must match the end of a block exactly.
|
||||
|
@ -873,17 +763,20 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
|
|||
d.bufs.Put(buf)
|
||||
return nil, errors.New("corruption detected: stream overrun 1")
|
||||
}
|
||||
copy(out, buf[0][:])
|
||||
copy(out[dstEvery:], buf[1][:])
|
||||
copy(out[dstEvery*2:], buf[2][:])
|
||||
copy(out[dstEvery*3:], buf[3][:])
|
||||
out = out[bufoff:]
|
||||
decoded += bufoff * 4
|
||||
// There must at least be 3 buffers left.
|
||||
if len(out) < dstEvery*3 {
|
||||
if len(out)-bufoff < dstEvery*3 {
|
||||
d.bufs.Put(buf)
|
||||
return nil, errors.New("corruption detected: stream overrun 2")
|
||||
}
|
||||
//copy(out, buf[0][:])
|
||||
//copy(out[dstEvery:], buf[1][:])
|
||||
//copy(out[dstEvery*2:], buf[2][:])
|
||||
*(*[bufoff]byte)(out) = buf[0]
|
||||
*(*[bufoff]byte)(out[dstEvery:]) = buf[1]
|
||||
*(*[bufoff]byte)(out[dstEvery*2:]) = buf[2]
|
||||
*(*[bufoff]byte)(out[dstEvery*3:]) = buf[3]
|
||||
out = out[bufoff:]
|
||||
decoded += bufoff * 4
|
||||
}
|
||||
}
|
||||
if off > 0 {
|
||||
|
@ -995,7 +888,6 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
|
|||
|
||||
const shift = 56
|
||||
const tlSize = 1 << 8
|
||||
const tlMask = tlSize - 1
|
||||
single := d.dt.single[:tlSize]
|
||||
|
||||
// Use temp table to avoid bound checks/append penalty.
|
||||
|
@ -1108,17 +1000,22 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
|
|||
d.bufs.Put(buf)
|
||||
return nil, errors.New("corruption detected: stream overrun 1")
|
||||
}
|
||||
copy(out, buf[0][:])
|
||||
copy(out[dstEvery:], buf[1][:])
|
||||
copy(out[dstEvery*2:], buf[2][:])
|
||||
copy(out[dstEvery*3:], buf[3][:])
|
||||
out = out[bufoff:]
|
||||
decoded += bufoff * 4
|
||||
// There must at least be 3 buffers left.
|
||||
if len(out) < dstEvery*3 {
|
||||
if len(out)-bufoff < dstEvery*3 {
|
||||
d.bufs.Put(buf)
|
||||
return nil, errors.New("corruption detected: stream overrun 2")
|
||||
}
|
||||
|
||||
//copy(out, buf[0][:])
|
||||
//copy(out[dstEvery:], buf[1][:])
|
||||
//copy(out[dstEvery*2:], buf[2][:])
|
||||
// copy(out[dstEvery*3:], buf[3][:])
|
||||
*(*[bufoff]byte)(out) = buf[0]
|
||||
*(*[bufoff]byte)(out[dstEvery:]) = buf[1]
|
||||
*(*[bufoff]byte)(out[dstEvery*2:]) = buf[2]
|
||||
*(*[bufoff]byte)(out[dstEvery*3:]) = buf[3]
|
||||
out = out[bufoff:]
|
||||
decoded += bufoff * 4
|
||||
}
|
||||
}
|
||||
if off > 0 {
|
||||
|
|
|
@ -1,488 +0,0 @@
|
|||
// +build !appengine
|
||||
// +build gc
|
||||
// +build !noasm
|
||||
|
||||
#include "textflag.h"
|
||||
#include "funcdata.h"
|
||||
#include "go_asm.h"
|
||||
|
||||
#define bufoff 256 // see decompress.go, we're using [4][256]byte table
|
||||
|
||||
// func decompress4x_8b_loop_x86(pbr0, pbr1, pbr2, pbr3 *bitReaderShifted,
|
||||
// peekBits uint8, buf *byte, tbl *dEntrySingle) uint8
|
||||
TEXT ·decompress4x_8b_loop_x86(SB), NOSPLIT, $8
|
||||
#define off R8
|
||||
#define buffer DI
|
||||
#define table SI
|
||||
|
||||
#define br_bits_read R9
|
||||
#define br_value R10
|
||||
#define br_offset R11
|
||||
#define peek_bits R12
|
||||
#define exhausted DX
|
||||
|
||||
#define br0 R13
|
||||
#define br1 R14
|
||||
#define br2 R15
|
||||
#define br3 BP
|
||||
|
||||
MOVQ BP, 0(SP)
|
||||
|
||||
XORQ exhausted, exhausted // exhausted = false
|
||||
XORQ off, off // off = 0
|
||||
|
||||
MOVBQZX peekBits+32(FP), peek_bits
|
||||
MOVQ buf+40(FP), buffer
|
||||
MOVQ tbl+48(FP), table
|
||||
|
||||
MOVQ pbr0+0(FP), br0
|
||||
MOVQ pbr1+8(FP), br1
|
||||
MOVQ pbr2+16(FP), br2
|
||||
MOVQ pbr3+24(FP), br3
|
||||
|
||||
main_loop:
|
||||
|
||||
// const stream = 0
|
||||
// br0.fillFast()
|
||||
MOVBQZX bitReaderShifted_bitsRead(br0), br_bits_read
|
||||
MOVQ bitReaderShifted_value(br0), br_value
|
||||
MOVQ bitReaderShifted_off(br0), br_offset
|
||||
|
||||
// if b.bitsRead >= 32 {
|
||||
CMPQ br_bits_read, $32
|
||||
JB skip_fill0
|
||||
|
||||
SUBQ $32, br_bits_read // b.bitsRead -= 32
|
||||
SUBQ $4, br_offset // b.off -= 4
|
||||
|
||||
// v := b.in[b.off-4 : b.off]
|
||||
// v = v[:4]
|
||||
// low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
|
||||
MOVQ bitReaderShifted_in(br0), AX
|
||||
MOVL 0(br_offset)(AX*1), AX // AX = uint32(b.in[b.off:b.off+4])
|
||||
|
||||
// b.value |= uint64(low) << (b.bitsRead & 63)
|
||||
MOVQ br_bits_read, CX
|
||||
SHLQ CL, AX
|
||||
ORQ AX, br_value
|
||||
|
||||
// exhausted = exhausted || (br0.off < 4)
|
||||
CMPQ br_offset, $4
|
||||
SETLT DL
|
||||
ORB DL, DH
|
||||
|
||||
// }
|
||||
skip_fill0:
|
||||
|
||||
// val0 := br0.peekTopBits(peekBits)
|
||||
MOVQ br_value, AX
|
||||
MOVQ peek_bits, CX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v0 := table[val0&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v0
|
||||
|
||||
// br0.advance(uint8(v0.entry))
|
||||
MOVB AH, BL // BL = uint8(v0.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CL, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// val1 := br0.peekTopBits(peekBits)
|
||||
MOVQ peek_bits, CX
|
||||
MOVQ br_value, AX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v1 := table[val1&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v1
|
||||
|
||||
// br0.advance(uint8(v1.entry))
|
||||
MOVB AH, BH // BH = uint8(v1.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CX, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// these two writes get coalesced
|
||||
// buf[stream][off] = uint8(v0.entry >> 8)
|
||||
// buf[stream][off+1] = uint8(v1.entry >> 8)
|
||||
MOVW BX, 0(buffer)(off*1)
|
||||
|
||||
// SECOND PART:
|
||||
// val2 := br0.peekTopBits(peekBits)
|
||||
MOVQ br_value, AX
|
||||
MOVQ peek_bits, CX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v2 := table[val0&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v0
|
||||
|
||||
// br0.advance(uint8(v0.entry))
|
||||
MOVB AH, BL // BL = uint8(v0.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CL, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// val3 := br0.peekTopBits(peekBits)
|
||||
MOVQ peek_bits, CX
|
||||
MOVQ br_value, AX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v3 := table[val1&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v1
|
||||
|
||||
// br0.advance(uint8(v1.entry))
|
||||
MOVB AH, BH // BH = uint8(v1.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CX, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// these two writes get coalesced
|
||||
// buf[stream][off+2] = uint8(v2.entry >> 8)
|
||||
// buf[stream][off+3] = uint8(v3.entry >> 8)
|
||||
MOVW BX, 0+2(buffer)(off*1)
|
||||
|
||||
// update the bit reader structure
|
||||
MOVB br_bits_read, bitReaderShifted_bitsRead(br0)
|
||||
MOVQ br_value, bitReaderShifted_value(br0)
|
||||
MOVQ br_offset, bitReaderShifted_off(br0)
|
||||
|
||||
// const stream = 1
|
||||
// br1.fillFast()
|
||||
MOVBQZX bitReaderShifted_bitsRead(br1), br_bits_read
|
||||
MOVQ bitReaderShifted_value(br1), br_value
|
||||
MOVQ bitReaderShifted_off(br1), br_offset
|
||||
|
||||
// if b.bitsRead >= 32 {
|
||||
CMPQ br_bits_read, $32
|
||||
JB skip_fill1
|
||||
|
||||
SUBQ $32, br_bits_read // b.bitsRead -= 32
|
||||
SUBQ $4, br_offset // b.off -= 4
|
||||
|
||||
// v := b.in[b.off-4 : b.off]
|
||||
// v = v[:4]
|
||||
// low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
|
||||
MOVQ bitReaderShifted_in(br1), AX
|
||||
MOVL 0(br_offset)(AX*1), AX // AX = uint32(b.in[b.off:b.off+4])
|
||||
|
||||
// b.value |= uint64(low) << (b.bitsRead & 63)
|
||||
MOVQ br_bits_read, CX
|
||||
SHLQ CL, AX
|
||||
ORQ AX, br_value
|
||||
|
||||
// exhausted = exhausted || (br1.off < 4)
|
||||
CMPQ br_offset, $4
|
||||
SETLT DL
|
||||
ORB DL, DH
|
||||
|
||||
// }
|
||||
skip_fill1:
|
||||
|
||||
// val0 := br1.peekTopBits(peekBits)
|
||||
MOVQ br_value, AX
|
||||
MOVQ peek_bits, CX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v0 := table[val0&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v0
|
||||
|
||||
// br1.advance(uint8(v0.entry))
|
||||
MOVB AH, BL // BL = uint8(v0.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CL, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// val1 := br1.peekTopBits(peekBits)
|
||||
MOVQ peek_bits, CX
|
||||
MOVQ br_value, AX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v1 := table[val1&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v1
|
||||
|
||||
// br1.advance(uint8(v1.entry))
|
||||
MOVB AH, BH // BH = uint8(v1.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CX, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// these two writes get coalesced
|
||||
// buf[stream][off] = uint8(v0.entry >> 8)
|
||||
// buf[stream][off+1] = uint8(v1.entry >> 8)
|
||||
MOVW BX, 256(buffer)(off*1)
|
||||
|
||||
// SECOND PART:
|
||||
// val2 := br1.peekTopBits(peekBits)
|
||||
MOVQ br_value, AX
|
||||
MOVQ peek_bits, CX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v2 := table[val0&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v0
|
||||
|
||||
// br1.advance(uint8(v0.entry))
|
||||
MOVB AH, BL // BL = uint8(v0.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CL, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// val3 := br1.peekTopBits(peekBits)
|
||||
MOVQ peek_bits, CX
|
||||
MOVQ br_value, AX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v3 := table[val1&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v1
|
||||
|
||||
// br1.advance(uint8(v1.entry))
|
||||
MOVB AH, BH // BH = uint8(v1.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CX, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// these two writes get coalesced
|
||||
// buf[stream][off+2] = uint8(v2.entry >> 8)
|
||||
// buf[stream][off+3] = uint8(v3.entry >> 8)
|
||||
MOVW BX, 256+2(buffer)(off*1)
|
||||
|
||||
// update the bit reader structure
|
||||
MOVB br_bits_read, bitReaderShifted_bitsRead(br1)
|
||||
MOVQ br_value, bitReaderShifted_value(br1)
|
||||
MOVQ br_offset, bitReaderShifted_off(br1)
|
||||
|
||||
// const stream = 2
|
||||
// br2.fillFast()
|
||||
MOVBQZX bitReaderShifted_bitsRead(br2), br_bits_read
|
||||
MOVQ bitReaderShifted_value(br2), br_value
|
||||
MOVQ bitReaderShifted_off(br2), br_offset
|
||||
|
||||
// if b.bitsRead >= 32 {
|
||||
CMPQ br_bits_read, $32
|
||||
JB skip_fill2
|
||||
|
||||
SUBQ $32, br_bits_read // b.bitsRead -= 32
|
||||
SUBQ $4, br_offset // b.off -= 4
|
||||
|
||||
// v := b.in[b.off-4 : b.off]
|
||||
// v = v[:4]
|
||||
// low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
|
||||
MOVQ bitReaderShifted_in(br2), AX
|
||||
MOVL 0(br_offset)(AX*1), AX // AX = uint32(b.in[b.off:b.off+4])
|
||||
|
||||
// b.value |= uint64(low) << (b.bitsRead & 63)
|
||||
MOVQ br_bits_read, CX
|
||||
SHLQ CL, AX
|
||||
ORQ AX, br_value
|
||||
|
||||
// exhausted = exhausted || (br2.off < 4)
|
||||
CMPQ br_offset, $4
|
||||
SETLT DL
|
||||
ORB DL, DH
|
||||
|
||||
// }
|
||||
skip_fill2:
|
||||
|
||||
// val0 := br2.peekTopBits(peekBits)
|
||||
MOVQ br_value, AX
|
||||
MOVQ peek_bits, CX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v0 := table[val0&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v0
|
||||
|
||||
// br2.advance(uint8(v0.entry))
|
||||
MOVB AH, BL // BL = uint8(v0.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CL, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// val1 := br2.peekTopBits(peekBits)
|
||||
MOVQ peek_bits, CX
|
||||
MOVQ br_value, AX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v1 := table[val1&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v1
|
||||
|
||||
// br2.advance(uint8(v1.entry))
|
||||
MOVB AH, BH // BH = uint8(v1.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CX, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// these two writes get coalesced
|
||||
// buf[stream][off] = uint8(v0.entry >> 8)
|
||||
// buf[stream][off+1] = uint8(v1.entry >> 8)
|
||||
MOVW BX, 512(buffer)(off*1)
|
||||
|
||||
// SECOND PART:
|
||||
// val2 := br2.peekTopBits(peekBits)
|
||||
MOVQ br_value, AX
|
||||
MOVQ peek_bits, CX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v2 := table[val0&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v0
|
||||
|
||||
// br2.advance(uint8(v0.entry))
|
||||
MOVB AH, BL // BL = uint8(v0.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CL, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// val3 := br2.peekTopBits(peekBits)
|
||||
MOVQ peek_bits, CX
|
||||
MOVQ br_value, AX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v3 := table[val1&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v1
|
||||
|
||||
// br2.advance(uint8(v1.entry))
|
||||
MOVB AH, BH // BH = uint8(v1.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CX, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// these two writes get coalesced
|
||||
// buf[stream][off+2] = uint8(v2.entry >> 8)
|
||||
// buf[stream][off+3] = uint8(v3.entry >> 8)
|
||||
MOVW BX, 512+2(buffer)(off*1)
|
||||
|
||||
// update the bit reader structure
|
||||
MOVB br_bits_read, bitReaderShifted_bitsRead(br2)
|
||||
MOVQ br_value, bitReaderShifted_value(br2)
|
||||
MOVQ br_offset, bitReaderShifted_off(br2)
|
||||
|
||||
// const stream = 3
|
||||
// br3.fillFast()
|
||||
MOVBQZX bitReaderShifted_bitsRead(br3), br_bits_read
|
||||
MOVQ bitReaderShifted_value(br3), br_value
|
||||
MOVQ bitReaderShifted_off(br3), br_offset
|
||||
|
||||
// if b.bitsRead >= 32 {
|
||||
CMPQ br_bits_read, $32
|
||||
JB skip_fill3
|
||||
|
||||
SUBQ $32, br_bits_read // b.bitsRead -= 32
|
||||
SUBQ $4, br_offset // b.off -= 4
|
||||
|
||||
// v := b.in[b.off-4 : b.off]
|
||||
// v = v[:4]
|
||||
// low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
|
||||
MOVQ bitReaderShifted_in(br3), AX
|
||||
MOVL 0(br_offset)(AX*1), AX // AX = uint32(b.in[b.off:b.off+4])
|
||||
|
||||
// b.value |= uint64(low) << (b.bitsRead & 63)
|
||||
MOVQ br_bits_read, CX
|
||||
SHLQ CL, AX
|
||||
ORQ AX, br_value
|
||||
|
||||
// exhausted = exhausted || (br3.off < 4)
|
||||
CMPQ br_offset, $4
|
||||
SETLT DL
|
||||
ORB DL, DH
|
||||
|
||||
// }
|
||||
skip_fill3:
|
||||
|
||||
// val0 := br3.peekTopBits(peekBits)
|
||||
MOVQ br_value, AX
|
||||
MOVQ peek_bits, CX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v0 := table[val0&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v0
|
||||
|
||||
// br3.advance(uint8(v0.entry))
|
||||
MOVB AH, BL // BL = uint8(v0.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CL, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// val1 := br3.peekTopBits(peekBits)
|
||||
MOVQ peek_bits, CX
|
||||
MOVQ br_value, AX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v1 := table[val1&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v1
|
||||
|
||||
// br3.advance(uint8(v1.entry))
|
||||
MOVB AH, BH // BH = uint8(v1.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CX, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// these two writes get coalesced
|
||||
// buf[stream][off] = uint8(v0.entry >> 8)
|
||||
// buf[stream][off+1] = uint8(v1.entry >> 8)
|
||||
MOVW BX, 768(buffer)(off*1)
|
||||
|
||||
// SECOND PART:
|
||||
// val2 := br3.peekTopBits(peekBits)
|
||||
MOVQ br_value, AX
|
||||
MOVQ peek_bits, CX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v2 := table[val0&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v0
|
||||
|
||||
// br3.advance(uint8(v0.entry))
|
||||
MOVB AH, BL // BL = uint8(v0.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CL, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// val3 := br3.peekTopBits(peekBits)
|
||||
MOVQ peek_bits, CX
|
||||
MOVQ br_value, AX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v3 := table[val1&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v1
|
||||
|
||||
// br3.advance(uint8(v1.entry))
|
||||
MOVB AH, BH // BH = uint8(v1.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CX, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// these two writes get coalesced
|
||||
// buf[stream][off+2] = uint8(v2.entry >> 8)
|
||||
// buf[stream][off+3] = uint8(v3.entry >> 8)
|
||||
MOVW BX, 768+2(buffer)(off*1)
|
||||
|
||||
// update the bit reader structure
|
||||
MOVB br_bits_read, bitReaderShifted_bitsRead(br3)
|
||||
MOVQ br_value, bitReaderShifted_value(br3)
|
||||
MOVQ br_offset, bitReaderShifted_off(br3)
|
||||
|
||||
ADDQ $4, off // off += 4
|
||||
|
||||
TESTB DH, DH // any br[i].ofs < 4?
|
||||
JNZ end
|
||||
|
||||
CMPQ off, $bufoff
|
||||
JL main_loop
|
||||
|
||||
end:
|
||||
MOVQ 0(SP), BP
|
||||
|
||||
MOVB off, ret+56(FP)
|
||||
RET
|
||||
|
||||
#undef off
|
||||
#undef buffer
|
||||
#undef table
|
||||
|
||||
#undef br_bits_read
|
||||
#undef br_value
|
||||
#undef br_offset
|
||||
#undef peek_bits
|
||||
#undef exhausted
|
||||
|
||||
#undef br0
|
||||
#undef br1
|
||||
#undef br2
|
||||
#undef br3
|
|
@ -1,197 +0,0 @@
|
|||
// +build !appengine
|
||||
// +build gc
|
||||
// +build !noasm
|
||||
|
||||
#include "textflag.h"
|
||||
#include "funcdata.h"
|
||||
#include "go_asm.h"
|
||||
|
||||
|
||||
#define bufoff 256 // see decompress.go, we're using [4][256]byte table
|
||||
|
||||
//func decompress4x_8b_loop_x86(pbr0, pbr1, pbr2, pbr3 *bitReaderShifted,
|
||||
// peekBits uint8, buf *byte, tbl *dEntrySingle) uint8
|
||||
TEXT ·decompress4x_8b_loop_x86(SB), NOSPLIT, $8
|
||||
#define off R8
|
||||
#define buffer DI
|
||||
#define table SI
|
||||
|
||||
#define br_bits_read R9
|
||||
#define br_value R10
|
||||
#define br_offset R11
|
||||
#define peek_bits R12
|
||||
#define exhausted DX
|
||||
|
||||
#define br0 R13
|
||||
#define br1 R14
|
||||
#define br2 R15
|
||||
#define br3 BP
|
||||
|
||||
MOVQ BP, 0(SP)
|
||||
|
||||
XORQ exhausted, exhausted // exhausted = false
|
||||
XORQ off, off // off = 0
|
||||
|
||||
MOVBQZX peekBits+32(FP), peek_bits
|
||||
MOVQ buf+40(FP), buffer
|
||||
MOVQ tbl+48(FP), table
|
||||
|
||||
MOVQ pbr0+0(FP), br0
|
||||
MOVQ pbr1+8(FP), br1
|
||||
MOVQ pbr2+16(FP), br2
|
||||
MOVQ pbr3+24(FP), br3
|
||||
|
||||
main_loop:
|
||||
{{ define "decode_2_values_x86" }}
|
||||
// const stream = {{ var "id" }}
|
||||
// br{{ var "id"}}.fillFast()
|
||||
MOVBQZX bitReaderShifted_bitsRead(br{{ var "id" }}), br_bits_read
|
||||
MOVQ bitReaderShifted_value(br{{ var "id" }}), br_value
|
||||
MOVQ bitReaderShifted_off(br{{ var "id" }}), br_offset
|
||||
|
||||
// if b.bitsRead >= 32 {
|
||||
CMPQ br_bits_read, $32
|
||||
JB skip_fill{{ var "id" }}
|
||||
|
||||
SUBQ $32, br_bits_read // b.bitsRead -= 32
|
||||
SUBQ $4, br_offset // b.off -= 4
|
||||
|
||||
// v := b.in[b.off-4 : b.off]
|
||||
// v = v[:4]
|
||||
// low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
|
||||
MOVQ bitReaderShifted_in(br{{ var "id" }}), AX
|
||||
MOVL 0(br_offset)(AX*1), AX // AX = uint32(b.in[b.off:b.off+4])
|
||||
|
||||
// b.value |= uint64(low) << (b.bitsRead & 63)
|
||||
MOVQ br_bits_read, CX
|
||||
SHLQ CL, AX
|
||||
ORQ AX, br_value
|
||||
|
||||
// exhausted = exhausted || (br{{ var "id"}}.off < 4)
|
||||
CMPQ br_offset, $4
|
||||
SETLT DL
|
||||
ORB DL, DH
|
||||
// }
|
||||
skip_fill{{ var "id" }}:
|
||||
|
||||
// val0 := br{{ var "id"}}.peekTopBits(peekBits)
|
||||
MOVQ br_value, AX
|
||||
MOVQ peek_bits, CX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v0 := table[val0&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v0
|
||||
|
||||
// br{{ var "id"}}.advance(uint8(v0.entry))
|
||||
MOVB AH, BL // BL = uint8(v0.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CL, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// val1 := br{{ var "id"}}.peekTopBits(peekBits)
|
||||
MOVQ peek_bits, CX
|
||||
MOVQ br_value, AX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v1 := table[val1&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v1
|
||||
|
||||
// br{{ var "id"}}.advance(uint8(v1.entry))
|
||||
MOVB AH, BH // BH = uint8(v1.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CX, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
|
||||
// these two writes get coalesced
|
||||
// buf[stream][off] = uint8(v0.entry >> 8)
|
||||
// buf[stream][off+1] = uint8(v1.entry >> 8)
|
||||
MOVW BX, {{ var "bufofs" }}(buffer)(off*1)
|
||||
|
||||
// SECOND PART:
|
||||
// val2 := br{{ var "id"}}.peekTopBits(peekBits)
|
||||
MOVQ br_value, AX
|
||||
MOVQ peek_bits, CX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v2 := table[val0&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v0
|
||||
|
||||
// br{{ var "id"}}.advance(uint8(v0.entry))
|
||||
MOVB AH, BL // BL = uint8(v0.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CL, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
// val3 := br{{ var "id"}}.peekTopBits(peekBits)
|
||||
MOVQ peek_bits, CX
|
||||
MOVQ br_value, AX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
|
||||
// v3 := table[val1&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v1
|
||||
|
||||
// br{{ var "id"}}.advance(uint8(v1.entry))
|
||||
MOVB AH, BH // BH = uint8(v1.entry >> 8)
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CX, br_value // value <<= n
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
|
||||
// these two writes get coalesced
|
||||
// buf[stream][off+2] = uint8(v2.entry >> 8)
|
||||
// buf[stream][off+3] = uint8(v3.entry >> 8)
|
||||
MOVW BX, {{ var "bufofs" }}+2(buffer)(off*1)
|
||||
|
||||
// update the bit reader structure
|
||||
MOVB br_bits_read, bitReaderShifted_bitsRead(br{{ var "id" }})
|
||||
MOVQ br_value, bitReaderShifted_value(br{{ var "id" }})
|
||||
MOVQ br_offset, bitReaderShifted_off(br{{ var "id" }})
|
||||
{{ end }}
|
||||
|
||||
{{ set "id" "0" }}
|
||||
{{ set "ofs" "0" }}
|
||||
{{ set "bufofs" "0" }} {{/* id * bufoff */}}
|
||||
{{ template "decode_2_values_x86" . }}
|
||||
|
||||
{{ set "id" "1" }}
|
||||
{{ set "ofs" "8" }}
|
||||
{{ set "bufofs" "256" }}
|
||||
{{ template "decode_2_values_x86" . }}
|
||||
|
||||
{{ set "id" "2" }}
|
||||
{{ set "ofs" "16" }}
|
||||
{{ set "bufofs" "512" }}
|
||||
{{ template "decode_2_values_x86" . }}
|
||||
|
||||
{{ set "id" "3" }}
|
||||
{{ set "ofs" "24" }}
|
||||
{{ set "bufofs" "768" }}
|
||||
{{ template "decode_2_values_x86" . }}
|
||||
|
||||
ADDQ $4, off // off += 4
|
||||
|
||||
TESTB DH, DH // any br[i].ofs < 4?
|
||||
JNZ end
|
||||
|
||||
CMPQ off, $bufoff
|
||||
JL main_loop
|
||||
end:
|
||||
MOVQ 0(SP), BP
|
||||
|
||||
MOVB off, ret+56(FP)
|
||||
RET
|
||||
#undef off
|
||||
#undef buffer
|
||||
#undef table
|
||||
|
||||
#undef br_bits_read
|
||||
#undef br_value
|
||||
#undef br_offset
|
||||
#undef peek_bits
|
||||
#undef exhausted
|
||||
|
||||
#undef br0
|
||||
#undef br1
|
||||
#undef br2
|
||||
#undef br3
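The `{{ set }}`, `{{ var }}` and `{{ template }}` directives show this .s file is itself a Go text/template, expanded once per stream id. The commit does not include the generator, so the sketch below only illustrates how such `set`/`var` helpers can be wired up with a custom FuncMap; their exact semantics here are an assumption inferred from the template:

```go
package main

import (
	"os"
	"text/template"
)

func main() {
	// Shared store backing the template's set/var helpers.
	vars := map[string]string{}
	funcs := template.FuncMap{
		"set": func(k, v string) string { vars[k] = v; return "" },
		"var": func(k string) string { return vars[k] },
	}

	// Tiny stand-in for the asm template: one block, stamped out per id.
	const src = `{{ define "block" }}skip_fill{{ var "id" }}:
{{ end }}{{ set "id" "0" }}{{ template "block" . }}{{ set "id" "1" }}{{ template "block" . }}`

	t := template.Must(template.New("asm").Funcs(funcs).Parse(src))
	if err := t.Execute(os.Stdout, nil); err != nil {
		panic(err)
	}
	// Output:
	// skip_fill0:
	// skip_fill1:
}
```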
|
|
@ -2,30 +2,42 @@
|
|||
// +build amd64,!appengine,!noasm,gc
|
||||
|
||||
// This file contains the specialisation of Decoder.Decompress4X
|
||||
// that uses an asm implementation of its main loop.
|
||||
// and Decoder.Decompress1X that use an asm implementation of their main loops.
|
||||
package huff0
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/klauspost/compress/internal/cpuinfo"
|
||||
)
|
||||
|
||||
// decompress4x_main_loop_x86 is an x86 assembler implementation
|
||||
// of Decompress4X when tablelog > 8.
|
||||
// go:noescape
|
||||
func decompress4x_main_loop_x86(pbr0, pbr1, pbr2, pbr3 *bitReaderShifted,
|
||||
peekBits uint8, buf *byte, tbl *dEntrySingle) uint8
|
||||
//
|
||||
//go:noescape
|
||||
func decompress4x_main_loop_amd64(ctx *decompress4xContext)
|
||||
|
||||
// decompress4x_8b_loop_x86 is an x86 assembler implementation
|
||||
// of Decompress4X when tablelog <= 8 which decodes 4 entries
|
||||
// per loop.
|
||||
// go:noescape
|
||||
func decompress4x_8b_loop_x86(pbr0, pbr1, pbr2, pbr3 *bitReaderShifted,
|
||||
peekBits uint8, buf *byte, tbl *dEntrySingle) uint8
|
||||
//
|
||||
//go:noescape
|
||||
func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext)
|
||||
|
||||
// fallback8BitSize is the size where using Go version is faster.
|
||||
const fallback8BitSize = 800
|
||||
|
||||
type decompress4xContext struct {
|
||||
pbr *[4]bitReaderShifted
|
||||
peekBits uint8
|
||||
out *byte
|
||||
dstEvery int
|
||||
tbl *dEntrySingle
|
||||
decoded int
|
||||
limit *byte
|
||||
}
|
||||
|
||||
// Decompress4X will decompress a 4X encoded stream.
|
||||
// The length of the supplied input must match the end of a block exactly.
|
||||
// The *capacity* of the dst slice must match the destination size of
|
||||
|
@ -42,6 +54,7 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
|
|||
if cap(dst) < fallback8BitSize && use8BitTables {
|
||||
return d.decompress4X8bit(dst, src)
|
||||
}
|
||||
|
||||
var br [4]bitReaderShifted
|
||||
// Decode "jump table"
|
||||
start := 6
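The "jump table" is the 4X stream header: three little-endian uint16 lengths for streams 0-2 (the fourth stream runs to the end of the block), which is why parsing starts at offset 6. A standalone sketch of that split, under the same layout assumption:

```go
package main

import (
	"errors"
	"fmt"
)

// splitStreams parses the 6-byte jump table of a 4X block: three
// little-endian uint16 lengths; stream 3 is whatever remains.
func splitStreams(src []byte) (streams [4][]byte, err error) {
	if len(src) < 6 {
		return streams, errors.New("block too short for jump table")
	}
	off := 6
	for i := 0; i < 3; i++ {
		n := int(src[i*2]) | int(src[i*2+1])<<8
		if off+n > len(src) {
			return streams, errors.New("jump table exceeds input")
		}
		streams[i] = src[off : off+n]
		off += n
	}
	streams[3] = src[off:]
	return streams, nil
}

func main() {
	blk := []byte{1, 0, 1, 0, 1, 0, 'a', 'b', 'c', 'd'}
	s, _ := splitStreams(blk)
	fmt.Printf("%s %s %s %s\n", s[0], s[1], s[2], s[3]) // a b c d
}
```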
|
||||
|
@ -71,70 +84,25 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
|
|||
const tlMask = tlSize - 1
|
||||
single := d.dt.single[:tlSize]
|
||||
|
||||
// Use temp table to avoid bound checks/append penalty.
|
||||
buf := d.buffer()
|
||||
var off uint8
|
||||
var decoded int
|
||||
|
||||
const debug = false
|
||||
|
||||
// see: bitReaderShifted.peekBitsFast()
|
||||
peekBits := uint8((64 - d.actualTableLog) & 63)
|
||||
|
||||
// Decode 2 values from each decoder/loop.
|
||||
const bufoff = 256
|
||||
for {
|
||||
if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 {
|
||||
break
|
||||
if len(out) > 4*4 && !(br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4) {
|
||||
ctx := decompress4xContext{
|
||||
pbr: &br,
|
||||
peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast()
|
||||
out: &out[0],
|
||||
dstEvery: dstEvery,
|
||||
tbl: &single[0],
|
||||
limit: &out[dstEvery-4], // Always stop decoding when first buffer gets here to avoid writing OOB on last.
|
||||
}
|
||||
|
||||
if use8BitTables {
|
||||
off = decompress4x_8b_loop_x86(&br[0], &br[1], &br[2], &br[3], peekBits, &buf[0][0], &single[0])
|
||||
decompress4x_8b_main_loop_amd64(&ctx)
|
||||
} else {
|
||||
off = decompress4x_main_loop_x86(&br[0], &br[1], &br[2], &br[3], peekBits, &buf[0][0], &single[0])
|
||||
}
|
||||
if debug {
|
||||
fmt.Print("DEBUG: ")
|
||||
fmt.Printf("off=%d,", off)
|
||||
for i := 0; i < 4; i++ {
|
||||
fmt.Printf(" br[%d]={bitsRead=%d, value=%x, off=%d}",
|
||||
i, br[i].bitsRead, br[i].value, br[i].off)
|
||||
}
|
||||
fmt.Println("")
|
||||
decompress4x_main_loop_amd64(&ctx)
|
||||
}
|
||||
|
||||
if off != 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if bufoff > dstEvery {
|
||||
d.bufs.Put(buf)
|
||||
return nil, errors.New("corruption detected: stream overrun 1")
|
||||
}
|
||||
copy(out, buf[0][:])
|
||||
copy(out[dstEvery:], buf[1][:])
|
||||
copy(out[dstEvery*2:], buf[2][:])
|
||||
copy(out[dstEvery*3:], buf[3][:])
|
||||
out = out[bufoff:]
|
||||
decoded += bufoff * 4
|
||||
// There must at least be 3 buffers left.
|
||||
if len(out) < dstEvery*3 {
|
||||
d.bufs.Put(buf)
|
||||
return nil, errors.New("corruption detected: stream overrun 2")
|
||||
}
|
||||
}
|
||||
if off > 0 {
|
||||
ioff := int(off)
|
||||
if len(out) < dstEvery*3+ioff {
|
||||
d.bufs.Put(buf)
|
||||
return nil, errors.New("corruption detected: stream overrun 3")
|
||||
}
|
||||
copy(out, buf[0][:off])
|
||||
copy(out[dstEvery:], buf[1][:off])
|
||||
copy(out[dstEvery*2:], buf[2][:off])
|
||||
copy(out[dstEvery*3:], buf[3][:off])
|
||||
decoded += int(off) * 4
|
||||
out = out[off:]
|
||||
decoded = ctx.decoded
|
||||
out = out[decoded/4:]
|
||||
}
|
||||
|
||||
// Decode remaining.
|
||||
|
@ -150,7 +118,6 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
|
|||
for bitsLeft > 0 {
|
||||
br.fill()
|
||||
if offset >= endsAt {
|
||||
d.bufs.Put(buf)
|
||||
return nil, errors.New("corruption detected: stream overrun 4")
|
||||
}
|
||||
|
||||
|
@ -164,7 +131,6 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
|
|||
offset++
|
||||
}
|
||||
if offset != endsAt {
|
||||
d.bufs.Put(buf)
|
||||
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
|
||||
}
|
||||
decoded += offset - dstEvery*i
|
||||
|
@ -173,9 +139,88 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
d.bufs.Put(buf)
|
||||
if dstSize != decoded {
|
||||
return nil, errors.New("corruption detected: short output block")
|
||||
}
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// decompress1x_main_loop_amd64 is an x86 assembler implementation
|
||||
// of Decompress1X when tablelog > 8.
|
||||
//
|
||||
//go:noescape
|
||||
func decompress1x_main_loop_amd64(ctx *decompress1xContext)
|
||||
|
||||
// decompress1x_main_loop_bmi2 is an x86 with BMI2 assembler implementation
|
||||
// of Decompress1X when tablelog > 8.
|
||||
//
|
||||
//go:noescape
|
||||
func decompress1x_main_loop_bmi2(ctx *decompress1xContext)
|
||||
|
||||
type decompress1xContext struct {
|
||||
pbr *bitReaderShifted
|
||||
peekBits uint8
|
||||
out *byte
|
||||
outCap int
|
||||
tbl *dEntrySingle
|
||||
decoded int
|
||||
}
|
||||
|
||||
// Error reported by asm implementations
|
||||
const error_max_decoded_size_exeeded = -1
|
||||
|
||||
// Decompress1X will decompress a 1X encoded stream.
|
||||
// The cap of the output buffer will be the maximum decompressed size.
|
||||
// The length of the supplied input must match the end of a block exactly.
|
||||
func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
|
||||
if len(d.dt.single) == 0 {
|
||||
return nil, errors.New("no table loaded")
|
||||
}
|
||||
var br bitReaderShifted
|
||||
err := br.init(src)
|
||||
if err != nil {
|
||||
return dst, err
|
||||
}
|
||||
maxDecodedSize := cap(dst)
|
||||
dst = dst[:maxDecodedSize]
|
||||
|
||||
const tlSize = 1 << tableLogMax
|
||||
const tlMask = tlSize - 1
|
||||
|
||||
if maxDecodedSize >= 4 {
|
||||
ctx := decompress1xContext{
|
||||
pbr: &br,
|
||||
out: &dst[0],
|
||||
outCap: maxDecodedSize,
|
||||
peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast()
|
||||
tbl: &d.dt.single[0],
|
||||
}
|
||||
|
||||
if cpuinfo.HasBMI2() {
|
||||
decompress1x_main_loop_bmi2(&ctx)
|
||||
} else {
|
||||
decompress1x_main_loop_amd64(&ctx)
|
||||
}
|
||||
if ctx.decoded == error_max_decoded_size_exeeded {
|
||||
return nil, ErrMaxDecodedSizeExceeded
|
||||
}
|
||||
|
||||
dst = dst[:ctx.decoded]
|
||||
}
|
||||
|
||||
// br.off < 8, so uint8 is fine
|
||||
bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
|
||||
for bitsLeft > 0 {
|
||||
br.fill()
|
||||
if len(dst) >= maxDecodedSize {
|
||||
br.close()
|
||||
return nil, ErrMaxDecodedSizeExceeded
|
||||
}
|
||||
v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
|
||||
nBits := uint8(v.entry)
|
||||
br.advance(nBits)
|
||||
bitsLeft -= nBits
|
||||
dst = append(dst, uint8(v.entry>>8))
|
||||
}
|
||||
return dst, br.close()
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
|
@ -1,195 +0,0 @@
|
|||
// +build !appengine
|
||||
// +build gc
|
||||
// +build !noasm
|
||||
|
||||
#include "textflag.h"
|
||||
#include "funcdata.h"
|
||||
#include "go_asm.h"
|
||||
|
||||
#ifdef GOAMD64_v4
|
||||
#ifndef GOAMD64_v3
|
||||
#define GOAMD64_v3
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define bufoff 256 // see decompress.go, we're using [4][256]byte table
|
||||
|
||||
//func decompress4x_main_loop_x86(pbr0, pbr1, pbr2, pbr3 *bitReaderShifted,
|
||||
// peekBits uint8, buf *byte, tbl *dEntrySingle) uint8
|
||||
TEXT ·decompress4x_main_loop_x86(SB), NOSPLIT, $8
|
||||
#define off R8
|
||||
#define buffer DI
|
||||
#define table SI
|
||||
|
||||
#define br_bits_read R9
|
||||
#define br_value R10
|
||||
#define br_offset R11
|
||||
#define peek_bits R12
|
||||
#define exhausted DX
|
||||
|
||||
#define br0 R13
|
||||
#define br1 R14
|
||||
#define br2 R15
|
||||
#define br3 BP
|
||||
|
||||
MOVQ BP, 0(SP)
|
||||
|
||||
XORQ exhausted, exhausted // exhausted = false
|
||||
XORQ off, off // off = 0
|
||||
|
||||
MOVBQZX peekBits+32(FP), peek_bits
|
||||
MOVQ buf+40(FP), buffer
|
||||
MOVQ tbl+48(FP), table
|
||||
|
||||
MOVQ pbr0+0(FP), br0
|
||||
MOVQ pbr1+8(FP), br1
|
||||
MOVQ pbr2+16(FP), br2
|
||||
MOVQ pbr3+24(FP), br3
|
||||
|
||||
main_loop:
|
||||
{{ define "decode_2_values_x86" }}
|
||||
// const stream = {{ var "id" }}
|
||||
// br{{ var "id"}}.fillFast()
|
||||
MOVBQZX bitReaderShifted_bitsRead(br{{ var "id" }}), br_bits_read
|
||||
MOVQ bitReaderShifted_value(br{{ var "id" }}), br_value
|
||||
MOVQ bitReaderShifted_off(br{{ var "id" }}), br_offset
|
||||
|
||||
// We must have at least 2 * max tablelog left
|
||||
CMPQ br_bits_read, $64-22
|
||||
JBE skip_fill{{ var "id" }}
|
||||
|
||||
SUBQ $32, br_bits_read // b.bitsRead -= 32
|
||||
SUBQ $4, br_offset // b.off -= 4
|
||||
|
||||
// v := b.in[b.off-4 : b.off]
|
||||
// v = v[:4]
|
||||
// low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
|
||||
MOVQ bitReaderShifted_in(br{{ var "id" }}), AX
|
||||
|
||||
// b.value |= uint64(low) << (b.bitsRead & 63)
|
||||
#ifdef GOAMD64_v3
|
||||
SHLXQ br_bits_read, 0(br_offset)(AX*1), AX // AX = uint32(b.in[b.off:b.off+4]) << (b.bitsRead & 63)
|
||||
#else
|
||||
MOVL 0(br_offset)(AX*1), AX // AX = uint32(b.in[b.off:b.off+4])
|
||||
MOVQ br_bits_read, CX
|
||||
SHLQ CL, AX
|
||||
#endif
|
||||
|
||||
ORQ AX, br_value
|
||||
|
||||
// exhausted = exhausted || (br{{ var "id"}}.off < 4)
|
||||
CMPQ br_offset, $4
|
||||
SETLT DL
|
||||
ORB DL, DH
|
||||
// }
|
||||
skip_fill{{ var "id" }}:
|
||||
|
||||
// val0 := br{{ var "id"}}.peekTopBits(peekBits)
|
||||
#ifdef GOAMD64_v3
|
||||
SHRXQ peek_bits, br_value, AX // AX = (value >> peek_bits) & mask
|
||||
#else
|
||||
MOVQ br_value, AX
|
||||
MOVQ peek_bits, CX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
#endif
|
||||
|
||||
// v0 := table[val0&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v0
|
||||
|
||||
// br{{ var "id"}}.advance(uint8(v0.entry))
|
||||
MOVB AH, BL // BL = uint8(v0.entry >> 8)
|
||||
|
||||
#ifdef GOAMD64_v3
|
||||
MOVBQZX AL, CX
|
||||
SHLXQ AX, br_value, br_value // value <<= n
|
||||
#else
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CL, br_value // value <<= n
|
||||
#endif
|
||||
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
|
||||
#ifdef GOAMD64_v3
|
||||
SHRXQ peek_bits, br_value, AX // AX = (value >> peek_bits) & mask
|
||||
#else
|
||||
// val1 := br{{ var "id"}}.peekTopBits(peekBits)
|
||||
MOVQ peek_bits, CX
|
||||
MOVQ br_value, AX
|
||||
SHRQ CL, AX // AX = (value >> peek_bits) & mask
|
||||
#endif
|
||||
|
||||
// v1 := table[val1&mask]
|
||||
MOVW 0(table)(AX*2), AX // AX - v1
|
||||
|
||||
// br{{ var "id"}}.advance(uint8(v1.entry))
|
||||
MOVB AH, BH // BH = uint8(v1.entry >> 8)
|
||||
|
||||
#ifdef GOAMD64_v3
|
||||
MOVBQZX AL, CX
|
||||
SHLXQ AX, br_value, br_value // value <<= n
|
||||
#else
|
||||
MOVBQZX AL, CX
|
||||
SHLQ CL, br_value // value <<= n
|
||||
#endif
|
||||
|
||||
ADDQ CX, br_bits_read // bits_read += n
|
||||
|
||||
|
||||
// these two writes get coalesced
|
||||
// buf[stream][off] = uint8(v0.entry >> 8)
|
||||
// buf[stream][off+1] = uint8(v1.entry >> 8)
|
||||
MOVW BX, {{ var "bufofs" }}(buffer)(off*1)
|
||||
|
||||
// update the bit reader structure
|
||||
MOVB br_bits_read, bitReaderShifted_bitsRead(br{{ var "id" }})
|
||||
MOVQ br_value, bitReaderShifted_value(br{{ var "id" }})
|
||||
MOVQ br_offset, bitReaderShifted_off(br{{ var "id" }})
|
||||
{{ end }}
|
||||
|
||||
{{ set "id" "0" }}
|
||||
{{ set "ofs" "0" }}
|
||||
{{ set "bufofs" "0" }} {{/* id * bufoff */}}
|
||||
{{ template "decode_2_values_x86" . }}
|
||||
|
||||
{{ set "id" "1" }}
|
||||
{{ set "ofs" "8" }}
|
||||
{{ set "bufofs" "256" }}
|
||||
{{ template "decode_2_values_x86" . }}
|
||||
|
||||
{{ set "id" "2" }}
|
||||
{{ set "ofs" "16" }}
|
||||
{{ set "bufofs" "512" }}
|
||||
{{ template "decode_2_values_x86" . }}
|
||||
|
||||
{{ set "id" "3" }}
|
||||
{{ set "ofs" "24" }}
|
||||
{{ set "bufofs" "768" }}
|
||||
{{ template "decode_2_values_x86" . }}
|
||||
|
||||
ADDQ $2, off // off += 2
|
||||
|
||||
TESTB DH, DH // any br[i].ofs < 4?
|
||||
JNZ end
|
||||
|
||||
CMPQ off, $bufoff
|
||||
JL main_loop
|
||||
end:
|
||||
MOVQ 0(SP), BP
|
||||
|
||||
MOVB off, ret+56(FP)
|
||||
RET
|
||||
#undef off
|
||||
#undef buffer
|
||||
#undef table
|
||||
|
||||
#undef br_bits_read
|
||||
#undef br_value
|
||||
#undef br_offset
|
||||
#undef peek_bits
|
||||
#undef exhausted
|
||||
|
||||
#undef br0
|
||||
#undef br1
|
||||
#undef br2
|
||||
#undef br3
|
|
@ -122,17 +122,21 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
|
|||
d.bufs.Put(buf)
|
||||
return nil, errors.New("corruption detected: stream overrun 1")
|
||||
}
|
||||
copy(out, buf[0][:])
|
||||
copy(out[dstEvery:], buf[1][:])
|
||||
copy(out[dstEvery*2:], buf[2][:])
|
||||
copy(out[dstEvery*3:], buf[3][:])
|
||||
out = out[bufoff:]
|
||||
decoded += bufoff * 4
|
||||
// There must at least be 3 buffers left.
|
||||
if len(out) < dstEvery*3 {
|
||||
if len(out)-bufoff < dstEvery*3 {
|
||||
d.bufs.Put(buf)
|
||||
return nil, errors.New("corruption detected: stream overrun 2")
|
||||
}
|
||||
//copy(out, buf[0][:])
|
||||
//copy(out[dstEvery:], buf[1][:])
|
||||
//copy(out[dstEvery*2:], buf[2][:])
|
||||
//copy(out[dstEvery*3:], buf[3][:])
|
||||
*(*[bufoff]byte)(out) = buf[0]
|
||||
*(*[bufoff]byte)(out[dstEvery:]) = buf[1]
|
||||
*(*[bufoff]byte)(out[dstEvery*2:]) = buf[2]
|
||||
*(*[bufoff]byte)(out[dstEvery*3:]) = buf[3]
|
||||
out = out[bufoff:]
|
||||
decoded += bufoff * 4
|
||||
}
|
||||
}
|
||||
if off > 0 {
|
||||
|
@ -191,3 +195,105 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
|
|||
}
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// Decompress1X will decompress a 1X encoded stream.
|
||||
// The cap of the output buffer will be the maximum decompressed size.
|
||||
// The length of the supplied input must match the end of a block exactly.
|
||||
func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
|
||||
if len(d.dt.single) == 0 {
|
||||
return nil, errors.New("no table loaded")
|
||||
}
|
||||
if use8BitTables && d.actualTableLog <= 8 {
|
||||
return d.decompress1X8Bit(dst, src)
|
||||
}
|
||||
var br bitReaderShifted
|
||||
err := br.init(src)
|
||||
if err != nil {
|
||||
return dst, err
|
||||
}
|
||||
maxDecodedSize := cap(dst)
|
||||
dst = dst[:0]
|
||||
|
||||
// Avoid bounds check by always having full sized table.
|
||||
const tlSize = 1 << tableLogMax
|
||||
const tlMask = tlSize - 1
|
||||
dt := d.dt.single[:tlSize]
|
||||
|
||||
// Use temp table to avoid bound checks/append penalty.
|
||||
bufs := d.buffer()
|
||||
buf := &bufs[0]
|
||||
var off uint8
|
||||
|
||||
for br.off >= 8 {
|
||||
br.fillFast()
|
||||
v := dt[br.peekBitsFast(d.actualTableLog)&tlMask]
|
||||
br.advance(uint8(v.entry))
|
||||
buf[off+0] = uint8(v.entry >> 8)
|
||||
|
||||
v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
|
||||
br.advance(uint8(v.entry))
|
||||
buf[off+1] = uint8(v.entry >> 8)
|
||||
|
||||
// Refill
|
||||
br.fillFast()
|
||||
|
||||
v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
|
||||
br.advance(uint8(v.entry))
|
||||
buf[off+2] = uint8(v.entry >> 8)
|
||||
|
||||
v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
|
||||
br.advance(uint8(v.entry))
|
||||
buf[off+3] = uint8(v.entry >> 8)
|
||||
|
||||
off += 4
|
||||
if off == 0 {
|
||||
if len(dst)+256 > maxDecodedSize {
|
||||
br.close()
|
||||
d.bufs.Put(bufs)
|
||||
return nil, ErrMaxDecodedSizeExceeded
|
||||
}
|
||||
dst = append(dst, buf[:]...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(dst)+int(off) > maxDecodedSize {
|
||||
d.bufs.Put(bufs)
|
||||
br.close()
|
||||
return nil, ErrMaxDecodedSizeExceeded
|
||||
}
|
||||
dst = append(dst, buf[:off]...)
|
||||
|
||||
// br.off < 8, so uint8 is fine
|
||||
bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
|
||||
for bitsLeft > 0 {
|
||||
br.fill()
|
||||
if false && br.bitsRead >= 32 {
|
||||
if br.off >= 4 {
|
||||
v := br.in[br.off-4:]
|
||||
v = v[:4]
|
||||
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
|
||||
br.value = (br.value << 32) | uint64(low)
|
||||
br.bitsRead -= 32
|
||||
br.off -= 4
|
||||
} else {
|
||||
for br.off > 0 {
|
||||
br.value = (br.value << 8) | uint64(br.in[br.off-1])
|
||||
br.bitsRead -= 8
|
||||
br.off--
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(dst) >= maxDecodedSize {
|
||||
d.bufs.Put(bufs)
|
||||
br.close()
|
||||
return nil, ErrMaxDecodedSizeExceeded
|
||||
}
|
||||
v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
|
||||
nBits := uint8(v.entry)
|
||||
br.advance(nBits)
|
||||
bitsLeft -= nBits
|
||||
dst = append(dst, uint8(v.entry>>8))
|
||||
}
|
||||
d.bufs.Put(bufs)
|
||||
return dst, br.close()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
// Package cpuinfo gives runtime info about the current CPU.
|
||||
//
|
||||
// This is a very limited module meant for use internally
|
||||
// in this project. For a more versatile solution, check
|
||||
// https://github.com/klauspost/cpuid.
|
||||
package cpuinfo
|
||||
|
||||
// HasBMI1 checks whether an x86 CPU supports the BMI1 extension.
|
||||
func HasBMI1() bool {
|
||||
return hasBMI1
|
||||
}
|
||||
|
||||
// HasBMI2 checks whether an x86 CPU supports the BMI2 extension.
|
||||
func HasBMI2() bool {
|
||||
return hasBMI2
|
||||
}
|
||||
|
||||
// DisableBMI2 will disable BMI2, for testing purposes.
|
||||
// Call the returned function to restore the previous state.
|
||||
func DisableBMI2() func() {
|
||||
old := hasBMI2
|
||||
hasBMI2 = false
|
||||
return func() {
|
||||
hasBMI2 = old
|
||||
}
|
||||
}
|
||||
|
||||
// HasBMI checks whether an x86 CPU supports both BMI1 and BMI2 extensions.
|
||||
func HasBMI() bool {
|
||||
return HasBMI1() && HasBMI2()
|
||||
}
|
||||
|
||||
var hasBMI1 bool
|
||||
var hasBMI2 bool
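Because cpuinfo sits under internal/, only packages inside klauspost/compress can import it. A sketch of how a test inside the module might exercise both code paths with the DisableBMI2 hook above; the decode call is a hypothetical placeholder:

```go
package huff0

import (
	"testing"

	"github.com/klauspost/compress/internal/cpuinfo"
)

// TestBothKernels runs a decode once as-is and once with the BMI2
// fast path force-disabled, restoring the probe state afterwards.
func TestBothKernels(t *testing.T) {
	decode := func() { /* hypothetical decode under test */ }

	decode() // whatever the host CPU supports

	restore := cpuinfo.DisableBMI2()
	defer restore()
	decode() // generic (non-BMI2) path
}
```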
|
11
vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo_amd64.go
generated
vendored
Normal file
|
@ -0,0 +1,11 @@
|
|||
//go:build amd64 && !appengine && !noasm && gc
|
||||
// +build amd64,!appengine,!noasm,gc
|
||||
|
||||
package cpuinfo
|
||||
|
||||
//go:noescape
|
||||
func x86extensions() (bmi1, bmi2 bool)
|
||||
|
||||
func init() {
|
||||
hasBMI1, hasBMI2 = x86extensions()
|
||||
}
|
36
vendor/github.com/klauspost/compress/internal/cpuinfo/cpuinfo_amd64.s
generated
vendored
Normal file
|
@ -0,0 +1,36 @@
|
|||
// +build !appengine
|
||||
// +build gc
|
||||
// +build !noasm
|
||||
|
||||
#include "textflag.h"
|
||||
#include "funcdata.h"
|
||||
#include "go_asm.h"
|
||||
|
||||
TEXT ·x86extensions(SB), NOSPLIT, $0
|
||||
// 1. determine max EAX value
|
||||
XORQ AX, AX
|
||||
CPUID
|
||||
|
||||
CMPQ AX, $7
|
||||
JB unsupported
|
||||
|
||||
// 2. EAX = 7, ECX = 0 --- see Table 3-8 "Information Returned by CPUID Instruction"
|
||||
MOVQ $7, AX
|
||||
MOVQ $0, CX
|
||||
CPUID
|
||||
|
||||
BTQ $3, BX // bit 3 = BMI1
|
||||
SETCS AL
|
||||
|
||||
BTQ $8, BX // bit 8 = BMI2
|
||||
SETCS AH
|
||||
|
||||
MOVB AL, bmi1+0(FP)
|
||||
MOVB AH, bmi2+1(FP)
|
||||
RET
|
||||
|
||||
unsupported:
|
||||
XORQ AX, AX
|
||||
MOVB AL, bmi1+0(FP)
|
||||
MOVB AL, bmi2+1(FP)
|
||||
RET
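The `BTQ $3, BX` / `BTQ $8, BX` tests read feature bits from the EBX result of CPUID leaf 7. The same extraction in Go, given a raw EBX value (a sketch; obtaining EBX itself still needs the assembly above or a library such as github.com/klauspost/cpuid):

```go
package main

import "fmt"

// bmiFromEBX mirrors the two BT tests above on CPUID(EAX=7,ECX=0).EBX:
// bit 3 reports BMI1, bit 8 reports BMI2.
func bmiFromEBX(ebx uint32) (bmi1, bmi2 bool) {
	return ebx&(1<<3) != 0, ebx&(1<<8) != 0
}

func main() {
	fmt.Println(bmiFromEBX(0x108)) // true true (bits 3 and 8 set)
}
```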
|
|
@ -18,6 +18,7 @@ func load64(b []byte, i int) uint64 {
|
|||
// emitLiteral writes a literal chunk and returns the number of bytes written.
|
||||
//
|
||||
// It assumes that:
|
||||
//
|
||||
// dst is long enough to hold the encoded bytes
|
||||
// 1 <= len(lit) && len(lit) <= 65536
|
||||
func emitLiteral(dst, lit []byte) int {
|
||||
|
@ -42,6 +43,7 @@ func emitLiteral(dst, lit []byte) int {
|
|||
// emitCopy writes a copy chunk and returns the number of bytes written.
|
||||
//
|
||||
// It assumes that:
|
||||
//
|
||||
// dst is long enough to hold the encoded bytes
|
||||
// 1 <= offset && offset <= 65535
|
||||
// 4 <= length && length <= 65535
|
||||
|
@ -89,6 +91,7 @@ func emitCopy(dst []byte, offset, length int) int {
|
|||
// src[i:i+k-j] and src[j:k] have the same contents.
|
||||
//
|
||||
// It assumes that:
|
||||
//
|
||||
// 0 <= i && i < j && j <= len(src)
|
||||
func extendMatch(src []byte, i, j int) int {
|
||||
for ; j < len(src) && src[i] == src[j]; i, j = i+1, j+1 {
|
||||
|
@ -105,8 +108,9 @@ func hash(u, shift uint32) uint32 {
|
|||
// been written.
|
||||
//
|
||||
// It also assumes that:
|
||||
//
|
||||
// len(dst) >= MaxEncodedLen(len(src)) &&
|
||||
// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize
|
||||
// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize
|
||||
func encodeBlock(dst, src []byte) (d int) {
|
||||
// Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive.
|
||||
// The table element type is uint16, as s < sLimit and sLimit < len(src)
|
||||
|
|
|
@ -12,6 +12,8 @@ The `zstd` package is provided as open source software using a Go standard licen
|
|||
|
||||
Currently the package is heavily optimized for 64 bit processors and will be significantly slower on 32 bit processors.
|
||||
|
||||
For seekable zstd streams, see [this excellent package](https://github.com/SaveTheRbtz/zstd-seekable-format-go).
|
||||
|
||||
## Installation
|
||||
|
||||
Install using `go get -u github.com/klauspost/compress`. The package is located in `github.com/klauspost/compress/zstd`.
|
||||
|
@ -386,47 +388,31 @@ In practice this means that concurrency is often limited to utilizing about 3 co
|
|||
|
||||
### Benchmarks
|
||||
|
||||
These are some examples of performance compared to [datadog cgo library](https://github.com/DataDog/zstd).
|
||||
|
||||
The first two are streaming decodes and the last are smaller inputs.
|
||||
|
||||
|
||||
Running on AMD Ryzen 9 3950X 16-Core Processor. AMD64 assembly used.
|
||||
|
||||
```
|
||||
BenchmarkDecoderSilesia-8 3 385000067 ns/op 550.51 MB/s 5498 B/op 8 allocs/op
|
||||
BenchmarkDecoderSilesiaCgo-8 6 197666567 ns/op 1072.25 MB/s 270672 B/op 8 allocs/op
|
||||
BenchmarkDecoderSilesia-32 5 206878840 ns/op 1024.50 MB/s 49808 B/op 43 allocs/op
|
||||
BenchmarkDecoderEnwik9-32 1 1271809000 ns/op 786.28 MB/s 72048 B/op 52 allocs/op
|
||||
|
||||
BenchmarkDecoderEnwik9-8 1 2027001600 ns/op 493.34 MB/s 10496 B/op 18 allocs/op
|
||||
BenchmarkDecoderEnwik9Cgo-8 2 979499200 ns/op 1020.93 MB/s 270672 B/op 8 allocs/op
|
||||
Concurrent blocks, performance:
|
||||
|
||||
Concurrent performance:
|
||||
|
||||
BenchmarkDecoder_DecodeAllParallel/kppkn.gtb.zst-16 28915 42469 ns/op 4340.07 MB/s 114 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/geo.protodata.zst-16 116505 9965 ns/op 11900.16 MB/s 16 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/plrabn12.txt.zst-16 8952 134272 ns/op 3588.70 MB/s 915 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/lcet10.txt.zst-16 11820 102538 ns/op 4161.90 MB/s 594 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/asyoulik.txt.zst-16 34782 34184 ns/op 3661.88 MB/s 60 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/alice29.txt.zst-16 27712 43447 ns/op 3500.58 MB/s 99 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/html_x_4.zst-16 62826 18750 ns/op 21845.10 MB/s 104 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/paper-100k.pdf.zst-16 631545 1794 ns/op 57078.74 MB/s 2 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/fireworks.jpeg.zst-16 1690140 712 ns/op 172938.13 MB/s 1 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/urls.10K.zst-16 10432 113593 ns/op 6180.73 MB/s 1143 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/html.zst-16 113206 10671 ns/op 9596.27 MB/s 15 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/comp-data.bin.zst-16 1530615 779 ns/op 5229.49 MB/s 0 B/op 0 allocs/op
|
||||
|
||||
BenchmarkDecoder_DecodeAllParallelCgo/kppkn.gtb.zst-16 65217 16192 ns/op 11383.34 MB/s 46 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallelCgo/geo.protodata.zst-16 292671 4039 ns/op 29363.19 MB/s 6 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallelCgo/plrabn12.txt.zst-16 26314 46021 ns/op 10470.43 MB/s 293 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallelCgo/lcet10.txt.zst-16 33897 34900 ns/op 12227.96 MB/s 205 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallelCgo/asyoulik.txt.zst-16 104348 11433 ns/op 10949.01 MB/s 20 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallelCgo/alice29.txt.zst-16 75949 15510 ns/op 9805.60 MB/s 32 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallelCgo/html_x_4.zst-16 173910 6756 ns/op 60624.29 MB/s 37 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallelCgo/paper-100k.pdf.zst-16 923076 1339 ns/op 76474.87 MB/s 1 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallelCgo/fireworks.jpeg.zst-16 922920 1351 ns/op 91102.57 MB/s 2 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallelCgo/urls.10K.zst-16 27649 43618 ns/op 16096.19 MB/s 407 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallelCgo/html.zst-16 279073 4160 ns/op 24614.18 MB/s 6 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallelCgo/comp-data.bin.zst-16 749938 1579 ns/op 2581.71 MB/s 0 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/kppkn.gtb.zst-32 67356 17857 ns/op 10321.96 MB/s 22.48 pct 102 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/geo.protodata.zst-32 266656 4421 ns/op 26823.21 MB/s 11.89 pct 19 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/plrabn12.txt.zst-32 20992 56842 ns/op 8477.17 MB/s 39.90 pct 754 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/lcet10.txt.zst-32 27456 43932 ns/op 9714.01 MB/s 33.27 pct 524 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/asyoulik.txt.zst-32 78432 15047 ns/op 8319.15 MB/s 40.34 pct 66 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/alice29.txt.zst-32 65800 18436 ns/op 8249.63 MB/s 37.75 pct 88 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/html_x_4.zst-32 102993 11523 ns/op 35546.09 MB/s 3.637 pct 143 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/paper-100k.pdf.zst-32 1000000 1070 ns/op 95720.98 MB/s 80.53 pct 3 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/fireworks.jpeg.zst-32 749802 1752 ns/op 70272.35 MB/s 100.0 pct 5 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/urls.10K.zst-32 22640 52934 ns/op 13263.37 MB/s 26.25 pct 1014 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/html.zst-32 226412 5232 ns/op 19572.27 MB/s 14.49 pct 20 B/op 0 allocs/op
|
||||
BenchmarkDecoder_DecodeAllParallel/comp-data.bin.zst-32 923041 1276 ns/op 3194.71 MB/s 31.26 pct 0 B/op 0 allocs/op
|
||||
```
|
||||
|
||||
This reflects the performance around May 2020, but this may be out of date.
|
||||
This reflects the performance around May 2022, but this may be out of date.
|
||||
|
||||
## Zstd inside ZIP files
|
||||
|
||||
|
|
|
@ -63,13 +63,6 @@ func (b *bitReader) get32BitsFast(n uint8) uint32 {
|
|||
return v
|
||||
}
|
||||
|
||||
func (b *bitReader) get16BitsFast(n uint8) uint16 {
|
||||
const regMask = 64 - 1
|
||||
v := uint16((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask))
|
||||
b.bitsRead += n
|
||||
return v
|
||||
}
|
||||
|
||||
// fillFast() will make sure at least 32 bits are available.
|
||||
// There must be at least 4 bytes available.
|
||||
func (b *bitReader) fillFast() {
|
||||
|
|
|
@ -5,8 +5,6 @@
|
|||
|
||||
package zstd
|
||||
|
||||
import "fmt"
|
||||
|
||||
// bitWriter will write bits.
|
||||
// First bit will be LSB of the first byte of output.
|
||||
type bitWriter struct {
|
||||
|
@ -73,80 +71,6 @@ func (b *bitWriter) addBits16Clean(value uint16, bits uint8) {
    b.nBits += bits
}

// flush will flush all pending full bytes.
// There will be at least 56 bits available for writing when this has been called.
// Using flush32 is faster, but leaves less space for writing.
func (b *bitWriter) flush() {
    v := b.nBits >> 3
    switch v {
    case 0:
    case 1:
        b.out = append(b.out,
            byte(b.bitContainer),
        )
    case 2:
        b.out = append(b.out,
            byte(b.bitContainer),
            byte(b.bitContainer>>8),
        )
    case 3:
        b.out = append(b.out,
            byte(b.bitContainer),
            byte(b.bitContainer>>8),
            byte(b.bitContainer>>16),
        )
    case 4:
        b.out = append(b.out,
            byte(b.bitContainer),
            byte(b.bitContainer>>8),
            byte(b.bitContainer>>16),
            byte(b.bitContainer>>24),
        )
    case 5:
        b.out = append(b.out,
            byte(b.bitContainer),
            byte(b.bitContainer>>8),
            byte(b.bitContainer>>16),
            byte(b.bitContainer>>24),
            byte(b.bitContainer>>32),
        )
    case 6:
        b.out = append(b.out,
            byte(b.bitContainer),
            byte(b.bitContainer>>8),
            byte(b.bitContainer>>16),
            byte(b.bitContainer>>24),
            byte(b.bitContainer>>32),
            byte(b.bitContainer>>40),
        )
    case 7:
        b.out = append(b.out,
            byte(b.bitContainer),
            byte(b.bitContainer>>8),
            byte(b.bitContainer>>16),
            byte(b.bitContainer>>24),
            byte(b.bitContainer>>32),
            byte(b.bitContainer>>40),
            byte(b.bitContainer>>48),
        )
    case 8:
        b.out = append(b.out,
            byte(b.bitContainer),
            byte(b.bitContainer>>8),
            byte(b.bitContainer>>16),
            byte(b.bitContainer>>24),
            byte(b.bitContainer>>32),
            byte(b.bitContainer>>40),
            byte(b.bitContainer>>48),
            byte(b.bitContainer>>56),
        )
    default:
        panic(fmt.Errorf("bits (%d) > 64", b.nBits))
    }
    b.bitContainer >>= v << 3
    b.nBits &= 7
}

// flush32 will flush out, so there are at least 32 bits available for writing.
func (b *bitWriter) flush32() {
    if b.nBits < 32 {
@ -5,9 +5,13 @@
package zstd

import (
    "bytes"
    "encoding/binary"
    "errors"
    "fmt"
    "io"
    "os"
    "path/filepath"
    "sync"

    "github.com/klauspost/compress/huff0"
@ -38,14 +42,14 @@ const (
    // maxCompressedBlockSize is the biggest allowed compressed block size (128KB)
    maxCompressedBlockSize = 128 << 10

    compressedBlockOverAlloc = 16
    maxCompressedBlockSizeAlloc = 128<<10 + compressedBlockOverAlloc

    // Maximum possible block size (all Raw+Uncompressed).
    maxBlockSize = (1 << 21) - 1

    // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#literals_section_header
    maxCompressedLiteralSize = 1 << 18
    maxRLELiteralSize = 1 << 20
    maxMatchLen = 131074
    maxSequences = 0x7f00 + 0xffff
    maxMatchLen = 131074
    maxSequences = 0x7f00 + 0xffff

    // We support slightly less than the reference decoder to be able to
    // use ints on 32 bit archs.
@ -97,7 +101,6 @@ type blockDec struct {

    // Block is RLE, this is the size.
    RLESize uint32
    tmp [4]byte

    Type blockType
@ -136,7 +139,7 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
    b.Type = blockType((bh >> 1) & 3)
    // find size.
    cSize := int(bh >> 3)
    maxSize := maxBlockSize
    maxSize := maxCompressedBlockSizeAlloc
    switch b.Type {
    case blockTypeReserved:
        return ErrReservedBlockType
@ -157,9 +160,9 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
            println("Data size on stream:", cSize)
        }
        b.RLESize = 0
        maxSize = maxCompressedBlockSize
        maxSize = maxCompressedBlockSizeAlloc
        if windowSize < maxCompressedBlockSize && b.lowMem {
            maxSize = int(windowSize)
            maxSize = int(windowSize) + compressedBlockOverAlloc
        }
        if cSize > maxCompressedBlockSize || uint64(cSize) > b.WindowSize {
            if debugDecoder {
@ -190,9 +193,9 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
    // Read block data.
    if cap(b.dataStorage) < cSize {
        if b.lowMem || cSize > maxCompressedBlockSize {
            b.dataStorage = make([]byte, 0, cSize)
            b.dataStorage = make([]byte, 0, cSize+compressedBlockOverAlloc)
        } else {
            b.dataStorage = make([]byte, 0, maxCompressedBlockSize)
            b.dataStorage = make([]byte, 0, maxCompressedBlockSizeAlloc)
        }
    }
    if cap(b.dst) <= maxSize {
@ -360,14 +363,9 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err
        }
        if cap(b.literalBuf) < litRegenSize {
            if b.lowMem {
                b.literalBuf = make([]byte, litRegenSize)
                b.literalBuf = make([]byte, litRegenSize, litRegenSize+compressedBlockOverAlloc)
            } else {
                if litRegenSize > maxCompressedLiteralSize {
                    // Exceptional
                    b.literalBuf = make([]byte, litRegenSize)
                } else {
                    b.literalBuf = make([]byte, litRegenSize, maxCompressedLiteralSize)
                }
                b.literalBuf = make([]byte, litRegenSize, maxCompressedBlockSize+compressedBlockOverAlloc)
            }
        }
        literals = b.literalBuf[:litRegenSize]
@ -397,14 +395,14 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err
        // Ensure we have space to store it.
        if cap(b.literalBuf) < litRegenSize {
            if b.lowMem {
                b.literalBuf = make([]byte, 0, litRegenSize)
                b.literalBuf = make([]byte, 0, litRegenSize+compressedBlockOverAlloc)
            } else {
                b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
                b.literalBuf = make([]byte, 0, maxCompressedBlockSize+compressedBlockOverAlloc)
            }
        }
        var err error
        // Use our out buffer.
        huff.MaxDecodedSize = maxCompressedBlockSize
        huff.MaxDecodedSize = litRegenSize
        if fourStreams {
            literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals)
        } else {
@ -429,9 +427,9 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err
        // Ensure we have space to store it.
        if cap(b.literalBuf) < litRegenSize {
            if b.lowMem {
                b.literalBuf = make([]byte, 0, litRegenSize)
                b.literalBuf = make([]byte, 0, litRegenSize+compressedBlockOverAlloc)
            } else {
                b.literalBuf = make([]byte, 0, maxCompressedBlockSize)
                b.literalBuf = make([]byte, 0, maxCompressedBlockSize+compressedBlockOverAlloc)
            }
        }
        huff := hist.huffTree
@ -448,7 +446,7 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err
            return in, err
        }
        hist.huffTree = huff
        huff.MaxDecodedSize = maxCompressedBlockSize
        huff.MaxDecodedSize = litRegenSize
        // Use our out buffer.
        if fourStreams {
            literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals)
@ -463,6 +461,8 @@ func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err
        if len(literals) != litRegenSize {
            return in, fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
        }
        // Re-cap to get extra size.
        literals = b.literalBuf[:len(literals)]
        if debugDecoder {
            printf("Decompressed %d literals into %d bytes\n", litCompSize, litRegenSize)
        }
@ -486,10 +486,15 @@ func (b *blockDec) decodeCompressed(hist *history) error {
        b.dst = append(b.dst, hist.decoders.literals...)
        return nil
    }
    err = hist.decoders.decodeSync(hist)
    before := len(hist.decoders.out)
    err = hist.decoders.decodeSync(hist.b[hist.ignoreBuffer:])
    if err != nil {
        return err
    }
    if hist.decoders.maxSyncLen > 0 {
        hist.decoders.maxSyncLen += uint64(before)
        hist.decoders.maxSyncLen -= uint64(len(hist.decoders.out))
    }
    b.dst = hist.decoders.out
    hist.recentOffsets = hist.decoders.prevOffset
    return nil
@ -632,6 +637,22 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) {
        println("initializing sequences:", err)
        return err
    }
    // Extract blocks...
    if false && hist.dict == nil {
        fatalErr := func(err error) {
            if err != nil {
                panic(err)
            }
        }
        fn := fmt.Sprintf("n-%d-lits-%d-prev-%d-%d-%d-win-%d.blk", hist.decoders.nSeqs, len(hist.decoders.literals), hist.recentOffsets[0], hist.recentOffsets[1], hist.recentOffsets[2], hist.windowSize)
        var buf bytes.Buffer
        fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.litLengths.fse))
        fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.matchLengths.fse))
        fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.offsets.fse))
        buf.Write(in)
        os.WriteFile(filepath.Join("testdata", "seqs", fn), buf.Bytes(), os.ModePerm)
    }

    return nil
}

@ -650,6 +671,7 @@ func (b *blockDec) decodeSequences(hist *history) error {
    }
    hist.decoders.windowSize = hist.windowSize
    hist.decoders.prevOffset = hist.recentOffsets

    err := hist.decoders.decode(b.sequence)
    hist.recentOffsets = hist.decoders.prevOffset
    return err
@ -7,7 +7,6 @@ package zstd
import (
    "fmt"
    "io"
    "io/ioutil"
)

type byteBuffer interface {
@ -23,7 +22,7 @@ type byteBuffer interface {
    readByte() (byte, error)

    // Skip n bytes.
    skipN(n int) error
    skipN(n int64) error
}

// in-memory buffer
@ -52,10 +51,6 @@ func (b *byteBuf) readBig(n int, dst []byte) ([]byte, error) {
    return r, nil
}

func (b *byteBuf) remain() []byte {
    return *b
}

func (b *byteBuf) readByte() (byte, error) {
    bb := *b
    if len(bb) < 1 {
@ -66,9 +61,12 @@ func (b *byteBuf) readByte() (byte, error) {
    return r, nil
}

func (b *byteBuf) skipN(n int) error {
func (b *byteBuf) skipN(n int64) error {
    bb := *b
    if len(bb) < n {
    if n < 0 {
        return fmt.Errorf("negative skip (%d) requested", n)
    }
    if int64(len(bb)) < n {
        return io.ErrUnexpectedEOF
    }
    *b = bb[n:]
@ -124,9 +122,9 @@ func (r *readerWrapper) readByte() (byte, error) {
    return r.tmp[0], nil
}

func (r *readerWrapper) skipN(n int) error {
    n2, err := io.CopyN(ioutil.Discard, r.r, int64(n))
    if n2 != int64(n) {
func (r *readerWrapper) skipN(n int64) error {
    n2, err := io.CopyN(io.Discard, r.r, n)
    if n2 != n {
        err = io.ErrUnexpectedEOF
    }
    return err
@ -13,12 +13,6 @@ type byteReader struct {
    off int
}

// init will initialize the reader and set the input.
func (b *byteReader) init(in []byte) {
    b.b = in
    b.off = 0
}

// advance the stream b n bytes.
func (b *byteReader) advance(n uint) {
    b.off += int(n)
@ -312,6 +312,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
    // Grab a block decoder and frame decoder.
    block := <-d.decoders
    frame := block.localFrame
    initialSize := len(dst)
    defer func() {
        if debugDecoder {
            printf("re-adding decoder: %p", block)
@ -347,19 +348,33 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
        }
        frame.history.setDict(&dict)
    }

    if frame.FrameContentSize != fcsUnknown && frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) {
        return dst, ErrDecoderSizeExceeded
    if frame.WindowSize > d.o.maxWindowSize {
        if debugDecoder {
            println("window size exceeded:", frame.WindowSize, ">", d.o.maxWindowSize)
        }
        return dst, ErrWindowSizeExceeded
    }
    if frame.FrameContentSize < 1<<30 {
        // Never preallocate more than 1 GB up front.
    if frame.FrameContentSize != fcsUnknown {
        if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)-initialSize) {
            if debugDecoder {
                println("decoder size exceeded; fcs:", frame.FrameContentSize, "> mcs:", d.o.maxDecodedSize-uint64(len(dst)-initialSize), "len:", len(dst))
            }
            return dst, ErrDecoderSizeExceeded
        }
        if d.o.limitToCap && frame.FrameContentSize > uint64(cap(dst)-len(dst)) {
            if debugDecoder {
                println("decoder size exceeded; fcs:", frame.FrameContentSize, "> (cap-len)", cap(dst)-len(dst))
            }
            return dst, ErrDecoderSizeExceeded
        }
        if cap(dst)-len(dst) < int(frame.FrameContentSize) {
            dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize))
            dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize)+compressedBlockOverAlloc)
            copy(dst2, dst)
            dst = dst2
        }
    }
    if cap(dst) == 0 {

    if cap(dst) == 0 && !d.o.limitToCap {
        // Allocate len(input) * 2 by default if nothing is provided
        // and we didn't get frame content size.
        size := len(input) * 2
@ -377,6 +392,9 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
        if err != nil {
            return dst, err
        }
        if uint64(len(dst)-initialSize) > d.o.maxDecodedSize {
            return dst, ErrDecoderSizeExceeded
        }
        if len(frame.bBuf) == 0 {
            if debugDecoder {
                println("frame dbuf empty")
@ -437,7 +455,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) {
        println("got", len(d.current.b), "bytes, error:", d.current.err, "data crc:", tmp)
    }

    if len(next.b) > 0 {
    if !d.o.ignoreChecksum && len(next.b) > 0 {
        n, err := d.current.crc.Write(next.b)
        if err == nil {
            if n != len(next.b) {
@ -449,7 +467,7 @@ func (d *Decoder) nextBlock(blocking bool) (ok bool) {
        got := d.current.crc.Sum64()
        var tmp [4]byte
        binary.LittleEndian.PutUint32(tmp[:], uint32(got))
        if !bytes.Equal(tmp[:], next.d.checkCRC) && !ignoreCRC {
        if !d.o.ignoreChecksum && !bytes.Equal(tmp[:], next.d.checkCRC) {
            if debugDecoder {
                println("CRC Check Failed:", tmp[:], " (got) !=", next.d.checkCRC, "(on stream)")
            }
@ -533,9 +551,15 @@ func (d *Decoder) nextBlockSync() (ok bool) {

        // Update/Check CRC
        if d.frame.HasCheckSum {
            d.frame.crc.Write(d.current.b)
            if !d.o.ignoreChecksum {
                d.frame.crc.Write(d.current.b)
            }
            if d.current.d.Last {
                d.current.err = d.frame.checkCRC()
                if !d.o.ignoreChecksum {
                    d.current.err = d.frame.checkCRC()
                } else {
                    d.current.err = d.frame.consumeCRC()
                }
                if d.current.err != nil {
                    println("CRC error:", d.current.err)
                    return false
@ -629,60 +653,18 @@ func (d *Decoder) startSyncDecoder(r io.Reader) error {

// Create Decoder:
// ASYNC:
// Spawn 4 go routines.
// 0: Read frames and decode blocks.
// 1: Decode block and literals. Receives hufftree and seqdecs, returns seqdecs and huff tree.
// 2: Wait for recentOffsets if needed. Decode sequences, send recentOffsets.
// 3: Wait for stream history, execute sequences, send stream history.
// Spawn 3 go routines.
// 0: Read frames and decode block literals.
// 1: Decode sequences.
// 2: Execute sequences, send to output.
func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output chan decodeOutput) {
    defer d.streamWg.Done()
    br := readerWrapper{r: r}

    var seqPrepare = make(chan *blockDec, d.o.concurrent)
    var seqDecode = make(chan *blockDec, d.o.concurrent)
    var seqExecute = make(chan *blockDec, d.o.concurrent)

    // Async 1: Prepare blocks...
    go func() {
        var hist history
        var hasErr bool
        for block := range seqPrepare {
            if hasErr {
                if block != nil {
                    seqDecode <- block
                }
                continue
            }
            if block.async.newHist != nil {
                if debugDecoder {
                    println("Async 1: new history")
                }
                hist.reset()
                if block.async.newHist.dict != nil {
                    hist.setDict(block.async.newHist.dict)
                }
            }
            if block.err != nil || block.Type != blockTypeCompressed {
                hasErr = block.err != nil
                seqDecode <- block
                continue
            }

            remain, err := block.decodeLiterals(block.data, &hist)
            block.err = err
            hasErr = block.err != nil
            if err == nil {
                block.async.literals = hist.decoders.literals
                block.async.seqData = remain
            } else if debugDecoder {
                println("decodeLiterals error:", err)
            }
            seqDecode <- block
        }
        close(seqDecode)
    }()

    // Async 2: Decode sequences...
    // Async 1: Decode sequences...
    go func() {
        var hist history
        var hasErr bool
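The comment block above describes the reworked goroutine layout. As a rough, self-contained illustration of that shape — ordered stages connected by buffered channels — here is a sketch with invented names (`block`, `seqDecode`, `seqExecute` mirror the roles in the comment, not the library's actual types):

```
package main

import "fmt"

// block stands in for the library's per-block state.
type block struct{ id int }

func main() {
    seqDecode := make(chan *block, 4)
    seqExecute := make(chan *block, 4)
    done := make(chan struct{})

    // Stage 1: decode sequences, preserving block order.
    go func() {
        for b := range seqDecode {
            seqExecute <- b
        }
        close(seqExecute)
    }()

    // Stage 2: execute sequences and emit output.
    go func() {
        for b := range seqExecute {
            fmt.Println("output block", b.id)
        }
        close(done)
    }()

    // Stage 0: read frames, decode block literals, feed the pipeline.
    for i := 0; i < 3; i++ {
        seqDecode <- &block{id: i}
    }
    close(seqDecode)
    <-done
}
```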
@ -696,7 +678,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
            }
            if block.async.newHist != nil {
                if debugDecoder {
                    println("Async 2: new history, recent:", block.async.newHist.recentOffsets)
                    println("Async 1: new history, recent:", block.async.newHist.recentOffsets)
                }
                hist.decoders = block.async.newHist.decoders
                hist.recentOffsets = block.async.newHist.recentOffsets
@ -750,7 +732,7 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
            }
            if block.async.newHist != nil {
                if debugDecoder {
                    println("Async 3: new history")
                    println("Async 2: new history")
                }
                hist.windowSize = block.async.newHist.windowSize
                hist.allocFrameBuffer = block.async.newHist.allocFrameBuffer
@ -837,6 +819,33 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch

decodeStream:
    for {
        var hist history
        var hasErr bool

        decodeBlock := func(block *blockDec) {
            if hasErr {
                if block != nil {
                    seqDecode <- block
                }
                return
            }
            if block.err != nil || block.Type != blockTypeCompressed {
                hasErr = block.err != nil
                seqDecode <- block
                return
            }

            remain, err := block.decodeLiterals(block.data, &hist)
            block.err = err
            hasErr = block.err != nil
            if err == nil {
                block.async.literals = hist.decoders.literals
                block.async.seqData = remain
            } else if debugDecoder {
                println("decodeLiterals error:", err)
            }
            seqDecode <- block
        }
        frame := d.frame
        if debugDecoder {
            println("New frame...")
@ -856,6 +865,10 @@ decodeStream:
            }
        }
        if err == nil && d.frame.WindowSize > d.o.maxWindowSize {
            if debugDecoder {
                println("decoder size exceeded, fws:", d.frame.WindowSize, "> mws:", d.o.maxWindowSize)
            }

            err = ErrDecoderSizeExceeded
        }
        if err != nil {
@ -863,7 +876,7 @@ decodeStream:
            case <-ctx.Done():
            case dec := <-d.decoders:
                dec.sendErr(err)
                seqPrepare <- dec
                decodeBlock(dec)
            }
            break decodeStream
        }
@ -883,6 +896,10 @@ decodeStream:
            if debugDecoder {
                println("Alloc History:", h.allocFrameBuffer)
            }
            hist.reset()
            if h.dict != nil {
                hist.setDict(h.dict)
            }
            dec.async.newHist = &h
            dec.async.fcs = frame.FrameContentSize
            historySent = true
@ -909,7 +926,7 @@ decodeStream:
        }
        err = dec.err
        last := dec.Last
        seqPrepare <- dec
        decodeBlock(dec)
        if err != nil {
            break decodeStream
        }
@ -918,7 +935,7 @@ decodeStream:
            }
        }
    }
    close(seqPrepare)
    close(seqDecode)
    wg.Wait()
    d.frame.history.b = frameHistCache
}
@ -19,6 +19,8 @@ type decoderOptions struct {
    maxDecodedSize uint64
    maxWindowSize uint64
    dicts []dict
    ignoreChecksum bool
    limitToCap bool
}

func (o *decoderOptions) setDefault() {
@ -31,7 +33,7 @@ func (o *decoderOptions) setDefault() {
    if o.concurrent > 4 {
        o.concurrent = 4
    }
    o.maxDecodedSize = 1 << 63
    o.maxDecodedSize = 64 << 30
}

// WithDecoderLowmem will set whether to use a lower amount of memory,
@ -66,7 +68,7 @@ func WithDecoderConcurrency(n int) DOption {
// WithDecoderMaxMemory allows to set a maximum decoded size for in-memory
// non-streaming operations or maximum window size for streaming operations.
// This can be used to control memory usage of potentially hostile content.
// Maximum and default is 1 << 63 bytes.
// Maximum is 1 << 63 bytes. Default is 64GiB.
func WithDecoderMaxMemory(n uint64) DOption {
    return func(o *decoderOptions) error {
        if n == 0 {
@ -112,3 +114,22 @@ func WithDecoderMaxWindow(size uint64) DOption {
        return nil
    }
}

// WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes,
// or any size set in WithDecoderMaxMemory.
// This can be used to limit decoding to a specific maximum output size.
// Disabled by default.
func WithDecodeAllCapLimit(b bool) DOption {
    return func(o *decoderOptions) error {
        o.limitToCap = b
        return nil
    }
}

// IgnoreChecksum allows to forcibly ignore checksum checking.
func IgnoreChecksum(b bool) DOption {
    return func(o *decoderOptions) error {
        o.ignoreChecksum = b
        return nil
    }
}
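The two options added here compose with the existing `DOption`s. A minimal usage sketch — the option names come from the hunk above, while the demo payload and the `mustCompress` helper are invented for illustration:

```
package main

import (
    "fmt"
    "log"

    "github.com/klauspost/compress/zstd"
)

func main() {
    // A nil reader is fine when the decoder is only used via DecodeAll.
    dec, err := zstd.NewReader(nil,
        zstd.WithDecoderMaxMemory(64<<20), // hard cap on total decoded size
        zstd.WithDecodeAllCapLimit(true),  // additionally stop at cap(dst)-len(dst)
        zstd.IgnoreChecksum(true),         // skip CRC verification entirely
    )
    if err != nil {
        log.Fatal(err)
    }
    defer dec.Close()

    compressed := mustCompress([]byte("example payload"))

    // With WithDecodeAllCapLimit, the spare capacity of dst is the output limit.
    dst := make([]byte, 0, 1<<20)
    out, err := dec.DecodeAll(compressed, dst)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("%s\n", out)
}

func mustCompress(b []byte) []byte {
    enc, err := zstd.NewWriter(nil)
    if err != nil {
        log.Fatal(err)
    }
    defer enc.Close()
    return enc.EncodeAll(b, nil)
}
```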

@ -156,8 +156,8 @@ encodeLoop:
            panic("offset0 was 0")
        }

        nextHashS := hashLen(cv, betterShortTableBits, betterShortLen)
        nextHashL := hashLen(cv, betterLongTableBits, betterLongLen)
        nextHashS := hashLen(cv, betterShortTableBits, betterShortLen)
        candidateL := e.longTable[nextHashL]
        candidateS := e.table[nextHashS]

@ -416,15 +416,23 @@ encodeLoop:

        // Try to find a better match by searching for a long match at the end of the current best match
        if s+matched < sLimit {
            // Allow some bytes at the beginning to mismatch.
            // Sweet spot is around 3 bytes, but depends on input.
            // The skipped bytes are tested in Extend backwards,
            // and still picked up as part of the match if they do.
            const skipBeginning = 3

            nextHashL := hashLen(load6432(src, s+matched), betterLongTableBits, betterLongLen)
            cv := load3232(src, s)
            s2 := s + skipBeginning
            cv := load3232(src, s2)
            candidateL := e.longTable[nextHashL]
            coffsetL := candidateL.offset - e.cur - matched
            if coffsetL >= 0 && coffsetL < s && s-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) {
            coffsetL := candidateL.offset - e.cur - matched + skipBeginning
            if coffsetL >= 0 && coffsetL < s2 && s2-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) {
                // Found a long match, at least 4 bytes.
                matchedNext := e.matchlen(s+4, coffsetL+4, src) + 4
                matchedNext := e.matchlen(s2+4, coffsetL+4, src) + 4
                if matchedNext > matched {
                    t = coffsetL
                    s = s2
                    matched = matchedNext
                    if debugMatches {
                        println("long match at end-of-match")
@ -434,12 +442,13 @@ encodeLoop:

            // Check prev long...
            if true {
                coffsetL = candidateL.prev - e.cur - matched
                if coffsetL >= 0 && coffsetL < s && s-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) {
                coffsetL = candidateL.prev - e.cur - matched + skipBeginning
                if coffsetL >= 0 && coffsetL < s2 && s2-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) {
                    // Found a long match, at least 4 bytes.
                    matchedNext := e.matchlen(s+4, coffsetL+4, src) + 4
                    matchedNext := e.matchlen(s2+4, coffsetL+4, src) + 4
                    if matchedNext > matched {
                        t = coffsetL
                        s = s2
                        matched = matchedNext
                        if debugMatches {
                            println("prev long match at end-of-match")
@ -518,8 +527,8 @@ encodeLoop:
        }

        // Store this, since we have it.
        nextHashS := hashLen(cv, betterShortTableBits, betterShortLen)
        nextHashL := hashLen(cv, betterLongTableBits, betterLongLen)
        nextHashS := hashLen(cv, betterShortTableBits, betterShortLen)

        // We have at least 4 byte match.
        // No need to check backwards. We come straight from a match
@ -674,8 +683,8 @@ encodeLoop:
            panic("offset0 was 0")
        }

        nextHashS := hashLen(cv, betterShortTableBits, betterShortLen)
        nextHashL := hashLen(cv, betterLongTableBits, betterLongLen)
        nextHashS := hashLen(cv, betterShortTableBits, betterShortLen)
        candidateL := e.longTable[nextHashL]
        candidateS := e.table[nextHashS]

@ -1047,8 +1056,8 @@ encodeLoop:
        }

        // Store this, since we have it.
        nextHashS := hashLen(cv, betterShortTableBits, betterShortLen)
        nextHashL := hashLen(cv, betterLongTableBits, betterLongLen)
        nextHashS := hashLen(cv, betterShortTableBits, betterShortLen)

        // We have at least 4 byte match.
        // No need to check backwards. We come straight from a match
@ -127,8 +127,8 @@ encodeLoop:
            panic("offset0 was 0")
        }

        nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen)
        nextHashL := hashLen(cv, dFastLongTableBits, dFastLongLen)
        nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen)
        candidateL := e.longTable[nextHashL]
        candidateS := e.table[nextHashS]

@ -439,8 +439,8 @@ encodeLoop:
        var t int32
        for {

            nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen)
            nextHashL := hashLen(cv, dFastLongTableBits, dFastLongLen)
            nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen)
            candidateL := e.longTable[nextHashL]
            candidateS := e.table[nextHashS]

@ -785,8 +785,8 @@ encodeLoop:
            panic("offset0 was 0")
        }

        nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen)
        nextHashL := hashLen(cv, dFastLongTableBits, dFastLongLen)
        nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen)
        candidateL := e.longTable[nextHashL]
        candidateS := e.table[nextHashS]

@ -969,7 +969,7 @@ encodeLoop:
            te0 := tableEntry{offset: index0 + e.cur, val: uint32(cv0)}
            te1 := tableEntry{offset: index1 + e.cur, val: uint32(cv1)}
            longHash1 := hashLen(cv0, dFastLongTableBits, dFastLongLen)
            longHash2 := hashLen(cv0, dFastLongTableBits, dFastLongLen)
            longHash2 := hashLen(cv1, dFastLongTableBits, dFastLongLen)
            e.longTable[longHash1] = te0
            e.longTable[longHash2] = te1
            e.markLongShardDirty(longHash1)
@ -1002,8 +1002,8 @@ encodeLoop:
        }

        // Store this, since we have it.
        nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen)
        nextHashL := hashLen(cv, dFastLongTableBits, dFastLongLen)
        nextHashS := hashLen(cv, dFastShortTableBits, dFastShortLen)

        // We have at least 4 byte match.
        // No need to check backwards. We come straight from a match
@ -1103,7 +1103,8 @@ func (e *doubleFastEncoderDict) Reset(d *dict, singleBlock bool) {
    }

    if allDirty || dirtyShardCnt > dLongTableShardCnt/2 {
        copy(e.longTable[:], e.dictLongTable)
        //copy(e.longTable[:], e.dictLongTable)
        e.longTable = *(*[dFastLongTableSize]tableEntry)(e.dictLongTable)
        for i := range e.longTableShardDirty {
            e.longTableShardDirty[i] = false
        }
@ -1114,7 +1115,9 @@ func (e *doubleFastEncoderDict) Reset(d *dict, singleBlock bool) {
            continue
        }

        copy(e.longTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize], e.dictLongTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize])
        // copy(e.longTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize], e.dictLongTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize])
        *(*[dLongTableShardSize]tableEntry)(e.longTable[i*dLongTableShardSize:]) = *(*[dLongTableShardSize]tableEntry)(e.dictLongTable[i*dLongTableShardSize:])

        e.longTableShardDirty[i] = false
    }
}