Ticket #5114: base.py

File base.py, 10.6 KB (added by aaron@…, 17 years ago)

Fixed postgres base.py at line 244 so that it properly quotes the m2m_db_table() name.

Line 
1"""
2PostgreSQL database backend for Django.
3
4Requires psycopg 1: http://initd.org/projects/psycopg1
5"""
6
7from django.db.backends import util
8try:
9 import psycopg as Database
10except ImportError, e:
11 from django.core.exceptions import ImproperlyConfigured
12 raise ImproperlyConfigured, "Error loading psycopg module: %s" % e
13
# Expose psycopg's exception classes as attributes of this backend module
# (presumably so generic code can catch them without importing psycopg).
DatabaseError, IntegrityError = Database.DatabaseError, Database.IntegrityError
16
17try:
18 # Only exists in Python 2.4+
19 from threading import local
20except ImportError:
21 # Import copy of _thread_local.py from Python 2.4
22 from django.utils._threading_local import local
23
def smart_basestring(s, charset):
    """Return *s* encoded with *charset* when it is a unicode object.

    Every other value (bytestrings, numbers, None, ...) is handed back
    untouched.
    """
    if not isinstance(s, unicode):
        return s
    return s.encode(charset)
28
class UnicodeCursorWrapper(object):
    """
    A thin wrapper around psycopg cursors that allows them to accept Unicode
    strings as params.

    This is necessary because psycopg doesn't apply any DB quoting to
    parameters that are Unicode strings. If a param is Unicode, this will
    convert it to a bytestring using DEFAULT_CHARSET before passing it to
    psycopg. The SQL string itself is encoded the same way, so a unicode
    query is never combined with bytestring parameters.

    A parameter set may be a sequence (for ``%s`` placeholders) or a dict
    (for ``%(name)s`` placeholders). Any attribute not defined here is looked
    up on the wrapped cursor.
    """
    def __init__(self, cursor, charset):
        self.cursor = cursor
        self.charset = charset

    def _encode_params(self, params):
        """Encode one parameter set with ``self.charset``.

        A dict keeps its keys and has its values encoded. Any other iterable
        becomes a list of encoded values. None is treated as "no params".
        """
        charset = self.charset
        if params is None:
            return []
        if isinstance(params, dict):
            # Iterating a dict would yield only its keys, so handle it apart.
            encoded = {}
            for key, value in params.items():
                encoded[key] = smart_basestring(value, charset)
            return encoded
        return [smart_basestring(p, charset) for p in params]

    def execute(self, sql, params=()):
        return self.cursor.execute(smart_basestring(sql, self.charset),
                                   self._encode_params(params))

    def executemany(self, sql, param_list):
        new_param_list = []
        for params in param_list:
            encoded = self._encode_params(params)
            if not isinstance(encoded, dict):
                # Sequence parameter sets are handed to executemany as tuples.
                encoded = tuple(encoded)
            new_param_list.append(encoded)
        return self.cursor.executemany(smart_basestring(sql, self.charset),
                                       new_param_list)

    def __getattr__(self, attr):
        # Python calls __getattr__ only after normal lookup (instance dict,
        # then class) has failed, so whatever reaches here belongs to the
        # wrapped cursor. Reading 'cursor' straight from __dict__ turns a
        # wrapper whose __init__ never ran (e.g. copy/pickle) into a clean
        # AttributeError instead of infinite recursion.
        try:
            cursor = self.__dict__['cursor']
        except KeyError:
            raise AttributeError(attr)
        return getattr(cursor, attr)

    def __iter__(self):
        # Special methods are looked up on the type, never via __getattr__,
        # so iteration has to be provided explicitly.
        return iter(self.cursor.fetchall())
55
# Parsed result of SELECT version(), as a list of ints such as [8, 1, 4].
# DatabaseWrapper.cursor() fills it in on first use; get_sql_flush() reads it.
postgres_version = None


class DatabaseWrapper(local):
    """Lazily opened psycopg 1 connection for this backend.

    Subclasses ``local``, so ``connection`` and ``queries`` are kept
    separately for every thread.
    """
    def __init__(self, **kwargs):
        self.connection = None
        # NOTE(review): presumably appended to by util.CursorDebugWrapper
        # when DEBUG is on.
        self.queries = []
        # Extra keyword arguments, forwarded verbatim to Database.connect().
        self.options = kwargs

    def _conninfo_quote(value):
        """Return *value* as a single-quoted libpq connection-string value.

        libpq requires backslashes and single quotes inside a quoted value to
        be backslash-escaped. Quoting every value also keeps a value that
        contains spaces from being split into separate keywords.
        """
        text = '%s' % (value,)
        text = text.replace('\\', '\\\\').replace("'", "\\'")
        return "'%s'" % text
    _conninfo_quote = staticmethod(_conninfo_quote)

    def _build_conn_string(settings):
        """Build the libpq ``keyword='value'`` string from Django *settings*.

        Reads DATABASE_USER, DATABASE_NAME, DATABASE_PASSWORD, DATABASE_HOST
        and DATABASE_PORT. Empty optional settings are left out.
        """
        quote = DatabaseWrapper._conninfo_quote
        parts = []
        if settings.DATABASE_USER:
            parts.append("user=%s" % quote(settings.DATABASE_USER))
        parts.append("dbname=%s" % quote(settings.DATABASE_NAME))
        for keyword, value in (("password", settings.DATABASE_PASSWORD),
                               ("host", settings.DATABASE_HOST),
                               ("port", settings.DATABASE_PORT)):
            if value:
                parts.append("%s=%s" % (keyword, quote(value)))
        return " ".join(parts)
    _build_conn_string = staticmethod(_build_conn_string)

    def _parse_version(version_string):
        """Return the numeric version from a ``SELECT version()`` result.

        For example, "PostgreSQL 8.1.4 on i686-pc-linux-gnu" gives [8, 1, 4].
        Only the leading dotted digits of the second word are used. So a
        pre-release such as "8.2beta1" gives [8, 2], and a trailing comma
        ("9.4.5,") is ignored.

        Raises ValueError if no version number follows the first word.
        """
        import re
        match = re.match(r'\S+\s+(\d+(?:\.\d+)*)', version_string)
        if match is None:
            raise ValueError("Unable to determine PostgreSQL version from %r"
                             % (version_string,))
        return [int(part) for part in match.group(1).split('.')]
    _parse_version = staticmethod(_parse_version)

    def cursor(self):
        """Return a cursor on this thread's connection, connecting if needed.

        - A newly opened connection has its time zone set to
          settings.TIME_ZONE.
        - The returned cursor encodes unicode with settings.DEFAULT_CHARSET.
        - When settings.DEBUG is true, the cursor is wrapped in
          util.CursorDebugWrapper.

        Raises ImproperlyConfigured when DATABASE_NAME is not set.
        """
        global postgres_version
        from django.conf import settings
        set_tz = False
        if self.connection is None:
            set_tz = True
            if not settings.DATABASE_NAME:
                from django.core.exceptions import ImproperlyConfigured
                raise ImproperlyConfigured("You need to specify DATABASE_NAME in your Django settings file.")
            self.connection = Database.connect(self._build_conn_string(settings), **self.options)
            self.connection.set_isolation_level(1) # make transactions transparent to all cursors
        cursor = self.connection.cursor()
        if set_tz:
            cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
        cursor = UnicodeCursorWrapper(cursor, settings.DEFAULT_CHARSET)
        if not postgres_version:
            cursor.execute("SELECT version()")
            postgres_version = self._parse_version(cursor.fetchone()[0])
        if settings.DEBUG:
            return util.CursorDebugWrapper(cursor, self)
        return cursor

    def _commit(self):
        """Commit the open connection; do nothing if none is open."""
        if self.connection is not None:
            return self.connection.commit()

    def _rollback(self):
        """Roll back the open connection; do nothing if none is open."""
        if self.connection is not None:
            return self.connection.rollback()

    def close(self):
        """Close and forget the open connection, if any."""
        if self.connection is not None:
            self.connection.close()
            self.connection = None
107
# Capability flag for this backend. NOTE(review): presumably read by generic
# Django code to decide whether to emit constraint SQL.
supports_constraints = True


def quote_name(name):
    """Return *name* as a double-quoted SQL identifier.

    A name that already starts and ends with a double quote is returned
    as-is, so quoting twice is harmless.
    """
    if not (name.startswith('"') and name.endswith('"')):
        name = '"%s"' % name
    return name
114
def dictfetchone(cursor):
    """Fetch the next row as a dict using the cursor's own dictfetchone()."""
    fetch = cursor.dictfetchone
    return fetch()
118
def dictfetchmany(cursor, number):
    """Fetch up to *number* rows as dicts using the cursor's dictfetchmany()."""
    fetch = cursor.dictfetchmany
    return fetch(number)
122
def dictfetchall(cursor):
    """Fetch every remaining row as a dict using the cursor's dictfetchall()."""
    fetch = cursor.dictfetchall
    return fetch()
126
def get_last_insert_id(cursor, table_name, pk_name):
    """Return CURRVAL of the ``"<table_name>_<pk_name>_seq"`` sequence.

    Assumes the primary key's sequence follows that naming scheme (the one
    PostgreSQL gives a serial column). TODO confirm for renamed or truncated
    names.
    """
    sequence = '"%s_%s_seq"' % (table_name, pk_name)
    cursor.execute("SELECT CURRVAL('%s')" % sequence)
    row = cursor.fetchone()
    return row[0]
130
def get_date_extract_sql(lookup_type, table_name):
    """Return an EXTRACT() expression for *lookup_type* ('year', 'month', 'day').

    Despite its name, *table_name* is interpolated as the operand of FROM,
    presumably a (quoted) column reference.
    """
    # http://www.postgresql.org/docs/8.0/static/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
    template = "EXTRACT('%s' FROM %s)"
    return template % (lookup_type, table_name)
135
def get_date_trunc_sql(lookup_type, field_name):
    """Return a DATE_TRUNC() expression cutting *field_name* to *lookup_type*.

    *lookup_type* is one of 'year', 'month' or 'day'.
    """
    # http://www.postgresql.org/docs/8.0/static/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
    template = "DATE_TRUNC('%s', %s)"
    return template % (lookup_type, field_name)
140
def get_limit_offset_sql(limit, offset=None):
    """Return a ``LIMIT n`` clause, with ``OFFSET m`` added for a truthy offset."""
    clauses = ["LIMIT %s" % limit]
    if offset:
        # 0 and None are both falsy, so neither produces an OFFSET clause.
        clauses.append("OFFSET %s" % offset)
    return " ".join(clauses)
146
def get_random_function_sql():
    """Return the PostgreSQL random-number function call."""
    function_call = "RANDOM()"
    return function_call
149
def get_deferrable_sql():
    """Return the (space-prefixed) clause deferring a constraint check to commit.

    NOTE(review): presumably appended to foreign-key constraint SQL.
    """
    clause = " DEFERRABLE INITIALLY DEFERRED"
    return clause
152
def get_fulltext_search_sql(field_name):
    """Full-text search is not supported by this backend; always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError
155
def get_drop_foreignkey_sql():
    """Return the keyword phrase used to drop a foreign-key constraint.

    NOTE(review): presumably placed inside ``ALTER TABLE ... <name>``.
    """
    phrase = "DROP CONSTRAINT"
    return phrase
158
def get_pk_default_value():
    """Return the SQL keyword standing in for a primary-key value.

    NOTE(review): presumably used in INSERTs so the column default applies.
    """
    keyword = "DEFAULT"
    return keyword
161
def get_sql_flush(style, tables, sequences):
    """Return a list of SQL statements required to remove all data from
    all tables in the database (without actually removing the tables
    themselves) and put the database in an empty 'initial' state.

    ``tables`` lists the names of the tables to empty.

    ``sequences`` lists dicts with 'table' and 'column' keys. Each named
    sequence ``<table>_<column>_seq`` is restarted at 1; when 'column' is
    empty, ``<table>_id_seq`` is used instead.

    Returns an empty list when ``tables`` is empty.
    """
    if not tables:
        return []

    # postgres_version is set by DatabaseWrapper.cursor() as a list of ints.
    # Compare (major, minor) as a tuple so that e.g. 9.0 counts as newer than
    # 8.1. An unknown version (no cursor opened yet) falls back to DELETE,
    # which every version understands.
    if postgres_version and tuple(postgres_version[:2]) >= (8, 1):
        # Postgres 8.1+ can do 'TRUNCATE x, y, z...;'. In fact, it *has to* in
        # order to be able to truncate tables referenced by a foreign key in
        # any other table. The result is a single SQL TRUNCATE statement.
        sql = ['%s %s;' % (
            style.SQL_KEYWORD('TRUNCATE'),
            style.SQL_FIELD(', '.join([quote_name(table) for table in tables])),
        )]
    else:
        # Older versions of Postgres can't do TRUNCATE in a single call, so
        # they must use a simple delete.
        sql = ['%s %s %s;' % (
            style.SQL_KEYWORD('DELETE'),
            style.SQL_KEYWORD('FROM'),
            style.SQL_FIELD(quote_name(table)),
        ) for table in tables]

    # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
    # to reset sequence indices.
    for sequence_info in sequences:
        table_name = sequence_info['table']
        column_name = sequence_info['column'] or 'id'
        sequence_name = '%s_%s_seq' % (table_name, column_name)
        sql.append('%s %s %s %s %s %s;' % (
            style.SQL_KEYWORD('ALTER'),
            style.SQL_KEYWORD('SEQUENCE'),
            style.SQL_FIELD(quote_name(sequence_name)),
            style.SQL_KEYWORD('RESTART'),
            style.SQL_KEYWORD('WITH'),
            style.SQL_FIELD('1'),
        ))
    return sql
216
def get_sql_sequence_reset(style, model_list):
    """Return SQL statements that re-sync sequences with existing data.

    For each model in *model_list* there is one statement for the sequence
    of its AutoField (if any). There is also one for each many-to-many join
    table, whose sequence is taken to be ``<m2m table>_id_seq`` on column
    ``id``.
    """
    from django.db import models

    def setval_statement(sequence_name, column_name, table_name):
        # Use `coalesce` to set the sequence to the max value of the column
        # if there are rows, or 1 if there are none. The third `setval`
        # argument (`is_called`) is true only when rows exist, because then
        # the max value is already in use.
        return "%s setval('%s', coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (
            style.SQL_KEYWORD('SELECT'),
            style.SQL_FIELD(quote_name(sequence_name)),
            style.SQL_FIELD(quote_name(column_name)),
            style.SQL_FIELD(quote_name(column_name)),
            style.SQL_KEYWORD('IS NOT'),
            style.SQL_KEYWORD('FROM'),
            style.SQL_TABLE(quote_name(table_name)),
        )

    statements = []
    for model in model_list:
        opts = model._meta
        for field in opts.fields:
            if isinstance(field, models.AutoField):
                statements.append(setval_statement(
                    '%s_%s_seq' % (opts.db_table, field.column),
                    field.column,
                    opts.db_table,
                ))
                # A model can have only one AutoField.
                break
        for m2m_field in opts.many_to_many:
            statements.append(setval_statement(
                '%s_id_seq' % m2m_field.m2m_db_table(),
                'id',
                m2m_field.m2m_db_table(),
            ))
    return statements
246
# psycopg 1 returns mx.DateTime objects by default. Django expects the
# standard library's datetime/time types, so the casts for those column
# types are replaced here. The numbers are PostgreSQL type OIDs; 1266 is
# time with time zone and 1184 is timestamp with time zone.
def _register_typecast(oids, name, caster):
    Database.register_type(Database.new_type(oids, name, caster))

try:
    # psycopg2 has no top-level new_type(), so this is also a version check.
    _register_typecast((1082,), "DATE", util.typecast_date)
except AttributeError:
    raise Exception("You appear to be using psycopg version 2. Set your DATABASE_ENGINE to 'postgresql_psycopg2' instead of 'postgresql'.")
_register_typecast((1083, 1266), "TIME", util.typecast_time)
_register_typecast((1114, 1184), "TIMESTAMP", util.typecast_timestamp)
_register_typecast((16,), "BOOLEAN", util.typecast_boolean)
_register_typecast((1700,), "NUMERIC", util.typecast_decimal)
del _register_typecast
258
# Lookup type -> PostgreSQL comparison SQL; ``%s`` marks where the
# right-hand operand goes. Case-insensitive lookups use ILIKE.
OPERATOR_MAPPING = dict(
    exact='= %s',
    iexact='ILIKE %s',
    contains='LIKE %s',
    icontains='ILIKE %s',
    gt='> %s',
    gte='>= %s',
    lt='< %s',
    lte='<= %s',
    startswith='LIKE %s',
    endswith='LIKE %s',
    istartswith='ILIKE %s',
    iendswith='ILIKE %s',
)
Back to Top