Django

Code

Ticket #6161: oracle_qsrf_patch_v2.diff

File oracle_qsrf_patch_v2.diff, 14.3 kB (added by jbronn, 1 year ago)

A bit of an improvement; uses qs-rf facilities to be more concise and reduce amount of code.

  • branches/queryset-refactor/django/db/backends/oracle/base.py

    old new  
    9393        return 30 
    9494 
    9595    def query_set_class(self, DefaultQuerySet): 
    96         from django.db import connection 
    97         from django.db.models.query import EmptyResultSet, GET_ITERATOR_CHUNK_SIZE, quote_only_if_word 
     96        # Getting the DefaultQuery class. 
     97        DefaultQuery = DefaultQuerySet().query.__class__ 
    9898 
    99         class OracleQuerySet(DefaultQuerySet): 
    100  
    101             def iterator(self): 
    102                 "Performs the SELECT database lookup of this QuerySet." 
    103  
    104                 from django.db.models.query import get_cached_row 
    105  
    106                 # self._select is a dictionary, and dictionaries' key order is 
    107                 # undefined, so we convert it to a list of tuples. 
    108                 extra_select = self._select.items() 
    109  
    110                 full_query = None 
    111  
    112                 try: 
    113                     try: 
    114                         select, sql, params, full_query = self._get_sql_clause(get_full_query=True) 
    115                     except TypeError: 
    116                         select, sql, params = self._get_sql_clause() 
    117                 except EmptyResultSet: 
    118                     raise StopIteration 
    119                 if not full_query: 
    120                     full_query = "SELECT %s%s\n%s" % ((self._distinct and "DISTINCT " or ""), ', '.join(select), sql) 
    121  
    122                 cursor = connection.cursor() 
    123                 cursor.execute(full_query, params) 
    124  
    125                 fill_cache = self._select_related 
    126                 fields = self.model._meta.fields 
    127                 index_end = len(fields) 
    128  
    129                 # so here's the logic; 
    130                 # 1. retrieve each row in turn 
    131                 # 2. convert NCLOBs 
    132  
    133                 while 1: 
    134                     rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) 
    135                     if not rows: 
    136                         raise StopIteration 
    137                     for row in rows: 
    138                         row = self.resolve_columns(row, fields) 
    139                         if fill_cache: 
    140                             obj, index_end = get_cached_row(klass=self.model, row=row, 
    141                                                             index_start=0, max_depth=self._max_related_depth) 
    142                         else: 
    143                             obj = self.model(*row[:index_end]) 
    144                         for i, k in enumerate(extra_select): 
    145                             setattr(obj, k[0], row[index_end+i]) 
    146                         yield obj 
    147  
    148  
    149             def _get_sql_clause(self, get_full_query=False): 
    150                 from django.db.models.query import fill_table_cache, \ 
    151                     handle_legacy_orderlist, orderfield2column 
    152  
    153                 opts = self.model._meta 
    154                 qn = connection.ops.quote_name 
    155  
    156                 # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z. 
    157                 select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields] 
    158                 tables = [quote_only_if_word(t) for t in self._tables] 
    159                 joins = SortedDict() 
    160                 where = self._where[:] 
    161                 params = self._params[:] 
    162  
    163                 # Convert self._filters into SQL. 
    164                 joins2, where2, params2 = self._filters.get_sql(opts) 
    165                 joins.update(joins2) 
    166                 where.extend(where2) 
    167                 params.extend(params2) 
    168  
    169                 # Add additional tables and WHERE clauses based on select_related. 
    170                 if self._select_related: 
    171                     fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table]) 
    172  
    173                 # Add any additional SELECTs. 
    174                 if self._select: 
    175                     select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()]) 
    176  
    177                 # Start composing the body of the SQL statement. 
    178                 sql = [" FROM", qn(opts.db_table)] 
    179  
    180                 # Compose the join dictionary into SQL describing the joins. 
    181                 if joins: 
    182                     sql.append(" ".join(["%s %s %s ON %s" % (join_type, table, alias, condition) 
    183                                     for (alias, (table, join_type, condition)) in joins.items()])) 
    184  
    185                 # Compose the tables clause into SQL. 
    186                 if tables: 
    187                     sql.append(", " + ", ".join(tables)) 
    188  
    189                 # Compose the where clause into SQL. 
    190                 if where: 
    191                     sql.append(where and "WHERE " + " AND ".join(where)) 
    192  
    193                 # ORDER BY clause 
    194                 order_by = [] 
    195                 if self._order_by is not None: 
    196                     ordering_to_use = self._order_by 
    197                 else: 
    198                     ordering_to_use = opts.ordering 
    199                 for f in handle_legacy_orderlist(ordering_to_use): 
    200                     if f == '?': # Special case. 
    201                         order_by.append(DatabaseOperations().random_function_sql()) 
    202                     else: 
    203                         if f.startswith('-'): 
    204                             col_name = f[1:] 
    205                             order = "DESC" 
    206                         else: 
    207                             col_name = f 
    208                             order = "ASC" 
    209                         if "." in col_name: 
    210                             table_prefix, col_name = col_name.split('.', 1) 
    211                             table_prefix = qn(table_prefix) + '.' 
    212                         else: 
    213                             # Use the database table as a column prefix if it wasn't given, 
    214                             # and if the requested column isn't a custom SELECT. 
    215                             if "." not in col_name and col_name not in (self._select or ()): 
    216                                 table_prefix = qn(opts.db_table) + '.' 
    217                             else: 
    218                                 table_prefix = '' 
    219                         order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order)) 
    220                 if order_by: 
    221                     sql.append("ORDER BY " + ", ".join(order_by)) 
    222  
    223                 # Look for column name collisions in the select elements 
    224                 # and fix them with an AS alias.  This allows us to do a 
    225                 # SELECT * later in the paging query. 
    226                 cols = [clause.split('.')[-1] for clause in select] 
    227                 for index, col in enumerate(cols): 
    228                     if cols.count(col) > 1: 
    229                         col = '%s%d' % (col.replace('"', ''), index) 
    230                         cols[index] = col 
    231                         select[index] = '%s AS %s' % (select[index], col) 
    232  
    233                 # LIMIT and OFFSET clauses 
    234                 # To support limits and offsets, Oracle requires some funky rewriting of an otherwise normal looking query. 
    235                 select_clause = ",".join(select) 
    236                 distinct = (self._distinct and "DISTINCT " or "") 
    237  
    238                 if order_by: 
    239                     order_by_clause = " OVER (ORDER BY %s )" % (", ".join(order_by)) 
    240                 else: 
    241                     #Oracle's row_number() function always requires an order-by clause. 
    242                     #So we need to define a default order-by, since none was provided. 
    243                     order_by_clause = " OVER (ORDER BY %s.%s)" % \ 
    244                         (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column)) 
    245                 # limit_and_offset_clause 
    246                 if self._limit is None: 
    247                     assert self._offset is None, "'offset' is not allowed without 'limit'" 
    248  
    249                 if self._offset is not None: 
    250                     offset = int(self._offset) 
    251                 else: 
    252                     offset = 0 
    253                 if self._limit is not None: 
    254                     limit = int(self._limit) 
    255                 else: 
    256                     limit = None 
    257  
    258                 limit_and_offset_clause = '' 
    259                 if limit is not None: 
    260                     limit_and_offset_clause = "WHERE rn > %s AND rn <= %s" % (offset, limit+offset) 
    261                 elif offset: 
    262                     limit_and_offset_clause = "WHERE rn > %s" % (offset) 
    263  
    264                 if len(limit_and_offset_clause) > 0: 
    265                     fmt = \ 
    266     """SELECT * FROM 
    267       (SELECT %s%s, 
    268               ROW_NUMBER()%s AS rn 
    269        %s) 
    270     %s""" 
    271                     full_query = fmt % (distinct, select_clause, 
    272                                         order_by_clause, ' '.join(sql).strip(), 
    273                                         limit_and_offset_clause) 
    274                 else: 
    275                     full_query = None 
    276  
    277                 if get_full_query: 
    278                     return select, " ".join(sql), params, full_query 
    279                 else: 
    280                     return select, " ".join(sql), params 
    281  
     99        class OracleQuery(DefaultQuery): 
    282100            def resolve_columns(self, row, fields=()): 
    283101                from django.db.models.fields import DateField, DateTimeField, \ 
    284                     TimeField, BooleanField, NullBooleanField, DecimalField, Field 
     102                     TimeField, BooleanField, NullBooleanField, DecimalField, Field 
    285103                values = [] 
    286104                for value, field in map(None, row, fields): 
    287105                    if isinstance(value, Database.LOB): 
     
    325143                    values.append(value) 
    326144                return values 
    327145 
     146            def as_sql(self, with_limits=True): 
     147                """ 
     148                Creates the SQL for this query. Returns the SQL string and list of 
     149                parameters.  This is overriden from the original Query class to 
     150                accommodate Oracle's limit/offset SQL. 
     151                 
     152                If 'with_limits' is False, any limit/offset information is not included 
     153                in the query. 
     154                """ 
     155                # The `do_offset` flag indicates whether we need to construct the 
     156                # SQL needed to use limit/offset w/Oracle. 
     157                do_offset = with_limits and (self.high_mark or self.low_mark) 
     158 
     159                # If no offsets, just return the result of the base class `as_sql`. 
     160                if not do_offset: 
     161                    return super(OracleQuery, self).as_sql(with_limits=False) 
     162 
     163                # `get_columns` needs to be called before `get_ordering` to populate 
     164                # `_select_alias`. 
     165                self.pre_sql_setup() 
     166                out_cols = self.get_columns() 
     167                ordering = self.get_ordering() 
     168 
     169                # Getting the "ORDER BY" SQL for the ROW_NUMBER() result. 
     170                if ordering: 
     171                    rn_orderby = ', '.join(ordering) 
     172                else: 
     173                    # Oracle's ROW_NUMBER() function always requires an order-by clause. 
     174                    # So we need to define a default order-by, since none was provided. 
     175                    qn = self.quote_name_unless_alias 
     176                    opts = self.model._meta 
     177                    rn_orderby = '%s.%s' % (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column)) 
     178 
     179                # Getting the selection SQL and the params, which has the `rn` 
     180                # extra selection SQL; we pop `rn` after this completes so we do 
     181                # not get the attribute on the returned models. 
     182                self.extra_select['rn'] = 'ROW_NUMBER() OVER (ORDER BY %s )' % rn_orderby 
     183                sql, params= super(OracleQuery, self).as_sql(with_limits=False) 
     184                self.extra_select.pop('rn') 
     185 
     186                # Constructing the result SQL, using the initial select SQL 
     187                # obtained above. 
     188                result = ['SELECT * FROM (%s)' % sql] 
     189                 
     190                # Place WHERE condition on `rn` for the desired range. 
     191                result.append('WHERE rn > %d' % self.low_mark) 
     192                if self.high_mark: 
     193                    result.append('AND rn <= %d' % self.high_mark) 
     194 
     195                # Returning the SQL w/params. 
     196                return ' '.join(result), params 
     197               
     198        from django.db import connection 
     199        class OracleQuerySet(DefaultQuerySet): 
     200            "The OracleQuerySet is overriden to use OracleQuery." 
     201            def __init__(self, model=None, query=None): 
     202                super(OracleQuerySet, self).__init__(model=model, query=query) 
     203                self.query = query or OracleQuery(self.model, connection) 
     204 
    328205        return OracleQuerySet 
    329206 
    330207    def quote_name(self, name): 
     
    446323    charset = 'utf-8' 
    447324 
    448325    def _format_params(self, params): 
     326        sz_kwargs, result = {}, {} 
    449327        if isinstance(params, dict): 
    450             result = {} 
    451328            charset = self.charset 
    452329            for key, value in params.items(): 
    453330                result[smart_str(key, charset)] = smart_str(value, charset) 
    454             return result 
     331                if hasattr(value, 'oracle_type'): sz_kwargs[key] = value.oracle_type() 
    455332        else: 
    456             return tuple([smart_str(p, self.charset, True) for p in params]) 
     333            for i, param in enumerate(params): 
     334                key = 'arg%d' % i 
     335                result[key] = smart_str(param, self.charset, True) 
     336                if hasattr(param, 'oracle_type'): sz_kwargs[key] = param.oracle_type() 
    457337 
     338        # If any of the parameters had an `oracle_type` method, then we set 
     339        # the inputsizes for those parameters using the returned type 
     340        if sz_kwargs: self.setinputsizes(**sz_kwargs) 
     341        return result 
     342 
    458343    def execute(self, query, params=None): 
    459344        if params is None: 
    460345            params = [] 
     
    472357 
    473358    def executemany(self, query, params=None): 
    474359        try: 
    475           args = [(':arg%d' % i) for i in range(len(params[0]))] 
     360            args = [(':arg%d' % i) for i in range(len(params[0]))] 
    476361        except (IndexError, TypeError): 
    477           # No params given, nothing to do 
    478           return None 
     362            # No params given, nothing to do 
     363            return None 
    479364        # cx_Oracle wants no trailing ';' for SQL statements.  For PL/SQL, it 
    480365        # it does want a trailing ';' but not a trailing '/'.  However, these 
    481366        # characters must be included in the original query in case the query