]> err.no Git - dak/commitdiff
Add classes Validator and ValidatorTestCase.
authorTorsten Werner <twerner@debian.org>
Sun, 23 Jan 2011 22:29:06 +0000 (23:29 +0100)
committerTorsten Werner <twerner@debian.org>
Sun, 23 Jan 2011 22:29:06 +0000 (23:29 +0100)
Signed-off-by: Torsten Werner <twerner@debian.org>
daklib/dbconn.py
tests/dbtest_validation.py [new file with mode: 0755]

index 5d0e054e2c3e37a22c8c03033ff4d2939ad75f51..877c5a39fc4032e670c9f6295e65e1276c50fd7c 100755 (executable)
@@ -54,7 +54,8 @@ from inspect import getargspec
 
 import sqlalchemy
 from sqlalchemy import create_engine, Table, MetaData, Column, Integer
-from sqlalchemy.orm import sessionmaker, mapper, relation, object_session, backref
+from sqlalchemy.orm import sessionmaker, mapper, relation, object_session, \
+    backref, MapperExtension, EXT_CONTINUE
 from sqlalchemy import types as sqltypes
 
 # Don't remove this, we re-export the exceptions to scripts which import us
@@ -65,7 +66,7 @@ from sqlalchemy.orm.exc import NoResultFound
 # in the database
 from config import Config
 from textutils import fix_maintainer
-from dak_exceptions import NoSourceFieldError
+from dak_exceptions import DBUpdateError, NoSourceFieldError
 
 # suppress some deprecation warnings in squeeze related to sqlalchemy
 import warnings
@@ -192,7 +193,10 @@ class ORMObject(object):
         for property in all_properties:
             # check for list or query
             if property[-6:] == '_count':
-                value = getattr(self, property[:-6])
+                real_property = property[:-6]
+                if not hasattr(self, real_property):
+                    continue
+                value = getattr(self, real_property)
                 if hasattr(value, '__len__'):
                     # list
                     value = len(value)
@@ -202,6 +206,8 @@ class ORMObject(object):
                 else:
                     raise KeyError('Do not understand property %s.' % property)
             else:
+                if not hasattr(self, property):
+                    continue
                 # plain object
                 value = getattr(self, property)
                 if value is None:
@@ -239,10 +245,36 @@ class ORMObject(object):
         '''
         return '<%s %s>' % (self.classname(), self.json())
 
+    def validate(self):
+        '''
+        This function should be implemented by derived classes to validate self.
+        It may raise the DBUpdateError exception if needed.
+        '''
+        pass
+
 __all__.append('ORMObject')
 
 ################################################################################
 
+class Validator(MapperExtension):
+    '''
+    This class calls the validate() method for each instance for the
+    'before_update' and 'before_insert' events. A global object validator is
+    used for configuring the individual mappers.
+    '''
+
+    def before_update(self, mapper, connection, instance):
+        instance.validate()
+        return EXT_CONTINUE
+
+    def before_insert(self, mapper, connection, instance):
+        instance.validate()
+        return EXT_CONTINUE
+
+validator = Validator()
+
+################################################################################
+
 class Architecture(ORMObject):
     def __init__(self, arch_string = None, description = None):
         self.arch_string = arch_string
@@ -263,6 +295,12 @@ class Architecture(ORMObject):
     def properties(self):
         return ['arch_string', 'arch_id', 'suites_count']
 
+    def validate(self):
+        if self.arch_string is None or len(self.arch_string) == 0:
+            raise DBUpdateError( \
+                "Validation failed because 'arch_string' must not be empty in object\n%s" % \
+                str(self))
+
 __all__.append('Architecture')
 
 @session_wrapper
@@ -2907,10 +2945,11 @@ class DBConn(object):
 
     def __setupmappers(self):
         mapper(Architecture, self.tbl_architecture,
-           properties = dict(arch_id = self.tbl_architecture.c.id,
+            properties = dict(arch_id = self.tbl_architecture.c.id,
                suites = relation(Suite, secondary=self.tbl_suite_architectures,
                    order_by='suite_name',
-                   backref=backref('architectures', order_by='arch_string'))))
+                   backref=backref('architectures', order_by='arch_string'))),
+            extension = validator)
 
         mapper(Archive, self.tbl_archive,
                properties = dict(archive_id = self.tbl_archive.c.id,
diff --git a/tests/dbtest_validation.py b/tests/dbtest_validation.py
new file mode 100755 (executable)
index 0000000..597097c
--- /dev/null
@@ -0,0 +1,37 @@
+#!/usr/bin/env python
+
+from db_test import DBDakTestCase
+
+from daklib.dbconn import Architecture
+from daklib.dak_exceptions import DBUpdateError
+
+import unittest
+
+class ValidatorTestCase(DBDakTestCase):
+    """
+    The ValidatorTestCase tests the validation mechanism.
+    """
+
+    def must_fail(self):
+        ''''
+        This function must fail with DBUpdateError because arch_string is not
+        set. It rolls back the transaction before re-raising the exception.
+        '''
+        try:
+            architecture = Architecture()
+            self.session.add(architecture)
+            self.session.flush()
+        except:
+            self.session.rollback()
+            raise
+
+    def test_validation(self):
+        'tests validate()'
+        self.assertRaises(DBUpdateError, self.must_fail)
+        # should not fail
+        architecture = Architecture('i386')
+        self.session.add(architecture)
+        self.session.flush()
+
+if __name__ == '__main__':
+    unittest.main()