Ticket #6161: oracle_qsrf_patch.diff

File oracle_qsrf_patch.diff, 15.3 KB (added by jbronn, 16 years ago)

Patch enabling Oracle functionality on the queryset-refactor branch.

  • branches/queryset-refactor/django/db/backends/oracle/base.py

     
    9393        return 30
    9494
    9595    def query_set_class(self, DefaultQuerySet):
    96         from django.db import connection
    97         from django.db.models.query import EmptyResultSet, GET_ITERATOR_CHUNK_SIZE, quote_only_if_word
     96        # Getting the DefaultQuery class.
     97        DefaultQuery = DefaultQuerySet().query.__class__
    9898
    99         class OracleQuerySet(DefaultQuerySet):
    100 
    101             def iterator(self):
    102                 "Performs the SELECT database lookup of this QuerySet."
    103 
    104                 from django.db.models.query import get_cached_row
    105 
    106                 # self._select is a dictionary, and dictionaries' key order is
    107                 # undefined, so we convert it to a list of tuples.
    108                 extra_select = self._select.items()
    109 
    110                 full_query = None
    111 
    112                 try:
    113                     try:
    114                         select, sql, params, full_query = self._get_sql_clause(get_full_query=True)
    115                     except TypeError:
    116                         select, sql, params = self._get_sql_clause()
    117                 except EmptyResultSet:
    118                     raise StopIteration
    119                 if not full_query:
    120                     full_query = "SELECT %s%s\n%s" % ((self._distinct and "DISTINCT " or ""), ', '.join(select), sql)
    121 
    122                 cursor = connection.cursor()
    123                 cursor.execute(full_query, params)
    124 
    125                 fill_cache = self._select_related
    126                 fields = self.model._meta.fields
    127                 index_end = len(fields)
    128 
    129                 # so here's the logic;
    130                 # 1. retrieve each row in turn
    131                 # 2. convert NCLOBs
    132 
    133                 while 1:
    134                     rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
    135                     if not rows:
    136                         raise StopIteration
    137                     for row in rows:
    138                         row = self.resolve_columns(row, fields)
    139                         if fill_cache:
    140                             obj, index_end = get_cached_row(klass=self.model, row=row,
    141                                                             index_start=0, max_depth=self._max_related_depth)
    142                         else:
    143                             obj = self.model(*row[:index_end])
    144                         for i, k in enumerate(extra_select):
    145                             setattr(obj, k[0], row[index_end+i])
    146                         yield obj
    147 
    148 
    149             def _get_sql_clause(self, get_full_query=False):
    150                 from django.db.models.query import fill_table_cache, \
    151                     handle_legacy_orderlist, orderfield2column
    152 
    153                 opts = self.model._meta
    154                 qn = connection.ops.quote_name
    155 
    156                 # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z.
    157                 select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields]
    158                 tables = [quote_only_if_word(t) for t in self._tables]
    159                 joins = SortedDict()
    160                 where = self._where[:]
    161                 params = self._params[:]
    162 
    163                 # Convert self._filters into SQL.
    164                 joins2, where2, params2 = self._filters.get_sql(opts)
    165                 joins.update(joins2)
    166                 where.extend(where2)
    167                 params.extend(params2)
    168 
    169                 # Add additional tables and WHERE clauses based on select_related.
    170                 if self._select_related:
    171                     fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table])
    172 
    173                 # Add any additional SELECTs.
    174                 if self._select:
    175                     select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()])
    176 
    177                 # Start composing the body of the SQL statement.
    178                 sql = [" FROM", qn(opts.db_table)]
    179 
    180                 # Compose the join dictionary into SQL describing the joins.
    181                 if joins:
    182                     sql.append(" ".join(["%s %s %s ON %s" % (join_type, table, alias, condition)
    183                                     for (alias, (table, join_type, condition)) in joins.items()]))
    184 
    185                 # Compose the tables clause into SQL.
    186                 if tables:
    187                     sql.append(", " + ", ".join(tables))
    188 
    189                 # Compose the where clause into SQL.
    190                 if where:
    191                     sql.append(where and "WHERE " + " AND ".join(where))
    192 
    193                 # ORDER BY clause
    194                 order_by = []
    195                 if self._order_by is not None:
    196                     ordering_to_use = self._order_by
    197                 else:
    198                     ordering_to_use = opts.ordering
    199                 for f in handle_legacy_orderlist(ordering_to_use):
    200                     if f == '?': # Special case.
    201                         order_by.append(DatabaseOperations().random_function_sql())
    202                     else:
    203                         if f.startswith('-'):
    204                             col_name = f[1:]
    205                             order = "DESC"
    206                         else:
    207                             col_name = f
    208                             order = "ASC"
    209                         if "." in col_name:
    210                             table_prefix, col_name = col_name.split('.', 1)
    211                             table_prefix = qn(table_prefix) + '.'
    212                         else:
    213                             # Use the database table as a column prefix if it wasn't given,
    214                             # and if the requested column isn't a custom SELECT.
    215                             if "." not in col_name and col_name not in (self._select or ()):
    216                                 table_prefix = qn(opts.db_table) + '.'
    217                             else:
    218                                 table_prefix = ''
    219                         order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order))
    220                 if order_by:
    221                     sql.append("ORDER BY " + ", ".join(order_by))
    222 
    223                 # Look for column name collisions in the select elements
    224                 # and fix them with an AS alias.  This allows us to do a
    225                 # SELECT * later in the paging query.
    226                 cols = [clause.split('.')[-1] for clause in select]
    227                 for index, col in enumerate(cols):
    228                     if cols.count(col) > 1:
    229                         col = '%s%d' % (col.replace('"', ''), index)
    230                         cols[index] = col
    231                         select[index] = '%s AS %s' % (select[index], col)
    232 
    233                 # LIMIT and OFFSET clauses
    234                 # To support limits and offsets, Oracle requires some funky rewriting of an otherwise normal looking query.
    235                 select_clause = ",".join(select)
    236                 distinct = (self._distinct and "DISTINCT " or "")
    237 
    238                 if order_by:
    239                     order_by_clause = " OVER (ORDER BY %s )" % (", ".join(order_by))
    240                 else:
    241                     #Oracle's row_number() function always requires an order-by clause.
    242                     #So we need to define a default order-by, since none was provided.
    243                     order_by_clause = " OVER (ORDER BY %s.%s)" % \
    244                         (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column))
    245                 # limit_and_offset_clause
    246                 if self._limit is None:
    247                     assert self._offset is None, "'offset' is not allowed without 'limit'"
    248 
    249                 if self._offset is not None:
    250                     offset = int(self._offset)
    251                 else:
    252                     offset = 0
    253                 if self._limit is not None:
    254                     limit = int(self._limit)
    255                 else:
    256                     limit = None
    257 
    258                 limit_and_offset_clause = ''
    259                 if limit is not None:
    260                     limit_and_offset_clause = "WHERE rn > %s AND rn <= %s" % (offset, limit+offset)
    261                 elif offset:
    262                     limit_and_offset_clause = "WHERE rn > %s" % (offset)
    263 
    264                 if len(limit_and_offset_clause) > 0:
    265                     fmt = \
    266     """SELECT * FROM
    267       (SELECT %s%s,
    268               ROW_NUMBER()%s AS rn
    269        %s)
    270     %s"""
    271                     full_query = fmt % (distinct, select_clause,
    272                                         order_by_clause, ' '.join(sql).strip(),
    273                                         limit_and_offset_clause)
    274                 else:
    275                     full_query = None
    276 
    277                 if get_full_query:
    278                     return select, " ".join(sql), params, full_query
    279                 else:
    280                     return select, " ".join(sql), params
    281 
     99        class OracleQuery(DefaultQuery):
    282100            def resolve_columns(self, row, fields=()):
    283101                from django.db.models.fields import DateField, DateTimeField, \
    284                     TimeField, BooleanField, NullBooleanField, DecimalField, Field
     102                     TimeField, BooleanField, NullBooleanField, DecimalField, Field
    285103                values = []
    286104                for value, field in map(None, row, fields):
    287105                    if isinstance(value, Database.LOB):
     
    325143                    values.append(value)
    326144                return values
    327145
     146            def as_sql(self, with_limits=True):
     147                """
     148                Creates the SQL for this query. Returns the SQL string and list of
     149                parameters.  This is overriden from the original Query class to
     150                accommodate Oracle's limit/offset SQL.
     151               
     152                If 'with_limits' is False, any limit/offset information is not included
     153                in the query.
     154                """
     155                # The quotename function and the `do_offset` flag, required for the
     156                # quirky SQL needed to perform limits/offsets in Oracle.
     157                qn = self.quote_name_unless_alias
     158                do_offset = with_limits and (self.high_mark or self.low_mark)
     159
     160                self.pre_sql_setup()
     161                out_cols = self.get_columns()
     162                ordering = self.get_ordering()
     163                # This must come after 'select' and 'ordering' -- see docstring of
     164                # get_from_clause() for details.
     165                from_, f_params = self.get_from_clause()
     166                where, w_params = self.where.as_sql(qn=qn)
     167
     168                result = ['SELECT']
     169                if self.distinct:
     170                    result.append('DISTINCT')
     171                result.append(', '.join(out_cols))
     172
     173                # To do the equivalent of limit/offset queries in Oracle requires
     174                # selecting an additional row number with the full query, and then
     175                # limiting the selection based on row number.
     176                if do_offset:
     177                    # Inserting the extra SELECT.
     178                    result.insert(0, 'SELECT * FROM\n (')
     179                    if ordering:
     180                        rn_orderby = ', '.join(ordering)
     181                    else:
     182                        # Oracle's row_number() function always requires an order-by clause.
     183                        # So we need to define a default order-by, since none was provided.
     184                        opts = self.model._meta
     185                        rn_orderby = '%s.%s' % (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column))
     186                    # Appending the ROW_NUMBER() SQL into the extra select.
     187                    result.append(',\n ROW_NUMBER() OVER (ORDER BY %s ) AS rn\n' % rn_orderby)
     188
     189                # Now adding on the regular FROM statement.
     190                result.append('FROM')
     191                result.extend(from_)
     192                params = list(f_params)
     193
     194                if where:
     195                    result.append('WHERE %s' % where)
     196                if self.extra_where:
     197                    if not where:
     198                        result.append('WHERE')
     199                    else:
     200                        result.append('AND')
     201                    result.append(' AND'.join(self.extra_where))
     202                params.extend(w_params)
     203
     204                if self.group_by:
     205                    grouping = self.get_grouping()
     206                    result.append('GROUP BY %s' % ', '.join(grouping))
     207
     208                if ordering:
     209                    result.append('ORDER BY %s' % ', '.join(ordering))
     210
     211                if do_offset:
     212                    # Closing off the extra SELECT placed in above that gets the
     213                    # row number (`rn`) and place WHERE condition on `rn` for the
     214                    # desired range.
     215                    result.append(')\n WHERE rn > %d' % self.low_mark)
     216                    if self.high_mark:
     217                        result.append('AND rn <= %d' % self.high_mark)
     218
     219                params.extend(self.extra_params)
     220                return ' '.join(result), tuple(params)
     221             
     222        from django.db import connection
     223        class OracleQuerySet(DefaultQuerySet):
     224            "The OracleQuerySet is overriden to use OracleQuery."
     225            def __init__(self, model=None, query=None):
     226                super(OracleQuerySet, self).__init__(model=model, query=query)
     227                self.query = query or OracleQuery(self.model, connection)
     228
    328229        return OracleQuerySet
    329230
    330231    def quote_name(self, name):
     
    446347    charset = 'utf-8'
    447348
    448349    def _format_params(self, params):
     350        sz_kwargs, result = {}, {}
    449351        if isinstance(params, dict):
    450             result = {}
    451352            charset = self.charset
    452353            for key, value in params.items():
    453354                result[smart_str(key, charset)] = smart_str(value, charset)
    454             return result
     355                if hasattr(value, 'oracle_type'): sz_kwargs[key] = value.oracle_type()
    455356        else:
    456             return tuple([smart_str(p, self.charset, True) for p in params])
     357            for i, param in enumerate(params):
     358                key = 'arg%d' % i
     359                result[key] = smart_str(param, self.charset, True)
     360                if hasattr(param, 'oracle_type'): sz_kwargs[key] = param.oracle_type()
    457361
     362        # If any of the parameters had an `oracle_type` method, then we set
     363        # the inputsizes for those parameters using the returned type
     364        if sz_kwargs: self.setinputsizes(**sz_kwargs)
     365        return result
     366
    458367    def execute(self, query, params=None):
    459368        if params is None:
    460369            params = []
     
    472381
    473382    def executemany(self, query, params=None):
    474383        try:
    475           args = [(':arg%d' % i) for i in range(len(params[0]))]
     384            args = [(':arg%d' % i) for i in range(len(params[0]))]
    476385        except (IndexError, TypeError):
    477           # No params given, nothing to do
    478           return None
     386            # No params given, nothing to do
     387            return None
    479388        # cx_Oracle wants no trailing ';' for SQL statements.  For PL/SQL, it
    480389        # it does want a trailing ';' but not a trailing '/'.  However, these
    481390        # characters must be included in the original query in case the query
Back to Top