Ticket #5543: base.py.diff

File base.py.diff, 6.7 KB (added by hklv, 16 years ago)

Patch for Django-1.0/django/db/backends/oracle/base.py. Adds support for Oracle stored procedures and functions ("callproc" and "callfunc" method ). Added a "commit" method. Fixed a typo in the _format_params method.

  • base.py

    old new  
    77import os
    88import datetime
    99import time
     10import string
    1011
    1112# Oracle takes client-side character set encoding from the environment.
    1213os.environ['NLS_LANG'] = '.UTF8'
     
    251252            else:
    252253                conn_string = "%s/%s@%s" % (settings.DATABASE_USER, settings.DATABASE_PASSWORD, settings.DATABASE_NAME)
    253254                self.connection = Database.connect(conn_string, **self.options)
     255               
    254256            cursor = FormatStylePlaceholderCursor(self.connection)
    255257            # Set oracle date to ansi date format.  This only needs to execute
    256258            # once when we create a new connection.
    257             cursor.execute("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD' "
    258                            "NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'")
     259            cursor.execute("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD'"
     260                "NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'")
    259261            try:
    260262                self.oracle_version = int(self.connection.version.split('.')[0])
    261263                # There's no way for the DatabaseOperations class to know the
     
    280282        cursor.arraysize = 100
    281283        return cursor
    282284
     285    def commit(self):
     286            self.connection.commit()
     287
    283288
    284289class OracleParam(object):
    285290    """
     
    314319    charset = 'utf-8'
    315320
    316321    def _format_params(self, params):
     322        if params is None:
     323            params = []
    317324        if isinstance(params, dict):
    318325            result = {}
    319326            for key, value in params.items():
    320                 result[smart_str(key, self.charset)] = OracleParam(param, self.charset)
     327                result[smart_str(key, self.charset)] = OracleParam(value, self.charset)
    321328            return result
    322329        else:
    323330            return tuple([OracleParam(p, self.charset, True) for p in params])
     
    333340            for key, value in iterator:
    334341                if value.input_size: sizes[key] = value.input_size
    335342        if isinstance(sizes, dict):
    336             self.setinputsizes(**sizes)
     343            return self.setinputsizes(**sizes)
    337344        else:
    338             self.setinputsizes(*sizes)
    339 
     345            return self.setinputsizes(*sizes)
     346           
     347           
    340348    def _param_generator(self, params):
    341349        if isinstance(params, dict):
    342350            return dict([(k, p.smart_str) for k, p in params.iteritems()])
     
    344352            return [p.smart_str for p in params]
    345353
    346354    def execute(self, query, params=None):
    347         if params is None:
    348             params = []
    349         else:
    350             params = self._format_params(params)
     355        params = self._format_params(params)
    351356        args = [(':arg%d' % i) for i in range(len(params))]
    352357        # cx_Oracle wants no trailing ';' for SQL statements.  For PL/SQL, it
    353358        # it does want a trailing ';' but not a trailing '/'.  However, these
     
    356361        if query.endswith(';') or query.endswith('/'):
    357362            query = query[:-1]
    358363        query = smart_str(query, self.charset) % tuple(args)
     364        #print 'query', query, [x.smart_str for x in params]
     365        #print
    359366        self._guess_input_sizes([params])
    360367        try:
    361368            return Database.Cursor.execute(self, query, self._param_generator(params))
     
    388395                e = IntegrityError(e.message)
    389396            raise e
    390397
     398
     399    def callproc(self, name, params=None, keys=None):
     400        """
     401        Call a PL/SQL procedure with the given name.
     402        The list of parameters must contain one entry for each argument that the procedure expects.
     403        The result is a modified copy of the input sequence.
     404        Input parameters are left untouched;
     405        output and input/output parameters are replaced with possibly new values,
     406        and must be initialised as Oracle variables;
     407        ex: arg1 = cursor.var(cx_Oracle.STRING), arg1.setvalue('this is an input value').
     408        An optional list of corresponding keys will return a dictionnary instead of a list.
     409        """
     410        if isinstance(params, dict):
     411            raise('callproc: "params" must be a list, not a dictionnary')
     412        params  = self._format_params(params)
     413        # "keys" is a list of corresponding keys to return a dictionnary of results
     414        if keys:
     415            if isinstance(keys, list):
     416                if len(keys) != len(params):
     417                    raise('callproc: the number of keys differs from the number of params')
     418            else:
     419                raise('callproc: "keys" must be a list, not a dictionnary')
     420        results = self._guess_input_sizes([params])
     421        try:
     422            Database.Cursor.execute( self
     423                , "begin %s(%s); end;" % (name,
     424                    string.join([":arg%d"%i for i in range(len(params))], ','))
     425                , self._param_generator(params)
     426            )
     427        except DatabaseError, e:
     428            # cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400.
     429            if e.message.code == 1400 and type(e) != IntegrityError:
     430                e = IntegrityError(e.message)
     431            raise e
     432        if isinstance(keys, list):
     433            return dict([(keys[i], results[i].getvalue()) for i in range(len(keys))])
     434        else:
     435            return tuple([result.getvalue() for result in results])
     436
     437
     438    def callfunc(self, name, returnType, params=None):
     439        """
     440        Call a function with the given name.
     441        The return type must be a Oracle type; ex: cx_Oracle.STRING
     442        The sequence of parameters must contain one entry for each argument that the function expects.
     443        The result of the call is the return value of the function.
     444        """
     445        if isinstance(params, dict):
     446            raise('callfunc: parameters must be in a list, not a dictionnary')
     447        if 'cx_Oracle.' not in str(returnType):
     448            raise('callfunc: returnType must be a Oracle type')
     449        params = self._format_params([self.var(returnType)]+params)
     450        results = self._guess_input_sizes([params])
     451        try:
     452            Database.Cursor.execute(self
     453                , "begin :arg0 := %s(%s); end;" % (name,
     454                    string.join([":arg%d"%(i+1) for i in range(len(params)-1)], ','))
     455                , self._param_generator(params)
     456            )
     457        except DatabaseError, e:
     458            # cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400.
     459            if e.message.code == 1400 and type(e) != IntegrityError:
     460                e = IntegrityError(e.message)
     461            raise e
     462        return results[0].getvalue()
     463           
    391464    def fetchone(self):
    392465        row = Database.Cursor.fetchone(self)
    393466        if row is None:
Back to Top