161 lines
4.5 KiB
Go
161 lines
4.5 KiB
Go
//go:build amd64 && !purego
|
|
// +build amd64,!purego
|
|
|
|
package base64
|
|
|
|
import (
|
|
"encoding/base64"
|
|
|
|
"github.com/segmentio/asm/cpu"
|
|
"github.com/segmentio/asm/cpu/x86"
|
|
"github.com/segmentio/asm/internal/unsafebytes"
|
|
)
|
|
|
|
// An Encoding is a radix 64 encoding/decoding scheme, defined by a
// 64-character alphabet.
//
// A standard library encoding built from the same alphabet (base) is always
// present. It handles padding, strict mode, length computations, short inputs,
// and the tail of every Encode/Decode call. Optional vectorized kernels
// (enc/dec), together with the lookup tables they read, process a prefix of
// larger inputs when they have been installed.
type Encoding struct {
	// enc is the vectorized encoder, or nil when not installed. It is called
	// with a pointer to enclut[0] and returns (bytes written to dst, bytes
	// consumed from src).
	enc    func(dst, src []byte, lut *int8) (int, int)
	enclut [32]int8

	// dec is the vectorized decoder, or nil when not installed. It is called
	// with a pointer to declut[0] and returns (bytes written to dst, bytes
	// consumed from src).
	dec    func(dst, src []byte, lut *int8) (int, int)
	declut [48]int8

	// base is the portable fallback implementation.
	base *base64.Encoding
}
|
|
|
|
// Input-size thresholds for the vectorized fast paths. Encode and Decode only
// invoke the AVX2 kernels when len(src) is at least the matching value;
// shorter inputs go straight to the standard library implementation.
//
// NOTE(review): presumably these are the smallest inputs the assembly kernels
// can handle without reading or writing out of bounds — confirm against the
// assembly before changing them.
const (
	minEncodeLen = 28 // raw input bytes, checked by Encode
	minDecodeLen = 45 // base64 text bytes, checked by Decode
)
|
|
|
|
func newEncoding(encoder string) *Encoding {
|
|
e := &Encoding{base: base64.NewEncoding(encoder)}
|
|
if cpu.X86.Has(x86.AVX2) {
|
|
e.enableEncodeAVX2(encoder)
|
|
e.enableDecodeAVX2(encoder)
|
|
}
|
|
return e
|
|
}
|
|
|
|
func (e *Encoding) enableEncodeAVX2(encoder string) {
|
|
// Translate values 0..63 to the Base64 alphabet. There are five sets:
|
|
//
|
|
// From To Add Index Example
|
|
// [0..25] [65..90] +65 0 ABCDEFGHIJKLMNOPQRSTUVWXYZ
|
|
// [26..51] [97..122] +71 1 abcdefghijklmnopqrstuvwxyz
|
|
// [52..61] [48..57] -4 [2..11] 0123456789
|
|
// [62] [43] -19 12 +
|
|
// [63] [47] -16 13 /
|
|
tab := [32]int8{int8(encoder[0]), int8(encoder[letterRange]) - letterRange}
|
|
for i, ch := range encoder[2*letterRange:] {
|
|
tab[2+i] = int8(ch) - 2*letterRange - int8(i)
|
|
}
|
|
|
|
e.enc = encodeAVX2
|
|
e.enclut = tab
|
|
}
|
|
|
|
func (e *Encoding) enableDecodeAVX2(encoder string) {
|
|
c62, c63 := int8(encoder[62]), int8(encoder[63])
|
|
url := c63 == '_'
|
|
if url {
|
|
c63 = '/'
|
|
}
|
|
|
|
// Translate values from the Base64 alphabet using five sets. Values outside
|
|
// of these ranges are considered invalid:
|
|
//
|
|
// From To Add Index Example
|
|
// [47] [63] +16 1 /
|
|
// [43] [62] +19 2 +
|
|
// [48..57] [52..61] +4 3 0123456789
|
|
// [65..90] [0..25] -65 4,5 ABCDEFGHIJKLMNOPQRSTUVWXYZ
|
|
// [97..122] [26..51] -71 6,7 abcdefghijklmnopqrstuvwxyz
|
|
tab := [48]int8{
|
|
0, 63 - c63, 62 - c62, 4, -65, -65, -71, -71,
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
|
0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
|
0x11, 0x11, 0x13, 0x1B, 0x1B, 0x1B, 0x1B, 0x1B,
|
|
}
|
|
tab[(c62&15)+16] = 0x1A
|
|
tab[(c63&15)+16] = 0x1A
|
|
|
|
if url {
|
|
e.dec = decodeAVX2URI
|
|
} else {
|
|
e.dec = decodeAVX2
|
|
}
|
|
e.declut = tab
|
|
}
|
|
|
|
// WithPadding creates a duplicate Encoding updated with a specified padding
|
|
// character, or NoPadding to disable padding. The padding character must not
|
|
// be contained in the encoding alphabet, must not be '\r' or '\n', and must
|
|
// be no greater than '\xFF'.
|
|
func (enc Encoding) WithPadding(padding rune) *Encoding {
|
|
enc.base = enc.base.WithPadding(padding)
|
|
return &enc
|
|
}
|
|
|
|
// Strict creates a duplicate encoding updated with strict decoding enabled.
|
|
// This requires that trailing padding bits are zero.
|
|
func (enc Encoding) Strict() *Encoding {
|
|
enc.base = enc.base.Strict()
|
|
return &enc
|
|
}
|
|
|
|
// Encode encodes src using the defined encoding alphabet.
|
|
// This will write EncodedLen(len(src)) bytes to dst.
|
|
func (enc *Encoding) Encode(dst, src []byte) {
|
|
if len(src) >= minEncodeLen && enc.enc != nil {
|
|
d, s := enc.enc(dst, src, &enc.enclut[0])
|
|
dst = dst[d:]
|
|
src = src[s:]
|
|
}
|
|
enc.base.Encode(dst, src)
|
|
}
|
|
|
|
// Encode encodes src using the encoding enc, writing
|
|
// EncodedLen(len(src)) bytes to dst.
|
|
func (enc *Encoding) EncodeToString(src []byte) string {
|
|
buf := make([]byte, enc.base.EncodedLen(len(src)))
|
|
enc.Encode(buf, src)
|
|
return string(buf)
|
|
}
|
|
|
|
// EncodedLen calculates the base64-encoded byte length for a message
|
|
// of length n.
|
|
func (enc *Encoding) EncodedLen(n int) int {
|
|
return enc.base.EncodedLen(n)
|
|
}
|
|
|
|
// Decode decodes src using the defined encoding alphabet.
|
|
// This will write DecodedLen(len(src)) bytes to dst and return the number of
|
|
// bytes written.
|
|
func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
|
|
var d, s int
|
|
if len(src) >= minDecodeLen && enc.dec != nil {
|
|
d, s = enc.dec(dst, src, &enc.declut[0])
|
|
dst = dst[d:]
|
|
src = src[s:]
|
|
}
|
|
n, err = enc.base.Decode(dst, src)
|
|
n += d
|
|
return
|
|
}
|
|
|
|
// DecodeString decodes the base64 encoded string s, returns the decoded
|
|
// value as bytes.
|
|
func (enc *Encoding) DecodeString(s string) ([]byte, error) {
|
|
src := unsafebytes.BytesOf(s)
|
|
dst := make([]byte, enc.base.DecodedLen(len(s)))
|
|
n, err := enc.Decode(dst, src)
|
|
return dst[:n], err
|
|
}
|
|
|
|
// DecodedLen calculates the decoded byte length for a base64-encoded message
|
|
// of length n.
|
|
func (enc *Encoding) DecodedLen(n int) int {
|
|
return enc.base.DecodedLen(n)
|
|
}
|