Code

Ticket #11487: base.2.py

File base.2.py, 21.6 KB (added by mdpetry, 5 years ago)

changed variable to Database.CLOB

Line 
1"""
2Oracle database backend for Django.
3
4Requires cx_Oracle: http://cx-oracle.sourceforge.net/
5"""
6
7import os
8import datetime
9import time
10try:
11    from decimal import Decimal
12except ImportError:
13    from django.utils._decimal import Decimal
14
# The Oracle client picks up its character-set behaviour from the process
# environment, so both settings are made here at module level.
#
# Keep unicode text from being mangled by an encoding pass through the
# (possibly non-unicode) database character set.
os.environ['ORA_NCHAR_LITERAL_REPLACE'] = 'TRUE'
# Client-side character set encoding: always UTF-8.
os.environ['NLS_LANG'] = '.UTF8'
20
21try:
22    import cx_Oracle as Database
23except ImportError, e:
24    from django.core.exceptions import ImproperlyConfigured
25    raise ImproperlyConfigured("Error loading cx_Oracle module: %s" % e)
26
27from django.db.backends import *
28from django.db.backends.signals import connection_created
29from django.db.backends.oracle import query
30from django.db.backends.oracle.client import DatabaseClient
31from django.db.backends.oracle.creation import DatabaseCreation
32from django.db.backends.oracle.introspection import DatabaseIntrospection
33from django.utils.encoding import smart_str, force_unicode
34
# Module-level aliases for cx_Oracle's exception classes; the cursor wrapper
# below catches DatabaseError and re-raises ORA-01400 as IntegrityError.
DatabaseError, IntegrityError = Database.DatabaseError, Database.IntegrityError
37
38
class DatabaseFeatures(BaseDatabaseFeatures):
    """Capability flags that describe the Oracle backend to the ORM."""
    can_return_id_from_insert = True
    # Value returned by fetchmany() when no rows remain.
    empty_fetchmany_value = ()
    interprets_empty_strings_as_nulls = True
    needs_datetime_string_cast = False
    uses_custom_query_class = True
    uses_savepoints = True
46
47
class DatabaseOperations(BaseDatabaseOperations):
    """
    Oracle-specific SQL fragments and value conversions used by the ORM.
    """

    def autoinc_sql(self, table, column):
        """
        Return a (sequence_sql, trigger_sql) pair of PL/SQL statements that
        emulate an auto-incrementing ``column`` on ``table``.
        """
        # To simulate auto-incrementing primary keys in Oracle, we have to
        # create a sequence and a trigger. The sequence is only created when
        # USER_CATALOG does not already list it; the trigger fills the column
        # from the sequence whenever a row is inserted with it NULL.
        sq_name = get_sequence_name(table)
        tr_name = get_trigger_name(table)
        tbl_name = self.quote_name(table)
        col_name = self.quote_name(column)
        sequence_sql = """
DECLARE
    i INTEGER;
BEGIN
    SELECT COUNT(*) INTO i FROM USER_CATALOG
        WHERE TABLE_NAME = '%(sq_name)s' AND TABLE_TYPE = 'SEQUENCE';
    IF i = 0 THEN
        EXECUTE IMMEDIATE 'CREATE SEQUENCE "%(sq_name)s"';
    END IF;
END;
/""" % locals()
        trigger_sql = """
CREATE OR REPLACE TRIGGER "%(tr_name)s"
BEFORE INSERT ON %(tbl_name)s
FOR EACH ROW
WHEN (new.%(col_name)s IS NULL)
    BEGIN
        SELECT "%(sq_name)s".nextval
        INTO :new.%(col_name)s FROM dual;
    END;
/""" % locals()
        return sequence_sql, trigger_sql

    def date_extract_sql(self, lookup_type, field_name):
        # http://download-east.oracle.com/docs/cd/B10501_01/server.920/a96540/functions42a.htm#1017163
        if lookup_type == 'week_day':
            # TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday.
            # This numbering relies on the NLS_TERRITORY = 'AMERICA' session
            # setting made when DatabaseWrapper opens a connection.
            return "TO_CHAR(%s, 'D')" % field_name
        else:
            return "EXTRACT(%s FROM %s)" % (lookup_type, field_name)

    def date_trunc_sql(self, lookup_type, field_name):
        # Oracle uses TRUNC() for both dates and numbers.
        # http://download-east.oracle.com/docs/cd/B10501_01/server.920/a96540/functions155a.htm#SQLRF06151
        if lookup_type == 'day':
            sql = 'TRUNC(%s)' % field_name
        else:
            sql = "TRUNC(%s, '%s')" % (field_name, lookup_type)
        return sql

    def datetime_cast_sql(self):
        return "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')"

    def deferrable_sql(self):
        return " DEFERRABLE INITIALLY DEFERRED"

    def drop_sequence_sql(self, table):
        return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table))

    def fetch_returned_insert_id(self, cursor):
        # The variable was attached to the cursor by InsertIdVar when the
        # RETURNING ... INTO clause from return_insert_id() was bound.
        return long(cursor._insert_id_var.getvalue())

    def field_cast_sql(self, db_type):
        # LOB columns are wrapped in DBMS_LOB.SUBSTR so they can be compared.
        if db_type and db_type.endswith('LOB'):
            return "DBMS_LOB.SUBSTR(%s)"
        else:
            return "%s"

    def last_insert_id(self, cursor, table_name, pk_name):
        sq_name = get_sequence_name(table_name)
        cursor.execute('SELECT "%s".currval FROM dual' % sq_name)
        return cursor.fetchone()[0]

    def lookup_cast(self, lookup_type):
        # Case-insensitive lookups compare upper-cased values on both sides
        # (the matching operators in DatabaseWrapper use UPPER(%s)).
        if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
            return "UPPER(%s)"
        return "%s"

    def max_name_length(self):
        return 30

    def prep_for_iexact_query(self, x):
        return x

    def process_clob(self, value):
        """Read a CLOB value into a unicode string; None becomes u''."""
        if value is None:
            return u''
        return force_unicode(value.read())

    def query_class(self, DefaultQueryClass):
        return query.query_class(DefaultQueryClass, Database)

    def quote_name(self, name):
        # SQL92 requires delimited (quoted) names to be case-sensitive.  When
        # not quoted, Oracle has case-insensitive behavior for identifiers, but
        # always defaults to uppercase.
        # We simplify things by making Oracle identifiers always uppercase.
        if not name.startswith('"') and not name.endswith('"'):
            name = '"%s"' % util.truncate_name(name.upper(),
                                               self.max_name_length())
        return name.upper()

    def random_function_sql(self):
        return "DBMS_RANDOM.RANDOM"

    def regex_lookup_9(self, lookup_type):
        raise NotImplementedError("Regexes are not supported in Oracle before version 10g.")

    def regex_lookup_10(self, lookup_type):
        # 'c' = case-sensitive match, 'i' = case-insensitive match.
        if lookup_type == 'regex':
            match_option = "'c'"
        else:
            match_option = "'i'"
        return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option

    def regex_lookup(self, lookup_type):
        # DatabaseWrapper._cursor() replaces this method on the connection's
        # ops instance with regex_lookup_9 or regex_lookup_10 once the server
        # version is known. If we get here first, open a cursor to trigger
        # that setup and dispatch again.
        from django.db import connection
        connection.cursor()
        if getattr(connection, 'oracle_version', None) is None:
            # _cursor() could not parse the server version, so no
            # version-specific method was installed and re-dispatching would
            # recurse forever. Assume a 10g+ server (REGEXP_LIKE available);
            # an older server will reject the generated SQL instead.
            return self.regex_lookup_10(lookup_type)
        return connection.ops.regex_lookup(lookup_type)

    def return_insert_id(self):
        return "RETURNING %s INTO %%s", (InsertIdVar(),)

    def savepoint_create_sql(self, sid):
        return "SAVEPOINT " + self.quote_name(sid)

    def savepoint_rollback_sql(self, sid):
        return "ROLLBACK TO SAVEPOINT " + self.quote_name(sid)

    def sql_flush(self, style, tables, sequences):
        """
        Return SQL statements that delete all rows from ``tables`` and then
        reset each sequence described in ``sequences`` to match.
        """
        if not tables:
            return []
        # Oracle does support TRUNCATE, but it seems to get us into
        # FK referential trouble, whereas DELETE FROM table works.
        sql = ['%s %s %s;' % \
                (style.SQL_KEYWORD('DELETE'),
                 style.SQL_KEYWORD('FROM'),
                 style.SQL_FIELD(self.quote_name(table)))
                for table in tables]
        # Since we've just deleted all the rows, running our sequence
        # ALTER code will reset the sequence to 0.
        reset_sql = _get_sequence_reset_sql()
        for sequence_info in sequences:
            sequence_name = get_sequence_name(sequence_info['table'])
            table_name = self.quote_name(sequence_info['table'])
            column_name = self.quote_name(sequence_info['column'] or 'id')
            sql.append(reset_sql % {'sequence': sequence_name,
                                    'table': table_name,
                                    'column': column_name})
        return sql

    def sequence_reset_sql(self, style, model_list):
        """
        Return PL/SQL blocks that resynchronize the sequences backing the
        AutoFields and auto-created many-to-many tables of ``model_list``.
        """
        from django.db import models
        output = []
        reset_sql = _get_sequence_reset_sql()
        for model in model_list:
            for f in model._meta.local_fields:
                if isinstance(f, models.AutoField):
                    table_name = self.quote_name(model._meta.db_table)
                    sequence_name = get_sequence_name(model._meta.db_table)
                    column_name = self.quote_name(f.column)
                    output.append(reset_sql % {'sequence': sequence_name,
                                               'table': table_name,
                                               'column': column_name})
                    # Only one AutoField is allowed per model, so don't
                    # continue to loop
                    break
            for f in model._meta.many_to_many:
                if not f.rel.through:
                    table_name = self.quote_name(f.m2m_db_table())
                    sequence_name = get_sequence_name(f.m2m_db_table())
                    column_name = self.quote_name('id')
                    output.append(reset_sql % {'sequence': sequence_name,
                                               'table': table_name,
                                               'column': column_name})
        return output

    def start_transaction_sql(self):
        return ''

    def tablespace_sql(self, tablespace, inline=False):
        if inline:
            prefix = "USING INDEX "
        else:
            prefix = ""
        return "%sTABLESPACE %s" % (prefix, self.quote_name(tablespace))

    def value_to_db_time(self, value):
        """
        Convert a time (or 'HH:MM:SS' string) to a datetime on 1900-01-01,
        which is how time values are sent to Oracle. None passes through.
        """
        if value is None:
            return None
        if isinstance(value, basestring):
            # strptime fills in the date part as 1900-01-01.
            return datetime.datetime(*(time.strptime(value, '%H:%M:%S')[:6]))
        return datetime.datetime(1900, 1, 1, value.hour, value.minute,
                                 value.second, value.microsecond)

    def year_lookup_bounds_for_date_field(self, value):
        first = '%s-01-01'
        second = '%s-12-31'
        return [first % value, second % value]

    def combine_expression(self, connector, sub_expressions):
        "Oracle requires special cases for %% and & operators in query expressions"
        if connector == '%%':
            return 'MOD(%s)' % ','.join(sub_expressions)
        elif connector == '&':
            return 'BITAND(%s)' % ','.join(sub_expressions)
        elif connector == '|':
            raise NotImplementedError("Bit-wise or is not supported in Oracle.")
        return super(DatabaseOperations, self).combine_expression(connector, sub_expressions)
258
259
class DatabaseWrapper(BaseDatabaseWrapper):
    """Django connection wrapper for Oracle, built on cx_Oracle."""

    # SQL operator templates for each lookup type. The LIKEC patterns
    # declare backslash as the escape character.
    operators = {
        'exact': '= %s',
        'iexact': '= UPPER(%s)',
        'contains': "LIKEC %s ESCAPE '\\'",
        'icontains': "LIKEC UPPER(%s) ESCAPE '\\'",
        'gt': '> %s',
        'gte': '>= %s',
        'lt': '< %s',
        'lte': '<= %s',
        'startswith': "LIKEC %s ESCAPE '\\'",
        'endswith': "LIKEC %s ESCAPE '\\'",
        'istartswith': "LIKEC UPPER(%s) ESCAPE '\\'",
        'iendswith': "LIKEC UPPER(%s) ESCAPE '\\'",
    }
    # Major version of the connected server; set by _cursor() when the
    # version string can be parsed, otherwise left as None.
    oracle_version = None

    def __init__(self, *args, **kwargs):
        super(DatabaseWrapper, self).__init__(*args, **kwargs)

        self.features = DatabaseFeatures()
        self.ops = DatabaseOperations()
        self.client = DatabaseClient(self)
        self.creation = DatabaseCreation(self)
        self.introspection = DatabaseIntrospection(self)
        self.validation = BaseDatabaseValidation()

    def _valid_connection(self):
        return self.connection is not None

    def _connect_string(self):
        """
        Build a "user/password@dsn" connect string from the settings. When a
        port is configured the DSN is made with Database.makedsn (host
        defaulting to 'localhost'); otherwise DATABASE_NAME is used as the
        DSN directly.
        """
        settings_dict = self.settings_dict
        if not settings_dict['DATABASE_HOST'].strip():
            settings_dict['DATABASE_HOST'] = 'localhost'
        if settings_dict['DATABASE_PORT'].strip():
            dsn = Database.makedsn(settings_dict['DATABASE_HOST'],
                                   int(settings_dict['DATABASE_PORT']),
                                   settings_dict['DATABASE_NAME'])
        else:
            dsn = settings_dict['DATABASE_NAME']
        return "%s/%s@%s" % (settings_dict['DATABASE_USER'],
                             settings_dict['DATABASE_PASSWORD'], dsn)

    def _cursor(self):
        """
        Return a FormatStylePlaceholderCursor, connecting first if needed.
        A fresh connection also gets its session configured.
        """
        cursor = None
        if not self._valid_connection():
            conn_string = self._connect_string()
            self.connection = Database.connect(conn_string, **self.settings_dict['DATABASE_OPTIONS'])
            cursor = FormatStylePlaceholderCursor(self.connection)
            # Set oracle date to ansi date format.  This only needs to execute
            # once when we create a new connection. We also set the Territory
            # to 'AMERICA' which forces Sunday to evaluate to a '1' in TO_CHAR().
            cursor.execute("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS' "
                           "NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF' "
                           "NLS_TERRITORY = 'AMERICA'")
            try:
                self.oracle_version = int(self.connection.version.split('.')[0])
                # There's no way for the DatabaseOperations class to know the
                # currently active Oracle version, so we do some setups here.
                # TODO: Multi-db support will need a better solution (a way to
                # communicate the current version).
                if self.oracle_version <= 9:
                    self.ops.regex_lookup = self.ops.regex_lookup_9
                else:
                    self.ops.regex_lookup = self.ops.regex_lookup_10
            except ValueError:
                # Unparseable version string: leave oracle_version as None.
                pass
            try:
                self.connection.stmtcachesize = 20
            except AttributeError:
                # Django docs specify cx_Oracle version 4.3.1 or higher, but
                # stmtcachesize is available only in 4.3.2 and up.
                pass
            connection_created.send(sender=self.__class__)
        if cursor is None:
            cursor = FormatStylePlaceholderCursor(self.connection)
        return cursor

    # Oracle doesn't support savepoint commits.  Ignore them.
    def _savepoint_commit(self, sid):
        pass
342
343
class OracleParam(object):
    """
    Wrapper object for formatting parameters for Oracle. If the encoded
    string representation of the value is large enough (greater than 4000
    bytes) the input size needs to be set as CLOB. Alternatively, if the
    parameter has an `input_size` attribute, then the value of the
    `input_size` attribute will be used instead. Otherwise, no input size
    will be set for the parameter when executing the query.

    Attributes:
      smart_str  -- the value to bind: the cursor variable created by the
                    param's `bind_parameter` hook, or the param encoded in
                    the cursor's charset.
      input_size -- the input size for Cursor.setinputsizes, or None.
    """

    def __init__(self, param, cursor, strings_only=False):
        if hasattr(param, 'bind_parameter'):
            # Late-binding params (e.g. InsertIdVar) create their own variable.
            self.smart_str = param.bind_parameter(cursor)
        else:
            self.smart_str = smart_str(param, cursor.charset, strings_only)
        if hasattr(param, 'input_size'):
            # If parameter has `input_size` attribute, use that.
            self.input_size = param.input_size
        elif isinstance(param, basestring) and len(self.smart_str) > 4000:
            # Oracle limits string binds to 4000 *bytes*, so measure the
            # encoded value: non-ASCII unicode text can exceed 4000 bytes
            # while having 4000 characters or fewer. Bind it as a CLOB.
            self.input_size = Database.CLOB
        else:
            self.input_size = None
367
368
class InsertIdVar(object):
    """
    Placeholder parameter for the ``RETURNING ... INTO`` clause of an INSERT.

    When bound, it creates a numeric cursor variable to receive the new
    row's id and stores it on the cursor as ``_insert_id_var``, where
    DatabaseOperations.fetch_returned_insert_id reads it back.
    """

    def bind_parameter(self, cursor):
        id_var = cursor.var(Database.NUMBER)
        # Keep a handle so the value can be read after execution.
        cursor._insert_id_var = id_var
        return id_var
380
381
382class FormatStylePlaceholderCursor(object):
383    """
384    Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
385    style. This fixes it -- but note that if you want to use a literal "%s" in
386    a query, you'll need to use "%%s".
387
388    We also do automatic conversion between Unicode on the Python side and
389    UTF-8 -- for talking to Oracle -- in here.
390    """
391    charset = 'utf-8'
392
393    def __init__(self, connection):
394        self.cursor = connection.cursor()
395        # Necessary to retrieve decimal values without rounding error.
396        self.cursor.numbersAsStrings = True
397        # Default arraysize of 1 is highly sub-optimal.
398        self.cursor.arraysize = 100
399
400    def _format_params(self, params):
401        return tuple([OracleParam(p, self, True) for p in params])
402
403    def _guess_input_sizes(self, params_list):
404        sizes = [None] * len(params_list[0])
405        for params in params_list:
406            for i, value in enumerate(params):
407                if value.input_size:
408                    sizes[i] = value.input_size
409        self.setinputsizes(*sizes)
410
411    def _param_generator(self, params):
412        return [p.smart_str for p in params]
413
414    def execute(self, query, params=None):
415        if params is None:
416            params = []
417        else:
418            params = self._format_params(params)
419        args = [(':arg%d' % i) for i in range(len(params))]
420        # cx_Oracle wants no trailing ';' for SQL statements.  For PL/SQL, it
421        # it does want a trailing ';' but not a trailing '/'.  However, these
422        # characters must be included in the original query in case the query
423        # is being passed to SQL*Plus.
424        if query.endswith(';') or query.endswith('/'):
425            query = query[:-1]
426        query = smart_str(query, self.charset) % tuple(args)
427        self._guess_input_sizes([params])
428        try:
429            return self.cursor.execute(query, self._param_generator(params))
430        except DatabaseError, e:
431            # cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400.
432            if e.args[0].code == 1400 and not isinstance(e, IntegrityError):
433                e = IntegrityError(e.args[0])
434            raise e
435
436    def executemany(self, query, params=None):
437        try:
438            args = [(':arg%d' % i) for i in range(len(params[0]))]
439        except (IndexError, TypeError):
440            # No params given, nothing to do
441            return None
442        # cx_Oracle wants no trailing ';' for SQL statements.  For PL/SQL, it
443        # it does want a trailing ';' but not a trailing '/'.  However, these
444        # characters must be included in the original query in case the query
445        # is being passed to SQL*Plus.
446        if query.endswith(';') or query.endswith('/'):
447            query = query[:-1]
448        query = smart_str(query, self.charset) % tuple(args)
449        formatted = [self._format_params(i) for i in params]
450        self._guess_input_sizes(formatted)
451        try:
452            return self.cursor.executemany(query,
453                                [self._param_generator(p) for p in formatted])
454        except DatabaseError, e:
455            # cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400.
456            if e.args[0].code == 1400 and not isinstance(e, IntegrityError):
457                e = IntegrityError(e.args[0])
458            raise e
459
460    def fetchone(self):
461        row = self.cursor.fetchone()
462        if row is None:
463            return row
464        return self._rowfactory(row)
465
466    def fetchmany(self, size=None):
467        if size is None:
468            size = self.arraysize
469        return tuple([self._rowfactory(r)
470                      for r in self.cursor.fetchmany(size)])
471
472    def fetchall(self):
473        return tuple([self._rowfactory(r)
474                      for r in self.cursor.fetchall()])
475
476    def _rowfactory(self, row):
477        # Cast numeric values as the appropriate Python type based upon the
478        # cursor description, and convert strings to unicode.
479        casted = []
480        for value, desc in zip(row, self.cursor.description):
481            if value is not None and desc[1] is Database.NUMBER:
482                precision, scale = desc[4:6]
483                if scale == -127:
484                    if precision == 0:
485                        # NUMBER column: decimal-precision floating point
486                        # This will normally be an integer from a sequence,
487                        # but it could be a decimal value.
488                        if '.' in value:
489                            value = Decimal(value)
490                        else:
491                            value = int(value)
492                    else:
493                        # FLOAT column: binary-precision floating point.
494                        # This comes from FloatField columns.
495                        value = float(value)
496                elif precision > 0:
497                    # NUMBER(p,s) column: decimal-precision fixed point.
498                    # This comes from IntField and DecimalField columns.
499                    if scale == 0:
500                        value = int(value)
501                    else:
502                        value = Decimal(value)
503                elif '.' in value:
504                    # No type information. This normally comes from a
505                    # mathematical expression in the SELECT list. Guess int
506                    # or Decimal based on whether it has a decimal point.
507                    value = Decimal(value)
508                else:
509                    value = int(value)
510            elif desc[1] in (Database.STRING, Database.FIXED_CHAR,
511                             Database.LONG_STRING):
512                value = to_unicode(value)
513            casted.append(value)
514        return tuple(casted)
515
516    def __getattr__(self, attr):
517        if attr in self.__dict__:
518            return self.__dict__[attr]
519        else:
520            return getattr(self.cursor, attr)
521
522    def __iter__(self):
523        return iter(self.cursor)
524
525
def to_unicode(s):
    """
    Return ``s`` as a unicode object when it is a (byte or unicode) string;
    any other value, including None, is returned untouched.
    """
    if not isinstance(s, basestring):
        return s
    return force_unicode(s)
534
535
536def _get_sequence_reset_sql():
537    # TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
538    return """
539DECLARE
540    startvalue integer;
541    cval integer;
542BEGIN
543    LOCK TABLE %(table)s IN SHARE MODE;
544    SELECT NVL(MAX(%(column)s), 0) INTO startvalue FROM %(table)s;
545    SELECT "%(sequence)s".nextval INTO cval FROM dual;
546    cval := startvalue - cval;
547    IF cval != 0 THEN
548        EXECUTE IMMEDIATE 'ALTER SEQUENCE "%(sequence)s" MINVALUE 0 INCREMENT BY '||cval;
549        SELECT "%(sequence)s".nextval INTO cval FROM dual;
550        EXECUTE IMMEDIATE 'ALTER SEQUENCE "%(sequence)s" INCREMENT BY 1';
551    END IF;
552    COMMIT;
553END;
554/"""
555
556
def get_sequence_name(table):
    """Return the upper-cased name of the sequence backing ``table``."""
    suffix = '_SQ'
    # Leave room for the suffix inside the maximum identifier length.
    budget = DatabaseOperations().max_name_length() - len(suffix)
    return util.truncate_name(table, budget).upper() + suffix
560
561
def get_trigger_name(table):
    """Return the upper-cased name of the auto-increment trigger on ``table``."""
    suffix = '_TR'
    # Leave room for the suffix inside the maximum identifier length.
    budget = DatabaseOperations().max_name_length() - len(suffix)
    return util.truncate_name(table, budget).upper() + suffix