db_adapters.py
4.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# -*- coding: utf-8 -*-
import functools
import types

from server.log import log
class SQLAlchemyAdapterMetaClass(type):
    """Metaclass that wraps a class's public functions in a commit/rollback guard.

    Every plain function defined in the class body is replaced by a wrapper
    that calls ``self.commit()`` after a successful call and
    ``self.rollback()`` (then re-raises) when the call fails. Dunder names and
    the session-management methods listed in ``no_wrap`` are left untouched.
    """

    @staticmethod
    def wrap(func):
        """Return ``func`` wrapped so each call commits on success or rolls back on error.

        :param func: an instance method; the instance it is called on must
            provide ``commit()`` and ``rollback()``.
        :returns: the wrapping function, carrying ``func``'s name and docstring.
        """
        @functools.wraps(func)
        def auto_commit(self, *args, **kwargs):
            try:
                # TODO: a trick for a DB transaction issue
                # (a pre-call ``self.commit()`` used to be tried here).
                return_value = func(self, *args, **kwargs)
                self.commit()
                return return_value
            except BaseException:
                # Any failure, including a failed commit, undoes the
                # transaction before the original exception propagates.
                self.rollback()
                raise
        return auto_commit

    def __new__(cls, name, bases, attrs):
        """Create the class, wrapping its public functions with ``wrap``."""
        # Methods in this list are never wrapped: they manage the session
        # themselves, and wrapping ``commit``/``rollback`` would recurse.
        no_wrap = ["commit", "merge", "rollback", "remove", "session"]
        # An explicit loop instead of ``map``: on Python 3 ``map`` is lazy and
        # would never run. Iterate over a snapshot because values are replaced.
        for attr_name, value in list(attrs.items()):
            if attr_name in no_wrap or attr_name.startswith("__"):
                continue
            # Only plain functions are wrapped. Constants, staticmethods and
            # classmethods would otherwise become broken wrappers.
            if isinstance(value, types.FunctionType):
                attrs[attr_name] = cls.wrap(value)
        return super(SQLAlchemyAdapterMetaClass, cls).__new__(cls, name, bases, attrs)
class DBAdapter(object):
    """Minimal base for database adapters: it only keeps the session handle.

    :param db_session: the session object later used by subclasses; it is
        stored unchanged on ``self.db_session``.
    """

    def __init__(self, db_session):
        # Keep the caller's session as-is; subclasses build their API on it.
        self.db_session = db_session
class SQLAlchemyAdapter(SQLAlchemyAdapterMetaClass("_SQLAlchemyAdapterBase", (DBAdapter,), {})):
    """Database adapter backed by an SQLAlchemy session.

    The metaclass is applied by inheriting from a throw-away base that is
    created by calling ``SQLAlchemyAdapterMetaClass`` directly. Python 2
    honours a ``__metaclass__`` class attribute, but Python 3 silently ignores
    it, which would leave every method unwrapped and nothing committed. This
    form applies the metaclass on both versions.

    The metaclass wraps every method except ``commit``, ``merge``,
    ``rollback``, ``remove``, ``session`` and dunder methods. Each wrapped call
    commits the session on success and rolls it back on error.
    """

    def __init__(self, db_session):
        super(SQLAlchemyAdapter, self).__init__(db_session)

    # -------------- ORM basic functions (never auto-wrapped) -------------- #

    def commit(self):
        """Commit the current transaction of the underlying session."""
        self.db_session.commit()

    def remove(self):
        """Call ``remove()`` on the underlying session object."""
        # NOTE(review): ``remove`` exists on scoped sessions, not on a plain
        # Session; presumably db_session is a scoped_session. Confirm.
        self.db_session.remove()

    def merge(self, obj):
        """Merge ``obj`` into the session and return the session-bound instance."""
        # Session.merge returns the merged copy, so pass it on to the caller.
        return self.db_session.merge(obj)

    def rollback(self):
        """Roll back the current transaction of the underlying session."""
        self.db_session.rollback()

    def session(self):
        """Return the underlying session object."""
        return self.db_session

    # ------ public usage functions (auto-committed by the metaclass) ------ #
    # ``ObjectClass.query`` is assumed to be a query property on the model
    # class (e.g. Flask-SQLAlchemy's). TODO confirm.

    def get_object(self, ObjectClass, id):
        """Retrieve one ``ObjectClass`` instance by its primary key ``id``, or None."""
        return ObjectClass.query.get(id)

    def get_all_objects(self, ObjectClass, *criterion):
        """Return a list of all ``ObjectClass`` rows matching the filter ``criterion``."""
        return ObjectClass.query.filter(*criterion).all()

    def get_all_objects_by(self, ObjectClass, **kwargs):
        """Return a list of all ``ObjectClass`` rows whose columns equal ``kwargs``."""
        return ObjectClass.query.filter_by(**kwargs).all()

    def get_all_objects_order_by(self, ObjectClass, limit=None, *order_by, **kwargs):
        """Return ``ObjectClass`` rows filtered by ``kwargs`` and sorted by ``order_by``.

        ``limit`` comes before ``*order_by``, so it must be passed positionally
        (possibly as None) whenever ordering clauses are given.
        """
        query = ObjectClass.query.filter_by(**kwargs).order_by(*order_by)
        if limit is not None:
            # NOTE(review): this branch returns an un-executed Query, not a
            # list. The query is run only when iterated, after the metaclass
            # wrapper has already committed. Kept for caller compatibility.
            return query.limit(limit)
        return query.all()

    def get_first_object_by(self, ObjectClass, **kwargs):
        """Return the first ``ObjectClass`` row whose columns equal ``kwargs``, or None."""
        return ObjectClass.query.filter_by(**kwargs).first()

    def count(self, ObjectClass, *criterion):
        """Return the number of ``ObjectClass`` rows matching ``criterion``."""
        return ObjectClass.query.filter(*criterion).count()

    def count_by(self, ObjectClass, **kwargs):
        """Return the number of ``ObjectClass`` rows whose columns equal ``kwargs``."""
        return ObjectClass.query.filter_by(**kwargs).count()

    def get_first_object(self, ObjectClass, *criterion):
        """Return the first ``ObjectClass`` row matching ``criterion``, or None."""
        return ObjectClass.query.filter(*criterion).first()

    def add_object(self, inst):
        """Add the existing instance ``inst`` to the session."""
        self.db_session.add(inst)

    def add_object_kwargs(self, ObjectClass, **kwargs):
        """Create ``ObjectClass(**kwargs)``, add it to the session and return it."""
        instance = ObjectClass(**kwargs)
        self.db_session.add(instance)
        return instance

    def update_object(self, object, **kwargs):
        """Set the attributes named in ``kwargs`` on ``object``.

        :raises KeyError: if ``object`` lacks any of the named attributes.
            Every key is checked before anything is assigned, so a bad key
            leaves ``object`` unmodified.
        """
        for key in kwargs:
            if not hasattr(object, key):
                raise KeyError("Object '%s' has no field '%s'." % (type(object), key))
        for key, value in kwargs.items():
            setattr(object, key, value)

    def delete_object(self, instance):
        """Mark ``instance`` for deletion in the session."""
        self.db_session.delete(instance)

    def delete_all_objects(self, ObjectClass, *criterion):
        """Bulk-delete ``ObjectClass`` rows matching ``criterion``; return the row count."""
        # Same synchronization choice as delete_all_objects_by: objects already
        # loaded in the session are not updated by the bulk delete.
        return ObjectClass.query.filter(*criterion).delete(synchronize_session=False)

    def delete_all_objects_by(self, ObjectClass, **kwargs):
        """Bulk-delete ``ObjectClass`` rows whose columns equal ``kwargs``; return the row count.

        Case sensitivity of the comparison presumably follows the database
        column collation. Confirm for the target backend.
        """
        query = ObjectClass.query.filter_by(**kwargs)
        # The 'evaluate' synchronization strategy cannot handle every criterion
        # (e.g. ``in_`` in some SQLAlchemy versions), so in-session sync is skipped.
        return query.delete(synchronize_session=False)

    def exec_sql(self, sql_str):
        """Execute a raw SQL statement; failures are logged rather than raised.

        Like the other public methods, this one is wrapped by the metaclass,
        so a commit follows the call. Always returns None.
        """
        try:
            # NOTE(review): SQLAlchemy 1.4+ expects ``text(sql_str)`` instead of
            # a plain string. Confirm the installed version.
            self.session().execute(sql_str)
        except Exception as ex:
            log.error(ex)
            # Reset the failed transaction so the session remains usable.
            self.rollback()
        return None