Django

Code

Ticket #6161: oracle_qsrf_patch.diff

File oracle_qsrf_patch.diff, 15.3 kB (added by jbronn, 1 year ago)

Patch enabling Oracle functionality on the queryset-refactor branch.

  • branches/queryset-refactor/django/db/backends/oracle/base.py

    old new  
    9393        return 30 
    9494 
    9595    def query_set_class(self, DefaultQuerySet): 
    96         from django.db import connection 
    97         from django.db.models.query import EmptyResultSet, GET_ITERATOR_CHUNK_SIZE, quote_only_if_word 
     96        # Getting the DefaultQuery class. 
     97        DefaultQuery = DefaultQuerySet().query.__class__ 
    9898 
    99         class OracleQuerySet(DefaultQuerySet): 
    100  
    101             def iterator(self): 
    102                 "Performs the SELECT database lookup of this QuerySet." 
    103  
    104                 from django.db.models.query import get_cached_row 
    105  
    106                 # self._select is a dictionary, and dictionaries' key order is 
    107                 # undefined, so we convert it to a list of tuples. 
    108                 extra_select = self._select.items() 
    109  
    110                 full_query = None 
    111  
    112                 try: 
    113                     try: 
    114                         select, sql, params, full_query = self._get_sql_clause(get_full_query=True) 
    115                     except TypeError: 
    116                         select, sql, params = self._get_sql_clause() 
    117                 except EmptyResultSet: 
    118                     raise StopIteration 
    119                 if not full_query: 
    120                     full_query = "SELECT %s%s\n%s" % ((self._distinct and "DISTINCT " or ""), ', '.join(select), sql) 
    121  
    122                 cursor = connection.cursor() 
    123                 cursor.execute(full_query, params) 
    124  
    125                 fill_cache = self._select_related 
    126                 fields = self.model._meta.fields 
    127                 index_end = len(fields) 
    128  
    129                 # so here's the logic; 
    130                 # 1. retrieve each row in turn 
    131                 # 2. convert NCLOBs 
    132  
    133                 while 1: 
    134                     rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) 
    135                     if not rows: 
    136                         raise StopIteration 
    137                     for row in rows: 
    138                         row = self.resolve_columns(row, fields) 
    139                         if fill_cache: 
    140                             obj, index_end = get_cached_row(klass=self.model, row=row, 
    141                                                             index_start=0, max_depth=self._max_related_depth) 
    142                         else: 
    143                             obj = self.model(*row[:index_end]) 
    144                         for i, k in enumerate(extra_select): 
    145                             setattr(obj, k[0], row[index_end+i]) 
    146                         yield obj 
    147  
    148  
    149             def _get_sql_clause(self, get_full_query=False): 
    150                 from django.db.models.query import fill_table_cache, \ 
    151                     handle_legacy_orderlist, orderfield2column 
    152  
    153                 opts = self.model._meta 
    154                 qn = connection.ops.quote_name 
    155  
    156                 # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z. 
    157                 select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields] 
    158                 tables = [quote_only_if_word(t) for t in self._tables] 
    159                 joins = SortedDict() 
    160                 where = self._where[:] 
    161                 params = self._params[:] 
    162  
    163                 # Convert self._filters into SQL. 
    164                 joins2, where2, params2 = self._filters.get_sql(opts) 
    165                 joins.update(joins2) 
    166                 where.extend(where2) 
    167                 params.extend(params2) 
    168  
    169                 # Add additional tables and WHERE clauses based on select_related. 
    170                 if self._select_related: 
    171                     fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table]) 
    172  
    173                 # Add any additional SELECTs. 
    174                 if self._select: 
    175                     select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()]) 
    176  
    177                 # Start composing the body of the SQL statement. 
    178                 sql = [" FROM", qn(opts.db_table)] 
    179  
    180                 # Compose the join dictionary into SQL describing the joins. 
    181                 if joins: 
    182                     sql.append(" ".join(["%s %s %s ON %s" % (join_type, table, alias, condition) 
    183                                     for (alias, (table, join_type, condition)) in joins.items()])) 
    184  
    185                 # Compose the tables clause into SQL. 
    186                 if tables: 
    187                     sql.append(", " + ", ".join(tables)) 
    188  
    189                 # Compose the where clause into SQL. 
    190                 if where: 
    191                     sql.append(where and "WHERE " + " AND ".join(where)) 
    192  
    193                 # ORDER BY clause 
    194                 order_by = [] 
    195                 if self._order_by is not None: 
    196                     ordering_to_use = self._order_by 
    197                 else: 
    198                     ordering_to_use = opts.ordering 
    199                 for f in handle_legacy_orderlist(ordering_to_use): 
    200                     if f == '?': # Special case. 
    201                         order_by.append(DatabaseOperations().random_function_sql()) 
    202                     else: 
    203                         if f.startswith('-'): 
    204                             col_name = f[1:] 
    205                             order = "DESC" 
    206                         else: 
    207                             col_name = f 
    208                             order = "ASC" 
    209                         if "." in col_name: 
    210                             table_prefix, col_name = col_name.split('.', 1) 
    211                             table_prefix = qn(table_prefix) + '.' 
    212                         else: 
    213                             # Use the database table as a column prefix if it wasn't given, 
    214                             # and if the requested column isn't a custom SELECT. 
    215                             if "." not in col_name and col_name not in (self._select or ()): 
    216                                 table_prefix = qn(opts.db_table) + '.' 
    217                             else: 
    218                                 table_prefix = '' 
    219                         order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order)) 
    220                 if order_by: 
    221                     sql.append("ORDER BY " + ", ".join(order_by)) 
    222  
    223                 # Look for column name collisions in the select elements 
    224                 # and fix them with an AS alias.  This allows us to do a 
    225                 # SELECT * later in the paging query. 
    226                 cols = [clause.split('.')[-1] for clause in select] 
    227                 for index, col in enumerate(cols): 
    228                     if cols.count(col) > 1: 
    229                         col = '%s%d' % (col.replace('"', ''), index) 
    230                         cols[index] = col 
    231                         select[index] = '%s AS %s' % (select[index], col) 
    232  
    233                 # LIMIT and OFFSET clauses 
    234                 # To support limits and offsets, Oracle requires some funky rewriting of an otherwise normal looking query. 
    235                 select_clause = ",".join(select) 
    236                 distinct = (self._distinct and "DISTINCT " or "") 
    237  
    238                 if order_by: 
    239                     order_by_clause = " OVER (ORDER BY %s )" % (", ".join(order_by)) 
    240                 else: 
    241                     #Oracle's row_number() function always requires an order-by clause. 
    242                     #So we need to define a default order-by, since none was provided. 
    243                     order_by_clause = " OVER (ORDER BY %s.%s)" % \ 
    244                         (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column)) 
    245                 # limit_and_offset_clause 
    246                 if self._limit is None: 
    247                     assert self._offset is None, "'offset' is not allowed without 'limit'" 
    248  
    249                 if self._offset is not None: 
    250                     offset = int(self._offset) 
    251                 else: 
    252                     offset = 0 
    253                 if self._limit is not None: 
    254                     limit = int(self._limit) 
    255                 else: 
    256                     limit = None 
    257  
    258                 limit_and_offset_clause = '' 
    259                 if limit is not None: 
    260                     limit_and_offset_clause = "WHERE rn > %s AND rn <= %s" % (offset, limit+offset) 
    261                 elif offset: 
    262                     limit_and_offset_clause = "WHERE rn > %s" % (offset) 
    263  
    264                 if len(limit_and_offset_clause) > 0: 
    265                     fmt = \ 
    266     """SELECT * FROM 
    267       (SELECT %s%s, 
    268               ROW_NUMBER()%s AS rn 
    269        %s) 
    270     %s""" 
    271                     full_query = fmt % (distinct, select_clause, 
    272                                         order_by_clause, ' '.join(sql).strip(), 
    273                                         limit_and_offset_clause) 
    274                 else: 
    275                     full_query = None 
    276  
    277                 if get_full_query: 
    278                     return select, " ".join(sql), params, full_query 
    279                 else: 
    280                     return select, " ".join(sql), params 
    281  
     99        class OracleQuery(DefaultQuery): 
    282100            def resolve_columns(self, row, fields=()): 
    283101                from django.db.models.fields import DateField, DateTimeField, \ 
    284                     TimeField, BooleanField, NullBooleanField, DecimalField, Field 
     102                     TimeField, BooleanField, NullBooleanField, DecimalField, Field 
    285103                values = [] 
    286104                for value, field in map(None, row, fields): 
    287105                    if isinstance(value, Database.LOB): 
     
    325143                    values.append(value) 
    326144                return values 
    327145 
     146            def as_sql(self, with_limits=True): 
     147                """ 
     148                Creates the SQL for this query. Returns the SQL string and list of 
     149                parameters.  This is overriden from the original Query class to 
     150                accommodate Oracle's limit/offset SQL. 
     151                 
     152                If 'with_limits' is False, any limit/offset information is not included 
     153                in the query. 
     154                """ 
     155                # The quotename function and the `do_offset` flag, required for the 
     156                # quirky SQL needed to perform limits/offsets in Oracle. 
     157                qn = self.quote_name_unless_alias 
     158                do_offset = with_limits and (self.high_mark or self.low_mark) 
     159 
     160                self.pre_sql_setup() 
     161                out_cols = self.get_columns() 
     162                ordering = self.get_ordering() 
     163                # This must come after 'select' and 'ordering' -- see docstring of 
     164                # get_from_clause() for details. 
     165                from_, f_params = self.get_from_clause() 
     166                where, w_params = self.where.as_sql(qn=qn) 
     167 
     168                result = ['SELECT'] 
     169                if self.distinct: 
     170                    result.append('DISTINCT') 
     171                result.append(', '.join(out_cols)) 
     172 
     173                # To do the equivalent of limit/offset queries in Oracle requires 
     174                # selecting an additional row number with the full query, and then 
     175                # limiting the selection based on row number. 
     176                if do_offset: 
     177                    # Inserting the extra SELECT. 
     178                    result.insert(0, 'SELECT * FROM\n (') 
     179                    if ordering: 
     180                        rn_orderby = ', '.join(ordering) 
     181                    else: 
     182                        # Oracle's row_number() function always requires an order-by clause. 
     183                        # So we need to define a default order-by, since none was provided. 
     184                        opts = self.model._meta 
     185                        rn_orderby = '%s.%s' % (qn(opts.db_table), qn(opts.fields[0].db_column or opts.fields[0].column)) 
     186                    # Appending the ROW_NUMBER() SQL into the extra select. 
     187                    result.append(',\n ROW_NUMBER() OVER (ORDER BY %s ) AS rn\n' % rn_orderby) 
     188 
     189                # Now adding on the regular FROM statement. 
     190                result.append('FROM') 
     191                result.extend(from_) 
     192                params = list(f_params) 
     193 
     194                if where: 
     195                    result.append('WHERE %s' % where) 
     196                if self.extra_where: 
     197                    if not where: 
     198                        result.append('WHERE') 
     199                    else: 
     200                        result.append('AND') 
     201                    result.append(' AND'.join(self.extra_where)) 
     202                params.extend(w_params) 
     203 
     204                if self.group_by: 
     205                    grouping = self.get_grouping() 
     206                    result.append('GROUP BY %s' % ', '.join(grouping)) 
     207 
     208                if ordering: 
     209                    result.append('ORDER BY %s' % ', '.join(ordering)) 
     210 
     211                if do_offset: 
     212                    # Closing off the extra SELECT placed in above that gets the 
     213                    # row number (`rn`) and place WHERE condition on `rn` for the 
     214                    # desired range. 
     215                    result.append(')\n WHERE rn > %d' % self.low_mark) 
     216                    if self.high_mark: 
     217                        result.append('AND rn <= %d' % self.high_mark) 
     218 
     219                params.extend(self.extra_params) 
     220                return ' '.join(result), tuple(params) 
     221               
     222        from django.db import connection 
     223        class OracleQuerySet(DefaultQuerySet): 
     224            "The OracleQuerySet is overriden to use OracleQuery." 
     225            def __init__(self, model=None, query=None): 
     226                super(OracleQuerySet, self).__init__(model=model, query=query) 
     227                self.query = query or OracleQuery(self.model, connection) 
     228 
    328229        return OracleQuerySet 
    329230 
    330231    def quote_name(self, name): 
     
    446347    charset = 'utf-8' 
    447348 
    448349    def _format_params(self, params): 
     350        sz_kwargs, result = {}, {} 
    449351        if isinstance(params, dict): 
    450             result = {} 
    451352            charset = self.charset 
    452353            for key, value in params.items(): 
    453354                result[smart_str(key, charset)] = smart_str(value, charset) 
    454             return result 
     355                if hasattr(value, 'oracle_type'): sz_kwargs[key] = value.oracle_type() 
    455356        else: 
    456             return tuple([smart_str(p, self.charset, True) for p in params]) 
     357            for i, param in enumerate(params): 
     358                key = 'arg%d' % i 
     359                result[key] = smart_str(param, self.charset, True) 
     360                if hasattr(param, 'oracle_type'): sz_kwargs[key] = param.oracle_type() 
    457361 
     362        # If any of the parameters had an `oracle_type` method, then we set 
     363        # the inputsizes for those parameters using the returned type 
     364        if sz_kwargs: self.setinputsizes(**sz_kwargs) 
     365        return result 
     366 
    458367    def execute(self, query, params=None): 
    459368        if params is None: 
    460369            params = [] 
     
    472381 
    473382    def executemany(self, query, params=None): 
    474383        try: 
    475           args = [(':arg%d' % i) for i in range(len(params[0]))] 
     384            args = [(':arg%d' % i) for i in range(len(params[0]))] 
    476385        except (IndexError, TypeError): 
    477           # No params given, nothing to do 
    478           return None 
     386            # No params given, nothing to do 
     387            return None 
    479388        # cx_Oracle wants no trailing ';' for SQL statements.  For PL/SQL, it 
    480389        # it does want a trailing ';' but not a trailing '/'.  However, these 
    481390        # characters must be included in the original query in case the query