Ticket #6161: oracle_qsrf_patch_v3.diff

File oracle_qsrf_patch_v3.diff, 15.5 KB (added by Erin Kelly, 16 years ago)

New version of patch that reverts the changes added in [6905]

  • django/db/backends/oracle/base.py

     
    9999        return 30
    100100
    101101    def query_set_class(self, DefaultQuerySet):
    102         from django.db import connection
    103         from django.db.models.query import EmptyResultSet, GET_ITERATOR_CHUNK_SIZE, quote_only_if_word
     102        # Getting the DefaultQuery class.
     103        DefaultQuery = DefaultQuerySet().query.__class__
    104104
    105         class OracleQuerySet(DefaultQuerySet):
    106 
    107             def iterator(self):
    108                 "Performs the SELECT database lookup of this QuerySet."
    109 
    110                 from django.db.models.query import get_cached_row
    111 
    112                 # self._select is a dictionary, and dictionaries' key order is
    113                 # undefined, so we convert it to a list of tuples.
    114                 extra_select = self._select.items()
    115 
    116                 full_query = None
    117 
    118                 try:
    119                     try:
    120                         select, sql, params, full_query = self._get_sql_clause(get_full_query=True)
    121                     except TypeError:
    122                         select, sql, params = self._get_sql_clause()
    123                 except EmptyResultSet:
    124                     raise StopIteration
    125                 if not full_query:
    126                     full_query = "SELECT %s%s\n%s" % ((self._distinct and "DISTINCT " or ""), ', '.join(select), sql)
    127 
    128                 cursor = connection.cursor()
    129                 cursor.execute(full_query, params)
    130 
    131                 fill_cache = self._select_related
    132                 fields = self.model._meta.fields
    133                 index_end = len(fields)
    134 
    135                 # so here's the logic;
    136                 # 1. retrieve each row in turn
    137                 # 2. convert NCLOBs
    138 
    139                 while 1:
    140                     rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
    141                     if not rows:
    142                         raise StopIteration
    143                     for row in rows:
    144                         row = self.resolve_columns(row, fields)
    145                         if fill_cache:
    146                             obj, index_end = get_cached_row(klass=self.model, row=row,
    147                                                             index_start=0, max_depth=self._max_related_depth)
    148                         else:
    149                             obj = self.model(*row[:index_end])
    150                         for i, k in enumerate(extra_select):
    151                             setattr(obj, k[0], row[index_end+i])
    152                         yield obj
    153 
    154 
    155             def _get_sql_clause(self, get_full_query=False):
    156                 from django.db.models.query import fill_table_cache, \
    157                     handle_legacy_orderlist, orderfield2column
    158 
    159                 opts = self.model._meta
    160                 qn = connection.ops.quote_name
    161 
    162                 # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z.
    163                 select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields]
    164                 tables = [quote_only_if_word(t) for t in self._tables]
    165                 joins = SortedDict()
    166                 where = self._where[:]
    167                 params = self._params[:]
    168 
    169                 # Convert self._filters into SQL.
    170                 joins2, where2, params2 = self._filters.get_sql(opts)
    171                 joins.update(joins2)
    172                 where.extend(where2)
    173                 params.extend(params2)
    174 
    175                 # Add additional tables and WHERE clauses based on select_related.
    176                 if self._select_related:
    177                     fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table])
    178 
    179                 # Add any additional SELECTs.
    180                 if self._select:
    181                     select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()])
    182 
    183                 # Start composing the body of the SQL statement.
    184                 sql = [" FROM", qn(opts.db_table)]
    185 
    186                 # Compose the join dictionary into SQL describing the joins.
    187                 if joins:
    188                     sql.append(" ".join(["%s %s %s ON %s" % (join_type, table, alias, condition)
    189                                     for (alias, (table, join_type, condition)) in joins.items()]))
    190 
    191                 # Compose the tables clause into SQL.
    192                 if tables:
    193                     sql.append(", " + ", ".join(tables))
    194 
    195                 # Compose the where clause into SQL.
    196                 if where:
    197                     sql.append(where and "WHERE " + " AND ".join(where))
    198 
    199                 # ORDER BY clause
    200                 order_by = []
    201                 if self._order_by is not None:
    202                     ordering_to_use = self._order_by
    203                 else:
    204                     ordering_to_use = opts.ordering
    205                 for f in handle_legacy_orderlist(ordering_to_use):
    206                     if f == '?': # Special case.
    207                         order_by.append(DatabaseOperations().random_function_sql())
    208                     else:
    209                         if f.startswith('-'):
    210                             col_name = f[1:]
    211                             order = "DESC"
    212                         else:
    213                             col_name = f
    214                             order = "ASC"
    215                         if "." in col_name:
    216                             table_prefix, col_name = col_name.split('.', 1)
    217                             table_prefix = qn(table_prefix) + '.'
    218                         else:
    219                             # Use the database table as a column prefix if it wasn't given,
    220                             # and if the requested column isn't a custom SELECT.
    221                             if "." not in col_name and col_name not in (self._select or ()):
    222                                 table_prefix = qn(opts.db_table) + '.'
    223                             else:
    224                                 table_prefix = ''
    225                         order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order))
    226                 if order_by:
    227                     sql.append("ORDER BY " + ", ".join(order_by))
    228 
    229                 # Look for column name collisions in the select elements
    230                 # and fix them with an AS alias.  This allows us to do a
    231                 # SELECT * later in the paging query.
    232                 cols = [clause.split('.')[-1] for clause in select]
    233                 for index, col in enumerate(cols):
    234                     if cols.count(col) > 1:
    235                         col = '%s%d' % (col.replace('"', ''), index)
    236                         cols[index] = col
    237                         select[index] = '%s AS %s' % (select[index], col)
    238 
    239                 # LIMIT and OFFSET clauses
    240                 # To support limits and offsets, Oracle requires some funky rewriting of an otherwise normal looking query.
    241                 select_clause = ",".join(select)
    242                 distinct = (self._distinct and "DISTINCT " or "")
    243 
    244                 if order_by:
    245                     order_by_clause = " OVER (ORDER BY %s )" % (", ".join(order_by))
    246                 else:
    247                     #Oracle's row_number() function always requires an order-by clause.
    248                     #So we need to define a default order-by, since none was provided.
    249                     order_by_clause = " OVER (ORDER BY %s.%s)" % \
    250                         (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column))
    251                 # limit_and_offset_clause
    252                 if self._limit is None:
    253                     assert self._offset is None, "'offset' is not allowed without 'limit'"
    254 
    255                 if self._offset is not None:
    256                     offset = int(self._offset)
    257                 else:
    258                     offset = 0
    259                 if self._limit is not None:
    260                     limit = int(self._limit)
    261                 else:
    262                     limit = None
    263 
    264                 limit_and_offset_clause = ''
    265                 if limit is not None:
    266                     limit_and_offset_clause = "WHERE rn > %s AND rn <= %s" % (offset, limit+offset)
    267                 elif offset:
    268                     limit_and_offset_clause = "WHERE rn > %s" % (offset)
    269 
    270                 if len(limit_and_offset_clause) > 0:
    271                     fmt = \
    272     """SELECT * FROM
    273       (SELECT %s%s,
    274               ROW_NUMBER()%s AS rn
    275        %s)
    276     %s"""
    277                     full_query = fmt % (distinct, select_clause,
    278                                         order_by_clause, ' '.join(sql).strip(),
    279                                         limit_and_offset_clause)
    280                 else:
    281                     full_query = None
    282 
    283                 if get_full_query:
    284                     return select, " ".join(sql), params, full_query
    285                 else:
    286                     return select, " ".join(sql), params
    287 
     105        class OracleQuery(DefaultQuery):
    288106            def resolve_columns(self, row, fields=()):
    289107                from django.db.models.fields import DateField, DateTimeField, \
    290                     TimeField, BooleanField, NullBooleanField, DecimalField, Field
     108                     TimeField, BooleanField, NullBooleanField, DecimalField, Field
    291109                values = []
    292110                for value, field in map(None, row, fields):
    293111                    if isinstance(value, Database.LOB):
     
    331149                    values.append(value)
    332150                return values
    333151
     152            def as_sql(self, with_limits=True):
     153                """
     154                Creates the SQL for this query. Returns the SQL string and list of
     155                parameters.  This is overriden from the original Query class to
     156                accommodate Oracle's limit/offset SQL.
     157               
     158                If 'with_limits' is False, any limit/offset information is not included
     159                in the query.
     160                """
     161                # The `do_offset` flag indicates whether we need to construct the
     162                # SQL needed to use limit/offset w/Oracle.
     163                do_offset = with_limits and (self.high_mark or self.low_mark)
     164
     165                # If no offsets, just return the result of the base class `as_sql`.
     166                if not do_offset:
     167                    return super(OracleQuery, self).as_sql(with_limits=False)
     168
     169                # `get_columns` needs to be called before `get_ordering` to populate
     170                # `_select_alias`.
     171                self.pre_sql_setup()
     172                out_cols = self.get_columns()
     173                ordering = self.get_ordering()
     174
     175                # Getting the "ORDER BY" SQL for the ROW_NUMBER() result.
     176                if ordering:
     177                    rn_orderby = ', '.join(ordering)
     178                else:
     179                    # Oracle's ROW_NUMBER() function always requires an order-by clause.
     180                    # So we need to define a default order-by, since none was provided.
     181                    qn = self.quote_name_unless_alias
     182                    opts = self.model._meta
     183                    rn_orderby = '%s.%s' % (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column))
     184
     185                # Getting the selection SQL and the params, which has the `rn`
     186                # extra selection SQL; we pop `rn` after this completes so we do
     187                # not get the attribute on the returned models.
     188                self.extra_select['rn'] = 'ROW_NUMBER() OVER (ORDER BY %s )' % rn_orderby
     189                sql, params= super(OracleQuery, self).as_sql(with_limits=False)
     190                self.extra_select.pop('rn')
     191
     192                # Constructing the result SQL, using the initial select SQL
     193                # obtained above.
     194                result = ['SELECT * FROM (%s)' % sql]
     195               
     196                # Place WHERE condition on `rn` for the desired range.
     197                result.append('WHERE rn > %d' % self.low_mark)
     198                if self.high_mark:
     199                    result.append('AND rn <= %d' % self.high_mark)
     200
     201                # Returning the SQL w/params.
     202                return ' '.join(result), params
     203             
     204        from django.db import connection
     205        class OracleQuerySet(DefaultQuerySet):
     206            "The OracleQuerySet is overriden to use OracleQuery."
     207            def __init__(self, model=None, query=None):
     208                super(OracleQuerySet, self).__init__(model=model, query=query)
     209                self.query = query or OracleQuery(self.model, connection)
     210
    334211        return OracleQuerySet
    335212
    336213    def quote_name(self, name):
     
    480357    charset = 'utf-8'
    481358
    482359    def _format_params(self, params):
     360        sz_kwargs, result = {}, {}
    483361        if isinstance(params, dict):
    484             result = {}
    485362            charset = self.charset
    486363            for key, value in params.items():
    487364                result[smart_str(key, charset)] = smart_str(value, charset)
    488             return result
     365                if hasattr(value, 'oracle_type'): sz_kwargs[key] = value.oracle_type()
    489366        else:
    490             return tuple([smart_str(p, self.charset, True) for p in params])
     367            for i, param in enumerate(params):
     368                key = 'arg%d' % i
     369                result[key] = smart_str(param, self.charset, True)
     370                if hasattr(param, 'oracle_type'): sz_kwargs[key] = param.oracle_type()
     371   
     372        # If any of the parameters had an `oracle_type` method, then we set
     373        # the inputsizes for those parameters using the returned type
     374        if sz_kwargs: self.setinputsizes(**sz_kwargs)
     375        return result
    491376
    492     def _guess_input_sizes(self, params_list):
    493         # Mark any string parameter greater than 4000 characters as an NCLOB.
    494         if isinstance(params_list[0], dict):
    495             sizes = {}
    496             iterators = [params.iteritems() for params in params_list]
    497         else:
    498             sizes = [None] * len(params_list[0])
    499             iterators = [enumerate(params) for params in params_list]
    500         for iterator in iterators:
    501             for key, value in iterator:
    502                 if isinstance(value, basestring) and len(value) > 4000:
    503                     sizes[key] = Database.NCLOB
    504         if isinstance(sizes, dict):
    505             self.setinputsizes(**sizes)
    506         else:
    507             self.setinputsizes(*sizes)
    508 
    509377    def execute(self, query, params=None):
    510378        if params is None:
    511379            params = []
     
    519387        if query.endswith(';') or query.endswith('/'):
    520388            query = query[:-1]
    521389        query = smart_str(query, self.charset) % tuple(args)
    522         self._guess_input_sizes([params])
    523390        return Database.Cursor.execute(self, query, params)
    524391
    525392    def executemany(self, query, params=None):
    526393        try:
    527           args = [(':arg%d' % i) for i in range(len(params[0]))]
     394            args = [(':arg%d' % i) for i in range(len(params[0]))]
    528395        except (IndexError, TypeError):
    529           # No params given, nothing to do
    530           return None
     396            # No params given, nothing to do
     397            return None
    531398        # cx_Oracle wants no trailing ';' for SQL statements.  For PL/SQL, it
    532399        # it does want a trailing ';' but not a trailing '/'.  However, these
    533400        # characters must be included in the original query in case the query
     
    536403            query = query[:-1]
    537404        query = smart_str(query, self.charset) % tuple(args)
    538405        new_param_list = [self._format_params(i) for i in params]
    539         self._guess_input_sizes(new_param_list)
    540406        return Database.Cursor.executemany(self, query, new_param_list)
    541407
    542408    def fetchone(self):
Back to Top