Blame view

server/database/db_adapters.py 4.51 KB
胡边 committed
1 2
# -*- coding: utf-8 -*-

3 4
import functools
import types

from server.log import log

胡边 committed
5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48

class SQLAlchemyAdapterMetaClass(type):
    """Metaclass that gives a class's public methods commit/rollback semantics.

    Wrapping rules:
        - Only plain functions defined in the class body are wrapped.
        - Names starting with ``__`` are never wrapped.
        - The session-management methods in ``no_wrap`` (see ``__new__``)
          are never wrapped.

    A wrapped method:
        - calls ``self.commit()`` after the original returns;
        - calls ``self.rollback()`` and re-raises if the original or the
          commit raises.
    """

    @staticmethod
    def wrap(func):
        """Return ``func`` wrapped so the instance commits after it succeeds.

        If ``func`` or the subsequent ``self.commit()`` raises,
        ``self.rollback()`` is called and the original exception is re-raised.
        The wrapper keeps ``func``'s name and docstring.
        """

        @functools.wraps(func)
        def auto_commit(self, *args, **kwargs):
            # NOTE: a ``self.commit()`` *before* the call was once placed here
            # as a workaround for a DB transaction issue; it is disabled.
            try:
                return_value = func(self, *args, **kwargs)
                self.commit()
                return return_value
            except BaseException:
                # Any failure (including KeyboardInterrupt) rolls the session
                # back before propagating, so it is not left mid-transaction.
                self.rollback()
                raise

        return auto_commit

    def __new__(cls, name, bases, attrs):
        """Create the class after wrapping its eligible methods.

        The following are left untouched:
            - names starting with ``__``;
            - the names in ``no_wrap`` (wrapping ``commit``/``rollback`` would
              make the wrapper call itself);
            - any attribute that is not a plain function, such as constants,
              ``staticmethod``/``classmethod`` objects or properties, which
              the wrapper could not invoke correctly.
        """
        no_wrap = frozenset(["commit", "merge", "rollback", "remove", "session"])
        # Iterate over a snapshot of the names because values in ``attrs``
        # are rebound inside the loop. (A lazy ``map`` here would never run
        # on Python 3.)
        for attr_name in list(attrs):
            value = attrs[attr_name]
            if (
                attr_name not in no_wrap
                and not attr_name.startswith("__")
                and isinstance(value, types.FunctionType)
            ):
                attrs[attr_name] = cls.wrap(value)
        return super(SQLAlchemyAdapterMetaClass, cls).__new__(cls, name, bases, attrs)


class DBAdapter(object):
    """Minimal base adapter: stores the database session it was given."""

    def __init__(self, db_session):
        # Subclasses issue all their queries through this stored session.
        self.db_session = db_session


class SQLAlchemyAdapter(
    SQLAlchemyAdapterMetaClass("_SQLAlchemyAdapterBase", (DBAdapter,), {})
):
    """SQLAlchemy-backed adapter whose public methods auto-commit.

    The class is built by ``SQLAlchemyAdapterMetaClass``. Public methods are
    wrapped so that ``self.commit()`` runs after they return and
    ``self.rollback()`` runs if they raise. The excluded names are:
        - names beginning with ``__``;
        - the session-management methods ``commit``, ``merge``,
          ``rollback``, ``remove`` and ``session``.

    The metaclass is attached by inheriting from a throw-away base created
    with it. The ``__metaclass__`` class attribute would only work on
    Python 2 and is silently ignored by Python 3; this form works on both.

    The metaclass does not skip single-underscore names, so a private helper
    method added to this class would auto-commit as well.
    """

    def __init__(self, db_session):
        super(SQLAlchemyAdapter, self).__init__(db_session)

    # ------------------ ORM basic functions -------------------- #
    # These names are excluded from the metaclass's auto-commit wrapping.

    def commit(self):
        """Commit the underlying session."""
        self.db_session.commit()

    def remove(self):
        """Call ``remove()`` on the underlying session.

        NOTE(review): ``remove()`` suggests ``db_session`` is a
        ``scoped_session``; confirm against the caller that builds it.
        """
        self.db_session.remove()

    def merge(self, obj):
        """Merge ``obj`` into the session and return what ``merge()`` returns."""
        return self.db_session.merge(obj)

    def rollback(self):
        """Roll back the underlying session."""
        self.db_session.rollback()

    def session(self):
        """Return the underlying session object."""
        return self.db_session

    # -------------------- public usage functions ------------------ #
    # Each method below is subject to the metaclass's auto-commit wrapping.

    def get_object(self, ObjectClass, id):
        """Return the ``ObjectClass`` row whose primary key is ``id``, or None."""
        return ObjectClass.query.get(id)

    def get_all_objects(self, ObjectClass, *criterion):
        """Return a list of ``ObjectClass`` rows matching all SQL ``criterion``."""
        return ObjectClass.query.filter(*criterion).all()

    def get_all_objects_by(self, ObjectClass, **kwargs):
        """Return a list of ``ObjectClass`` rows whose columns equal ``kwargs``."""
        return ObjectClass.query.filter_by(**kwargs).all()

    def get_all_objects_order_by(self, ObjectClass, limit=None, *order_by, **kwargs):
        """Return ``ObjectClass`` rows matching ``kwargs``, ordered by ``order_by``.

        ``limit`` is the first positional parameter, so ordering clauses
        passed positionally bind to it first. Pass the limit (or None)
        before any ordering clause.

        Returns:
            A list of rows when ``limit`` is None. Otherwise an un-executed
            ``Query`` with the limit applied (kept for backward compatibility).
        """
        query = ObjectClass.query.filter_by(**kwargs).order_by(*order_by)
        if limit is not None:
            return query.limit(limit)
        return query.all()

    def get_first_object_by(self, ObjectClass, **kwargs):
        """Return the first row whose columns equal ``kwargs``, or None."""
        return ObjectClass.query.filter_by(**kwargs).first()

    def count(self, ObjectClass, *criterion):
        """Return how many ``ObjectClass`` rows match all SQL ``criterion``."""
        return ObjectClass.query.filter(*criterion).count()

    def count_by(self, ObjectClass, **kwargs):
        """Return how many ``ObjectClass`` rows have columns equal to ``kwargs``."""
        return ObjectClass.query.filter_by(**kwargs).count()

    def get_first_object(self, ObjectClass, *criterion):
        """Return the first row matching all SQL ``criterion``, or None."""
        return ObjectClass.query.filter(*criterion).first()

    def add_object(self, inst):
        """Add ``inst`` to the session."""
        self.db_session.add(inst)

    def add_object_kwargs(self, ObjectClass, **kwargs):
        """Build ``ObjectClass(**kwargs)``, add it to the session and return it."""
        instance = ObjectClass(**kwargs)
        self.db_session.add(instance)
        return instance

    def update_object(self, object, **kwargs):
        """Set attributes on ``object`` from the keyword arguments.

        Raises:
            KeyError: if ``object`` lacks an attribute named by one of the
                keys. All keys are checked before any attribute is set, so a
                bad key leaves ``object`` unmodified.
        """
        # ``object`` shadows the builtin, but it is part of the public
        # signature (callers may pass it by keyword), so the name is kept.
        missing = [key for key in kwargs if not hasattr(object, key)]
        if missing:
            raise KeyError("Object '%s' has no field '%s'." % (type(object), missing[0]))
        for key, value in kwargs.items():
            setattr(object, key, value)

    def delete_object(self, instance):
        """Mark ``instance`` for deletion in the session."""
        self.db_session.delete(instance)

    def delete_all_objects(self, ObjectClass, *criterion):
        """Bulk-delete ``ObjectClass`` rows matching SQL ``criterion``.

        Returns:
            The count returned by ``Query.delete()``.
        """
        # synchronize_session=False: objects already loaded into the session
        # are not updated to reflect the bulk delete.
        return ObjectClass.query.filter(*criterion).delete(synchronize_session=False)

    def delete_all_objects_by(self, ObjectClass, **kwargs):
        """Bulk-delete ``ObjectClass`` rows whose columns equal ``kwargs``.

        Returns:
            The count returned by ``Query.delete()``.
        """
        query = ObjectClass.query.filter_by(**kwargs)
        # Skip in-session synchronization. The default strategy cannot
        # evaluate every criterion (the original note cites ``in_()``).
        return query.delete(synchronize_session=False)

    def exec_sql(self, sql_str, params=None):
        """Execute a raw SQL statement through the session.

        Args:
            sql_str: statement passed straight to ``session.execute()``.
                Never build it by string formatting from untrusted input;
                use ``params`` for values instead.
                NOTE(review): SQLAlchemy 1.4+/2.0 expects a ``text()``
                construct rather than a plain string; confirm the version.
            params: optional bind parameters forwarded to ``session.execute()``.

        Returns:
            The result of ``session.execute()``, or None if execution raised.
            A raised exception is logged, not propagated. Because this method
            is auto-commit-wrapped, the session commits before the caller
            receives the result.
        """
        try:
            return self.session().execute(sql_str, params)
        except Exception as ex:
            log.error(ex)
            return None