Code

Ticket #5114: base.py

File base.py, 10.6 KB (added by aaron@…, 7 years ago)

Fixed the PostgreSQL backend's base.py at line 244 so that it properly quotes the table name returned by m2m_db_table().

Line 
1"""
2PostgreSQL database backend for Django.
3
4Requires psycopg 1: http://initd.org/projects/psycopg1
5"""
6
7from django.db.backends import util
8try:
9    import psycopg as Database
10except ImportError, e:
11    from django.core.exceptions import ImproperlyConfigured
12    raise ImproperlyConfigured, "Error loading psycopg module: %s" % e
13
# Re-export psycopg's exception classes under this backend module's namespace.
DatabaseError, IntegrityError = Database.DatabaseError, Database.IntegrityError
16
17try:
18    # Only exists in Python 2.4+
19    from threading import local
20except ImportError:
21    # Import copy of _thread_local.py from Python 2.4
22    from django.utils._threading_local import local
23
def smart_basestring(s, charset):
    """
    Encode ``s`` into a bytestring using ``charset`` when it is a Unicode
    object; any other value (bytestrings, numbers, None, ...) is returned
    unchanged.
    """
    if not isinstance(s, unicode):
        return s
    return s.encode(charset)
28
class UnicodeCursorWrapper(object):
    """
    Thin proxy around a psycopg cursor that lets it accept Unicode params.

    psycopg doesn't apply any DB quoting to parameters that are Unicode
    strings, so before forwarding execute()/executemany() this wrapper
    encodes every Unicode parameter into a bytestring using ``charset``.
    Any attribute not found on the wrapper itself is looked up on the
    wrapped cursor.
    """
    def __init__(self, cursor, charset):
        self.cursor = cursor
        self.charset = charset

    def _to_bytestring(self, value):
        # Only Unicode objects are encoded; everything else passes through.
        if not isinstance(value, unicode):
            return value
        return value.encode(self.charset)

    def execute(self, sql, params=()):
        """Run ``sql`` with Unicode parameters encoded (passed on as a list)."""
        encoded = [self._to_bytestring(param) for param in params]
        return self.cursor.execute(sql, encoded)

    def executemany(self, sql, param_list):
        """Run ``sql`` once per parameter sequence; each is passed on as a tuple."""
        encoded_rows = []
        for row in param_list:
            encoded_rows.append(tuple([self._to_bytestring(param) for param in row]))
        return self.cursor.executemany(sql, encoded_rows)

    def __getattr__(self, attr):
        # Python calls this only after normal lookup has failed: check the
        # instance dict, then delegate to the wrapped cursor.
        instance_attrs = self.__dict__
        if attr in instance_attrs:
            return instance_attrs[attr]
        return getattr(self.cursor, attr)
55
# Server version as a list of ints (e.g. [8, 1, 4]). Filled in by
# DatabaseWrapper.cursor() the first time a cursor is opened; if the version
# string cannot be parsed it stays falsy, so detection is retried next time.
postgres_version = None


def _parse_postgres_version(version_string):
    """
    Parse the output of ``SELECT version()`` into a list of ints.

    The second whitespace-separated word is expected to hold the version
    number. Only its leading dotted run of digits is used, so
    'PostgreSQL 8.1.4 on ...' gives [8, 1, 4], while '8.2beta1' gives [8, 2]
    and '9.6.3,' gives [9, 6, 3] instead of int() raising ValueError.
    Returns [] when no version number can be found.
    """
    import re
    words = version_string.split()
    if len(words) < 2:
        return []
    match = re.match(r'\d+(?:\.\d+)*', words[1])
    if match is None:
        return []
    return [int(part) for part in match.group(0).split('.')]


def _quote_conn_value(value):
    """
    Quote one value for a libpq "keyword=value" connection string.

    The value is wrapped in single quotes, and embedded backslashes and
    single quotes are backslash-escaped. Values containing spaces, quotes or
    backslashes (typically passwords) therefore cannot break the string.
    """
    text = '%s' % value
    return "'%s'" % text.replace('\\', '\\\\').replace("'", "\\'")


class DatabaseWrapper(local):
    """
    Lazily-opened psycopg 1 connection used by Django.

    Subclasses ``local`` (threading.local), so ``connection`` and ``queries``
    are per-thread. Keyword arguments given to the constructor are passed
    through to ``Database.connect``.
    """
    def __init__(self, **kwargs):
        self.connection = None
        # NOTE(review): presumably appended to by util.CursorDebugWrapper,
        # which receives this wrapper when DEBUG is on — confirm.
        self.queries = []
        self.options = kwargs

    def cursor(self):
        """
        Return a cursor, opening the connection on first use.

        On first use this connects with the DATABASE_* settings, sets the
        isolation level and sets the session time zone to TIME_ZONE. Every
        returned cursor encodes Unicode parameters with DEFAULT_CHARSET. When
        DEBUG is on, it is additionally wrapped in util.CursorDebugWrapper.
        The first cursor also records the server version in the module-level
        ``postgres_version``.

        Raises ImproperlyConfigured if DATABASE_NAME is empty or None.
        """
        from django.conf import settings
        set_tz = False
        if self.connection is None:
            set_tz = True
            # `not` rejects None as well as '', either of which is unusable.
            if not settings.DATABASE_NAME:
                from django.core.exceptions import ImproperlyConfigured
                raise ImproperlyConfigured("You need to specify DATABASE_NAME in your Django settings file.")
            conn_params = []
            if settings.DATABASE_USER:
                conn_params.append(('user', settings.DATABASE_USER))
            conn_params.append(('dbname', settings.DATABASE_NAME))
            if settings.DATABASE_PASSWORD:
                conn_params.append(('password', settings.DATABASE_PASSWORD))
            if settings.DATABASE_HOST:
                conn_params.append(('host', settings.DATABASE_HOST))
            if settings.DATABASE_PORT:
                conn_params.append(('port', settings.DATABASE_PORT))
            # Every value is quoted and escaped so special characters
            # (e.g. a quote in the password) cannot corrupt the string.
            conn_string = ' '.join(['%s=%s' % (key, _quote_conn_value(value))
                                    for key, value in conn_params])
            self.connection = Database.connect(conn_string, **self.options)
            self.connection.set_isolation_level(1) # make transactions transparent to all cursors
        cursor = self.connection.cursor()
        if set_tz:
            cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
        cursor = UnicodeCursorWrapper(cursor, settings.DEFAULT_CHARSET)
        global postgres_version
        if not postgres_version:
            cursor.execute("SELECT version()")
            postgres_version = _parse_postgres_version(cursor.fetchone()[0])
        if settings.DEBUG:
            return util.CursorDebugWrapper(cursor, self)
        return cursor

    def _commit(self):
        """Commit the current transaction; no-op if never connected."""
        if self.connection is not None:
            return self.connection.commit()

    def _rollback(self):
        """Roll back the current transaction; no-op if never connected."""
        if self.connection is not None:
            return self.connection.rollback()

    def close(self):
        """Close the connection (if open) so the next cursor() reconnects."""
        if self.connection is not None:
            self.connection.close()
            self.connection = None
107
# Backend feature flag: this backend supports constraints.
# NOTE(review): presumably consulted by Django's management code when
# deciding whether to emit constraint SQL — confirm.
supports_constraints = True

def quote_name(name):
    """
    Return ``name`` wrapped in double quotes for use as an SQL identifier.

    A name that already starts and ends with a double quote is returned
    unchanged, so quoting twice is harmless. Embedded double quotes are not
    escaped.
    """
    if not (name.startswith('"') and name.endswith('"')):
        name = '"%s"' % name
    return name
114
def dictfetchone(cursor):
    """Return the next row of ``cursor`` as a dict, via psycopg's dictfetchone()."""
    return cursor.dictfetchone()

def dictfetchmany(cursor, number):
    """Return up to ``number`` rows of ``cursor`` as dicts, via dictfetchmany()."""
    return cursor.dictfetchmany(number)

def dictfetchall(cursor):
    """Return all remaining rows of ``cursor`` as dicts, via dictfetchall()."""
    return cursor.dictfetchall()
126
def get_last_insert_id(cursor, table_name, pk_name):
    """
    Return CURRVAL() of the ``<table_name>_<pk_name>_seq`` sequence.

    The sequence name is double-quoted inside the string literal so that
    PostgreSQL does not fold it to lower case.
    """
    sql = "SELECT CURRVAL('\"%s_%s_seq\"')" % (table_name, pk_name)
    cursor.execute(sql)
    row = cursor.fetchone()
    return row[0]
130
def get_date_extract_sql(lookup_type, table_name):
    """
    Return SQL extracting the ``lookup_type`` part ('year', 'month' or 'day')
    from a date expression.

    Despite its name, ``table_name`` is the expression to extract from
    (presumably a quoted column reference); it is interpolated verbatim.
    """
    # http://www.postgresql.org/docs/8.0/static/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
    template = "EXTRACT('%s' FROM %s)"
    return template % (lookup_type, table_name)

def get_date_trunc_sql(lookup_type, field_name):
    """
    Return SQL truncating ``field_name`` to the given precision
    ('year', 'month' or 'day'); ``field_name`` is interpolated verbatim.
    """
    # http://www.postgresql.org/docs/8.0/static/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
    template = "DATE_TRUNC('%s', %s)"
    return template % (lookup_type, field_name)
140
def get_limit_offset_sql(limit, offset=None):
    """
    Return a ``LIMIT`` clause for ``limit``. An ``OFFSET`` clause is
    appended only when ``offset`` is truthy (not None and not 0).
    """
    parts = ["LIMIT %s" % limit]
    if offset:
        parts.append(" OFFSET %s" % offset)
    return "".join(parts)
146
def get_random_function_sql():
    """Return PostgreSQL's random-number function call, ``RANDOM()``."""
    return "RANDOM()"

def get_deferrable_sql():
    """
    Return the suffix that marks a constraint DEFERRABLE INITIALLY DEFERRED.
    Note the leading space: the result is meant to be appended directly.
    """
    return " DEFERRABLE INITIALLY DEFERRED"

def get_fulltext_search_sql(field_name):
    """Full-text search is not supported here; always raises NotImplementedError."""
    raise NotImplementedError

def get_drop_foreignkey_sql():
    """Return the keyword phrase used to drop a foreign key, ``DROP CONSTRAINT``."""
    return "DROP CONSTRAINT"

def get_pk_default_value():
    """
    Return the SQL keyword ``DEFAULT``, presumably used as the primary-key
    value so PostgreSQL fills it from the column default.
    """
    return "DEFAULT"
161
def _server_version_at_least(major, minor):
    """
    Return True if the detected server version is at least ``major.minor``.

    The detected version is the module-level ``postgres_version`` list that
    DatabaseWrapper.cursor() fills in. The check compares (major, minor) as a
    tuple, so e.g. 9.0 counts as newer than 8.1; comparing each component
    separately would reject it. Returns False when no version is known yet,
    so callers fall back to their most portable SQL.
    """
    if not postgres_version:
        return False
    # Pad e.g. [10] to [10, 0] so a bare major version still compares.
    version = (list(postgres_version) + [0, 0])[:2]
    return tuple(version) >= (major, minor)

def get_sql_flush(style, tables, sequences):
    """
    Return a list of SQL statements that remove all data from ``tables``
    (without dropping the tables themselves) and reset ``sequences``, putting
    the database in an empty 'initial' state.

    ``style`` supplies the SQL_KEYWORD/SQL_FIELD colourisers. Each entry of
    ``sequences`` is a dict with 'table' and 'column' keys. Returns [] when
    ``tables`` is empty.
    """
    if not tables:
        return []

    if _server_version_at_least(8, 1):
        # Postgres 8.1+ can do 'TRUNCATE x, y, z...;'. In fact, it *has to* in
        # order to truncate tables referenced by a foreign key in any other
        # table. The result is a single SQL TRUNCATE statement.
        sql = ['%s %s;' % (
            style.SQL_KEYWORD('TRUNCATE'),
            style.SQL_FIELD(', '.join([quote_name(table) for table in tables])),
        )]
    else:
        # Older servers can't TRUNCATE several tables in one call, so use a
        # simple DELETE per table. The same path is used when the server
        # version has not been detected yet.
        sql = ['%s %s %s;' % (
            style.SQL_KEYWORD('DELETE'),
            style.SQL_KEYWORD('FROM'),
            style.SQL_FIELD(quote_name(table)),
        ) for table in tables]

    # 'ALTER SEQUENCE sequence_name RESTART WITH 1;' statements to reset
    # sequence indices. The sequence name is <table>_<column>_seq, where the
    # column defaults to 'id' when none is given.
    for sequence_info in sequences:
        column_name = sequence_info['column'] or 'id'
        sequence_name = '%s_%s_seq' % (sequence_info['table'], column_name)
        sql.append('%s %s %s %s %s %s;' % (
            style.SQL_KEYWORD('ALTER'),
            style.SQL_KEYWORD('SEQUENCE'),
            style.SQL_FIELD(quote_name(sequence_name)),
            style.SQL_KEYWORD('RESTART'),
            style.SQL_KEYWORD('WITH'),
            style.SQL_FIELD('1'),
        ))
    return sql
216
def get_sql_sequence_reset(style, model_list):
    """
    Return a list of SQL statements that reset the sequences behind each
    model's AutoField and behind each of its many-to-many join tables.

    Each statement calls ``setval`` with ``coalesce`` so that the sequence is
    moved to the max pk value when the table has rows, or to 1 when it has
    none. The third ``setval`` argument (``is_called``) is true only when rows
    exist, because the max pk value is then already in use.
    """
    from django.db import models

    def setval_statement(sequence_name, column, table):
        # One "SELECT setval(...) FROM <table>;" statement.
        return "%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % \
            (style.SQL_KEYWORD('SELECT'),
             style.SQL_FIELD(quote_name(sequence_name)),
             style.SQL_FIELD(quote_name(column)),
             style.SQL_FIELD(quote_name(column)),
             style.SQL_KEYWORD('IS NOT'),
             style.SQL_KEYWORD('FROM'),
             style.SQL_TABLE(quote_name(table)))

    output = []
    for model in model_list:
        opts = model._meta
        auto_fields = [f for f in opts.fields if isinstance(f, models.AutoField)]
        if auto_fields:
            # Only one AutoField is allowed per model, so the first is the one.
            pk_column = auto_fields[0].column
            output.append(setval_statement('%s_%s_seq' % (opts.db_table, pk_column),
                                           pk_column, opts.db_table))
        for m2m_field in opts.many_to_many:
            join_table = m2m_field.m2m_db_table()
            output.append(setval_statement('%s_id_seq' % join_table, 'id', join_table))
    return output
246       
# psycopg uses mx.DateTime by default, but Django expects dates/times in
# Python's native (standard-library) datetime/time format. Register custom
# typecasts, keyed by PostgreSQL type OIDs, for the types listed below.
try:
    # DATE is registered first as a probe: an AttributeError here is taken
    # to mean that psycopg version 2 is installed instead of version 1.
    Database.register_type(Database.new_type((1082,), "DATE", util.typecast_date))
except AttributeError:
    raise Exception("You appear to be using psycopg version 2. Set your DATABASE_ENGINE to 'postgresql_psycopg2' instead of 'postgresql'.")
for _oids, _type_name, _caster in [
        ((1083, 1266), "TIME", util.typecast_time),
        ((1114, 1184), "TIMESTAMP", util.typecast_timestamp),
        ((16,), "BOOLEAN", util.typecast_boolean),
        ((1700,), "NUMERIC", util.typecast_decimal),
        ]:
    Database.register_type(Database.new_type(_oids, _type_name, _caster))
# Keep the module namespace free of the loop variables.
del _oids, _type_name, _caster
258
# SQL operator template for each Django field lookup; '%s' is filled with
# the query parameter placeholder.
OPERATOR_MAPPING = dict(
    # Equality and ordering comparisons.
    exact='= %s',
    gt='> %s',
    gte='>= %s',
    lt='< %s',
    lte='<= %s',
    # Pattern matching: the 'i'-prefixed (case-insensitive) lookups use
    # ILIKE, the others use LIKE.
    contains='LIKE %s',
    startswith='LIKE %s',
    endswith='LIKE %s',
    iexact='ILIKE %s',
    icontains='ILIKE %s',
    istartswith='ILIKE %s',
    iendswith='ILIKE %s',
)