2018-06-25 12:26:10 -07:00

378 lines
7.4 KiB
Go

/*
Copyright 2014 SAP SE
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package driver
import (
"database/sql/driver"
"errors"
"fmt"
"math"
"math/big"
"sync"
)
//bigint word size (*--> src/pkg/math/big/arith.go)
//
// These constants compute, at compile time, how many bytes one big.Word
// occupies on the current platform. decodeDecimal uses _S to pack decimal
// bytes into big.Words.
const (
// Compute the size _S of a Word in bytes.
// _m is a big.Word with every bit set.
_m = ^big.Word(0)
// _logS is log2 of the word size in bytes: each term is 1 when the word is
// wider than 8, 16 and 32 bits respectively (e.g. 2 for 32-bit words, 3 for
// 64-bit words).
_logS = _m>>8&1 + _m>>16&1 + _m>>32&1
// _S is the number of bytes per big.Word (untyped integer constant).
_S = 1 << _logS
)
const (
// http://en.wikipedia.org/wiki/Decimal128_floating-point_format
// dec128Digits is the maximum number of decimal digits of the coefficient;
// Decimal.Value rounds to this precision.
dec128Digits = 34
// dec128Bias is added to the exponent before it is stored and subtracted
// again when decoding.
dec128Bias = 6176
// dec128MinExp and dec128MaxExp bound the unbiased exponent used by
// Decimal.Value; exceeding dec128MaxExp yields ErrDecimalOutOfRange.
dec128MinExp = -6176
dec128MaxExp = 6111
)
const (
// decimalSize is the length of an encoded decimal: Scan requires exactly
// this many bytes and encodeDecimal produces this many.
decimalSize = 16 //number of bytes
)
// natZero, natOne and natTen are shared read-only constants. Never use them as
// the receiver (destination) of a mutating big.Int method.
var natZero = big.NewInt(0)
var natOne = big.NewInt(1)
var natTen = big.NewInt(10)
// nat holds the powers of ten 10^0 ... 10^10. exp10 returns these entries
// directly for small exponents, so they are shared and must not be modified.
var nat = []*big.Int{
natOne, //10^0
natTen, //10^1
big.NewInt(100), //10^2
big.NewInt(1000), //10^3
big.NewInt(10000), //10^4
big.NewInt(100000), //10^5
big.NewInt(1000000), //10^6
big.NewInt(10000000), //10^7
big.NewInt(100000000), //10^8
big.NewInt(1000000000), //10^9
big.NewInt(10000000000), //10^10
}
// lg10 is log2(10).
// NOTE(review): not referenced by any live code in this part of the file.
const lg10 = math.Ln10 / math.Ln2 // ~log2(10)
// maxDecimal is built from big-endian bytes 0x01ED09BEAD87C0378D8E63FFFFFFFF,
// presumably 10^34 - 1, the largest 34-digit coefficient.
// NOTE(review): confirm the value and check where it is used in the package.
var maxDecimal = new(big.Int).SetBytes([]byte{0x01, 0xED, 0x09, 0xBE, 0xAD, 0x87, 0xC0, 0x37, 0x8D, 0x8E, 0x63, 0xFF, 0xFF, 0xFF, 0xFF})
// decFlags is a bit set describing what happened while convertRatToDecimal
// converted a big.Rat into coefficient and exponent.
type decFlags byte
const (
// dfNotExact: the division left a remainder, so the coefficient was rounded.
dfNotExact decFlags = 1 << iota
// dfOverflow: the resulting exponent is larger than the allowed maximum.
dfOverflow
// dfUnderflow: the value's exponent is below the allowed minimum.
dfUnderflow
)
// ErrDecimalOutOfRange means that a big.Rat exceeds the size of hdb decimal fields.
// Decimal.Value returns it when the conversion overflows the exponent range.
var ErrDecimalOutOfRange = errors.New("decimal out of range error")
// big.Int free list
// bigIntFree recycles scratch *big.Int values (Decimal.Value uses one for the
// coefficient) to reduce allocations. Values taken from it carry arbitrary
// previous contents and must be overwritten before use.
var bigIntFree = sync.Pool{
New: func() interface{} { return new(big.Int) },
}
// big.Rat free list
// bigRatFree recycles scratch *big.Rat values used by convertRatToDecimal.
var bigRatFree = sync.Pool{
New: func() interface{} { return new(big.Rat) },
}
// A Decimal is the driver representation of a database decimal field value as big.Rat.
// Convert a *Decimal with (*big.Rat)(d) to use the math/big API on it.
type Decimal big.Rat
// Scan implements the database/sql/Scanner interface.
//
// src must be a []byte of exactly decimalSize bytes holding the encoded
// decimal. Special values (infinity, NaN, ...) are rejected.
//
// The result is stored with big.Rat.SetFrac / SetInt rather than by writing
// through Num() and Denom(). For an uninitialized (zero value) Rat, Denom
// returns a new Int that is not linked to the Rat, so a denominator assigned
// through it would be silently lost (1.5 would scan as 15). SetFrac also
// normalizes the fraction.
func (d *Decimal) Scan(src interface{}) error {
	b, ok := src.([]byte)
	if !ok {
		return fmt.Errorf("decimal: invalid data type %T", src)
	}
	if len(b) != decimalSize {
		return fmt.Errorf("decimal: invalid size %d of %v - %d expected", len(b), b, decimalSize)
	}
	if (b[15] & 0x60) == 0x60 {
		return fmt.Errorf("decimal: format (infinity, nan, ...) not supported : %v", b)
	}

	v := (*big.Rat)(d)
	p := new(big.Int) // coefficient, decoded independently of v's current state
	neg, exp := decodeDecimal(b, p)

	switch {
	case exp < 0:
		// value = p / 10^-exp; SetFrac copies the (shared) power of ten.
		v.SetFrac(p, exp10(-exp))
	case exp > 0:
		// value = p * 10^exp
		v.SetInt(p.Mul(p, exp10(exp)))
	default:
		v.SetInt(p)
	}
	if neg {
		v.Neg(v)
	}
	return nil
}
// Value implements the database/sql/Valuer interface.
//
// The rational value is rounded to dec128Digits significant digits and
// encoded. Values too small for the exponent range are stored as zero; values
// too large yield ErrDecimalOutOfRange.
func (d Decimal) Value() (driver.Value, error) {
	coef := bigIntFree.Get().(*big.Int)
	negative, exponent, flags := convertRatToDecimal((*big.Rat)(&d), coef, dec128Digits, dec128MinExp, dec128MaxExp)

	var (
		out driver.Value
		err error
	)
	switch {
	case flags&dfUnderflow != 0:
		// below the smallest representable exponent: store zero instead
		coef.Set(natZero)
		out, err = encodeDecimal(coef, false, 0)
	case flags&dfOverflow != 0:
		err = ErrDecimalOutOfRange
	default:
		out, err = encodeDecimal(coef, negative, exponent)
	}

	// encodeDecimal returns a fresh byte slice, so coef can be recycled now.
	bigIntFree.Put(coef)
	return out, err
}
// convertRatToDecimal converts x into a coefficient m and exponent exp such
// that |x| ~= m * 10^exp with at most digits significant decimal digits.
//
// m receives the non-negative coefficient. The results are:
//   - whether x is negative,
//   - the exponent (clamped to at least minExp unless x underflows),
//   - flags: dfNotExact if the coefficient was rounded (half up),
//     dfUnderflow if the decimal exponent of |x| is below minExp,
//     dfOverflow if the resulting exponent exceeds maxExp.
//
// Trailing zeros are stripped from the coefficient while exp < maxExp.
func convertRatToDecimal(x *big.Rat, m *big.Int, digits, minExp, maxExp int) (bool, int, decFlags) {
	neg := x.Sign() < 0 // store sign
	if x.Sign() == 0 {  // zero
		m.Set(natZero)
		return neg, 0, 0
	}

	c := bigRatFree.Get().(*big.Rat).Abs(x) // copy && abs
	a := c.Num()
	b := c.Denom()

	// exp: decimal exponent of the most significant digit of |x|.
	// shift: power of ten by which a/b has already been scaled down.
	exp, shift := 0, 0
	if c.IsInt() {
		exp = digits10(a) - 1
	} else {
		shift = digits10(a) - digits10(b)
		switch {
		case shift < 0:
			a.Mul(a, exp10(-shift))
		case shift > 0:
			b.Mul(b, exp10(shift))
		}
		// a/b is now in [0.1, 10): decide which side of 1 it is on
		if a.Cmp(b) < 0 {
			exp = shift - 1
		} else {
			exp = shift
		}
	}

	// turn the exponent of the leading digit into the exponent of the
	// least significant retained digit
	var df decFlags
	switch {
	case exp < minExp:
		df |= dfUnderflow
		exp = exp - digits + 1
	default:
		exp = max(exp-digits+1, minExp)
	}
	if exp > maxExp {
		df |= dfOverflow
	}

	// rescale a/b so that its integer quotient is the coefficient for exp;
	// the scale factor is shift (the condition previously tested exp by mistake)
	shift = exp - shift
	switch {
	case shift < 0:
		a.Mul(a, exp10(-shift))
	case shift > 0:
		b.Mul(b, exp10(shift))
	}

	m.QuoRem(a, b, a) // reuse a as rest
	if a.Sign() != 0 {
		// round (business >= 0.5 up)
		df |= dfNotExact
		if a.Add(a, a).Cmp(b) >= 0 {
			m.Add(m, natOne)
			if m.Cmp(exp10(digits)) == 0 {
				// rounding carried into an extra digit: m*10^exp == 10^(digits+exp)
				carry := min(digits, maxExp-exp)
				if carry < 1 { // overflow -> shift one at minimum
					df |= dfOverflow
					carry = 1
				}
				m.Set(exp10(digits - carry))
				exp += carry
			}
		}
	}

	// norm: move trailing zeros of the coefficient into the exponent
	for exp < maxExp {
		a.QuoRem(m, natTen, b) // reuse a, b
		if b.Sign() != 0 {
			break
		}
		m.Set(a)
		exp++
	}

	// performance (avoid expensive defer)
	bigRatFree.Put(c)
	return neg, exp, df
}
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// max returns the larger of a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
// exp10 returns 10^n for n >= 0.
//
// Powers below len(nat) are returned straight from the shared nat table and
// must not be modified by the caller. Larger powers are computed on every call.
// performance: a reusable work variable was tested, but big.Int.Set is
// expensive, so a fresh big.Int is created for n >= len(nat).
func exp10(n int) *big.Int {
	if n >= len(nat) {
		pow := big.NewInt(int64(n))
		return pow.Exp(natTen, pow, nil)
	}
	return nat[n]
}
// digits10 returns the number of decimal digits of p (1 for p == 0).
// p is expected to be non-negative.
//
// The search starts at a lower bound derived from the bit length k:
// p >= 2^(k-1) implies at least floor((k-1)*log10(2)) + 1 digits. The former
// start value k*100/332 uses 1/3.32 > log10(2). For numbers with a few
// thousand digits it therefore started above the true count and returned
// immediately with a too-large result (10^6000 was reported as 6003 digits).
func digits10(p *big.Int) int {
	k := p.BitLen() // 2^(k-1) <= p < 2^k for p > 0
	// 0.301029 < log10(2) ~= 0.3010299957, so i never exceeds the true count;
	// int64 keeps the product from overflowing on 32-bit platforms.
	i := int(int64(k-1)*301029/1000000) + 1
	if i < 1 {
		i = 1
	}
	for ; ; i++ {
		if p.Cmp(exp10(i)) < 0 {
			return i
		}
	}
}
// decodeDecimal decodes the encoded decimal b (at least 16 bytes, little
// endian). It stores the coefficient in m and returns the sign and the
// unbiased exponent.
//
// Bit layout as read here:
//   - bytes 0-13 plus bit 0 of byte 14 hold the coefficient (113 bits),
//   - bits 1-7 of byte 14 and bits 0-6 of byte 15 hold the biased exponent,
//   - bit 7 of byte 15 is the sign.
//
// b is only read. The coefficient bytes are copied into a local array before
// the exponent bits are masked off, instead of temporarily overwriting b[14]
// in the caller's (driver-owned) buffer.
func decodeDecimal(b []byte, m *big.Int) (bool, int) {
	neg := (b[15] & 0x80) != 0
	// <<1 drops the sign bit, >>2 drops the coefficient bit of byte 14
	exp := int((((uint16(b[15])<<8)|uint16(b[14]))<<1)>>2) - dec128Bias

	// little-endian coefficient bytes without sign/exponent bits
	var coef [15]byte
	copy(coef[:], b[:15])
	coef[14] &= 0x01

	// index of the most significant non-zero byte
	msb := len(coef) - 1
	for msb > 0 && coef[msb] == 0 {
		msb--
	}

	// pack the bytes into little-endian big.Words of _S bytes each
	numWords := msb/_S + 1
	w := make([]big.Word, numWords)
	k := numWords - 1
	d := big.Word(0)
	for i := msb; i >= 0; i-- {
		d |= big.Word(coef[i])
		if k*_S == i { // lowest byte of word k reached
			w[k] = d
			k--
			d = 0
		}
		d <<= 8
	}
	m.SetBits(w)
	return neg, exp
}
func encodeDecimal(m *big.Int, neg bool, exp int) (driver.Value, error) {
b := make([]byte, decimalSize)
// little endian bigint words (significand) -> little endian db decimal format
j := 0
for _, d := range m.Bits() {
for i := 0; i < 8; i++ {
b[j] = byte(d)
d >>= 8
j++
}
}
exp += dec128Bias
b[14] |= (byte(exp) << 1)
b[15] = byte(uint16(exp) >> 7)
if neg {
b[15] |= 0x80
}
return b, nil
}
// NullDecimal represents a Decimal that may be null.
// NullDecimal implements the Scanner interface so
// it can be used as a scan destination, similar to sql.NullString.
// It also implements the driver Valuer interface.
type NullDecimal struct {
// Decimal points to the value; it is only meaningful when Valid is true.
Decimal *Decimal
Valid bool // Valid is true if Decimal is not NULL
}
// Scan implements the Scanner interface.
//
// A nil value marks n as NULL (Valid == false). Any other value is decoded
// by Decimal.Scan:
//   - If n.Decimal is nil, a new Decimal is allocated first, so a zero
//     NullDecimal can be used directly as a scan destination.
//   - A value that cannot be decoded (wrong type, wrong size, ...) returns
//     the decoding error and leaves Valid false, instead of being silently
//     treated as NULL.
func (n *NullDecimal) Scan(value interface{}) error {
	if value == nil {
		n.Valid = false
		return nil
	}
	if n.Decimal == nil {
		n.Decimal = new(Decimal)
	}
	if err := n.Decimal.Scan(value); err != nil {
		n.Valid = false
		return err
	}
	n.Valid = true
	return nil
}
// Value implements the driver Valuer interface.
//
// A NULL value (Valid == false) yields nil. A valid NullDecimal without a
// Decimal pointer is reported as an error; otherwise the encoding is
// delegated to Decimal.Value.
func (n NullDecimal) Value() (driver.Value, error) {
	switch {
	case !n.Valid:
		return nil, nil
	case n.Decimal == nil:
		return nil, fmt.Errorf("invalid decimal value %v", n.Decimal)
	default:
		return n.Decimal.Value()
	}
}