From: Chris Lamb Date: Mon, 26 Oct 2009 19:25:33 +0000 (+0000) Subject: Factor out most common session handling into decorator. X-Git-Url: https://err.no/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c02076be73a969f103f12c91ef67250d6e4c08d2;p=dak Factor out most common session handling into decorator. Signed-off-by: Chris Lamb --- diff --git a/daklib/dbconn.py b/daklib/dbconn.py index c0b7d0e8..651b790b 100755 --- a/daklib/dbconn.py +++ b/daklib/dbconn.py @@ -37,6 +37,8 @@ import os import psycopg2 import traceback +from inspect import getargspec + from sqlalchemy import create_engine, Table, MetaData, select from sqlalchemy.orm import sessionmaker, mapper, relation @@ -55,6 +57,27 @@ __all__ = ['IntegrityError', 'SQLAlchemyError'] ################################################################################ +def session_wrapper(fn): + def wrapped(*args, **kwargs): + private_transaction = False + session = kwargs.get('session') + + # No session specified as last argument or in kwargs, create one. + if session is None and len(args) == len(getargspec(fn)[0]) - 1: + private_transaction = True + kwargs['session'] = DBConn().session() + + try: + return fn(*args, **kwargs) + finally: + if private_transaction: + # We created a session; close it. + kwargs['session'].close() + + return wrapped + +################################################################################ + class Architecture(object): def __init__(self, *args, **kwargs): pass @@ -76,6 +99,7 @@ class Architecture(object): __all__.append('Architecture') +@session_wrapper def get_architecture(architecture, session=None): """ Returns database id for given C{architecture}. @@ -89,13 +113,7 @@ def get_architecture(architecture, session=None): @rtype: Architecture @return: Architecture object for the given arch (None if not present) - """ - privatetrans = False - - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(Architecture).filter_by(arch_string=architecture) @@ -104,13 +122,11 @@ def get_architecture(architecture, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_architecture') +@session_wrapper def get_architecture_suites(architecture, session=None): """ Returns list of Suite objects for given C{architecture} name @@ -125,11 +141,6 @@ def get_architecture_suites(architecture, session=None): @rtype: list @return: list of Suite objects for the given name (may be empty) """ - privatetrans = False - - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(Suite) q = q.join(SuiteArchitecture) @@ -137,9 +148,6 @@ def get_architecture_suites(architecture, session=None): ret = q.all() - if privatetrans: - session.close() - return ret __all__.append('get_architecture_suites') @@ -155,6 +163,7 @@ class Archive(object): __all__.append('Archive') +@session_wrapper def get_archive(archive, session=None): """ returns database id for given c{archive}. @@ -172,11 +181,6 @@ def get_archive(archive, session=None): """ archive = archive.lower() - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(Archive).filter_by(archive_name=archive) if q.count() == 0: @@ -184,9 +188,6 @@ def get_archive(archive, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_archive') @@ -213,6 +214,7 @@ class DBBinary(object): __all__.append('DBBinary') +@session_wrapper def get_suites_binary_in(package, session=None): """ Returns list of Suite objects which given C{package} name is in @@ -224,19 +226,13 @@ def get_suites_binary_in(package, session=None): @return: list of Suite objects for the given package """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - ret = session.query(Suite).join(BinAssociation).join(DBBinary).filter_by(package=package).all() - session.close() - return ret __all__.append('get_suites_binary_in') +@session_wrapper def get_binary_from_id(id, session=None): """ Returns DBBinary object for given C{id} @@ -251,10 +247,6 @@ def get_binary_from_id(id, session=None): @rtype: DBBinary @return: DBBinary object for the given binary (None if not present) """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(DBBinary).filter_by(binary_id=id) @@ -263,13 +255,11 @@ def get_binary_from_id(id, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_binary_from_id') +@session_wrapper def get_binaries_from_name(package, version=None, architecture=None, session=None): """ Returns list of DBBinary objects for given C{package} name @@ -290,10 +280,6 @@ def get_binaries_from_name(package, version=None, architecture=None, session=Non @rtype: list @return: list of DBBinary objects for the given name (may be empty) """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(DBBinary).filter_by(package=package) @@ -307,13 +293,11 @@ def get_binaries_from_name(package, version=None, architecture=None, session=Non ret = q.all() - if privatetrans: - session.close() - return ret __all__.append('get_binaries_from_name') +@session_wrapper def get_binaries_from_source_id(source_id, session=None): """ Returns list of DBBinary objects for given C{source_id} @@ -328,29 +312,17 @@ def get_binaries_from_source_id(source_id, session=None): @rtype: list @return: list of DBBinary objects for the given name (may be empty) """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True ret = session.query(DBBinary).filter_by(source_id=source_id).all() - if privatetrans: - session.close() - return ret - __all__.append('get_binaries_from_source_id') - +@session_wrapper def get_binary_from_name_suite(package, suitename, session=None): ### For dak examine-package ### XXX: Doesn't use object API yet - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True sql = """SELECT DISTINCT(b.package), b.version, c.name, su.suite_name FROM binaries b, files fi, location l, component c, bin_associations ba, suite su @@ -365,13 +337,11 @@ def get_binary_from_name_suite(package, suitename, session=None): ret = session.execute(sql, {'package': package, 'suitename': suitename}) - if privatetrans: - session.close() - return ret __all__.append('get_binary_from_name_suite') +@session_wrapper def get_binary_components(package, suitename, arch, session=None): # Check for packages that have moved from one component to another query = """SELECT c.name FROM binaries b, bin_associations ba, suite s, location l, component c, architecture a, files f @@ -384,16 +354,8 @@ def get_binary_components(package, suitename, arch, session=None): vals = {'package': package, 'suitename': suitename, 'arch': arch} - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - ret = session.execute(query, vals) - if privatetrans: - session.close() - return ret __all__.append('get_binary_components') @@ -422,6 +384,7 @@ class Component(object): __all__.append('Component') +@session_wrapper def get_component(component, session=None): """ Returns database id for given C{component}. @@ -435,11 +398,6 @@ def get_component(component, session=None): """ component = component.lower() - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(Component).filter_by(component_name=component) if q.count() == 0: @@ -447,9 +405,6 @@ def get_component(component, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_component') @@ -517,6 +472,7 @@ def get_or_set_contents_file_id(filename, session=None): __all__.append('get_or_set_contents_file_id') +@session_wrapper def get_contents(suite, overridetype, section=None, session=None): """ Returns contents for a suite / overridetype combination, limiting @@ -540,11 +496,6 @@ def get_contents(suite, overridetype, section=None, session=None): package, arch_id) """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - # find me all of the contents for a given suite contents_q = """SELECT (p.path||'/'||n.file) AS fn, s.section, @@ -570,9 +521,6 @@ def get_contents(suite, overridetype, section=None, session=None): ret = session.execute(contents_q, vals) - if privatetrans: - session.close() - return ret __all__.append('get_contents') @@ -714,6 +662,7 @@ class DSCFile(object): __all__.append('DSCFile') +@session_wrapper def get_dscfiles(dscfile_id=None, source_id=None, poolfile_id=None, session=None): """ Returns a list of DSCFiles which may be empty @@ -731,11 +680,6 @@ def get_dscfiles(dscfile_id=None, source_id=None, poolfile_id=None, session=None @return: Possibly empty list of DSCFiles """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(DSCFile) if dscfile_id is not None: @@ -749,9 +693,6 @@ def get_dscfiles(dscfile_id=None, source_id=None, poolfile_id=None, session=None ret = q.all() - if privatetrans: - session.close() - return ret __all__.append('get_dscfiles') @@ -767,6 +708,7 @@ class PoolFile(object): __all__.append('PoolFile') +@session_wrapper def check_poolfile(filename, filesize, md5sum, location_id, session=None): """ Returns a tuple: @@ -794,11 +736,6 @@ def check_poolfile(filename, filesize, md5sum, location_id, session=None): (False, PoolFile object) if file found with size/md5sum mismatch """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(PoolFile).filter_by(filename=filename) q = q.join(Location).filter_by(location_id=location_id) @@ -816,13 +753,11 @@ def check_poolfile(filename, filesize, md5sum, location_id, session=None): if ret is None: ret = (True, obj) - if privatetrans: - session.close() - return ret __all__.append('check_poolfile') +@session_wrapper def get_poolfile_by_id(file_id, session=None): """ Returns a PoolFile objects or None for the given id @@ -834,11 +769,6 @@ def get_poolfile_by_id(file_id, session=None): @return: either the PoolFile object or None """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(PoolFile).filter_by(file_id=file_id) if q.count() > 0: @@ -846,14 +776,12 @@ def get_poolfile_by_id(file_id, session=None): else: ret = None - if privatetrans: - session.close() - return ret __all__.append('get_poolfile_by_id') +@session_wrapper def get_poolfile_by_name(filename, location_id=None, session=None): """ Returns an array of PoolFile objects for the given filename and @@ -869,11 +797,6 @@ def get_poolfile_by_name(filename, location_id=None, session=None): @return: array of PoolFile objects """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(PoolFile).filter_by(filename=filename) if location_id is not None: @@ -881,13 +804,11 @@ def get_poolfile_by_name(filename, location_id=None, session=None): ret = q.all() - if privatetrans: - session.close() - return ret __all__.append('get_poolfile_by_name') +@session_wrapper def get_poolfile_like_name(filename, session=None): """ Returns an array of PoolFile objects which are like the given name @@ -899,19 +820,11 @@ def get_poolfile_like_name(filename, session=None): @return: array of PoolFile objects """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - # TODO: There must be a way of properly using bind parameters with %FOO% q = session.query(PoolFile).filter(PoolFile.filename.like('%%%s%%' % filename)) ret = q.all() - if privatetrans: - session.close() - return ret __all__.append('get_poolfile_like_name') @@ -1028,6 +941,7 @@ class Location(object): __all__.append('Location') +@session_wrapper def get_location(location, component=None, archive=None, session=None): """ Returns Location object for the given combination of location, component @@ -1046,11 +960,6 @@ def get_location(location, component=None, archive=None, session=None): @return: Either a Location object or None if one can't be found """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(Location).filter_by(path=location) if archive is not None: @@ -1064,9 +973,6 @@ def get_location(location, component=None, archive=None, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_location') @@ -1167,6 +1073,7 @@ class NewComment(object): __all__.append('NewComment') +@session_wrapper def has_new_comment(package, version, session=None): """ Returns true if the given combination of C{package}, C{version} has a comment. @@ -1185,24 +1092,17 @@ def has_new_comment(package, version, session=None): @return: true/false """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(NewComment) q = q.filter_by(package=package) q = q.filter_by(version=version) ret = q.count() > 0 - if privatetrans: - session.close() - return ret __all__.append('has_new_comment') +@session_wrapper def get_new_comments(package=None, version=None, comment_id=None, session=None): """ Returns (possibly empty) list of NewComment objects for the given @@ -1223,14 +1123,8 @@ def get_new_comments(package=None, version=None, comment_id=None, session=None): @rtype: list @return: A (possibly empty) list of NewComment objects will be returned - """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(NewComment) if package is not None: q = q.filter_by(package=package) if version is not None: q = q.filter_by(version=version) @@ -1238,9 +1132,6 @@ def get_new_comments(package=None, version=None, comment_id=None, session=None): ret = q.all() - if privatetrans: - session.close() - return ret __all__.append('get_new_comments') @@ -1256,6 +1147,7 @@ class Override(object): __all__.append('Override') +@session_wrapper def get_override(package, suite=None, component=None, overridetype=None, session=None): """ Returns Override object for the given parameters @@ -1281,12 +1173,7 @@ def get_override(package, suite=None, component=None, overridetype=None, session @rtype: list @return: A (possibly empty) list of Override objects will be returned - """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(Override) q = q.filter_by(package=package) @@ -1305,9 +1192,6 @@ def get_override(package, suite=None, component=None, overridetype=None, session ret = q.all() - if privatetrans: - session.close() - return ret __all__.append('get_override') @@ -1324,6 +1208,7 @@ class OverrideType(object): __all__.append('OverrideType') +@session_wrapper def get_override_type(override_type, session=None): """ Returns OverrideType object for given C{override type}. @@ -1337,12 +1222,7 @@ def get_override_type(override_type, session=None): @rtype: int @return: the database id for the given override type - """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(OverrideType).filter_by(overridetype=override_type) @@ -1351,9 +1231,6 @@ def get_override_type(override_type, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_override_type') @@ -1469,6 +1346,7 @@ class Priority(object): __all__.append('Priority') +@session_wrapper def get_priority(priority, session=None): """ Returns Priority object for given C{priority name}. @@ -1482,12 +1360,7 @@ def get_priority(priority, session=None): @rtype: Priority @return: Priority object for the given priority - """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(Priority).filter_by(priority=priority) @@ -1496,13 +1369,11 @@ def get_priority(priority, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_priority') +@session_wrapper def get_priorities(session=None): """ Returns dictionary of priority names -> id mappings @@ -1514,19 +1385,12 @@ def get_priorities(session=None): @rtype: dictionary @return: dictionary of priority names -> id mappings """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True ret = {} q = session.query(Priority) for x in q.all(): ret[x.priority] = x.priority_id - if privatetrans: - session.close() - return ret __all__.append('get_priorities') @@ -1654,6 +1518,7 @@ class Queue(object): __all__.append('Queue') +@session_wrapper def get_queue(queuename, session=None): """ Returns Queue object for given C{queue name}. @@ -1667,12 +1532,7 @@ def get_queue(queuename, session=None): @rtype: Queue @return: Queue object for the given queue - """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(Queue).filter_by(queue_name=queuename) if q.count() == 0: @@ -1680,9 +1540,6 @@ def get_queue(queuename, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_queue') @@ -1698,6 +1555,7 @@ class QueueBuild(object): __all__.append('QueueBuild') +@session_wrapper def get_queue_build(filename, suite, session=None): """ Returns QueueBuild object for given C{filename} and C{suite}. @@ -1714,12 +1572,7 @@ def get_queue_build(filename, suite, session=None): @rtype: Queue @return: Queue object for the given queue - """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True if isinstance(suite, int): q = session.query(QueueBuild).filter_by(filename=filename).filter_by(suite_id=suite) @@ -1732,9 +1585,6 @@ def get_queue_build(filename, suite, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_queue_build') @@ -1762,6 +1612,7 @@ class Section(object): __all__.append('Section') +@session_wrapper def get_section(section, session=None): """ Returns Section object for given C{section name}. @@ -1775,12 +1626,7 @@ def get_section(section, session=None): @rtype: Section @return: Section object for the given section name - """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(Section).filter_by(section=section) if q.count() == 0: @@ -1788,13 +1634,11 @@ def get_section(section, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_section') +@session_wrapper def get_sections(session=None): """ Returns dictionary of section names -> id mappings @@ -1806,19 +1650,12 @@ def get_sections(session=None): @rtype: dictionary @return: dictionary of section names -> id mappings """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True ret = {} q = session.query(Section) for x in q.all(): ret[x.section] = x.section_id - if privatetrans: - session.close() - return ret __all__.append('get_sections') @@ -1834,6 +1671,7 @@ class DBSource(object): __all__.append('DBSource') +@session_wrapper def source_exists(source, source_version, suites = ["any"], session=None): """ Ensure that source exists somewhere in the archive for the binary @@ -1859,11 +1697,6 @@ def source_exists(source, source_version, suites = ["any"], session=None): """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - cnf = Config() ret = 1 @@ -1902,13 +1735,11 @@ def source_exists(source, source_version, suites = ["any"], session=None): # No source found so return not ok ret = 0 - if privatetrans: - session.close() - return ret __all__.append('source_exists') +@session_wrapper def get_suites_source_in(source, session=None): """ Returns list of Suite objects which given C{source} name is in @@ -1920,20 +1751,13 @@ def get_suites_source_in(source, session=None): @return: list of Suite objects for the given source """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - ret = session.query(Suite).join(SrcAssociation).join(DBSource).filter_by(source=source).all() - if privatetrans: - session.close() - return ret __all__.append('get_suites_source_in') +@session_wrapper def get_sources_from_name(source, version=None, dm_upload_allowed=None, session=None): """ Returns list of DBSource objects for given C{source} name and other parameters @@ -1955,10 +1779,6 @@ def get_sources_from_name(source, version=None, dm_upload_allowed=None, session= @rtype: list @return: list of DBSource objects for the given name (may be empty) """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(DBSource).filter_by(source=source) @@ -1970,13 +1790,11 @@ def get_sources_from_name(source, version=None, dm_upload_allowed=None, session= ret = q.all() - if privatetrans: - session.close() - return ret __all__.append('get_sources_from_name') +@session_wrapper def get_source_in_suite(source, suite, session=None): """ Returns list of DBSource objects for a combination of C{source} and C{suite}. @@ -1994,10 +1812,6 @@ def get_source_in_suite(source, suite, session=None): @return: the version for I{source} in I{suite} """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(SrcAssociation) q = q.join('source').filter_by(source=source) @@ -2009,9 +1823,6 @@ def get_source_in_suite(source, suite, session=None): # ???: Maybe we should just return the SrcAssociation object instead ret = q.one().source - if privatetrans: - session.close() - return ret __all__.append('get_source_in_suite') @@ -2090,6 +1901,7 @@ class Suite(object): __all__.append('Suite') +@session_wrapper def get_suite_architecture(suite, architecture, session=None): """ Returns a SuiteArchitecture object given C{suite} and ${arch} or None if it @@ -2109,11 +1921,6 @@ def get_suite_architecture(suite, architecture, session=None): @return: the SuiteArchitecture object or None """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(SuiteArchitecture) q = q.join(Architecture).filter_by(arch_string=architecture) q = q.join(Suite).filter_by(suite_name=suite) @@ -2123,13 +1930,11 @@ def get_suite_architecture(suite, architecture, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_suite_architecture') +@session_wrapper def get_suite(suite, session=None): """ Returns Suite object for given C{suite name}. @@ -2143,12 +1948,7 @@ def get_suite(suite, session=None): @rtype: Suite @return: Suite object for the requested suite name (None if not presenT) - """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(Suite).filter_by(suite_name=suite) @@ -2157,9 +1957,6 @@ def get_suite(suite, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_suite') @@ -2175,6 +1972,7 @@ class SuiteArchitecture(object): __all__.append('SuiteArchitecture') +@session_wrapper def get_suite_architectures(suite, skipsrc=False, skipall=False, session=None): """ Returns list of Architecture objects for given C{suite} name @@ -2198,11 +1996,6 @@ def get_suite_architectures(suite, skipsrc=False, skipall=False, session=None): @return: list of Architecture objects for the given name (may be empty) """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(Architecture) q = q.join(SuiteArchitecture) q = q.join(Suite).filter_by(suite_name=suite) @@ -2217,9 +2010,6 @@ def get_suite_architectures(suite, skipsrc=False, skipall=False, session=None): ret = q.all() - if privatetrans: - session.close() - return ret __all__.append('get_suite_architectures') @@ -2320,13 +2110,8 @@ def get_or_set_uid(uidname, session=None): __all__.append('get_or_set_uid') - +@session_wrapper def get_uid_from_fingerprint(fpr, session=None): - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(Uid) q = q.join(Fingerprint).filter_by(fingerprint=fpr) @@ -2335,9 +2120,6 @@ def get_uid_from_fingerprint(fpr, session=None): else: ret = q.one() - if privatetrans: - session.close() - return ret __all__.append('get_uid_from_fingerprint')