"""
Database API
(part of web.py)
"""

# todo:
#  - test with sqlite
#  - a store function?

__all__ = [
  "UnknownParamstyle", "UnknownDB",
  "sqllist", "sqlors", "aparam", "reparam",
  "SQLQuery", "sqlquote",
  "SQLLiteral", "sqlliteral",
  "connect",
  "TransactionError", "transaction", "transact", "commit", "rollback",
  "query",
  "select", "insert", "update", "delete"
]

import time
try:
    import datetime
except ImportError:
    datetime = None

from utils import storage, iters, iterbetter
import webapi as web

try:
    from DBUtils import PooledDB
    web.config._hasPooling = True
except ImportError:
    web.config._hasPooling = False


class _ItplError(ValueError):
    """Raised by `_interpolate` when a `$` expression is not terminated."""

    def __init__(self, text, pos):
        ValueError.__init__(self)
        self.text = text
        self.pos = pos

    def __str__(self):
        return "unfinished expression in %s at char %d" % (
            repr(self.text), self.pos)


def _interpolate(format):
    """
    Takes a format string and returns a list of 2-tuples of the form
    (boolean, string) where boolean says whether string should be evaled
    or not.

    Recognised forms are `${expression}` and `$name`, where a name may be
    followed by `.attr` accesses and balanced `(...)` / `[...]` groups.
    `$$` yields a literal `$`; a `$` followed by anything else (or by the
    end of the string) is kept as literal text.

    Adapted from Ka-Ping Yee's public domain Itpl module.
    """
    # Python 2 only: `tokenprog` is the tokenizer's compiled token regex.
    from tokenize import tokenprog

    def matchorfail(text, pos):
        match = tokenprog.match(text, pos)
        if match is None:
            raise _ItplError(text, pos)
        return match, match.end()

    namechars = "abcdefghijklmnopqrstuvwxyz" \
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
    chunks = []
    pos = 0

    while 1:
        dollar = format.find("$", pos)
        if dollar < 0:
            break
        # Slice instead of index so that a trailing "$" gives "" rather
        # than raising IndexError.
        nextchar = format[dollar + 1:dollar + 2]

        if nextchar == "{":
            chunks.append((0, format[pos:dollar]))
            pos, level = dollar + 2, 1
            while level:
                match, pos = matchorfail(format, pos)
                tstart, tend = match.regs[3]
                token = format[tstart:tend]
                if token == "{":
                    level = level + 1
                elif token == "}":
                    level = level - 1
            chunks.append((1, format[dollar + 2:pos - 1]))

        # `nextchar and ...` because "" is `in` every string.
        elif nextchar and nextchar in namechars:
            chunks.append((0, format[pos:dollar]))
            match, pos = matchorfail(format, dollar + 1)
            while pos < len(format):
                if format[pos] == "." and \
                    pos + 1 < len(format) and format[pos + 1] in namechars:
                    match, pos = matchorfail(format, pos + 1)
                elif format[pos] in "([":
                    pos, level = pos + 1, 1
                    while level:
                        match, pos = matchorfail(format, pos)
                        tstart, tend = match.regs[3]
                        token = format[tstart:tend]
                        if token[0] in "([":
                            level = level + 1
                        elif token[0] in ")]":
                            level = level - 1
                else:
                    break
            chunks.append((1, format[dollar + 1:pos]))

        else:
            # Literal "$"; for "$$" skip the second dollar as well.
            chunks.append((0, format[pos:dollar + 1]))
            pos = dollar + 1 + (nextchar == "$")

    if pos < len(format):
        chunks.append((0, format[pos:]))
    return chunks


class UnknownParamstyle(Exception):
    """
    raised for unsupported db paramstyles

    (currently supported: qmark, numeric, format, pyformat)
    """
    pass


def aparam():
    """
    Returns the appropriate string to be used to interpolate
    a value with the current `web.ctx.db_module` or simply %s
    if there isn't one.

        >>> aparam()
        '%s'
    """
    if hasattr(web.ctx, 'db_module'):
        style = web.ctx.db_module.paramstyle
    else:
        style = 'pyformat'

    if style == 'qmark':
        return '?'
    elif style == 'numeric':
        return ':1'
    elif style in ['format', 'pyformat']:
        return '%s'
    raise UnknownParamstyle(style)


def reparam(string_, dictionary):
    """
    Takes a string and a dictionary and interpolates the string
    using values from the dictionary. Returns an `SQLQuery` for the result.

        >>> reparam("s = $s", dict(s=True))
        <sql: "s = 't'">
    """
    vals = []
    result = []
    for live, chunk in _interpolate(string_):
        if live:
            result.append(aparam())
            # NOTE: `chunk` comes from the query template and is evaluated
            # as Python; templates must never be built from untrusted input.
            vals.append(eval(chunk, dictionary))
        else:
            result.append(chunk)
    return SQLQuery(''.join(result), vals)


def sqlify(obj):
    """
    converts `obj` to its proper SQL version

    Only used to render a query for display (`SQLQuery.__str__`); values
    sent to the database go through the driver's own parameter binding.

        >>> sqlify(None)
        'NULL'
        >>> sqlify(True)
        "'t'"
        >>> sqlify(3)
        '3'
    """
    # because `1 == True and hash(1) == hash(True)`
    # we have to do this the hard way...

    if obj is None:
        return 'NULL'
    elif obj is True:
        return "'t'"
    elif obj is False:
        return "'f'"
    elif datetime and isinstance(obj, datetime.datetime):
        return repr(obj.isoformat())
    else:
        return repr(obj)


class SQLQuery:
    """
    You can pass this sort of thing as a clause in any db function.
    Otherwise, you can pass a dictionary to the keyword argument `vars`
    and the function will call reparam for you.

    `s` is the SQL text containing parameter placeholders and `v` is the
    tuple of values for those placeholders.
    """
    # tested in sqlquote's docstring
    def __init__(self, s='', v=()):
        self.s, self.v = str(s), tuple(v)

    def __getitem__(self, key):  # for backwards-compatibility
        return [self.s, self.v][key]

    def __add__(self, other):
        # Build a new query rather than modifying `self`, so a clause that
        # is concatenated into a larger query can safely be reused.
        if isinstance(other, str):
            return SQLQuery(self.s + other, self.v)
        elif isinstance(other, SQLQuery):
            return SQLQuery(self.s + other.s, self.v + other.v)
        # Any other operand is ignored, as before.
        return SQLQuery(self.s, self.v)

    def __radd__(self, other):
        if isinstance(other, str):
            return SQLQuery(other + self.s, self.v)
        else:
            return NotImplemented

    def __str__(self):
        try:
            return self.s % tuple([sqlify(x) for x in self.v])
        except (ValueError, TypeError):
            # Placeholders that are not %-style cannot be rendered this way.
            return self.s

    def __repr__(self):
        return '<sql: %s>' % repr(str(self))


class SQLLiteral:
    """
    Protects a string from `sqlquote`.

        >>> insert('foo', time=SQLLiteral('NOW()'), _test=True)
        <sql: 'INSERT INTO foo (time) VALUES (NOW())'>
    """
    def __init__(self, v):
        self.v = v

    def __repr__(self):
        return self.v

sqlliteral = SQLLiteral


def sqlquote(a):
    """
    Ensures `a` is quoted properly for use in a SQL query.

        >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
        <sql: "WHERE x = 't' AND y = 3">
    """
    return SQLQuery(aparam(), (a,))


class UnknownDB(Exception):
    """raised for unsupported dbms"""
    pass


def connect(dbn, **keywords):
    """
    Connects to the specified database.

    `dbn` currently must be "postgres", "mysql", "sqlite" or "firebird".
    The `pw` and `db` keywords are renamed to whatever the selected driver
    expects; the remaining keywords are passed through unchanged.

    If DBUtils is installed, connection pooling will be used
    (except for sqlite).

    The actual connection is opened lazily by the first call to
    `web.ctx.db_cursor()`. Raises `UnknownDB` for any other `dbn`.
    """
    if dbn == "postgres":
        try:
            import psycopg2 as db
        except ImportError:
            try:
                import psycopg as db
            except ImportError:
                import pgdb as db
        if 'pw' in keywords:
            keywords['password'] = keywords['pw']
            del keywords['pw']
        keywords['database'] = keywords['db']
        del keywords['db']

    elif dbn == "mysql":
        import MySQLdb as db
        if 'pw' in keywords:
            keywords['passwd'] = keywords['pw']
            del keywords['pw']
        db.paramstyle = 'pyformat'  # it's both, like psycopg

    elif dbn == "sqlite":
        try:
            import sqlite3 as db
            db.paramstyle = 'qmark'
        except ImportError:
            try:
                from pysqlite2 import dbapi2 as db
                db.paramstyle = 'qmark'
            except ImportError:
                import sqlite as db
        web.config._hasPooling = False
        keywords['database'] = keywords['db']
        del keywords['db']

    elif dbn == "firebird":
        import kinterbasdb as db
        if 'pw' in keywords:
            keywords['passwd'] = keywords['pw']
            del keywords['pw']
        keywords['database'] = keywords['db']
        del keywords['db']

    else:
        raise UnknownDB(dbn)

    web.ctx.db_name = dbn
    web.ctx.db_module = db
    web.ctx.db_transaction = 0
    # Holds the connection keywords until the first cursor is requested.
    web.ctx.db = keywords

    def _version_tuple(version):
        # Leading digits of each dotted component, as integers, so that
        # "0.9.10" compares greater than "0.9.3".
        parts = []
        for piece in version.split('.'):
            digits = ''
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            parts.append(int(digits or 0))
        return tuple(parts)

    def _PooledDB(db, keywords):
        # In DBUtils 0.9.3, `dbapi` argument is renamed as `creator`
        # see Bug#122112
        if _version_tuple(PooledDB.__version__) < (0, 9, 3):
            return PooledDB.PooledDB(dbapi=db, **keywords)
        else:
            return PooledDB.PooledDB(creator=db, **keywords)

    def db_cursor():
        # `web.ctx.db` is still the keyword dict until the first use.
        if isinstance(web.ctx.db, dict):
            keywords = web.ctx.db
            if web.config._hasPooling:
                # The pool is cached in this module's globals under 'db'.
                if 'db' not in globals():
                    globals()['db'] = _PooledDB(db, keywords)
                web.ctx.db = globals()['db'].connection()
            else:
                web.ctx.db = db.connect(**keywords)
        return web.ctx.db.cursor()
    web.ctx.db_cursor = db_cursor

    web.ctx.dbq_count = 0

    def db_execute(cur, sql_query, dorollback=True):
        """executes an sql query"""

        web.ctx.dbq_count += 1

        try:
            a = time.time()
            out = cur.execute(sql_query.s, sql_query.v)
            b = time.time()
        except:
            # Deliberately broad: log, roll back and re-raise unchanged.
            if web.config.get('db_printing'):
                print >> web.debug, 'ERR:', str(sql_query)
            if dorollback:
                rollback(care=False)
            raise

        if web.config.get('db_printing'):
            print >> web.debug, '%s (%s): %s' % (round(b-a, 2), web.ctx.dbq_count, str(sql_query))

        return out
    web.ctx.db_execute = db_execute
    return web.ctx.db


class TransactionError(Exception):
    """Raised by `commit` and `rollback` when no transaction is open."""
    pass


class transaction:
    """
    A context that can be used in conjunction with "with" statements
    to implement SQL transactions. Starts a transaction on enter,
    rolls it back if there's an error; otherwise it commits it at the
    end.
    """
    def __enter__(self):
        transact()

    def __exit__(self, exctype, excvalue, traceback):
        if exctype is not None:
            rollback()
        else:
            commit()


def transact():
    """Start a transaction."""
    if not web.ctx.db_transaction:
        # commit everything up to now, so we don't rollback it later
        if hasattr(web.ctx.db, 'commit'):
            web.ctx.db.commit()
    else:
        # Nested transactions are implemented with savepoints.
        db_cursor = web.ctx.db_cursor()
        web.ctx.db_execute(db_cursor,
            SQLQuery("SAVEPOINT webpy_sp_%s" % web.ctx.db_transaction))
    web.ctx.db_transaction += 1


def commit():
    """
    Commits a transaction.

    Raises `TransactionError` if no transaction is open.
    """
    web.ctx.db_transaction -= 1
    if web.ctx.db_transaction < 0:
        # Do not leave the nesting counter negative.
        web.ctx.db_transaction = 0
        raise TransactionError("not in a transaction")

    if not web.ctx.db_transaction:
        if hasattr(web.ctx.db, 'commit'):
            web.ctx.db.commit()
    else:
        db_cursor = web.ctx.db_cursor()
        web.ctx.db_execute(db_cursor,
            SQLQuery("RELEASE SAVEPOINT webpy_sp_%s" % web.ctx.db_transaction))


def rollback(care=True):
    """
    Rolls back a transaction.

    If no transaction is open, raises `TransactionError` when `care` is
    true and silently returns otherwise.
    """
    web.ctx.db_transaction -= 1
    if web.ctx.db_transaction < 0:
        # Reset the counter on `web.ctx`; a negative value would disable
        # autocommit for every later query.
        web.ctx.db_transaction = 0
        if care:
            raise TransactionError("not in a transaction")
        else:
            return

    if not web.ctx.db_transaction:
        if hasattr(web.ctx.db, 'rollback'):
            web.ctx.db.rollback()
    else:
        db_cursor = web.ctx.db_cursor()
        # dorollback=False: a failure here must not recurse into rollback.
        web.ctx.db_execute(db_cursor,
            SQLQuery("ROLLBACK TO SAVEPOINT webpy_sp_%s" % web.ctx.db_transaction),
            dorollback=False)


def query(sql_query, vars=None, processed=False, _test=False):
    """
    Execute SQL query `sql_query` using dictionary `vars` to interpolate it.
    If `processed=True`, `vars` is a `reparam`-style list to use
    instead of interpolating.

    Returns an iterator of `storage` rows when the statement produces a
    result set, otherwise the cursor's row count. With `_test=True` the
    `SQLQuery` is returned without being executed.

        >>> query("SELECT * FROM foo", _test=True)
        <sql: 'SELECT * FROM foo'>
        >>> query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
        <sql: "SELECT * FROM foo WHERE x = 'f'">
        >>> query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
        <sql: "SELECT * FROM foo WHERE x = 'f'">
    """
    if vars is None:
        vars = {}

    if not processed and not isinstance(sql_query, SQLQuery):
        sql_query = reparam(sql_query, vars)

    if _test:
        return sql_query

    db_cursor = web.ctx.db_cursor()
    web.ctx.db_execute(db_cursor, sql_query)

    if db_cursor.description:
        names = [x[0] for x in db_cursor.description]

        def iterwrapper():
            row = db_cursor.fetchone()
            while row:
                yield storage(dict(zip(names, row)))
                row = db_cursor.fetchone()
        out = iterbetter(iterwrapper())
        if web.ctx.db_name != "sqlite":
            out.__len__ = lambda: int(db_cursor.rowcount)
        out.list = lambda: [storage(dict(zip(names, x))) \
                           for x in db_cursor.fetchall()]
    else:
        out = db_cursor.rowcount

    if not web.ctx.db_transaction:
        web.ctx.db.commit()
    return out


def sqllist(lst):
    """
    Converts the arguments for use in something like a WHERE clause.

        >>> sqllist(['a', 'b'])
        'a, b'
        >>> sqllist('a')
        'a'

    """
    if isinstance(lst, str):
        return lst
    else:
        return ', '.join(lst)


def sqlors(left, lst):
    """
    `left is a SQL clause like `tablename.arg = `
    and `lst` is a list of values. Returns a reparam-style
    pair featuring the SQL that ORs together the clause
    for each item in the lst.

        >>> sqlors('foo = ', [])
        <sql: '2+2=5'>
        >>> sqlors('foo = ', [1])
        <sql: 'foo = 1'>
        >>> sqlors('foo = ', 1)
        <sql: 'foo = 1'>
        >>> sqlors('foo = ', [1,2,3])
        <sql: '(foo = 1 OR foo = 2 OR foo = 3)'>
    """
    if isinstance(lst, iters):
        lst = list(lst)
        ln = len(lst)
        if ln == 0:
            # An always-false condition: nothing can match an empty list.
            return SQLQuery("2+2=5", [])
        if ln == 1:
            lst = lst[0]

    if isinstance(lst, iters):
        return SQLQuery('(' + left +
               (' OR ' + left).join([aparam() for param in lst]) + ")", lst)
    else:
        return SQLQuery(left + aparam(), [lst])


def sqlwhere(dictionary, grouping=' AND '):
    """
    Converts a `dictionary` to an SQL WHERE clause `SQLQuery`.

    The clauses appear in the dictionary's iteration order.

        >>> sqlwhere({'cust_id': 2, 'order_id':3})
        <sql: 'order_id = 3 AND cust_id = 2'>
        >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
        <sql: 'order_id = 3, cust_id = 2'>
    """

    return SQLQuery(grouping.join([
      '%s = %s' % (k, aparam()) for k in dictionary.keys()
    ]), dictionary.values())


def _where_clause(where, vars):
    """Turns the `where` argument of `update`/`delete` into an `SQLQuery`."""
    if isinstance(where, (int, long)):
        # A bare integer is shorthand for matching the `id` column.
        return "id = " + sqlquote(where)
    elif isinstance(where, (list, tuple)) and len(where) == 2:
        return SQLQuery(where[0], where[1])  # backwards-compatibility
    elif isinstance(where, SQLQuery):
        return where
    else:
        return reparam(where, vars)


def select(tables, vars=None, what='*', where=None, order=None, group=None,
           limit=None, offset=None, _test=False):
    """
    Selects `what` from `tables` with clauses `where`, `order`,
    `group`, `limit`, and `offset`. Uses vars to interpolate.
    Otherwise, each clause can be a SQLQuery.

        >>> select('foo', _test=True)
        <sql: 'SELECT * FROM foo'>
        >>> select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
        <sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
    """
    if vars is None:
        vars = {}
    qout = ""

    def gen_clause(sql, val):
        if isinstance(val, (int, long)):
            if sql == 'WHERE':
                nout = 'id = ' + sqlquote(val)
            else:
                nout = SQLQuery(val)
        elif isinstance(val, (list, tuple)) and len(val) == 2:
            nout = SQLQuery(val[0], val[1])  # backwards-compatibility
        elif isinstance(val, SQLQuery):
            nout = val
        elif val:
            nout = reparam(val, vars)
        else:
            return ""

        out = ""
        # `qout` is read from the enclosing scope: separate this clause
        # from whatever has been generated so far.
        if qout:
            out += " "
        out += sql + " " + nout
        return out

    if web.ctx.get('db_name') == "firebird":
        # Firebird puts the row limits directly after SELECT
        # (FIRST n SKIP m) instead of using LIMIT/OFFSET.
        for (sql, val) in (
           ('FIRST', limit),
           ('SKIP', offset)
           ):
            qout += gen_clause(sql, val)
        if qout:
            SELECT = 'SELECT ' + qout
        else:
            SELECT = 'SELECT'
        qout = ""
        sql_clauses = (
          (SELECT, what),
          ('FROM', sqllist(tables)),
          ('WHERE', where),
          ('GROUP BY', group),
          ('ORDER BY', order)
        )
    else:
        sql_clauses = (
          ('SELECT', what),
          ('FROM', sqllist(tables)),
          ('WHERE', where),
          ('GROUP BY', group),
          ('ORDER BY', order),
          ('LIMIT', limit),
          ('OFFSET', offset)
        )

    for (sql, val) in sql_clauses:
        qout += gen_clause(sql, val)

    if _test:
        return qout
    return query(qout, processed=True)


def insert(tablename, seqname=None, _test=False, **values):
    """
    Inserts `values` into `tablename`. Returns current sequence ID.
    Set `seqname` to the ID if it's not the default, or to `False`
    if there isn't one.

    Returns `None` when no ID can be fetched after the insert.

        >>> insert('foo', joe='bob', a=2, _test=True)
        <sql: "INSERT INTO foo (a, joe) VALUES (2, 'bob')">
    """

    if values:
        sql_query = SQLQuery("INSERT INTO %s (%s) VALUES (%s)" % (
            tablename,
            ", ".join(values.keys()),
            ', '.join([aparam() for x in values])
        ), values.values())
    else:
        sql_query = SQLQuery("INSERT INTO %s DEFAULT VALUES" % tablename)

    if _test:
        return sql_query

    db_cursor = web.ctx.db_cursor()
    if seqname is False:
        pass
    elif web.ctx.db_name == "postgres":
        if seqname is None:
            seqname = tablename + "_id_seq"
        sql_query += "; SELECT currval('%s')" % seqname
    elif web.ctx.db_name == "mysql":
        web.ctx.db_execute(db_cursor, sql_query)
        sql_query = SQLQuery("SELECT last_insert_id()")
    elif web.ctx.db_name == "sqlite":
        web.ctx.db_execute(db_cursor, sql_query)
        # not really the same...
        sql_query = SQLQuery("SELECT last_insert_rowid()")

    web.ctx.db_execute(db_cursor, sql_query)
    try:
        out = db_cursor.fetchone()[0]
    except Exception:
        # Best effort: the statement may not have produced a row to read.
        out = None

    if not web.ctx.db_transaction:
        web.ctx.db.commit()

    return out


def update(tables, where, vars=None, _test=False, **values):
    """
    Update `tables` with clause `where` (interpolated using `vars`)
    and setting `values`.

    Returns the cursor's row count.

        >>> joe = 'Joseph'
        >>> update('foo', where='name = $joe', name='bob', age=5,
        ...   vars=locals(), _test=True)
        <sql: "UPDATE foo SET age = 5, name = 'bob' WHERE name = 'Joseph'">
    """
    if vars is None:
        vars = {}

    where = _where_clause(where, vars)

    query = (
      "UPDATE " + sqllist(tables) +
      " SET " + sqlwhere(values, ', ') +
      " WHERE " + where)

    if _test:
        return query

    db_cursor = web.ctx.db_cursor()
    web.ctx.db_execute(db_cursor, query)

    if not web.ctx.db_transaction:
        web.ctx.db.commit()
    return db_cursor.rowcount


def delete(table, where=None, using=None, vars=None, _test=False):
    """
    Deletes from `table` with clauses `where` and `using`.

    `using` is ignored for firebird. Returns the cursor's row count.

        >>> name = 'Joe'
        >>> delete('foo', where='name = $name', vars=locals(), _test=True)
        <sql: "DELETE FROM foo WHERE name = 'Joe'">
    """
    if vars is None:
        vars = {}

    if where is not None:
        where = _where_clause(where, vars)

    # Always an SQLQuery, because db_execute reads `.s` and `.v`.
    q = SQLQuery('DELETE FROM ' + table)
    # USING has to come before WHERE.
    if using and web.ctx.get('db_name') != "firebird":
        q += ' USING ' + sqllist(using)
    if where is not None:
        q += ' WHERE ' + where

    if _test:
        return q

    db_cursor = web.ctx.db_cursor()
    web.ctx.db_execute(db_cursor, q)

    if not web.ctx.db_transaction:
        web.ctx.db.commit()
    return db_cursor.rowcount


if __name__ == "__main__":
    import doctest
    doctest.testmod()