508 lines
11 KiB
Go
Raw Normal View History

2018-06-21 10:49:35 -04:00
// +build !future
/*
Copyright 2018 SAP SE
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package driver
import (
"context"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
"regexp"
"sync"
"github.com/SAP/go-hdb/driver/sqltrace"
p "github.com/SAP/go-hdb/internal/protocol"
)
// reBulk recognizes a statement that starts (after optional whitespace) with
// the case-insensitive keyword "bulk" followed by at least one space. Capture
// group 3 is the remainder of the first line after the keyword.
var reBulk = regexp.MustCompile("(?i)^(\\s)*(bulk +)(.*)")

// checkBulkInsert reports whether sql is prefixed with the "bulk" keyword.
// If it is, the statement is returned with leading whitespace and the keyword
// removed; otherwise sql is returned unchanged.
func checkBulkInsert(sql string) (string, bool) {
	loc := reBulk.FindStringSubmatchIndex(sql)
	if loc == nil {
		return sql, false
	}
	// loc[6] is the start of capture group 3. The group is the last element
	// of the anchored pattern, so everything from there to the end of sql is
	// exactly what remains once the matched prefix is dropped.
	return sql[loc[6]:], true
}
// reCall recognizes a statement that starts (after optional whitespace) with
// the case-insensitive keyword "call" followed by at least one space.
var reCall = regexp.MustCompile("(?i)^(\\s)*(call +)(.*)")

// checkCallProcedure reports whether sql is a procedure call statement, i.e.
// whether it begins with the "call" keyword.
func checkCallProcedure(sql string) bool {
	isCall := reCall.FindStringIndex(sql) != nil
	return isCall
}
// database connection

// PrepareContext implements the database/sql/driver ConnPrepareContext
// interface.
//
// A leading "bulk" keyword (see checkBulkInsert) is stripped from query before
// it is prepared on the session and makes the returned statement a bulk insert
// statement; any other query yields a regular statement.
//
// The prepare runs in a separate goroutine so that the call can return
// ctx.Err() as soon as ctx is done. driver.ErrBadConn is returned when the
// session is already marked bad.
func (c *conn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) {
	if c.session.IsBad() {
		return nil, driver.ErrBadConn
	}

	// The worker writes only to these locals and never to the named results:
	// "return nil, ctx.Err()" below assigns the named results and would
	// otherwise race with a worker that is still running. The locals are read
	// only after done has been closed.
	var (
		prepared   driver.Stmt
		prepareErr error
	)

	done := make(chan struct{})
	go func() {
		prepareQuery, bulkInsert := checkBulkInsert(query)

		var (
			qt             p.QueryType
			id             uint64
			prmFieldSet    *p.ParameterFieldSet
			resultFieldSet *p.ResultFieldSet
		)
		qt, id, prmFieldSet, resultFieldSet, prepareErr = c.session.Prepare(prepareQuery)
		if prepareErr != nil {
			close(done)
			return
		}

		select {
		case <-ctx.Done():
			// ctx is done, so the caller returns ctx.Err(). done stays open
			// on purpose: closing it would let the caller's select pick the
			// empty result instead of the context error.
			return
		default:
		}

		if bulkInsert {
			prepared, prepareErr = newBulkInsertStmt(c.session, prepareQuery, id, prmFieldSet)
		} else {
			prepared, prepareErr = newStmt(qt, c.session, prepareQuery, id, prmFieldSet, resultFieldSet)
		}
		close(done)
	}()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-done:
		return prepared, prepareErr
	}
}
// QueryContext implements the database/sql/driver/QueryerContext interface.
//
// It is the fast path for queries without arguments. driver.ErrSkip is
// returned, so that database/sql falls back to prepare and execute, when
// args is not empty or when query is a procedure call. A query produced by
// encodeTableQuery is not sent to the server; it is resolved against the
// procedure call result store instead. Everything else is executed directly
// in a separate goroutine so that the call can return ctx.Err() as soon as ctx
// is done.
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
	if c.session.IsBad() {
		return nil, driver.ErrBadConn
	}
	if len(args) != 0 {
		return nil, driver.ErrSkip // fast path not possible (prepare needed)
	}

	// direct execution of call procedure
	// - returns no parameter metadata (sps 82) but only field values
	// --> let's take the 'prepare way' for stored procedures
	if checkCallProcedure(query) {
		return nil, driver.ErrSkip
	}

	sqltrace.Traceln(query)

	if tqID, tqIdx, ok := decodeTableQuery(query); ok {
		r := procedureCallResultStore.get(tqID)
		if r == nil {
			return nil, fmt.Errorf("invalid procedure table query %s", query)
		}
		return r.tableRows(int(tqIdx))
	}

	// The worker writes only to these locals and never to the named results:
	// "return nil, ctx.Err()" below assigns the named results and would
	// otherwise race with a worker that is still running. The locals are read
	// only after done has been closed.
	var (
		result    driver.Rows
		resultErr error
	)

	done := make(chan struct{})
	go func() {
		var (
			id             uint64
			resultFieldSet *p.ResultFieldSet
			fieldValues    *p.FieldValues
			attributes     p.PartAttributes
		)
		id, resultFieldSet, fieldValues, attributes, resultErr = c.session.QueryDirect(query)
		if resultErr != nil {
			close(done)
			return
		}

		select {
		case <-ctx.Done():
			// ctx is done, so the caller returns ctx.Err(). done stays open
			// on purpose: closing it would let the caller's select pick the
			// empty result instead of the context error.
			return
		default:
		}

		if id == 0 { // non select query
			result = noResult
		} else {
			result, resultErr = newQueryResult(c.session, id, resultFieldSet, fieldValues, attributes)
		}
		close(done)
	}()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-done:
		return result, resultErr
	}
}
//statement

// QueryContext implements the driver.StmtQueryContext interface.
//
// Depending on the query type of the statement the arguments are executed
// either as a procedure call or as a default query. The execution runs in a
// separate goroutine so that the call can return ctx.Err() as soon as ctx is
// done. driver.ErrBadConn is returned when the session is already marked bad.
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) {
	if s.session.IsBad() {
		return nil, driver.ErrBadConn
	}

	// The worker writes only to these locals and never to the named results:
	// "return nil, ctx.Err()" below assigns the named results and would
	// otherwise race with a worker that is still running. The locals are read
	// only after done has been closed.
	var (
		result    driver.Rows
		resultErr error
	)

	done := make(chan struct{})
	go func() {
		switch s.qt {
		case p.QtProcedureCall:
			result, resultErr = s.procedureCall(ctx, args)
		default:
			result, resultErr = s.defaultQuery(ctx, args)
		}
		close(done)
	}()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-done:
		return result, resultErr
	}
}
// defaultQuery traces and executes the prepared statement with args on the
// session. It returns noResult when the session reports a zero result id
// (non select query) and a query result otherwise. If ctx is already done
// once the session call has returned, the result is dropped and ctx.Err() is
// returned instead.
func (s *stmt) defaultQuery(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	sqltrace.Tracef("%s %v", s.query, args)

	resultID, fieldValues, attrs, err := s.session.Query(s.id, s.prmFieldSet, s.resultFieldSet, args)
	if err != nil {
		return nil, err
	}

	// Non-blocking cancellation check.
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
	}

	if resultID != 0 {
		return newQueryResult(s.session, resultID, s.resultFieldSet, fieldValues, attrs)
	}
	return noResult, nil // non select query
}
// procedureCall traces and executes the prepared procedure call with args on
// the session and wraps the returned field values and table results in a
// procedure call result. If ctx is already done once the session call has
// returned, the result is dropped and ctx.Err() is returned instead.
func (s *stmt) procedureCall(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	sqltrace.Tracef("%s %v", s.query, args)

	outValues, tables, err := s.session.Call(s.id, s.prmFieldSet, args)
	if err != nil {
		return nil, err
	}

	// Non-blocking cancellation check.
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
	}

	return newProcedureCallResult(s.session, s.prmFieldSet, outValues, tables)
}
// bulk insert statement

// Compile-time checks that *bulkInsertStmt implements all required
// database/sql/driver interfaces; a missing method fails the build.
var (
_ driver.Stmt = (*bulkInsertStmt)(nil)
_ driver.StmtExecContext = (*bulkInsertStmt)(nil)
_ driver.StmtQueryContext = (*bulkInsertStmt)(nil)
_ driver.NamedValueChecker = (*bulkInsertStmt)(nil)
)
// bulkInsertStmt is a prepared statement that buffers the argument rows
// passed to ExecContext and sends them to the session in one batch when it is
// flushed (see execBuffer and execFlush).
type bulkInsertStmt struct {
session *p.Session // session used to execute and to drop the statement
query string // statement text; used for tracing and in error messages
id uint64 // statement id passed to session.Exec and session.DropStatementID
prmFieldSet *p.ParameterFieldSet // parameter metadata; defines the expected number of values per row
numArg int // number of rows currently buffered in args
args []driver.NamedValue // buffered values of all pending rows, appended row after row
}
// newBulkInsertStmt creates a bulk insert statement for the prepared
// statement id with an empty argument buffer. The returned error is always
// nil.
func newBulkInsertStmt(session *p.Session, query string, id uint64, prmFieldSet *p.ParameterFieldSet) (*bulkInsertStmt, error) {
	s := &bulkInsertStmt{
		session:     session,
		query:       query,
		id:          id,
		prmFieldSet: prmFieldSet,
		args:        []driver.NamedValue{},
	}
	return s, nil
}
// Close implements the driver.Stmt interface by dropping the statement id on
// the session. Rows that are still buffered are not flushed here.
func (s *bulkInsertStmt) Close() error {
return s.session.DropStatementID(s.id)
}
// NumInput implements the driver.Stmt interface. It returns -1, which tells
// database/sql that the number of placeholders is unknown so that argument
// counts are not sanity checked there. This also permits the argument-less
// ExecContext call that flushes the buffer; execBuffer validates the count
// itself.
func (s *bulkInsertStmt) NumInput() int {
return -1
}
// Exec is only present to satisfy the driver.Stmt interface. It is superseded
// by ExecContext (driver.StmtExecContext) and panics when called.
func (s *bulkInsertStmt) Exec(args []driver.Value) (driver.Result, error) {
panic("deprecated")
}
// ExecContext implements the driver.StmtExecContext interface.
//
// A call with arguments buffers them as one row (see execBuffer); a call
// without arguments flushes all buffered rows to the session (see execFlush).
// The work runs in a separate goroutine so that the call can return ctx.Err()
// as soon as ctx is done. driver.ErrBadConn is returned when the session is
// already marked bad.
//
// NOTE(review): when ctx is done first, the worker goroutine keeps running
// and still modifies the statement's buffer; callers should not reuse the
// statement concurrently in that case.
func (s *bulkInsertStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
	if s.session.IsBad() {
		return nil, driver.ErrBadConn
	}

	sqltrace.Tracef("%s %v", s.query, args)

	// The worker writes only to these locals and never to the named results:
	// "return nil, ctx.Err()" below assigns the named results and would
	// otherwise race with a worker that is still running. The locals are read
	// only after done has been closed.
	var (
		result    driver.Result
		resultErr error
	)

	done := make(chan struct{})
	go func() {
		if len(args) == 0 { // len of a nil slice is 0 as well
			result, resultErr = s.execFlush()
		} else {
			result, resultErr = s.execBuffer(args)
		}
		close(done)
	}()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-done:
		return result, resultErr
	}
}
// execFlush sends all buffered rows to the session in a single Exec call and
// empties the buffer afterwards, whether or not the call succeeded. With an
// empty buffer nothing is sent and driver.ResultNoRows is returned.
func (s *bulkInsertStmt) execFlush() (driver.Result, error) {
	if s.numArg == 0 {
		return driver.ResultNoRows, nil
	}

	sqltrace.Traceln("execFlush")

	pending := s.args
	res, err := s.session.Exec(s.id, s.prmFieldSet, pending)

	// Keep the backing array for the next batch.
	s.args, s.numArg = pending[:0], 0
	return res, err
}
// execBuffer appends args as one row to the buffer. The number of arguments
// must equal the number of input fields of the statement. When maxSmallint
// rows are already buffered, the buffer is flushed first and the result and
// error of that flush are returned; the new row is buffered in either case.
// Without a flush driver.ResultNoRows is returned.
func (s *bulkInsertStmt) execBuffer(args []driver.NamedValue) (driver.Result, error) {
	want := s.prmFieldSet.NumInputField()
	if got := len(args); got != want {
		return nil, fmt.Errorf("invalid number of arguments %d - %d expected", got, want)
	}

	var (
		res driver.Result = driver.ResultNoRows
		err error
	)
	if s.numArg == maxSmallint { // TODO: check why bigArgument count does not work
		res, err = s.execFlush()
	}

	s.numArg++
	s.args = append(s.args, args...)
	return res, err
}
// Query is only present to satisfy the driver.Stmt interface. It is
// superseded by QueryContext (driver.StmtQueryContext) and panics when called.
func (s *bulkInsertStmt) Query(args []driver.Value) (driver.Rows, error) {
panic("deprecated")
}
// QueryContext implements the driver.StmtQueryContext interface. Queries are
// not supported on a bulk insert statement, so an error naming the statement
// is always returned.
func (s *bulkInsertStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	err := fmt.Errorf("query not allowed in context of bulk insert statement %s", s.query)
	return nil, err
}
// Deprecated: see NamedValueChecker.
//func (s *bulkInsertStmt) ColumnConverter(idx int) driver.ValueConverter {
//}
// CheckNamedValue implements the driver.NamedValueChecker interface. The
// check of nv is delegated to checkNamedValue together with the statement's
// parameter field set.
func (s *bulkInsertStmt) CheckNamedValue(nv *driver.NamedValue) error {
return checkNamedValue(s.prmFieldSet, nv)
}
//call result store

// callResultStore keeps procedure call results addressable by a numeric key
// so that a table query (see encodeTableQuery) can find its result again.
// Keys of deleted results are recycled. The zero value is ready to use and
// all methods are safe for concurrent use.
type callResultStore struct {
	mu    sync.RWMutex
	store map[uint64]*procedureCallResult // live results by key
	cnt   uint64                          // highest key handed out so far
	free  []uint64                        // released keys available for reuse
}

// get returns the result stored under key k or nil if there is none.
func (s *callResultStore) get(k uint64) *procedureCallResult {
	s.mu.RLock()
	defer s.mu.RUnlock()

	return s.store[k] // a nil map or a missing key yields nil
}

// add stores v and returns the key it is stored under. A previously released
// key is reused if one is available; otherwise a new key is generated.
func (s *callResultStore) add(v *procedureCallResult) uint64 {
	s.mu.Lock()
	defer s.mu.Unlock()

	var k uint64
	if n := len(s.free); n > 0 {
		k = s.free[n-1]
		s.free = s.free[:n-1]
	} else {
		s.cnt++
		k = s.cnt
	}

	if s.store == nil {
		s.store = make(map[uint64]*procedureCallResult)
	}
	s.store[k] = v
	return k
}

// del removes the result stored under key k and releases the key for reuse.
// A key that is not in the store is ignored: recycling it anyway (e.g. on a
// second delete of the same result) would put it on the free list twice and
// make add hand the same key to two different results.
func (s *callResultStore) del(k uint64) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if _, ok := s.store[k]; !ok {
		return
	}
	delete(s.store, k)
	s.free = append(s.free, k)
}

// procedureCallResultStore is the package-wide store of procedure call
// results that have not been closed yet.
var procedureCallResultStore = new(callResultStore)
//procedure call result

// Compile-time check that *procedureCallResult implements driver.Rows; a
// missing method fails the build.
var _ driver.Rows = (*procedureCallResult)(nil)
// procedureCallResult is the driver.Rows implementation returned for a
// procedure call. It yields a single row that holds the output parameter
// values followed by one table query string per table result.
type procedureCallResult struct {
id uint64 // key of this result in procedureCallResultStore
session *p.Session // session the call was executed on
prmFieldSet *p.ParameterFieldSet // parameter metadata; provides the output field names and count
fieldValues *p.FieldValues // output parameter values of the call
_tableRows []driver.Rows // one rows object per table result, indexed by table number
columns []string // output field names followed by "table <n>" entries
eof error // set to io.EOF once Next has been called
}
// newProcedureCallResult builds the rows object of a procedure call. Its
// columns are the names of the output parameters followed by one column
// "table <n>" per table result; for every table result a query result is
// created. The new result is registered in procedureCallResultStore and
// carries the key it was registered under. The first error of
// newQueryResult aborts the construction.
func newProcedureCallResult(session *p.Session, prmFieldSet *p.ParameterFieldSet, fieldValues *p.FieldValues, tableResults []*p.TableResult) (driver.Rows, error) {
	numOut := prmFieldSet.NumOutputField()

	columns := make([]string, 0, numOut+len(tableResults))
	for i := 0; i < numOut; i++ {
		columns = append(columns, prmFieldSet.OutputField(i).Name())
	}

	tableRows := make([]driver.Rows, 0, len(tableResults))
	for i, tr := range tableResults {
		rows, err := newQueryResult(session, tr.ID(), tr.FieldSet(), tr.FieldValues(), tr.Attrs())
		if err != nil {
			return nil, err
		}
		tableRows = append(tableRows, rows)
		columns = append(columns, fmt.Sprintf("table %d", i))
	}

	res := &procedureCallResult{
		session:     session,
		prmFieldSet: prmFieldSet,
		fieldValues: fieldValues,
		_tableRows:  tableRows,
		columns:     columns,
	}
	res.id = procedureCallResultStore.add(res)
	return res, nil
}
// Columns implements the driver.Rows interface. It returns the column names
// assembled by newProcedureCallResult.
func (r *procedureCallResult) Columns() []string {
return r.columns
}
// Close implements the driver.Rows interface. It removes the result from
// procedureCallResultStore, so table queries that refer to its id can no
// longer be resolved afterwards. It always returns nil.
func (r *procedureCallResult) Close() error {
procedureCallResultStore.del(r.id)
return nil
}
// Next implements the driver.Rows interface. The result consists of at most
// one row: the first call fills dest with the output parameter values (if
// there are any) followed by one table query string per table result, and
// every further call returns io.EOF. If there are neither output values nor
// table results, the first call already returns io.EOF. driver.ErrBadConn is
// returned when the session is marked bad.
func (r *procedureCallResult) Next(dest []driver.Value) error {
	switch {
	case r.session.IsBad():
		return driver.ErrBadConn
	case r.eof != nil:
		return r.eof
	}

	if r.fieldValues.NumRow() == 0 {
		if len(r._tableRows) == 0 {
			r.eof = io.EOF
			return r.eof
		}
	} else {
		r.fieldValues.Row(0, dest)
	}

	// Table results follow the output parameters in dest.
	base := r.prmFieldSet.NumOutputField()
	for j := range r._tableRows {
		dest[base+j] = encodeTableQuery(r.id, uint64(j))
	}

	r.eof = io.EOF
	return nil
}
// tableRows returns the rows of the table result with index idx. An error is
// returned when idx is out of range. idx originates from a uint64 decoded
// from a query string and converted to int, so it may be negative; that case
// is rejected explicitly instead of panicking on the slice access.
func (r *procedureCallResult) tableRows(idx int) (driver.Rows, error) {
	if idx < 0 {
		return nil, fmt.Errorf("invalid table row index %d", idx)
	}
	if idx >= len(r._tableRows) {
		return nil, fmt.Errorf("table row index %d exceeds maximun %d", idx, len(r._tableRows)-1)
	}
	return r._tableRows[idx], nil
}
// helper

// tableQueryPrefix marks a query string as an encoded table query.
const tableQueryPrefix = "@tq"

// encodeTableQuery builds the query string that refers to table result idx of
// the procedure call result id: the prefix followed by id and idx as 8 byte
// little endian values each.
func encodeTableQuery(id, idx uint64) string {
	const word = 8 // encoded size of a uint64

	buf := make([]byte, len(tableQueryPrefix)+2*word)
	n := copy(buf, tableQueryPrefix)
	binary.LittleEndian.PutUint64(buf[n:n+word], id)
	binary.LittleEndian.PutUint64(buf[n+word:], idx)
	return string(buf)
}

// decodeTableQuery is the inverse of encodeTableQuery. It returns the result
// id and the table index encoded in query; ok is false if query does not have
// the exact length and prefix of an encoded table query.
func decodeTableQuery(query string) (uint64, uint64, bool) {
	const word = 8 // encoded size of a uint64

	n := len(tableQueryPrefix)
	if len(query) != n+2*word || query[:n] != tableQueryPrefix {
		return 0, 0, false
	}

	payload := []byte(query[n:])
	id := binary.LittleEndian.Uint64(payload[:word])
	idx := binary.LittleEndian.Uint64(payload[word:])
	return id, idx, true
}