704 lines
20 KiB
Python
704 lines
20 KiB
Python
|
"""
|
||
|
Database API
|
||
|
(part of web.py)
|
||
|
"""
|
||
|
|
||
|
# todo:
|
||
|
# - test with sqlite
|
||
|
# - a store function?
|
||
|
|
||
|
__all__ = [
|
||
|
"UnknownParamstyle", "UnknownDB",
|
||
|
"sqllist", "sqlors", "aparam", "reparam",
|
||
|
"SQLQuery", "sqlquote",
|
||
|
"SQLLiteral", "sqlliteral",
|
||
|
"connect",
|
||
|
"TransactionError", "transaction", "transact", "commit", "rollback",
|
||
|
"query",
|
||
|
"select", "insert", "update", "delete"
|
||
|
]
|
||
|
|
||
|
import time
|
||
|
try: import datetime
|
||
|
except ImportError: datetime = None
|
||
|
|
||
|
from utils import storage, iters, iterbetter
|
||
|
import webapi as web
|
||
|
|
||
|
try:
|
||
|
from DBUtils import PooledDB
|
||
|
web.config._hasPooling = True
|
||
|
except ImportError:
|
||
|
web.config._hasPooling = False
|
||
|
|
||
|
class _ItplError(ValueError):
|
||
|
def __init__(self, text, pos):
|
||
|
ValueError.__init__(self)
|
||
|
self.text = text
|
||
|
self.pos = pos
|
||
|
def __str__(self):
|
||
|
return "unfinished expression in %s at char %d" % (
|
||
|
repr(self.text), self.pos)
|
||
|
|
||
|
def _interpolate(format):
|
||
|
"""
|
||
|
Takes a format string and returns a list of 2-tuples of the form
|
||
|
(boolean, string) where boolean says whether string should be evaled
|
||
|
or not.
|
||
|
|
||
|
from <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee)
|
||
|
"""
|
||
|
from tokenize import tokenprog
|
||
|
|
||
|
def matchorfail(text, pos):
|
||
|
match = tokenprog.match(text, pos)
|
||
|
if match is None:
|
||
|
raise _ItplError(text, pos)
|
||
|
return match, match.end()
|
||
|
|
||
|
namechars = "abcdefghijklmnopqrstuvwxyz" \
|
||
|
"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
|
||
|
chunks = []
|
||
|
pos = 0
|
||
|
|
||
|
while 1:
|
||
|
dollar = format.find("$", pos)
|
||
|
if dollar < 0:
|
||
|
break
|
||
|
nextchar = format[dollar + 1]
|
||
|
|
||
|
if nextchar == "{":
|
||
|
chunks.append((0, format[pos:dollar]))
|
||
|
pos, level = dollar + 2, 1
|
||
|
while level:
|
||
|
match, pos = matchorfail(format, pos)
|
||
|
tstart, tend = match.regs[3]
|
||
|
token = format[tstart:tend]
|
||
|
if token == "{":
|
||
|
level = level + 1
|
||
|
elif token == "}":
|
||
|
level = level - 1
|
||
|
chunks.append((1, format[dollar + 2:pos - 1]))
|
||
|
|
||
|
elif nextchar in namechars:
|
||
|
chunks.append((0, format[pos:dollar]))
|
||
|
match, pos = matchorfail(format, dollar + 1)
|
||
|
while pos < len(format):
|
||
|
if format[pos] == "." and \
|
||
|
pos + 1 < len(format) and format[pos + 1] in namechars:
|
||
|
match, pos = matchorfail(format, pos + 1)
|
||
|
elif format[pos] in "([":
|
||
|
pos, level = pos + 1, 1
|
||
|
while level:
|
||
|
match, pos = matchorfail(format, pos)
|
||
|
tstart, tend = match.regs[3]
|
||
|
token = format[tstart:tend]
|
||
|
if token[0] in "([":
|
||
|
level = level + 1
|
||
|
elif token[0] in ")]":
|
||
|
level = level - 1
|
||
|
else:
|
||
|
break
|
||
|
chunks.append((1, format[dollar + 1:pos]))
|
||
|
|
||
|
else:
|
||
|
chunks.append((0, format[pos:dollar + 1]))
|
||
|
pos = dollar + 1 + (nextchar == "$")
|
||
|
|
||
|
if pos < len(format):
|
||
|
chunks.append((0, format[pos:]))
|
||
|
return chunks
|
||
|
|
||
|
class UnknownParamstyle(Exception):
|
||
|
"""
|
||
|
raised for unsupported db paramstyles
|
||
|
|
||
|
(currently supported: qmark, numeric, format, pyformat)
|
||
|
"""
|
||
|
pass
|
||
|
|
||
|
def aparam():
|
||
|
"""
|
||
|
Returns the appropriate string to be used to interpolate
|
||
|
a value with the current `web.ctx.db_module` or simply %s
|
||
|
if there isn't one.
|
||
|
|
||
|
>>> aparam()
|
||
|
'%s'
|
||
|
"""
|
||
|
if hasattr(web.ctx, 'db_module'):
|
||
|
style = web.ctx.db_module.paramstyle
|
||
|
else:
|
||
|
style = 'pyformat'
|
||
|
|
||
|
if style == 'qmark':
|
||
|
return '?'
|
||
|
elif style == 'numeric':
|
||
|
return ':1'
|
||
|
elif style in ['format', 'pyformat']:
|
||
|
return '%s'
|
||
|
raise UnknownParamstyle, style
|
||
|
|
||
|
def reparam(string_, dictionary):
|
||
|
"""
|
||
|
Takes a string and a dictionary and interpolates the string
|
||
|
using values from the dictionary. Returns an `SQLQuery` for the result.
|
||
|
|
||
|
>>> reparam("s = $s", dict(s=True))
|
||
|
<sql: "s = 't'">
|
||
|
"""
|
||
|
vals = []
|
||
|
result = []
|
||
|
for live, chunk in _interpolate(string_):
|
||
|
if live:
|
||
|
result.append(aparam())
|
||
|
vals.append(eval(chunk, dictionary))
|
||
|
else: result.append(chunk)
|
||
|
return SQLQuery(''.join(result), vals)
|
||
|
|
||
|
def sqlify(obj):
|
||
|
"""
|
||
|
converts `obj` to its proper SQL version
|
||
|
|
||
|
>>> sqlify(None)
|
||
|
'NULL'
|
||
|
>>> sqlify(True)
|
||
|
"'t'"
|
||
|
>>> sqlify(3)
|
||
|
'3'
|
||
|
"""
|
||
|
|
||
|
# because `1 == True and hash(1) == hash(True)`
|
||
|
# we have to do this the hard way...
|
||
|
|
||
|
if obj is None:
|
||
|
return 'NULL'
|
||
|
elif obj is True:
|
||
|
return "'t'"
|
||
|
elif obj is False:
|
||
|
return "'f'"
|
||
|
elif datetime and isinstance(obj, datetime.datetime):
|
||
|
return repr(obj.isoformat())
|
||
|
else:
|
||
|
return repr(obj)
|
||
|
|
||
|
class SQLQuery:
|
||
|
"""
|
||
|
You can pass this sort of thing as a clause in any db function.
|
||
|
Otherwise, you can pass a dictionary to the keyword argument `vars`
|
||
|
and the function will call reparam for you.
|
||
|
"""
|
||
|
# tested in sqlquote's docstring
|
||
|
def __init__(self, s='', v=()):
|
||
|
self.s, self.v = str(s), tuple(v)
|
||
|
|
||
|
def __getitem__(self, key): # for backwards-compatibility
|
||
|
return [self.s, self.v][key]
|
||
|
|
||
|
def __add__(self, other):
|
||
|
if isinstance(other, str):
|
||
|
self.s += other
|
||
|
elif isinstance(other, SQLQuery):
|
||
|
self.s += other.s
|
||
|
self.v += other.v
|
||
|
return self
|
||
|
|
||
|
def __radd__(self, other):
|
||
|
if isinstance(other, str):
|
||
|
self.s = other + self.s
|
||
|
return self
|
||
|
else:
|
||
|
return NotImplemented
|
||
|
|
||
|
def __str__(self):
|
||
|
try:
|
||
|
return self.s % tuple([sqlify(x) for x in self.v])
|
||
|
except (ValueError, TypeError):
|
||
|
return self.s
|
||
|
|
||
|
def __repr__(self):
|
||
|
return '<sql: %s>' % repr(str(self))
|
||
|
|
||
|
class SQLLiteral:
|
||
|
"""
|
||
|
Protects a string from `sqlquote`.
|
||
|
|
||
|
>>> insert('foo', time=SQLLiteral('NOW()'), _test=True)
|
||
|
<sql: 'INSERT INTO foo (time) VALUES (NOW())'>
|
||
|
"""
|
||
|
def __init__(self, v):
|
||
|
self.v = v
|
||
|
|
||
|
def __repr__(self):
|
||
|
return self.v
|
||
|
|
||
|
sqlliteral = SQLLiteral
|
||
|
|
||
|
def sqlquote(a):
|
||
|
"""
|
||
|
Ensures `a` is quoted properly for use in a SQL query.
|
||
|
|
||
|
>>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
|
||
|
<sql: "WHERE x = 't' AND y = 3">
|
||
|
"""
|
||
|
return SQLQuery(aparam(), (a,))
|
||
|
|
||
|
class UnknownDB(Exception):
|
||
|
"""raised for unsupported dbms"""
|
||
|
pass
|
||
|
|
||
|
def connect(dbn, **keywords):
|
||
|
"""
|
||
|
Connects to the specified database.
|
||
|
|
||
|
`dbn` currently must be "postgres", "mysql", or "sqlite".
|
||
|
|
||
|
If DBUtils is installed, connection pooling will be used.
|
||
|
"""
|
||
|
if dbn == "postgres":
|
||
|
try:
|
||
|
import psycopg2 as db
|
||
|
except ImportError:
|
||
|
try:
|
||
|
import psycopg as db
|
||
|
except ImportError:
|
||
|
import pgdb as db
|
||
|
if 'pw' in keywords:
|
||
|
keywords['password'] = keywords['pw']
|
||
|
del keywords['pw']
|
||
|
keywords['database'] = keywords['db']
|
||
|
del keywords['db']
|
||
|
|
||
|
elif dbn == "mysql":
|
||
|
import MySQLdb as db
|
||
|
if 'pw' in keywords:
|
||
|
keywords['passwd'] = keywords['pw']
|
||
|
del keywords['pw']
|
||
|
db.paramstyle = 'pyformat' # it's both, like psycopg
|
||
|
|
||
|
elif dbn == "sqlite":
|
||
|
try:
|
||
|
import sqlite3 as db
|
||
|
db.paramstyle = 'qmark'
|
||
|
except ImportError:
|
||
|
try:
|
||
|
from pysqlite2 import dbapi2 as db
|
||
|
db.paramstyle = 'qmark'
|
||
|
except ImportError:
|
||
|
import sqlite as db
|
||
|
web.config._hasPooling = False
|
||
|
keywords['database'] = keywords['db']
|
||
|
del keywords['db']
|
||
|
|
||
|
elif dbn == "firebird":
|
||
|
import kinterbasdb as db
|
||
|
if 'pw' in keywords:
|
||
|
keywords['passwd'] = keywords['pw']
|
||
|
del keywords['pw']
|
||
|
keywords['database'] = keywords['db']
|
||
|
del keywords['db']
|
||
|
|
||
|
else:
|
||
|
raise UnknownDB, dbn
|
||
|
|
||
|
web.ctx.db_name = dbn
|
||
|
web.ctx.db_module = db
|
||
|
web.ctx.db_transaction = 0
|
||
|
web.ctx.db = keywords
|
||
|
|
||
|
def _PooledDB(db, keywords):
|
||
|
# In DBUtils 0.9.3, `dbapi` argument is renamed as `creator`
|
||
|
# see Bug#122112
|
||
|
if PooledDB.__version__.split('.') < '0.9.3'.split('.'):
|
||
|
return PooledDB.PooledDB(dbapi=db, **keywords)
|
||
|
else:
|
||
|
return PooledDB.PooledDB(creator=db, **keywords)
|
||
|
|
||
|
def db_cursor():
|
||
|
if isinstance(web.ctx.db, dict):
|
||
|
keywords = web.ctx.db
|
||
|
if web.config._hasPooling:
|
||
|
if 'db' not in globals():
|
||
|
globals()['db'] = _PooledDB(db, keywords)
|
||
|
web.ctx.db = globals()['db'].connection()
|
||
|
else:
|
||
|
web.ctx.db = db.connect(**keywords)
|
||
|
return web.ctx.db.cursor()
|
||
|
web.ctx.db_cursor = db_cursor
|
||
|
|
||
|
web.ctx.dbq_count = 0
|
||
|
|
||
|
def db_execute(cur, sql_query, dorollback=True):
|
||
|
"""executes an sql query"""
|
||
|
|
||
|
web.ctx.dbq_count += 1
|
||
|
|
||
|
try:
|
||
|
a = time.time()
|
||
|
out = cur.execute(sql_query.s, sql_query.v)
|
||
|
b = time.time()
|
||
|
except:
|
||
|
if web.config.get('db_printing'):
|
||
|
print >> web.debug, 'ERR:', str(sql_query)
|
||
|
if dorollback: rollback(care=False)
|
||
|
raise
|
||
|
|
||
|
if web.config.get('db_printing'):
|
||
|
print >> web.debug, '%s (%s): %s' % (round(b-a, 2), web.ctx.dbq_count, str(sql_query))
|
||
|
|
||
|
return out
|
||
|
web.ctx.db_execute = db_execute
|
||
|
return web.ctx.db
|
||
|
|
||
|
class TransactionError(Exception): pass
|
||
|
|
||
|
class transaction:
|
||
|
"""
|
||
|
A context that can be used in conjunction with "with" statements
|
||
|
to implement SQL transactions. Starts a transaction on enter,
|
||
|
rolls it back if there's an error; otherwise it commits it at the
|
||
|
end.
|
||
|
"""
|
||
|
def __enter__(self):
|
||
|
transact()
|
||
|
|
||
|
def __exit__(self, exctype, excvalue, traceback):
|
||
|
if exctype is not None:
|
||
|
rollback()
|
||
|
else:
|
||
|
commit()
|
||
|
|
||
|
def transact():
|
||
|
"""Start a transaction."""
|
||
|
if not web.ctx.db_transaction:
|
||
|
# commit everything up to now, so we don't rollback it later
|
||
|
if hasattr(web.ctx.db, 'commit'):
|
||
|
web.ctx.db.commit()
|
||
|
else:
|
||
|
db_cursor = web.ctx.db_cursor()
|
||
|
web.ctx.db_execute(db_cursor,
|
||
|
SQLQuery("SAVEPOINT webpy_sp_%s" % web.ctx.db_transaction))
|
||
|
web.ctx.db_transaction += 1
|
||
|
|
||
|
def commit():
|
||
|
"""Commits a transaction."""
|
||
|
web.ctx.db_transaction -= 1
|
||
|
if web.ctx.db_transaction < 0:
|
||
|
raise TransactionError, "not in a transaction"
|
||
|
|
||
|
if not web.ctx.db_transaction:
|
||
|
if hasattr(web.ctx.db, 'commit'):
|
||
|
web.ctx.db.commit()
|
||
|
else:
|
||
|
db_cursor = web.ctx.db_cursor()
|
||
|
web.ctx.db_execute(db_cursor,
|
||
|
SQLQuery("RELEASE SAVEPOINT webpy_sp_%s" % web.ctx.db_transaction))
|
||
|
|
||
|
def rollback(care=True):
|
||
|
"""Rolls back a transaction."""
|
||
|
web.ctx.db_transaction -= 1
|
||
|
if web.ctx.db_transaction < 0:
|
||
|
web.db_transaction = 0
|
||
|
if care:
|
||
|
raise TransactionError, "not in a transaction"
|
||
|
else:
|
||
|
return
|
||
|
|
||
|
if not web.ctx.db_transaction:
|
||
|
if hasattr(web.ctx.db, 'rollback'):
|
||
|
web.ctx.db.rollback()
|
||
|
else:
|
||
|
db_cursor = web.ctx.db_cursor()
|
||
|
web.ctx.db_execute(db_cursor,
|
||
|
SQLQuery("ROLLBACK TO SAVEPOINT webpy_sp_%s" % web.ctx.db_transaction),
|
||
|
dorollback=False)
|
||
|
|
||
|
def query(sql_query, vars=None, processed=False, _test=False):
|
||
|
"""
|
||
|
Execute SQL query `sql_query` using dictionary `vars` to interpolate it.
|
||
|
If `processed=True`, `vars` is a `reparam`-style list to use
|
||
|
instead of interpolating.
|
||
|
|
||
|
>>> query("SELECT * FROM foo", _test=True)
|
||
|
<sql: 'SELECT * FROM foo'>
|
||
|
>>> query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
|
||
|
<sql: "SELECT * FROM foo WHERE x = 'f'">
|
||
|
>>> query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
|
||
|
<sql: "SELECT * FROM foo WHERE x = 'f'">
|
||
|
"""
|
||
|
if vars is None: vars = {}
|
||
|
|
||
|
if not processed and not isinstance(sql_query, SQLQuery):
|
||
|
sql_query = reparam(sql_query, vars)
|
||
|
|
||
|
if _test: return sql_query
|
||
|
|
||
|
db_cursor = web.ctx.db_cursor()
|
||
|
web.ctx.db_execute(db_cursor, sql_query)
|
||
|
|
||
|
if db_cursor.description:
|
||
|
names = [x[0] for x in db_cursor.description]
|
||
|
def iterwrapper():
|
||
|
row = db_cursor.fetchone()
|
||
|
while row:
|
||
|
yield storage(dict(zip(names, row)))
|
||
|
row = db_cursor.fetchone()
|
||
|
out = iterbetter(iterwrapper())
|
||
|
if web.ctx.db_name != "sqlite":
|
||
|
out.__len__ = lambda: int(db_cursor.rowcount)
|
||
|
out.list = lambda: [storage(dict(zip(names, x))) \
|
||
|
for x in db_cursor.fetchall()]
|
||
|
else:
|
||
|
out = db_cursor.rowcount
|
||
|
|
||
|
if not web.ctx.db_transaction: web.ctx.db.commit()
|
||
|
return out
|
||
|
|
||
|
def sqllist(lst):
|
||
|
"""
|
||
|
Converts the arguments for use in something like a WHERE clause.
|
||
|
|
||
|
>>> sqllist(['a', 'b'])
|
||
|
'a, b'
|
||
|
>>> sqllist('a')
|
||
|
'a'
|
||
|
|
||
|
"""
|
||
|
if isinstance(lst, str):
|
||
|
return lst
|
||
|
else:
|
||
|
return ', '.join(lst)
|
||
|
|
||
|
def sqlors(left, lst):
|
||
|
"""
|
||
|
`left is a SQL clause like `tablename.arg = `
|
||
|
and `lst` is a list of values. Returns a reparam-style
|
||
|
pair featuring the SQL that ORs together the clause
|
||
|
for each item in the lst.
|
||
|
|
||
|
>>> sqlors('foo = ', [])
|
||
|
<sql: '2+2=5'>
|
||
|
>>> sqlors('foo = ', [1])
|
||
|
<sql: 'foo = 1'>
|
||
|
>>> sqlors('foo = ', 1)
|
||
|
<sql: 'foo = 1'>
|
||
|
>>> sqlors('foo = ', [1,2,3])
|
||
|
<sql: '(foo = 1 OR foo = 2 OR foo = 3)'>
|
||
|
"""
|
||
|
if isinstance(lst, iters):
|
||
|
lst = list(lst)
|
||
|
ln = len(lst)
|
||
|
if ln == 0:
|
||
|
return SQLQuery("2+2=5", [])
|
||
|
if ln == 1:
|
||
|
lst = lst[0]
|
||
|
|
||
|
if isinstance(lst, iters):
|
||
|
return SQLQuery('(' + left +
|
||
|
(' OR ' + left).join([aparam() for param in lst]) + ")", lst)
|
||
|
else:
|
||
|
return SQLQuery(left + aparam(), [lst])
|
||
|
|
||
|
def sqlwhere(dictionary, grouping=' AND '):
|
||
|
"""
|
||
|
Converts a `dictionary` to an SQL WHERE clause `SQLQuery`.
|
||
|
|
||
|
>>> sqlwhere({'cust_id': 2, 'order_id':3})
|
||
|
<sql: 'order_id = 3 AND cust_id = 2'>
|
||
|
>>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
|
||
|
<sql: 'order_id = 3, cust_id = 2'>
|
||
|
"""
|
||
|
|
||
|
return SQLQuery(grouping.join([
|
||
|
'%s = %s' % (k, aparam()) for k in dictionary.keys()
|
||
|
]), dictionary.values())
|
||
|
|
||
|
def select(tables, vars=None, what='*', where=None, order=None, group=None,
|
||
|
limit=None, offset=None, _test=False):
|
||
|
"""
|
||
|
Selects `what` from `tables` with clauses `where`, `order`,
|
||
|
`group`, `limit`, and `offset`. Uses vars to interpolate.
|
||
|
Otherwise, each clause can be a SQLQuery.
|
||
|
|
||
|
>>> select('foo', _test=True)
|
||
|
<sql: 'SELECT * FROM foo'>
|
||
|
>>> select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
|
||
|
<sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
|
||
|
"""
|
||
|
if vars is None: vars = {}
|
||
|
qout = ""
|
||
|
|
||
|
def gen_clause(sql, val):
|
||
|
if isinstance(val, (int, long)):
|
||
|
if sql == 'WHERE':
|
||
|
nout = 'id = ' + sqlquote(val)
|
||
|
else:
|
||
|
nout = SQLQuery(val)
|
||
|
elif isinstance(val, (list, tuple)) and len(val) == 2:
|
||
|
nout = SQLQuery(val[0], val[1]) # backwards-compatibility
|
||
|
elif isinstance(val, SQLQuery):
|
||
|
nout = val
|
||
|
elif val:
|
||
|
nout = reparam(val, vars)
|
||
|
else:
|
||
|
return ""
|
||
|
|
||
|
out = ""
|
||
|
if qout: out += " "
|
||
|
out += sql + " " + nout
|
||
|
return out
|
||
|
|
||
|
if web.ctx.get('db_name') == "firebird":
|
||
|
for (sql, val) in (
|
||
|
('FIRST', limit),
|
||
|
('SKIP', offset)
|
||
|
):
|
||
|
qout += gen_clause(sql, val)
|
||
|
if qout:
|
||
|
SELECT = 'SELECT ' + qout
|
||
|
else:
|
||
|
SELECT = 'SELECT'
|
||
|
qout = ""
|
||
|
sql_clauses = (
|
||
|
(SELECT, what),
|
||
|
('FROM', sqllist(tables)),
|
||
|
('WHERE', where),
|
||
|
('GROUP BY', group),
|
||
|
('ORDER BY', order)
|
||
|
)
|
||
|
else:
|
||
|
sql_clauses = (
|
||
|
('SELECT', what),
|
||
|
('FROM', sqllist(tables)),
|
||
|
('WHERE', where),
|
||
|
('GROUP BY', group),
|
||
|
('ORDER BY', order),
|
||
|
('LIMIT', limit),
|
||
|
('OFFSET', offset)
|
||
|
)
|
||
|
|
||
|
for (sql, val) in sql_clauses:
|
||
|
qout += gen_clause(sql, val)
|
||
|
|
||
|
if _test: return qout
|
||
|
return query(qout, processed=True)
|
||
|
|
||
|
def insert(tablename, seqname=None, _test=False, **values):
|
||
|
"""
|
||
|
Inserts `values` into `tablename`. Returns current sequence ID.
|
||
|
Set `seqname` to the ID if it's not the default, or to `False`
|
||
|
if there isn't one.
|
||
|
|
||
|
>>> insert('foo', joe='bob', a=2, _test=True)
|
||
|
<sql: "INSERT INTO foo (a, joe) VALUES (2, 'bob')">
|
||
|
"""
|
||
|
|
||
|
if values:
|
||
|
sql_query = SQLQuery("INSERT INTO %s (%s) VALUES (%s)" % (
|
||
|
tablename,
|
||
|
", ".join(values.keys()),
|
||
|
', '.join([aparam() for x in values])
|
||
|
), values.values())
|
||
|
else:
|
||
|
sql_query = SQLQuery("INSERT INTO %s DEFAULT VALUES" % tablename)
|
||
|
|
||
|
if _test: return sql_query
|
||
|
|
||
|
db_cursor = web.ctx.db_cursor()
|
||
|
if seqname is False:
|
||
|
pass
|
||
|
elif web.ctx.db_name == "postgres":
|
||
|
if seqname is None:
|
||
|
seqname = tablename + "_id_seq"
|
||
|
sql_query += "; SELECT currval('%s')" % seqname
|
||
|
elif web.ctx.db_name == "mysql":
|
||
|
web.ctx.db_execute(db_cursor, sql_query)
|
||
|
sql_query = SQLQuery("SELECT last_insert_id()")
|
||
|
elif web.ctx.db_name == "sqlite":
|
||
|
web.ctx.db_execute(db_cursor, sql_query)
|
||
|
# not really the same...
|
||
|
sql_query = SQLQuery("SELECT last_insert_rowid()")
|
||
|
|
||
|
web.ctx.db_execute(db_cursor, sql_query)
|
||
|
try:
|
||
|
out = db_cursor.fetchone()[0]
|
||
|
except Exception:
|
||
|
out = None
|
||
|
|
||
|
if not web.ctx.db_transaction: web.ctx.db.commit()
|
||
|
|
||
|
return out
|
||
|
|
||
|
def update(tables, where, vars=None, _test=False, **values):
|
||
|
"""
|
||
|
Update `tables` with clause `where` (interpolated using `vars`)
|
||
|
and setting `values`.
|
||
|
|
||
|
>>> joe = 'Joseph'
|
||
|
>>> update('foo', where='name = $joe', name='bob', age=5,
|
||
|
... vars=locals(), _test=True)
|
||
|
<sql: "UPDATE foo SET age = 5, name = 'bob' WHERE name = 'Joseph'">
|
||
|
"""
|
||
|
if vars is None: vars = {}
|
||
|
|
||
|
if isinstance(where, (int, long)):
|
||
|
where = "id = " + sqlquote(where)
|
||
|
elif isinstance(where, (list, tuple)) and len(where) == 2:
|
||
|
where = SQLQuery(where[0], where[1])
|
||
|
elif isinstance(where, SQLQuery):
|
||
|
pass
|
||
|
else:
|
||
|
where = reparam(where, vars)
|
||
|
|
||
|
query = (
|
||
|
"UPDATE " + sqllist(tables) +
|
||
|
" SET " + sqlwhere(values, ', ') +
|
||
|
" WHERE " + where)
|
||
|
|
||
|
if _test: return query
|
||
|
|
||
|
db_cursor = web.ctx.db_cursor()
|
||
|
web.ctx.db_execute(db_cursor, query)
|
||
|
|
||
|
if not web.ctx.db_transaction: web.ctx.db.commit()
|
||
|
return db_cursor.rowcount
|
||
|
|
||
|
def delete(table, where=None, using=None, vars=None, _test=False):
|
||
|
"""
|
||
|
Deletes from `table` with clauses `where` and `using`.
|
||
|
|
||
|
>>> name = 'Joe'
|
||
|
>>> delete('foo', where='name = $name', vars=locals(), _test=True)
|
||
|
<sql: "DELETE FROM foo WHERE name = 'Joe'">
|
||
|
"""
|
||
|
if vars is None: vars = {}
|
||
|
|
||
|
if isinstance(where, (int, long)):
|
||
|
where = "id = " + sqlquote(where)
|
||
|
elif isinstance(where, (list, tuple)) and len(where) == 2:
|
||
|
where = SQLQuery(where[0], where[1])
|
||
|
elif isinstance(where, SQLQuery):
|
||
|
pass
|
||
|
elif where is None:
|
||
|
pass
|
||
|
else:
|
||
|
where = reparam(where, vars)
|
||
|
|
||
|
q = 'DELETE FROM ' + table
|
||
|
if where:
|
||
|
q += ' WHERE ' + where
|
||
|
if using and web.ctx.get('db_name') != "firebird":
|
||
|
q += ' USING ' + sqllist(using)
|
||
|
|
||
|
if _test: return q
|
||
|
|
||
|
db_cursor = web.ctx.db_cursor()
|
||
|
web.ctx.db_execute(db_cursor, q)
|
||
|
|
||
|
if not web.ctx.db_transaction: web.ctx.db.commit()
|
||
|
return db_cursor.rowcount
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
import doctest
|
||
|
doctest.testmod()
|