""" $Id: base.py 37 2007-06-05 19:49:06Z the_paya $

IBM DB2 database backend for Django

Requires PyDB2: http://sourceforge.net/projects/pydb2/
With this patch: http://sourceforge.net/tracker/index.php?func=detail&aid=1731609&group_id=67548&atid=518208
"""
from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util
import re
from django.utils.encoding import smart_str, force_unicode

try:
    import DB2 as Database
except ImportError, e:
    from django.core.exceptions import ImproperlyConfigured
    raise ImproperlyConfigured, "Error loading DB2 python module: %s" % e
import datetime
from django.utils.datastructures import SortedDict

# Re-export the DB-API exception classes of the DB2 driver module so callers
# can reach them as attributes of this backend module.
Warning = Database.Warning
Error = Database.Error
InterfaceError = Database.InterfaceError
DatabaseError = Database.DatabaseError
DataError = Database.DataError
OperationalError = Database.OperationalError
IntegrityError = Database.IntegrityError
InternalError = Database.InternalError
ProgrammingError = Database.ProgrammingError
NotSupportedError = Database.NotSupportedError

class Cursor(Database.Cursor):
    """Return a cursor.

    Doing some translation tricks here: the query text and string parameters
    are encoded into ``charset`` before they reach PyDB2, and string values in
    fetched rows are decoded back into unicode. Set the database charset in
    the DATABASE_CHARSET setting; 'iso-8859-1' is used when it is missing.
    """

    def _read_charset():
        # Runs once while the class body executes. ``settings`` has to be
        # imported here: it is not imported at module level, and an unguarded
        # lookup would always fail and silently ignore DATABASE_CHARSET.
        try:
            from django.conf import settings
            return settings.DATABASE_CHARSET
        except Exception:
            # Settings not configured, or no DATABASE_CHARSET defined.
            return 'iso-8859-1'
    charset = _read_charset()
    del _read_charset

    def _rewrite_args(self, query, params=None):
        """Encode ``query``/``params`` into the charset and adapt the SQL to DB2.

        Returns a ``(query, params)`` pair. ``params`` becomes ``[]`` when it is
        None, otherwise it is encoded with _smart_str(). In statements with
        'CREATE' near the start, lone NULLs are dropped and PRIMARY KEY gets an
        explicit NOT NULL. Every ``%s`` placeholder becomes PyDB2's ``?``.
        """
        # formatting parameters into charset
        if params is None:
            params = []
        else:
            params = self._format_params(params)
        # formatting query into charset
        query = smart_str(query, self.charset)
        create_pos = query.find('CREATE')
        if 0 <= create_pos < 10:  # assuming this is a create command
            # DB2 doesn't want a 'lone' NULL (it's implied).
            query = re.sub('(?<!NOT) NULL', '', query)
            # DB2 does not like primary key definition without NOT NULL
            query = re.sub('(?<!NOT NULL) PRIMARY KEY', ' NOT NULL PRIMARY KEY', query)
        # PyDB2 uses '?' as the parameter style.
        query = query.replace("%s", "?")
        return query, params

    def _format_params(self, params=None):
        """Encode query parameters into the database charset."""
        return self._smart_str(params)

    def execute(self, query, params=None):
        """Rewrite ``query``/``params`` for DB2, then execute them."""
        query, params = self._rewrite_args(query, params)
        return Database.Cursor.execute(self, query, params)

    def executemany(self, query, params=None):
        """Rewrite ``query`` and every parameter set, then execute them all."""
        query, params = self._rewrite_args(query, params)
        return Database.Cursor.executemany(self, query, params)

    def fetchone(self):
        """Fetch one row with its string values decoded to unicode."""
        return self._force_unicode(Database.Cursor.fetchone(self))

    def fetchmany(self, size=None):
        """Fetch up to ``size`` rows (default: ``arraysize``) decoded to unicode."""
        if size is None:
            size = self.arraysize
        return self._force_unicode(Database.Cursor.fetchmany(self, size))

    def fetchall(self):
        """Fetch all remaining rows decoded to unicode."""
        # is this ever used ?
        return self._force_unicode(Database.Cursor.fetchall(self))

    def _smart_str(self, s=None):
        """Recursively encode unicode in ``s`` into the database charset.

        Dicts keep their type (keys and values are encoded), lists and tuples
        become tuples, and strings that cannot be encoded are replaced by ''.
        Non-string scalars and None are returned unchanged.
        """
        if s is None:
            return s
        charset = self.charset
        if isinstance(s, dict):
            result = {}
            for key, value in s.items():
                # _smart_str() takes a single argument; the charset is read
                # from self inside the recursive call.
                result[smart_str(key, charset)] = self._smart_str(value)
            return result
        elif isinstance(s, (tuple, list)):
            return tuple([self._smart_str(p) for p in s])
        elif isinstance(s, basestring):
            try:
                return smart_str(s, charset, True)
            except UnicodeEncodeError:
                return ''
        return s

    def _force_unicode(self, s=None):
        """Recursively decode byte strings in ``s`` from the database charset.

        Mirrors _smart_str(): dicts stay dicts, lists/tuples become tuples, and
        strings that cannot be decoded are replaced by u''. Non-string scalars
        and None are returned unchanged.
        """
        if s is None:
            return s
        charset = self.charset
        if isinstance(s, dict):
            result = {}
            for key, value in s.items():
                result[force_unicode(key, charset)] = self._force_unicode(value)
            return result
        elif isinstance(s, (tuple, list)):
            return tuple([self._force_unicode(p) for p in s])
        elif isinstance(s, basestring):
            try:
                return force_unicode(s, encoding=charset)
            except UnicodeError:
                # Decoding failures raise UnicodeDecodeError, which the previous
                # UnicodeEncodeError handler never caught.
                return u''
        return s

# Names of tables whose integrity checking has been switched off by
# turn_integrity_off_if_transaction() and not yet switched back on.
integrity_off_tables = []

def turn_integrity_off_if_transaction(sender):
    """ Turns off integrity on a table if transaction is managed manually.
    This function will be connected to the pre_save signal.

    When the current transaction is managed and ``sender`` is a Model
    subclass whose table is not yet in ``integrity_off_tables``, this runs
    ``SET INTEGRITY FOR <table> OFF READ ACCESS`` and records the table.
    """
    from django.db import transaction, connection
    from django.db.models import Model
    if not transaction.is_managed():
        return
    if not issubclass(sender, Model):
        return
    table = sender._meta.db_table
    if table in integrity_off_tables:
        return
    cursor = connection.cursor()
    try:
        # quote_name is not a module-level function of this backend; it lives
        # on the connection's DatabaseOperations instance.
        cursor.execute('SET INTEGRITY FOR %s OFF READ ACCESS'
                       % (connection.ops.quote_name(table),))
    finally:
        cursor.close()
    integrity_off_tables.append(table)

class Connection(Database.Connection):
    """PyDB2 connection that hands out charset-aware Cursor objects and turns
    table integrity checking back on before committing."""

    def cursor(self):
        """Return a Cursor wrapping a new cursor from the raw PyDB2 handle."""
        return Cursor(self._db.cursor())

    def commit(self):
        """Re-enable integrity checking on pending tables, then commit.

        Only tables listed in ``integrity_off_tables`` that are also returned
        by the introspection module's get_table_list() get
        ``SET INTEGRITY ... IMMEDIATE CHECKED``. The pending list is then
        emptied so a later managed transaction switches integrity off again.
        """
        if integrity_off_tables:
            # Nothing pending means no introspection query is needed.
            from django.db.backends.db2_9.introspection import get_table_list
            quote_name = DatabaseOperations().quote_name
            cursor = self.cursor()
            try:
                existing_tables = get_table_list(cursor)
                for table in integrity_off_tables:
                    if table in existing_tables:
                        cursor.execute('SET INTEGRITY FOR %s IMMEDIATE CHECKED'
                                       % (quote_name(table),))
            finally:
                cursor.close()
            # Integrity is on again for these tables; forget them in place so
            # turn_integrity_off_if_transaction() sees the shared list emptied.
            del integrity_off_tables[:]
        self._db.commit()

# DatabaseWrapper.cursor() opens connections through Database.connect(**kwargs);
# point it at this subclass so those connections use the behaviour above.
Database.connect = Connection

class DatabaseFeatures(BaseDatabaseFeatures):
    """Feature flags of the DB2 backend, overriding BaseDatabaseFeatures."""

    # Schema / DDL capabilities.
    autoindexes_primary_keys = True
    allows_unique_and_pk = True
    supports_constraints = True
    supports_compound_statements = True
    supports_tablespaces = False

    # Query-generation quirks.
    allows_group_by_ordinal = False
    needs_cast_for_iops = True
    needs_datetime_string_cast = True
    needs_upper_for_iops = True
    uses_case_insensitive_names = True

#dictfetchone = util.dictfetchone
#dictfetchmany = util.dictfetchmany
#dictfetchall = util.dictfetchall

class DatabaseOperations(BaseDatabaseOperations):
    """DB2-specific SQL fragments and helpers used by Django's ORM."""

    def quote_name(self, name):
        """Name quoting.

        Unquoted names are upper-cased, truncated per component to
        max_name_length() and double-quoted, so names of type schema.tablename
        become "SCHEMA"."TABLENAME". A name that already starts or ends with a
        double quote is only upper-cased.
        """
        if not name.startswith('"') and not name.endswith('"'):
            max_length = self.max_name_length()
            return '.'.join(['"%s"' % util.truncate_name(part.upper(), max_length)
                             for part in name.split('.')])
        return name.upper()

    def last_insert_id(self, cursor, table_name, pk_name):
        """Return the identity value most recently generated on this connection.

        IDENTITY_VAL_LOCAL() is not tied to a table, so it is read from the
        one-row SYSIBM.SYSDUMMY1 table. Selecting it FROM ``table_name`` would
        return one row per table row, and no row at all for an empty table.
        ``table_name`` and ``pk_name`` are kept for interface compatibility.
        Returns 0 when no value is available.
        """
        cursor.execute("SELECT IDENTITY_VAL_LOCAL() FROM SYSIBM.SYSDUMMY1")
        row = cursor.fetchone()
        if row is None or row[0] is None:
            return 0
        return int(row[0])

    def date_extract_sql(self, lookup_type, field_name):
        """Return SQL extracting 'year', 'month' or 'day' from ``field_name``
        (None for any other lookup_type)."""
        if lookup_type == 'year':
            return "YEAR(%s)" % field_name
        elif lookup_type == 'month':
            return "MONTH(%s)" % field_name
        elif lookup_type == 'day':
            return "DAY(%s)" % field_name

    def date_trunc_sql(self, lookup_type, field_name):
        """Return SQL truncating ``field_name`` to the start of the 'year',
        'month' or 'day' as a TIMESTAMP (None for any other lookup_type)."""
        # Hard one. I have seen TRUNC_TIMESTAMP somewhere but is not present
        # in any of my databases.
        # Doesn't work 'directly' since the GROUP BY needs this as well,
        # DB2 can't take "GROUP BY 1"
        if lookup_type == 'year':
            return "TIMESTAMP(DATE(%s -DAYOFYEAR(%s) DAYS +1 DAY),'00:00:00')" % (field_name, field_name)
        elif lookup_type == 'month':
            return "TIMESTAMP(DATE(%s -DAY(%s) DAYS +1 DAY),'00:00:00')" % (field_name, field_name)
        elif lookup_type == 'day':
            return "TIMESTAMP(DATE(%s),'00:00:00')" % field_name

    def datetime_cast_sql(self):
        return ""

    def limit_offset_sql(self, limit, offset=None):
        # Limits and offset are too complicated to be handled here.
        # Instead, they are handled in DB2QuerySet.
        return ""

    def random_function_sql(self):
        return "RAND()"

    def deferrable_sql(self):
        # DB2 does not support deferring constraints to end of transaction ?
        # " ON DELETE SET NULL" does not work when the column is NOT NULL,
        # so foreign keys cascade instead.
        return " ON DELETE CASCADE"

    def fulltext_search_sql(self, field_name):
        # DB2 has some nice extra packages that enables a CONTAINS() function,
        # but they're not available in my dbs.
        raise NotImplementedError

    def drop_foreignkey_sql(self):
        return "DROP CONSTRAINT"

    def pk_default_value(self):
        return "DEFAULT"

    def max_name_length(self):
        return 128

    def max_constraint_length(self):
        # This needs a patch of management.py to detect and use this function
        # 18 for primarykeys and constraint names
        return 18

    def alias(self, table, column):
        """Return ``<table>__<column>``, dropping any schema prefix of ``table``."""
        return "%s__%s" % (table.split('.')[-1], column)

    def start_transaction_sql(self):
        return "BEGIN;"

    def autoinc_sql(self, table):
        return None

    def drop_sequence(self, table):
        return "DROP SEQUENCE %s;" % self.quote_name(self.sequence_name(table))

    def sequence_name(self, table):
        return table

    def _sequence_reset_sql(self):
        """Template taking (quoted table, quoted column, restart value)."""
        return 'ALTER TABLE %s ALTER COLUMN %s RESTART WITH %s'

    def sql_flush(self, style, tables, sequences):
        """Return a list of SQL statements required to remove all data from
        all tables in the database (without actually removing the tables
        themselves) and put the database in an empty 'initial' state.

        NOTE: not implemented for DB2 -- this always returns an empty list.
        The commented-out statements below record an earlier attempt.
        """
        sql = []
        #~ for table in tables:
            #~ if table.count('.') == 0:
                #~ sql.append('SET INTEGRITY FOR %s OFF' % (quote_name(table),))
        #~ for table in tables:
            #~ if table.count('.') == 0:
                #~ sql.append('ALTER TABLE %s ACTIVATE NOT LOGGED INITIALLY WITH EMPTY TABLE' % (quote_name(table),))
        #~ for table in tables:
            #~ if table.count('.') == 0:
                #~ sql.append('%s %s %s;' % \
                    #~ (style.SQL_KEYWORD('DELETE'), style.SQL_KEYWORD('FROM'),
                        #~ style.SQL_FIELD(quote_name(table))))
        #~ for table in reversed(tables):
            #~ if table.count('.') == 0:
                #~ sql.append('SET INTEGRITY FOR %s IMMEDIATE CHECKED' % (quote_name(table),))
        #~ for table in tables:
            #~ if table.count('.') == 0:
                #~ for sequence_info in sequences:
                    #~ if sequence_info['table'].upper() == table.upper():
                        #~ column = sequence_info['column']
                        #~ if column is not None:
                            #~ query = _get_sequence_reset_sql() % (quote_name(table),quote_name(column),1)
                            #~ sql.append(query)
        return sql

    def sql_sequence_reset(self, style, model_list):
        """Returns a list of the SQL statements to reset sequences for the given models.

        Each model's AutoField identity is restarted at MAX(column) + 1 (1 for
        an empty table). The maximum is queried here because RESTART WITH
        needs a constant. IDENTITY_VAL_LOCAL() is per-session, not per-table,
        so it cannot provide this value.
        """
        from django.db import models, connection
        output = []
        query = self._sequence_reset_sql()
        for model in model_list:
            for f in model._meta.fields:
                if isinstance(f, models.AutoField):
                    table = self.quote_name(model._meta.db_table)
                    column = self.quote_name(f.column)
                    cursor = connection.cursor()
                    try:
                        cursor.execute("SELECT MAX(%s) FROM %s" % (column, table))
                        row = cursor.fetchone()
                    finally:
                        cursor.close()
                    if row is None or row[0] is None:
                        max_id = 0
                    else:
                        max_id = int(row[0])
                    output.append(query % (table, column, max_id + 1))
                    break # Only one AutoField is allowed per model, so don't bother continuing.
            # Many-to-many tables are not reset (previously commented out).
        return output

    def query_set_class(self, DefaultQuerySet):
        "Create a custom QuerySet class for DB2."

        from django.db import connection
        from django.db.models.query import EmptyResultSet, GET_ITERATOR_CHUNK_SIZE, quote_only_if_word

        # This backend module has no module-level quote_name() or
        # get_random_function_sql(), so the query set uses this operations
        # instance. Named ``ops`` because ``self`` inside the query set methods
        # below is the query set itself.
        ops = self
        qn = ops.quote_name

        class DB2QuerySet(DefaultQuerySet):

            def iterator(self):
                "Performs the SELECT database lookup of this QuerySet."

                from django.db.models.query import get_cached_row

                # self._select is a dictionary, and dictionaries' key order is
                # undefined, so we convert it to a list of tuples.
                extra_select = self._select.items()

                full_query = None

                try:
                    try:
                        select, sql, params, full_query = self._get_sql_clause(get_full_query=True)
                    except TypeError:
                        select, sql, params = self._get_sql_clause()
                except EmptyResultSet:
                    # Ends the generator; avoids raising StopIteration inside it.
                    return
                if not full_query:
                    full_query = "SELECT %s%s\n%s" % \
                                ((self._distinct and "DISTINCT " or ""),
                                ', '.join(select), sql)

                cursor = connection.cursor()
                cursor.execute(full_query, params)

                fill_cache = self._select_related
                fields = self.model._meta.fields
                index_end = len(fields)

                # so here's the logic;
                # 1. retrieve each row in turn
                # 2. convert NCLOBs

                while True:
                    rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
                    if not rows:
                        return
                    for row in rows:
                        row = self.resolve_columns(row, fields)
                        if fill_cache:
                            obj, index_end = get_cached_row(klass=self.model, row=row,
                                                            index_start=0, max_depth=self._max_related_depth)
                        else:
                            obj = self.model(*row[:index_end])
                        for i, k in enumerate(extra_select):
                            setattr(obj, k[0], row[index_end+i])
                        yield obj

            # DISTINCT could not work properly
            def _get_sql_clause(self, get_full_query=False):
                from django.db.models.query import fill_table_cache, \
                    handle_legacy_orderlist, orderfield2column

                opts = self.model._meta

                select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields]
                tables = [quote_only_if_word(t) for t in self._tables]
                joins = SortedDict()
                where = self._where[:]
                params = self._params[:]

                # Convert self._filters into SQL.
                joins2, where2, params2 = self._filters.get_sql(opts)
                joins.update(joins2)
                where.extend(where2)
                params.extend(params2)

                # Add additional tables and WHERE clauses based on select_related.
                if self._select_related:
                    fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table])

                # Add any additional SELECTs.
                if self._select:
                    select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()])

                # Start composing the body of the SQL statement.
                sql = [" FROM", qn(opts.db_table)]

                # Compose the join dictionary into SQL describing the joins.
                if joins:
                    sql.append(" ".join(["%s %s %s ON %s" % (join_type, table, alias, condition)
                                    for (alias, (table, join_type, condition)) in joins.items()]))

                # Compose the tables clause into SQL.
                if tables:
                    sql.append(", " + ", ".join(tables))

                # Compose the where clause into SQL.
                if where:
                    sql.append("WHERE " + " AND ".join(where))

                # ORDER BY clause
                order_by = []
                if self._order_by is not None:
                    ordering_to_use = self._order_by
                else:
                    ordering_to_use = opts.ordering
                for f in handle_legacy_orderlist(ordering_to_use):
                    if f == '?': # Special case.
                        order_by.append(ops.random_function_sql())
                    else:
                        if f.startswith('-'):
                            col_name = f[1:]
                            order = "DESC"
                        else:
                            col_name = f
                            order = "ASC"
                        if "." in col_name:
                            table_prefix, col_name = col_name.split('.', 1)
                            table_prefix = qn(table_prefix) + '.'
                        elif col_name not in (self._select or ()):
                            # Use the database table as a column prefix if it wasn't given,
                            # and if the requested column isn't a custom SELECT.
                            table_prefix = qn(opts.db_table) + '.'
                        else:
                            table_prefix = ''
                        order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order))
                if order_by:
                    sql.append("ORDER BY " + ", ".join(order_by))

                # Look for column name collisions in the select elements
                # and fix them with an AS alias.  This allows us to do a
                # SELECT * later in the paging query.
                cols = [clause.split('.')[-1] for clause in select]
                for index, col in enumerate(cols):
                    if cols.count(col) > 1:
                        col = '%s%d' % (col.replace('"', ''), index)
                        cols[index] = col
                        select[index] = '%s AS %s' % (select[index], col)

                # LIMIT and OFFSET clauses
                # To support limits and offsets, DB2 (like Oracle) requires some
                # funky rewriting of an otherwise normal looking query.
                select_clause = ",".join(select)
                distinct = (self._distinct and "DISTINCT " or "")

                if order_by:
                    order_by_clause = " OVER (ORDER BY %s )" % (", ".join(order_by))
                else:
                    #DB2's row_number() function always requires an order-by clause.
                    #So we need to define a default order-by, since none was provided.
                    order_by_clause = " OVER (ORDER BY %s.%s)" % \
                        (qn(opts.db_table),
                        qn(opts.fields[0].db_column or opts.fields[0].column))
                # limit_and_offset_clause
                if self._limit is None:
                    assert self._offset is None, "'offset' is not allowed without 'limit'"

                if self._offset is not None:
                    offset = int(self._offset)
                else:
                    offset = 0
                if self._limit is not None:
                    limit = int(self._limit)
                else:
                    limit = None
                if limit == 0:
                    limit = None
                limit_and_offset_clause = ''
                if limit is not None:
                    limit_and_offset_clause = "WHERE rn > %s AND rn <= %s" % (offset, limit+offset)
                elif offset:
                    limit_and_offset_clause = "WHERE rn > %s" % (offset)

                if len(limit_and_offset_clause) > 0:
                    if limit is not None:
                        fmt = "SELECT * FROM (SELECT %s%s, ROW_NUMBER()%s AS rn %s FETCH FIRST %s ROWS ONLY) AS foo %s"
                        full_query = fmt % (distinct, select_clause,
                                        order_by_clause, ' '.join(sql).strip(), limit+offset,
                                        limit_and_offset_clause)
                    else:
                        fmt = "SELECT * FROM (SELECT %s%s, ROW_NUMBER()%s AS rn %s ) AS foo %s"
                        full_query = fmt % (distinct, select_clause,
                                        order_by_clause, ' '.join(sql).strip(),
                                        limit_and_offset_clause)

                else:
                    full_query = None

                if get_full_query:
                    return select, " ".join(sql), params, full_query
                else:
                    return select, " ".join(sql), params

            def resolve_columns(self, row, fields=()):
                """Post-process one fetched row into a list of values.

                Trailing whitespace (e.g. CHAR padding) is stripped from
                non-empty string values; leading whitespace is data and is
                kept. BooleanField values map 1 -> True and anything else ->
                False. Values beyond ``fields`` are paired with None.
                """
                from django.db.models.fields import BooleanField
                values = []
                for value, field in map(None, row, fields):
                    if isinstance(value, basestring) and value:
                        value = value.rstrip()
                    # create real booleans
                    if isinstance(field, BooleanField):
                        value = {0: False, 1: True}.get(value, False)
                    values.append(value)
                return values

        return DB2QuerySet

class DatabaseWrapper(object):
    """Django connection wrapper for DB2 through PyDB2.

    The PyDB2 connection is opened lazily by cursor() from the DATABASE_NAME,
    DATABASE_USER and DATABASE_PASSWORD settings plus any keyword options
    given to the constructor.
    """

    features = DatabaseFeatures()
    ops = DatabaseOperations()

    # UPPER needs typecasting or DB2 does not know which upper function to use
    # it does not matter if the typecast is correct
    # NOTE(review): CAST(... as VARCHAR(50)) cuts lookup values longer than 50
    # characters, which can change the result of i* lookups -- confirm limit.
    string_lookup_length = '50'
    operators = {
        'exact': "= %s",
        'iexact': "= UPPER(CAST(%s as VARCHAR("+string_lookup_length+"))) ",
        'contains': "LIKE %s ESCAPE '\\'",
        'icontains': "LIKE UPPER(CAST(%s as VARCHAR("+string_lookup_length+"))) ESCAPE '\\'",
        'gt': "> %s",
        'gte': ">= %s",
        'lt': "< %s",
        'lte': "<= %s",
        'startswith': "LIKE %s ESCAPE '\\'",
        'endswith': "LIKE %s ESCAPE '\\'",
        'istartswith': "LIKE UPPER(CAST(%s as VARCHAR("+string_lookup_length+")))  ESCAPE '\\'",
        'iendswith': "LIKE UPPER(CAST(%s as VARCHAR("+string_lookup_length+")))  ESCAPE '\\'",
    }

    def __init__(self, **kwargs):
        """Store ``kwargs`` as extra connect() options; nothing is opened yet."""
        self.connection = None
        self.queries = []
        self.server_version = None
        self.options = kwargs

    def _valid_connection(self):
        return self.connection is not None

    def _required_setting(self, settings, name):
        """Return ``settings.<name>``, raising ImproperlyConfigured if it is
        empty or None."""
        value = getattr(settings, name)
        if not value:
            from django.core.exceptions import ImproperlyConfigured
            raise ImproperlyConfigured("You need to specify %s in your Django settings file." % name)
        return value

    def cursor(self):
        """Return a cursor, connecting first if necessary.

        The cursor is wrapped in util.CursorDebugWrapper when settings.DEBUG
        is true.
        """
        from django.conf import settings
        if self.connection is None:
            conn_dict = {}
            # A DB2 client is configured with nodes, and then with databases connected
            # to these nodes, I don't know if there is a way of specifying a complete
            # DSN with a hostname and port, the PyDB2 module shows no sign of this either.
            # So this DSN is actually the database name configured in the host's client instance.
            conn_dict['dsn'] = self._required_setting(settings, 'DATABASE_NAME')
            conn_dict['uid'] = self._required_setting(settings, 'DATABASE_USER')
            conn_dict['pwd'] = self._required_setting(settings, 'DATABASE_PASSWORD')
            # Just imitating others here, I haven't seen any "options" for DB2.
            # Constructor options are applied last and may override the above.
            conn_dict.update(self.options)
            self.connection = Database.connect(**conn_dict)
        cursor = self.connection.cursor()
        if settings.DEBUG:
            return util.CursorDebugWrapper(cursor, self)
        return cursor

    def _commit(self):
        if self.connection is not None:
            return self.connection.commit()

    def _rollback(self):
        if self.connection is not None:
            return self.connection.rollback()

    def close(self):
        """Close the PyDB2 connection, if open, and forget it."""
        if self.connection is not None:
            self.connection.close()
            self.connection = None

