2018-06-25 12:26:10 -07:00

364 lines
7.4 KiB
Go

/*
Copyright 2014 SAP SE
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package driver
import (
"database/sql/driver"
"errors"
"fmt"
"math"
"reflect"
"time"
p "github.com/SAP/go-hdb/internal/protocol"
)
// Value ranges of the HDB numeric data types; parameter values are range
// checked against these bounds before they are handed to the protocol layer.
const (
	// TINYINT is unsigned (0..255).
	minTinyint, maxTinyint = 0, math.MaxUint8

	// SMALLINT, INTEGER and BIGINT cover Go's signed 16, 32 and 64 bit ranges.
	minSmallint, maxSmallint = math.MinInt16, math.MaxInt16
	minInteger, maxInteger   = math.MinInt32, math.MaxInt32
	minBigint, maxBigint     = math.MinInt64, math.MaxInt64

	// REAL and DOUBLE are bounded by the largest finite float32 / float64.
	maxReal, maxDouble = math.MaxFloat32, math.MaxFloat64
)
var (
	// ErrIntegerOutOfRange is returned when an integer parameter value does
	// not fit into the value range of the target hdb integer field.
	ErrIntegerOutOfRange = errors.New("integer out of range error")

	// ErrFloatOutOfRange is returned when the magnitude of a float parameter
	// value exceeds the range of the target hdb float field.
	ErrFloatOutOfRange = errors.New("float out of range error")
)

// Reflection types computed once and reused by the time and byte slice
// converters.
var (
	typeOfTime  = reflect.TypeOf(time.Time{})
	typeOfBytes = reflect.TypeOf([]byte(nil))
)
// checkNamedValue converts nv.Value in place into the representation expected
// by the parameter field at position nv.Ordinal (1-based). Values whose
// ordinal lies beyond the input fields of prmFieldSet are left untouched.
// On a conversion error nv.Value is not modified and the error is returned.
func checkNamedValue(prmFieldSet *p.ParameterFieldSet, nv *driver.NamedValue) error {
	pos := nv.Ordinal - 1
	if pos >= prmFieldSet.NumInputField() {
		// NOTE(review): presumably output parameters; nothing to convert here.
		return nil
	}
	field := prmFieldSet.Field(pos)
	converted, err := convertNamedValue(pos, field, field.TypeCode().DataType(), nv.Value)
	if err == nil {
		nv.Value = converted
	}
	return err
}
// convertNamedValue converts the parameter value v for field f at position
// idx (0-based) according to the field's data type dt.
//
// Values implementing driver.Valuer (e.g. NullInt64) are first resolved
// through driver.DefaultParameterConverter, so the data type specific
// converters only ever see the underlying value. An unknown dt is an error.
func convertNamedValue(idx int, f *p.ParameterField, dt p.DataType, v driver.Value) (driver.Value, error) {
	if _, isValuer := v.(driver.Valuer); isValuer {
		resolved, err := driver.DefaultParameterConverter.ConvertValue(v)
		if err != nil {
			return nil, err
		}
		v = resolved
	}

	switch dt {
	case p.DtTinyint:
		return convertNvInteger(v, minTinyint, maxTinyint)
	case p.DtSmallint:
		return convertNvInteger(v, minSmallint, maxSmallint)
	case p.DtInteger:
		return convertNvInteger(v, minInteger, maxInteger)
	case p.DtBigint:
		return convertNvInteger(v, minBigint, maxBigint)
	case p.DtReal:
		return convertNvFloat(v, maxReal)
	case p.DtDouble:
		return convertNvFloat(v, maxDouble)
	case p.DtTime:
		return convertNvTime(v)
	case p.DtDecimal:
		return convertNvDecimal(v)
	case p.DtString:
		return convertNvString(v)
	case p.DtBytes:
		return convertNvBytes(v)
	case p.DtLob:
		return convertNvLob(idx, f, v)
	}
	return nil, fmt.Errorf("convert named value datatype error: %[1]d - %[1]s", dt)
}
// convertNvInteger prepares v for an HDB integer field whose value range is
// [min, max]:
//   - bool values pass through unchanged (bool is represented in HDB as
//     tinyint),
//   - signed and unsigned integers of any size, including named integer
//     types, are range checked and returned as int64; values that do not
//     fit yield ErrIntegerOutOfRange,
//   - pointers are dereferenced, a nil pointer yielding NULL (nil).
//
// Any other type is rejected with an error.
func convertNvInteger(v interface{}, min, max int64) (driver.Value, error) {
	// Each iteration peels off one level of pointer indirection.
	for v != nil {
		rv := reflect.ValueOf(v)
		switch rv.Kind() {
		case reflect.Bool:
			return rv.Bool(), nil
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			n := rv.Int()
			if n < min || n > max {
				return nil, ErrIntegerOutOfRange
			}
			return n, nil
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			// Unsigned values are only checked against max; every caller in
			// this file passes min <= 0.
			n := rv.Uint()
			if n > uint64(max) {
				return nil, ErrIntegerOutOfRange
			}
			return int64(n), nil
		case reflect.Ptr:
			if rv.IsNil() {
				return nil, nil
			}
			v = rv.Elem().Interface()
		default:
			return nil, fmt.Errorf("unsupported integer conversion type error %[1]T %[1]v", v)
		}
	}
	return v, nil
}
// convertNvFloat prepares v for an HDB float field (REAL or DOUBLE).
// float32 and float64 values, including named float types, are returned as
// float64 provided their magnitude does not exceed max; otherwise
// ErrFloatOutOfRange is returned. Pointers are dereferenced, a nil pointer
// yielding NULL (nil). Any other type is rejected with an error.
//
// NOTE(review): NaN passes the range check (math.Abs(NaN) > max is false),
// while +/-Inf is rejected.
func convertNvFloat(v interface{}, max float64) (driver.Value, error) {
	// Each iteration peels off one level of pointer indirection.
	for v != nil {
		rv := reflect.ValueOf(v)
		switch rv.Kind() {
		case reflect.Float32, reflect.Float64:
			f64 := rv.Float()
			if math.Abs(f64) > max {
				return nil, ErrFloatOutOfRange
			}
			return f64, nil
		case reflect.Ptr:
			if rv.IsNil() {
				return nil, nil
			}
			v = rv.Elem().Interface()
		default:
			return nil, fmt.Errorf("unsupported float conversion type error %[1]T %[1]v", v)
		}
	}
	return v, nil
}
// convertNvTime prepares v for a parameter of data type DtTime.
// time.Time values pass through, and values of any type convertible to
// time.Time (e.g. a named type defined as time.Time) are converted.
// Pointers are dereferenced, a nil pointer yielding NULL (nil). Any other
// type is rejected with an error.
func convertNvTime(v interface{}) (driver.Value, error) {
	if v == nil {
		return nil, nil
	}
	if t, ok := v.(time.Time); ok {
		return t, nil
	}

	rv := reflect.ValueOf(v)
	if rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return nil, nil
		}
		return convertNvTime(rv.Elem().Interface())
	}
	if !rv.Type().ConvertibleTo(typeOfTime) {
		return nil, fmt.Errorf("unsupported time conversion type error %[1]T %[1]v", v)
	}
	return rv.Convert(typeOfTime).Interface().(time.Time), nil
}
// convertNvDecimal prepares v for a parameter of data type DtDecimal.
// Only []byte values are accepted (presumably an already encoded decimal,
// e.g. produced by a driver.Valuer; TODO confirm against the protocol
// writer). Like the other converters, pointers are dereferenced and a nil
// pointer yields NULL (nil). Any other type is rejected with an error.
func convertNvDecimal(v interface{}) (driver.Value, error) {
	if v == nil {
		return nil, nil
	}
	if b, ok := v.([]byte); ok {
		return b, nil
	}
	// Follow pointers (e.g. *[]byte) so decimal parameters are handled the
	// same way as every other data type.
	if rv := reflect.ValueOf(v); rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return nil, nil
		}
		return convertNvDecimal(rv.Elem().Interface())
	}
	return nil, fmt.Errorf("unsupported decimal conversion type error %[1]T %[1]v", v)
}
// convertNvString prepares v for a parameter of data type DtString.
// string and []byte values pass through unchanged, and named string types
// are returned as plain string. Values of any other type convertible to
// []byte (e.g. a named byte slice type) are converted to []byte. Pointers
// are dereferenced, a nil pointer yielding NULL (nil). Any other type is
// rejected with an error.
func convertNvString(v interface{}) (driver.Value, error) {
	switch v.(type) {
	case nil, string, []byte:
		return v, nil
	}

	rv := reflect.ValueOf(v)
	kind := rv.Kind()
	if kind == reflect.Ptr {
		if rv.IsNil() {
			return nil, nil
		}
		return convertNvString(rv.Elem().Interface())
	}
	if kind == reflect.String {
		return rv.String(), nil
	}
	if rv.Type().ConvertibleTo(typeOfBytes) {
		return rv.Convert(typeOfBytes).Interface().([]byte), nil
	}
	return nil, fmt.Errorf("unsupported character conversion type error %[1]T %[1]v", v)
}
// convertNvBytes prepares v for a parameter of data type DtBytes.
// []byte values pass through unchanged, and values of any type convertible
// to []byte (named byte slices, strings) are converted to []byte. Pointers
// are dereferenced, a nil pointer yielding NULL (nil). Any other type is
// rejected with an error.
func convertNvBytes(v interface{}) (driver.Value, error) {
	switch v.(type) {
	case nil, []byte:
		return v, nil
	}

	rv := reflect.ValueOf(v)
	if rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return nil, nil
		}
		return convertNvBytes(rv.Elem().Interface())
	}
	if !rv.Type().ConvertibleTo(typeOfBytes) {
		return nil, fmt.Errorf("unsupported bytes conversion type error %[1]T %[1]v", v)
	}
	return rv.Convert(typeOfBytes).Interface().([]byte), nil
}
// convertNvLob prepares v for a parameter of data type DtLob at position idx.
// Accepted are Lob, *Lob, NullLob, *NullLob, and pointers to those. The
// Lob's reader is registered on f via f.SetLobReader, and a placeholder
// string is returned as the parameter value in its place.
//
// A nil value, a nil pointer (including a typed nil *Lob or *NullLob), or a
// NullLob with Valid == false yields NULL (nil). A Lob without an initial
// reader is rejected with an error, as is any other type.
//
// NOTE(review): the placeholder "<lob %d" has no closing '>'. It is kept
// unchanged in case anything depends on its exact text.
func convertNvLob(idx int, f *p.ParameterField, v interface{}) (driver.Value, error) {
	if v == nil {
		return v, nil
	}

	switch v := v.(type) {
	case Lob:
		if v.rd == nil {
			return nil, fmt.Errorf("lob error: initial reader %[1]T %[1]v", v)
		}
		f.SetLobReader(v.rd)
		return fmt.Sprintf("<lob %d", idx), nil
	case *Lob:
		// A typed nil *Lob would otherwise be dereferenced at v.rd and panic;
		// treat it as NULL like every other nil pointer.
		if v == nil {
			return nil, nil
		}
		if v.rd == nil {
			return nil, fmt.Errorf("lob error: initial reader %[1]T %[1]v", v)
		}
		f.SetLobReader(v.rd)
		return fmt.Sprintf("<lob %d", idx), nil
	case NullLob:
		if !v.Valid {
			return nil, nil
		}
		// NOTE(review): if NullLob.Lob is a pointer, a Valid NullLob with a
		// nil Lob still panics on v.Lob.rd; confirm the field type.
		if v.Lob.rd == nil {
			return nil, fmt.Errorf("lob error: initial reader %[1]T %[1]v", v)
		}
		f.SetLobReader(v.Lob.rd)
		return fmt.Sprintf("<lob %d", idx), nil
	case *NullLob:
		// A typed nil *NullLob is NULL as well (previously a nil dereference).
		if v == nil || !v.Valid {
			return nil, nil
		}
		if v.Lob.rd == nil {
			return nil, fmt.Errorf("lob error: initial reader %[1]T %[1]v", v)
		}
		f.SetLobReader(v.Lob.rd)
		return fmt.Sprintf("<lob %d", idx), nil
	}

	rv := reflect.ValueOf(v)
	if rv.Kind() == reflect.Ptr {
		// Indirect other pointers, e.g. **Lob.
		if rv.IsNil() {
			return nil, nil
		}
		return convertNvLob(idx, f, rv.Elem().Interface())
	}
	return nil, fmt.Errorf("unsupported lob conversion type error %[1]T %[1]v", v)
}