Ticket #11487: base.py

File base.py, 21.6 KB (added by Marcos Daniel Petry, 15 years ago)
Line 
1"""
2Oracle database backend for Django.
3
4Requires cx_Oracle: http://cx-oracle.sourceforge.net/
5"""
6
7import os
8import datetime
9import time
10try:
11 from decimal import Decimal
12except ImportError:
13 from django.utils._decimal import Decimal
14
# Oracle takes the client-side character set encoding from the environment,
# so request UTF-8 here, ahead of the cx_Oracle import below.  Enabling
# ORA_NCHAR_LITERAL_REPLACE prevents unicode from getting mangled by being
# encoded into a potentially non-unicode database character set.
os.environ.update({
    'NLS_LANG': '.UTF8',
    'ORA_NCHAR_LITERAL_REPLACE': 'TRUE',
})
20
21try:
22 import cx_Oracle as Database
23except ImportError, e:
24 from django.core.exceptions import ImproperlyConfigured
25 raise ImproperlyConfigured("Error loading cx_Oracle module: %s" % e)
26
27from django.db.backends import *
28from django.db.backends.signals import connection_created
29from django.db.backends.oracle import query
30from django.db.backends.oracle.client import DatabaseClient
31from django.db.backends.oracle.creation import DatabaseCreation
32from django.db.backends.oracle.introspection import DatabaseIntrospection
33from django.utils.encoding import smart_str, force_unicode
34
# Module-level aliases for the cx_Oracle exception classes; the cursor
# wrapper in this module catches and raises these names.
DatabaseError, IntegrityError = Database.DatabaseError, Database.IntegrityError
37
38
class DatabaseFeatures(BaseDatabaseFeatures):
    """Feature flags describing the Oracle backend."""

    # Transaction / insert capabilities.
    uses_savepoints = True
    can_return_id_from_insert = True

    # Oracle treats an empty string as NULL.
    interprets_empty_strings_as_nulls = True

    # Queries are built with the backend-specific class returned by
    # DatabaseOperations.query_class().
    uses_custom_query_class = True

    needs_datetime_string_cast = False

    # What fetchmany() returns once the cursor is exhausted: the cursor
    # wrapper builds its result with tuple(), so this is an empty tuple.
    empty_fetchmany_value = ()
46
47
class DatabaseOperations(BaseDatabaseOperations):
    """
    Oracle-specific SQL fragments and value conversions used by the ORM.
    """

    def autoinc_sql(self, table, column):
        """
        Return a ``(sequence_sql, trigger_sql)`` pair of PL/SQL blocks that
        simulate an auto-incrementing ``column`` on ``table``.

        The first block creates the table's sequence if USER_CATALOG does not
        already list it; the second installs a BEFORE INSERT trigger that
        fills ``column`` from that sequence whenever it is inserted as NULL.
        """
        # To simulate auto-incrementing primary keys in Oracle, we have to
        # create a sequence and a trigger.
        names = {
            'sq_name': get_sequence_name(table),
            'tr_name': get_trigger_name(table),
            'tbl_name': self.quote_name(table),
            'col_name': self.quote_name(column),
        }
        sequence_sql = """
DECLARE
    i INTEGER;
BEGIN
    SELECT COUNT(*) INTO i FROM USER_CATALOG
        WHERE TABLE_NAME = '%(sq_name)s' AND TABLE_TYPE = 'SEQUENCE';
    IF i = 0 THEN
        EXECUTE IMMEDIATE 'CREATE SEQUENCE "%(sq_name)s"';
    END IF;
END;
/""" % names
        trigger_sql = """
CREATE OR REPLACE TRIGGER "%(tr_name)s"
BEFORE INSERT ON %(tbl_name)s
FOR EACH ROW
WHEN (new.%(col_name)s IS NULL)
    BEGIN
        SELECT "%(sq_name)s".nextval
        INTO :new.%(col_name)s FROM dual;
    END;
/""" % names
        return sequence_sql, trigger_sql

    def date_extract_sql(self, lookup_type, field_name):
        """Return SQL extracting ``lookup_type`` from ``field_name``."""
        # http://download-east.oracle.com/docs/cd/B10501_01/server.920/a96540/functions42a.htm#1017163
        if lookup_type == 'week_day':
            # TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday
            # (DatabaseWrapper sets NLS_TERRITORY = 'AMERICA' for this).
            return "TO_CHAR(%s, 'D')" % field_name
        else:
            return "EXTRACT(%s FROM %s)" % (lookup_type, field_name)

    def date_trunc_sql(self, lookup_type, field_name):
        """Return SQL truncating ``field_name`` to ``lookup_type`` precision."""
        # Oracle uses TRUNC() for both dates and numbers.
        # http://download-east.oracle.com/docs/cd/B10501_01/server.920/a96540/functions155a.htm#SQLRF06151
        if lookup_type == 'day':
            sql = 'TRUNC(%s)' % field_name
        else:
            sql = "TRUNC(%s, '%s')" % (field_name, lookup_type)
        return sql

    def datetime_cast_sql(self):
        return "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')"

    def deferrable_sql(self):
        return " DEFERRABLE INITIALLY DEFERRED"

    def drop_sequence_sql(self, table):
        return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table))

    def fetch_returned_insert_id(self, cursor):
        """
        Return the id captured by the RETURNING ... INTO variable that
        InsertIdVar.bind_parameter() stored on ``cursor``.
        """
        return long(cursor._insert_id_var.getvalue())

    def field_cast_sql(self, db_type):
        """Wrap LOB-typed columns in DBMS_LOB.SUBSTR so they can be compared."""
        if db_type and db_type.endswith('LOB'):
            return "DBMS_LOB.SUBSTR(%s)"
        else:
            return "%s"

    def last_insert_id(self, cursor, table_name, pk_name):
        """Return the current value of ``table_name``'s sequence."""
        sq_name = get_sequence_name(table_name)
        cursor.execute('SELECT "%s".currval FROM dual' % sq_name)
        return cursor.fetchone()[0]

    def lookup_cast(self, lookup_type):
        # Case-insensitive lookups compare uppercased values on both sides
        # (see the UPPER() forms in DatabaseWrapper.operators).
        if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
            return "UPPER(%s)"
        return "%s"

    def max_name_length(self):
        return 30

    def prep_for_iexact_query(self, x):
        return x

    def process_clob(self, value):
        """Read a CLOB value into a unicode string (u'' for None)."""
        if value is None:
            return u''
        return force_unicode(value.read())

    def query_class(self, DefaultQueryClass):
        return query.query_class(DefaultQueryClass, Database)

    def quote_name(self, name):
        # SQL92 requires delimited (quoted) names to be case-sensitive. When
        # not quoted, Oracle has case-insensitive behavior for identifiers, but
        # always defaults to uppercase.
        # We simplify things by making Oracle identifiers always uppercase.
        if not name.startswith('"') and not name.endswith('"'):
            name = '"%s"' % util.truncate_name(name.upper(),
                                               self.max_name_length())
        return name.upper()

    def random_function_sql(self):
        return "DBMS_RANDOM.RANDOM"

    def regex_lookup_9(self, lookup_type):
        raise NotImplementedError("Regexes are not supported in Oracle before version 10g.")

    def regex_lookup_10(self, lookup_type):
        """Return a REGEXP_LIKE template; 'regex' is case-sensitive, else not."""
        if lookup_type == 'regex':
            match_option = "'c'"
        else:
            match_option = "'i'"
        return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option

    def regex_lookup(self, lookup_type):
        """
        Placeholder dispatched to before a connection exists.

        DatabaseWrapper._cursor() replaces ``regex_lookup`` on the ops
        instance with regex_lookup_9 or regex_lookup_10 once it has parsed
        the server version.  Opening a cursor triggers that setup; we then
        dispatch to whatever was installed.
        """
        from django.db import connection
        connection.cursor()
        ops = connection.ops
        if 'regex_lookup' in ops.__dict__:
            return ops.regex_lookup(lookup_type)
        # The server version string could not be parsed, so nothing was
        # installed and calling ops.regex_lookup() again would recurse
        # forever.  Fall back to the 10g+ REGEXP_LIKE form.
        return self.regex_lookup_10(lookup_type)

    def return_insert_id(self):
        return "RETURNING %s INTO %%s", (InsertIdVar(),)

    def savepoint_create_sql(self, sid):
        return "SAVEPOINT " + self.quote_name(sid)

    def savepoint_rollback_sql(self, sid):
        return "ROLLBACK TO SAVEPOINT " + self.quote_name(sid)

    def sql_flush(self, style, tables, sequences):
        """
        Return a list of SQL statements that empty ``tables`` (one
        'DELETE FROM x;' each) followed by one PL/SQL sequence-reset block
        per entry in ``sequences``.  Returns [] when ``tables`` is empty.
        """
        if not tables:
            return []
        # Oracle does support TRUNCATE, but it seems to get us into
        # FK referential trouble, whereas DELETE FROM table works.
        sql = ['%s %s %s;' % (style.SQL_KEYWORD('DELETE'),
                              style.SQL_KEYWORD('FROM'),
                              style.SQL_FIELD(self.quote_name(table)))
               for table in tables]
        # Since we've just deleted all the rows, running our sequence
        # ALTER code will reset the sequence to 0.
        # (Named reset_template so the module-level ``query`` import used by
        # query_class() is not shadowed.)
        reset_template = _get_sequence_reset_sql()
        for sequence_info in sequences:
            sql.append(reset_template % {
                'sequence': get_sequence_name(sequence_info['table']),
                'table': self.quote_name(sequence_info['table']),
                'column': self.quote_name(sequence_info['column'] or 'id'),
            })
        return sql

    def sequence_reset_sql(self, style, model_list):
        """
        Return PL/SQL blocks that resynchronize each model's sequence (for
        its AutoField and for auto-created many-to-many tables) with the
        current MAX() of the key column.
        """
        from django.db import models
        output = []
        reset_template = _get_sequence_reset_sql()
        for model in model_list:
            for f in model._meta.local_fields:
                if isinstance(f, models.AutoField):
                    output.append(reset_template % {
                        'sequence': get_sequence_name(model._meta.db_table),
                        'table': self.quote_name(model._meta.db_table),
                        'column': self.quote_name(f.column),
                    })
                    # Only one AutoField is allowed per model, so don't
                    # continue to loop
                    break
            for f in model._meta.many_to_many:
                if not f.rel.through:
                    output.append(reset_template % {
                        'sequence': get_sequence_name(f.m2m_db_table()),
                        'table': self.quote_name(f.m2m_db_table()),
                        'column': self.quote_name('id'),
                    })
        return output

    def start_transaction_sql(self):
        return ''

    def tablespace_sql(self, tablespace, inline=False):
        return "%sTABLESPACE %s" % ((inline and "USING INDEX " or ""),
                                    self.quote_name(tablespace))

    def value_to_db_time(self, value):
        """
        Convert a time (or an 'HH:MM:SS' string) to a datetime on 1900-01-01,
        or return None unchanged.
        """
        if value is None:
            return None
        if isinstance(value, basestring):
            return datetime.datetime(*(time.strptime(value, '%H:%M:%S')[:6]))
        return datetime.datetime(1900, 1, 1, value.hour, value.minute,
                                 value.second, value.microsecond)

    def year_lookup_bounds_for_date_field(self, value):
        """Return ['<value>-01-01', '<value>-12-31']."""
        first = '%s-01-01'
        second = '%s-12-31'
        return [first % value, second % value]

    def combine_expression(self, connector, sub_expressions):
        """
        Oracle requires special cases for %% and & operators in query
        expressions: they become MOD() and BITAND(); bit-wise OR is rejected.
        """
        if connector == '%%':
            return 'MOD(%s)' % ','.join(sub_expressions)
        elif connector == '&':
            return 'BITAND(%s)' % ','.join(sub_expressions)
        elif connector == '|':
            raise NotImplementedError("Bit-wise or is not supported in Oracle.")
        return super(DatabaseOperations, self).combine_expression(connector, sub_expressions)
258
259
class DatabaseWrapper(BaseDatabaseWrapper):
    """
    Django database wrapper around a cx_Oracle connection.
    """

    operators = {
        'exact': '= %s',
        'iexact': '= UPPER(%s)',
        'contains': "LIKEC %s ESCAPE '\\'",
        'icontains': "LIKEC UPPER(%s) ESCAPE '\\'",
        'gt': '> %s',
        'gte': '>= %s',
        'lt': '< %s',
        'lte': '<= %s',
        'startswith': "LIKEC %s ESCAPE '\\'",
        'endswith': "LIKEC %s ESCAPE '\\'",
        'istartswith': "LIKEC UPPER(%s) ESCAPE '\\'",
        'iendswith': "LIKEC UPPER(%s) ESCAPE '\\'",
    }
    # Major version of the connected server; filled in by _cursor() when a
    # new connection's version string can be parsed.
    oracle_version = None

    def __init__(self, *args, **kwargs):
        super(DatabaseWrapper, self).__init__(*args, **kwargs)

        self.features = DatabaseFeatures()
        self.ops = DatabaseOperations()
        self.client = DatabaseClient(self)
        self.creation = DatabaseCreation(self)
        self.introspection = DatabaseIntrospection(self)
        self.validation = BaseDatabaseValidation()

    def _valid_connection(self):
        return self.connection is not None

    def _connect_string(self):
        """
        Build a "user/password@dsn" connect string from settings_dict.

        With a DATABASE_PORT the DSN is built by Database.makedsn(host,
        port, name); otherwise DATABASE_NAME is used as the DSN directly.
        Note: an empty DATABASE_HOST is replaced with 'localhost' in
        settings_dict itself, not just locally.
        """
        settings_dict = self.settings_dict
        if not settings_dict['DATABASE_HOST'].strip():
            settings_dict['DATABASE_HOST'] = 'localhost'
        if settings_dict['DATABASE_PORT'].strip():
            dsn = Database.makedsn(settings_dict['DATABASE_HOST'],
                                   int(settings_dict['DATABASE_PORT']),
                                   settings_dict['DATABASE_NAME'])
        else:
            dsn = settings_dict['DATABASE_NAME']
        return "%s/%s@%s" % (settings_dict['DATABASE_USER'],
                             settings_dict['DATABASE_PASSWORD'], dsn)

    def _cursor(self):
        """
        Return a FormatStylePlaceholderCursor, opening and initializing a
        new connection first if none is open.
        """
        cursor = None
        if not self._valid_connection():
            conn_string = self._connect_string()
            self.connection = Database.connect(conn_string, **self.settings_dict['DATABASE_OPTIONS'])
            cursor = FormatStylePlaceholderCursor(self.connection)
            # Set oracle date to ansi date format. This only needs to execute
            # once when we create a new connection. We also set the Territory
            # to 'AMERICA' which forces Sunday to evaluate to a '1' in TO_CHAR().
            cursor.execute("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS' "
                           "NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF' "
                           "NLS_TERRITORY = 'AMERICA'")
            try:
                self.oracle_version = int(self.connection.version.split('.')[0])
            except ValueError:
                # Unparseable version: leave regex_lookup uninstalled;
                # DatabaseOperations.regex_lookup handles that case.
                pass
            else:
                # There's no way for the DatabaseOperations class to know the
                # currently active Oracle version, so we do some setups here.
                # TODO: Multi-db support will need a better solution (a way to
                # communicate the current version).
                if self.oracle_version <= 9:
                    self.ops.regex_lookup = self.ops.regex_lookup_9
                else:
                    self.ops.regex_lookup = self.ops.regex_lookup_10
            try:
                self.connection.stmtcachesize = 20
            except Exception:
                # Django docs specify cx_Oracle version 4.3.1 or higher, but
                # stmtcachesize is available only in 4.3.2 and up.  This is a
                # best-effort tweak, but don't swallow SystemExit or
                # KeyboardInterrupt the way a bare except would.
                pass
            connection_created.send(sender=self.__class__)
        if cursor is None:
            cursor = FormatStylePlaceholderCursor(self.connection)
        return cursor

    # Oracle doesn't support savepoint commits. Ignore them.
    def _savepoint_commit(self, sid):
        pass
338
339 # Oracle doesn't support savepoint commits. Ignore them.
340 def _savepoint_commit(self, sid):
341 pass
342
343
class OracleParam(object):
    """
    Wrapper object for formatting parameters for Oracle.

    ``smart_str`` holds the value actually bound: either the cursor variable
    returned by the parameter's own ``bind_parameter(cursor)`` hook, or the
    value encoded with the cursor's charset.

    ``input_size`` is the parameter's own ``input_size`` attribute when it
    has one.  Otherwise it is Database.LONG_STRING for string values whose
    *encoded* form is longer than 4000 bytes, and None in all other cases
    (no input size is set for that parameter).
    """

    def __init__(self, param, cursor, strings_only=False):
        if hasattr(param, 'bind_parameter'):
            self.smart_str = param.bind_parameter(cursor)
        else:
            self.smart_str = smart_str(param, cursor.charset, strings_only)
        if hasattr(param, 'input_size'):
            # If parameter has `input_size` attribute, use that.
            self.input_size = param.input_size
        elif isinstance(param, basestring) and len(self.smart_str) > 4000:
            # Measure the encoded bytes, not the characters: a multi-byte
            # unicode string can exceed the 4000-byte bind limit while
            # having 4000 characters or fewer.
            self.input_size = Database.LONG_STRING
        else:
            self.input_size = None
363 # Mark any string param greater than 4000 characters as an NCLOB.
364 self.input_size = Database.LONG_STRING
365 else:
366 self.input_size = None
367
368
class InsertIdVar(object):
    """
    Late-binding parameter for Cursor.execute that receives the id produced
    by an INSERT ... RETURNING ... INTO statement.
    """

    def bind_parameter(self, cursor):
        # Allocate a NUMBER output variable and remember it on the cursor,
        # where DatabaseOperations.fetch_returned_insert_id() reads it back.
        number_var = cursor.var(Database.NUMBER)
        setattr(cursor, '_insert_id_var', number_var)
        return number_var
380
381
class FormatStylePlaceholderCursor(object):
    """
    Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
    style. This fixes it -- but note that if you want to use a literal "%s" in
    a query, you'll need to use "%%s".

    We also do automatic conversion between Unicode on the Python side and
    UTF-8 -- for talking to Oracle -- in here.  Rows are converted by
    _rowfactory() whether they are read with fetchone/fetchmany/fetchall or
    by iterating the cursor.  Unknown attributes are delegated to the
    wrapped cx_Oracle cursor.
    """
    charset = 'utf-8'

    def __init__(self, connection):
        self.cursor = connection.cursor()
        # Necessary to retrieve decimal values without rounding error.
        self.cursor.numbersAsStrings = True
        # Default arraysize of 1 is highly sub-optimal.
        self.cursor.arraysize = 100

    def _format_params(self, params):
        return tuple([OracleParam(p, self, True) for p in params])

    def _guess_input_sizes(self, params_list):
        """Set each position's input size from the last param that defines one."""
        sizes = [None] * len(params_list[0])
        for params in params_list:
            for i, value in enumerate(params):
                if value.input_size:
                    sizes[i] = value.input_size
        self.setinputsizes(*sizes)

    def _param_generator(self, params):
        return [p.smart_str for p in params]

    def _prepare_query(self, query, num_params):
        """Encode ``query`` and replace its %s placeholders with :argN names."""
        # cx_Oracle wants no trailing ';' for SQL statements. For PL/SQL, it
        # does want a trailing ';' but not a trailing '/'. However, these
        # characters must be included in the original query in case the query
        # is being passed to SQL*Plus.
        if query.endswith(';') or query.endswith('/'):
            query = query[:-1]
        args = tuple([':arg%d' % i for i in range(num_params)])
        return smart_str(query, self.charset) % args

    def _raise_integrity_error_if_needed(self):
        """
        Called while a DatabaseError is being handled.  cx_Oracle <= 4.4.0
        wrongly raises a DatabaseError for ORA-01400; raise an IntegrityError
        in its place.  Otherwise return, so the caller can re-raise the
        original exception with its traceback intact.
        """
        import sys
        error = sys.exc_info()[1]
        if isinstance(error, IntegrityError):
            return
        args = getattr(error, 'args', None) or ()
        # Not every DatabaseError carries an error object with a ``code``.
        if args and getattr(args[0], 'code', None) == 1400:
            raise IntegrityError(args[0])

    def execute(self, query, params=None):
        if params is None:
            params = []
        else:
            params = self._format_params(params)
        query = self._prepare_query(query, len(params))
        self._guess_input_sizes([params])
        try:
            return self.cursor.execute(query, self._param_generator(params))
        except DatabaseError:
            self._raise_integrity_error_if_needed()
            raise

    def executemany(self, query, params=None):
        try:
            num_params = len(params[0])
        except (IndexError, TypeError):
            # No params given, nothing to do
            return None
        query = self._prepare_query(query, num_params)
        formatted = [self._format_params(i) for i in params]
        self._guess_input_sizes(formatted)
        try:
            return self.cursor.executemany(query,
                [self._param_generator(p) for p in formatted])
        except DatabaseError:
            self._raise_integrity_error_if_needed()
            raise

    def fetchone(self):
        row = self.cursor.fetchone()
        if row is None:
            return row
        return self._rowfactory(row)

    def fetchmany(self, size=None):
        if size is None:
            size = self.arraysize
        return tuple([self._rowfactory(r)
                      for r in self.cursor.fetchmany(size)])

    def fetchall(self):
        return tuple([self._rowfactory(r)
                      for r in self.cursor.fetchall()])

    def _rowfactory(self, row):
        # Cast numeric values as the appropriate Python type based upon the
        # cursor description, and convert strings to unicode.
        casted = []
        for value, desc in zip(row, self.cursor.description):
            if value is not None and desc[1] is Database.NUMBER:
                precision, scale = desc[4:6]
                if scale == -127:
                    if precision == 0:
                        # NUMBER column: decimal-precision floating point
                        # This will normally be an integer from a sequence,
                        # but it could be a decimal value.
                        if '.' in value:
                            value = Decimal(value)
                        else:
                            value = int(value)
                    else:
                        # FLOAT column: binary-precision floating point.
                        # This comes from FloatField columns.
                        value = float(value)
                elif precision > 0:
                    # NUMBER(p,s) column: decimal-precision fixed point.
                    # This comes from IntegerField and DecimalField columns.
                    if scale == 0:
                        value = int(value)
                    else:
                        value = Decimal(value)
                elif '.' in value:
                    # No type information. This normally comes from a
                    # mathematical expression in the SELECT list. Guess int
                    # or Decimal based on whether it has a decimal point.
                    value = Decimal(value)
                else:
                    value = int(value)
            elif desc[1] in (Database.STRING, Database.FIXED_CHAR,
                             Database.LONG_STRING):
                value = to_unicode(value)
            casted.append(value)
        return tuple(casted)

    def __getattr__(self, attr):
        if attr in self.__dict__:
            return self.__dict__[attr]
        if attr == 'cursor':
            # The wrapped cursor was never set (e.g. __init__ did not run);
            # delegating would look up self.cursor again and recurse forever.
            raise AttributeError(attr)
        return getattr(self.cursor, attr)

    def __iter__(self):
        # Convert rows exactly as the fetch* methods do; iterating the raw
        # cursor would leak numbersAsStrings values unconverted.
        return (self._rowfactory(row) for row in self.cursor)
520 return getattr(self.cursor, attr)
521
522 def __iter__(self):
523 return iter(self.cursor)
524
525
def to_unicode(s):
    """
    Return ``s`` as a Unicode object when it is a string; return any other
    value untouched.
    """
    if not isinstance(s, basestring):
        return s
    return force_unicode(s)
534
535
536def _get_sequence_reset_sql():
537 # TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
538 return """
539DECLARE
540 startvalue integer;
541 cval integer;
542BEGIN
543 LOCK TABLE %(table)s IN SHARE MODE;
544 SELECT NVL(MAX(%(column)s), 0) INTO startvalue FROM %(table)s;
545 SELECT "%(sequence)s".nextval INTO cval FROM dual;
546 cval := startvalue - cval;
547 IF cval != 0 THEN
548 EXECUTE IMMEDIATE 'ALTER SEQUENCE "%(sequence)s" MINVALUE 0 INCREMENT BY '||cval;
549 SELECT "%(sequence)s".nextval INTO cval FROM dual;
550 EXECUTE IMMEDIATE 'ALTER SEQUENCE "%(sequence)s" INCREMENT BY 1';
551 END IF;
552 COMMIT;
553END;
554/"""
555
556
def _get_suffixed_name(table, suffix):
    # Truncate the table name so that name + suffix still fits within
    # DatabaseOperations.max_name_length(), then uppercase it.
    room = DatabaseOperations().max_name_length() - len(suffix)
    return '%s%s' % (util.truncate_name(table, room).upper(), suffix)


def get_sequence_name(table):
    """Return the uppercased, length-limited name of ``table``'s sequence."""
    return _get_suffixed_name(table, '_SQ')


def get_trigger_name(table):
    """Return the uppercased, length-limited name of ``table``'s trigger."""
    return _get_suffixed_name(table, '_TR')
Back to Top