Ticket #6161: oracle_qsrf_patch_v2.diff

File oracle_qsrf_patch_v2.diff, 14.3 KB (added by jbronn, 16 years ago)

A bit of an improvement; uses qs-rf facilities to be more concise and reduce amount of code.

  • branches/queryset-refactor/django/db/backends/oracle/base.py

     
    9393        return 30
    9494
    9595    def query_set_class(self, DefaultQuerySet):
    96         from django.db import connection
    97         from django.db.models.query import EmptyResultSet, GET_ITERATOR_CHUNK_SIZE, quote_only_if_word
     96        # Getting the DefaultQuery class.
     97        DefaultQuery = DefaultQuerySet().query.__class__
    9898
    99         class OracleQuerySet(DefaultQuerySet):
    100 
    101             def iterator(self):
    102                 "Performs the SELECT database lookup of this QuerySet."
    103 
    104                 from django.db.models.query import get_cached_row
    105 
    106                 # self._select is a dictionary, and dictionaries' key order is
    107                 # undefined, so we convert it to a list of tuples.
    108                 extra_select = self._select.items()
    109 
    110                 full_query = None
    111 
    112                 try:
    113                     try:
    114                         select, sql, params, full_query = self._get_sql_clause(get_full_query=True)
    115                     except TypeError:
    116                         select, sql, params = self._get_sql_clause()
    117                 except EmptyResultSet:
    118                     raise StopIteration
    119                 if not full_query:
    120                     full_query = "SELECT %s%s\n%s" % ((self._distinct and "DISTINCT " or ""), ', '.join(select), sql)
    121 
    122                 cursor = connection.cursor()
    123                 cursor.execute(full_query, params)
    124 
    125                 fill_cache = self._select_related
    126                 fields = self.model._meta.fields
    127                 index_end = len(fields)
    128 
    129                 # so here's the logic;
    130                 # 1. retrieve each row in turn
    131                 # 2. convert NCLOBs
    132 
    133                 while 1:
    134                     rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
    135                     if not rows:
    136                         raise StopIteration
    137                     for row in rows:
    138                         row = self.resolve_columns(row, fields)
    139                         if fill_cache:
    140                             obj, index_end = get_cached_row(klass=self.model, row=row,
    141                                                             index_start=0, max_depth=self._max_related_depth)
    142                         else:
    143                             obj = self.model(*row[:index_end])
    144                         for i, k in enumerate(extra_select):
    145                             setattr(obj, k[0], row[index_end+i])
    146                         yield obj
    147 
    148 
    149             def _get_sql_clause(self, get_full_query=False):
    150                 from django.db.models.query import fill_table_cache, \
    151                     handle_legacy_orderlist, orderfield2column
    152 
    153                 opts = self.model._meta
    154                 qn = connection.ops.quote_name
    155 
    156                 # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z.
    157                 select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields]
    158                 tables = [quote_only_if_word(t) for t in self._tables]
    159                 joins = SortedDict()
    160                 where = self._where[:]
    161                 params = self._params[:]
    162 
    163                 # Convert self._filters into SQL.
    164                 joins2, where2, params2 = self._filters.get_sql(opts)
    165                 joins.update(joins2)
    166                 where.extend(where2)
    167                 params.extend(params2)
    168 
    169                 # Add additional tables and WHERE clauses based on select_related.
    170                 if self._select_related:
    171                     fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table])
    172 
    173                 # Add any additional SELECTs.
    174                 if self._select:
    175                     select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()])
    176 
    177                 # Start composing the body of the SQL statement.
    178                 sql = [" FROM", qn(opts.db_table)]
    179 
    180                 # Compose the join dictionary into SQL describing the joins.
    181                 if joins:
    182                     sql.append(" ".join(["%s %s %s ON %s" % (join_type, table, alias, condition)
    183                                     for (alias, (table, join_type, condition)) in joins.items()]))
    184 
    185                 # Compose the tables clause into SQL.
    186                 if tables:
    187                     sql.append(", " + ", ".join(tables))
    188 
    189                 # Compose the where clause into SQL.
    190                 if where:
    191                     sql.append(where and "WHERE " + " AND ".join(where))
    192 
    193                 # ORDER BY clause
    194                 order_by = []
    195                 if self._order_by is not None:
    196                     ordering_to_use = self._order_by
    197                 else:
    198                     ordering_to_use = opts.ordering
    199                 for f in handle_legacy_orderlist(ordering_to_use):
    200                     if f == '?': # Special case.
    201                         order_by.append(DatabaseOperations().random_function_sql())
    202                     else:
    203                         if f.startswith('-'):
    204                             col_name = f[1:]
    205                             order = "DESC"
    206                         else:
    207                             col_name = f
    208                             order = "ASC"
    209                         if "." in col_name:
    210                             table_prefix, col_name = col_name.split('.', 1)
    211                             table_prefix = qn(table_prefix) + '.'
    212                         else:
    213                             # Use the database table as a column prefix if it wasn't given,
    214                             # and if the requested column isn't a custom SELECT.
    215                             if "." not in col_name and col_name not in (self._select or ()):
    216                                 table_prefix = qn(opts.db_table) + '.'
    217                             else:
    218                                 table_prefix = ''
    219                         order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order))
    220                 if order_by:
    221                     sql.append("ORDER BY " + ", ".join(order_by))
    222 
    223                 # Look for column name collisions in the select elements
    224                 # and fix them with an AS alias.  This allows us to do a
    225                 # SELECT * later in the paging query.
    226                 cols = [clause.split('.')[-1] for clause in select]
    227                 for index, col in enumerate(cols):
    228                     if cols.count(col) > 1:
    229                         col = '%s%d' % (col.replace('"', ''), index)
    230                         cols[index] = col
    231                         select[index] = '%s AS %s' % (select[index], col)
    232 
    233                 # LIMIT and OFFSET clauses
    234                 # To support limits and offsets, Oracle requires some funky rewriting of an otherwise normal looking query.
    235                 select_clause = ",".join(select)
    236                 distinct = (self._distinct and "DISTINCT " or "")
    237 
    238                 if order_by:
    239                     order_by_clause = " OVER (ORDER BY %s )" % (", ".join(order_by))
    240                 else:
    241                     #Oracle's row_number() function always requires an order-by clause.
    242                     #So we need to define a default order-by, since none was provided.
    243                     order_by_clause = " OVER (ORDER BY %s.%s)" % \
    244                         (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column))
    245                 # limit_and_offset_clause
    246                 if self._limit is None:
    247                     assert self._offset is None, "'offset' is not allowed without 'limit'"
    248 
    249                 if self._offset is not None:
    250                     offset = int(self._offset)
    251                 else:
    252                     offset = 0
    253                 if self._limit is not None:
    254                     limit = int(self._limit)
    255                 else:
    256                     limit = None
    257 
    258                 limit_and_offset_clause = ''
    259                 if limit is not None:
    260                     limit_and_offset_clause = "WHERE rn > %s AND rn <= %s" % (offset, limit+offset)
    261                 elif offset:
    262                     limit_and_offset_clause = "WHERE rn > %s" % (offset)
    263 
    264                 if len(limit_and_offset_clause) > 0:
    265                     fmt = \
    266     """SELECT * FROM
    267       (SELECT %s%s,
    268               ROW_NUMBER()%s AS rn
    269        %s)
    270     %s"""
    271                     full_query = fmt % (distinct, select_clause,
    272                                         order_by_clause, ' '.join(sql).strip(),
    273                                         limit_and_offset_clause)
    274                 else:
    275                     full_query = None
    276 
    277                 if get_full_query:
    278                     return select, " ".join(sql), params, full_query
    279                 else:
    280                     return select, " ".join(sql), params
    281 
     99        class OracleQuery(DefaultQuery):
    282100            def resolve_columns(self, row, fields=()):
    283101                from django.db.models.fields import DateField, DateTimeField, \
    284                     TimeField, BooleanField, NullBooleanField, DecimalField, Field
     102                     TimeField, BooleanField, NullBooleanField, DecimalField, Field
    285103                values = []
    286104                for value, field in map(None, row, fields):
    287105                    if isinstance(value, Database.LOB):
     
    325143                    values.append(value)
    326144                return values
    327145
     146            def as_sql(self, with_limits=True):
     147                """
     148                Creates the SQL for this query. Returns the SQL string and list of
     149                parameters.  This is overriden from the original Query class to
     150                accommodate Oracle's limit/offset SQL.
     151               
     152                If 'with_limits' is False, any limit/offset information is not included
     153                in the query.
     154                """
     155                # The `do_offset` flag indicates whether we need to construct the
     156                # SQL needed to use limit/offset w/Oracle.
     157                do_offset = with_limits and (self.high_mark or self.low_mark)
     158
     159                # If no offsets, just return the result of the base class `as_sql`.
     160                if not do_offset:
     161                    return super(OracleQuery, self).as_sql(with_limits=False)
     162
     163                # `get_columns` needs to be called before `get_ordering` to populate
     164                # `_select_alias`.
     165                self.pre_sql_setup()
     166                out_cols = self.get_columns()
     167                ordering = self.get_ordering()
     168
     169                # Getting the "ORDER BY" SQL for the ROW_NUMBER() result.
     170                if ordering:
     171                    rn_orderby = ', '.join(ordering)
     172                else:
     173                    # Oracle's ROW_NUMBER() function always requires an order-by clause.
     174                    # So we need to define a default order-by, since none was provided.
     175                    qn = self.quote_name_unless_alias
     176                    opts = self.model._meta
     177                    rn_orderby = '%s.%s' % (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column))
     178
     179                # Getting the selection SQL and the params, which has the `rn`
     180                # extra selection SQL; we pop `rn` after this completes so we do
     181                # not get the attribute on the returned models.
     182                self.extra_select['rn'] = 'ROW_NUMBER() OVER (ORDER BY %s )' % rn_orderby
     183                sql, params= super(OracleQuery, self).as_sql(with_limits=False)
     184                self.extra_select.pop('rn')
     185
     186                # Constructing the result SQL, using the initial select SQL
     187                # obtained above.
     188                result = ['SELECT * FROM (%s)' % sql]
     189               
     190                # Place WHERE condition on `rn` for the desired range.
     191                result.append('WHERE rn > %d' % self.low_mark)
     192                if self.high_mark:
     193                    result.append('AND rn <= %d' % self.high_mark)
     194
     195                # Returning the SQL w/params.
     196                return ' '.join(result), params
     197             
     198        from django.db import connection
     199        class OracleQuerySet(DefaultQuerySet):
     200            "The OracleQuerySet is overriden to use OracleQuery."
     201            def __init__(self, model=None, query=None):
     202                super(OracleQuerySet, self).__init__(model=model, query=query)
     203                self.query = query or OracleQuery(self.model, connection)
     204
    328205        return OracleQuerySet
    329206
    330207    def quote_name(self, name):
     
    446323    charset = 'utf-8'
    447324
    448325    def _format_params(self, params):
     326        sz_kwargs, result = {}, {}
    449327        if isinstance(params, dict):
    450             result = {}
    451328            charset = self.charset
    452329            for key, value in params.items():
    453330                result[smart_str(key, charset)] = smart_str(value, charset)
    454             return result
     331                if hasattr(value, 'oracle_type'): sz_kwargs[key] = value.oracle_type()
    455332        else:
    456             return tuple([smart_str(p, self.charset, True) for p in params])
     333            for i, param in enumerate(params):
     334                key = 'arg%d' % i
     335                result[key] = smart_str(param, self.charset, True)
     336                if hasattr(param, 'oracle_type'): sz_kwargs[key] = param.oracle_type()
    457337
     338        # If any of the parameters had an `oracle_type` method, then we set
     339        # the inputsizes for those parameters using the returned type
     340        if sz_kwargs: self.setinputsizes(**sz_kwargs)
     341        return result
     342
    458343    def execute(self, query, params=None):
    459344        if params is None:
    460345            params = []
     
    472357
    473358    def executemany(self, query, params=None):
    474359        try:
    475           args = [(':arg%d' % i) for i in range(len(params[0]))]
     360            args = [(':arg%d' % i) for i in range(len(params[0]))]
    476361        except (IndexError, TypeError):
    477           # No params given, nothing to do
    478           return None
     362            # No params given, nothing to do
     363            return None
    479364        # cx_Oracle wants no trailing ';' for SQL statements.  For PL/SQL, it
    480365        # it does want a trailing ';' but not a trailing '/'.  However, these
    481366        # characters must be included in the original query in case the query
Back to Top