"""
Firebird database backend for Django.

Requires KInterbasDB 3.2: http://kinterbasdb.sourceforge.net/
The egenix mx (mx.DateTime) is NOT required

Database charset should be UNICODE_FSS or UTF8 (FireBird 2.0+)
To use UTF8 encoding add FIREBIRD_CHARSET = 'UTF8' to your settings.py 
UNICODE_FSS works with all versions and uses less memory
"""

from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util

# The backend requires the KInterbasDB driver; if it cannot be imported the
# failure is reported as a Django configuration error that carries the
# original ImportError message.
try:
    import kinterbasdb as Database
except ImportError, e:
    from django.core.exceptions import ImproperlyConfigured
    raise ImproperlyConfigured, "Error loading KInterbasDB module: %s" % e

# Expose the driver's exception classes under module-level names.
DatabaseError = Database.DatabaseError
IntegrityError = Database.IntegrityError

class DatabaseFeatures(BaseDatabaseFeatures):
    """Feature flags declared by the Firebird backend."""

    # Lookup operators: CONTAINING <value> is used instead of LIKE %<value>%,
    # and STARTING WITH is used for prefix matches (faster than LIKE).
    uses_custom_icontains = True
    uses_custom_startswith = True
    needs_upper_for_iops = True

    # The backend provides its own QuerySet class.
    uses_custom_queryset = True

    # Schema and value handling.
    autoindexes_primary_keys = False
    needs_datetime_string_cast = False
    needs_default_null = True
    supports_constraints = False  # some tests behaved strangely otherwise

class DatabaseOperations(BaseDatabaseOperations):
    _max_name_length = 31
    def __init__(self):
        """Start with no cached server version; it is filled in on first use."""
        self._firebird_version = None
    
    def get_generator_name(self, name):
        """
        Return the upper-cased generator name for `name`, suffixed with '_G'.

        The base name is truncated two characters short of the identifier
        limit so that the suffix still fits.
        """
        base = util.truncate_name(name, self._max_name_length - 2)
        return base.upper() + '_G'
        
    def get_trigger_name(self, name):
        """
        Return the upper-cased trigger name for `name`, suffixed with '_T'.

        The base name is truncated two characters short of the identifier
        limit so that the suffix still fits.
        """
        base = util.truncate_name(name, self._max_name_length - 2)
        return base.upper() + '_T'
    
    def _get_firebird_version(self):
        if self._firebird_version is None:
            from django.db import connection
            self._firebird_version = [int(val) for val in connection.server_version.split()[-1].split('.')]
        return self._firebird_version
    firebird_version = property(_get_firebird_version)
    
    def _autoinc_sql_with_style(self, style, table_name, column_name):
        """
        Build the SQL that simulates an auto-incrementing primary key.

        Firebird needs a generator plus a BEFORE INSERT trigger for this: the
        trigger assigns the next generator value whenever the new row's
        column is NULL or 0.

        Generator and trigger names are derived from the table name only,
        since Django supports just one auto field per model.

        Returns a (generator_sql, trigger_sql) tuple.
        """
        kw = style.SQL_KEYWORD
        tbl = style.SQL_TABLE
        fld = style.SQL_FIELD

        # Object names are derived from the raw table name, before quoting.
        generator = self.get_generator_name(table_name)
        trigger = self.get_trigger_name(table_name)
        quoted_column = self.quote_name(column_name)
        quoted_table = self.quote_name(table_name)

        create_generator = "%s %s;" % (kw('CREATE GENERATOR'), tbl(generator))

        header = "%s %s %s %s" % (
            kw('CREATE TRIGGER'), tbl(trigger), kw('FOR'), tbl(quoted_table))
        activation = "%s 0 %s" % (kw('ACTIVE BEFORE INSERT POSITION'), kw('AS'))
        condition = "  %s ((%s.%s %s) %s (%s.%s = 0)) %s" % (
            kw('IF'),
            kw('NEW'), fld(quoted_column), kw('IS NULL'),
            kw('OR'),
            kw('NEW'), fld(quoted_column),
            kw('THEN'))
        assignment = "    %s.%s = %s(%s, 1);" % (
            kw('NEW'), fld(quoted_column), kw('GEN_ID'), tbl(generator))

        create_trigger = "\n".join([
            header,
            activation,
            kw('BEGIN'),
            condition,
            "  %s" % kw('BEGIN'),
            assignment,
            "  %s" % kw('END'),
            kw('END'),
        ])
        return (create_generator, create_trigger)
    
    def autoinc_sql(self, table_name, column_name):
        """
        Return the (generator_sql, trigger_sql) pair for an auto field.

        This hook no longer receives a style argument, so Django's no-op
        style is used and the statements are returned without colouring.
        """
        from django.core.management.color import no_style
        return self._autoinc_sql_with_style(no_style(), table_name, column_name)

    def max_name_length(self):
        """Return the identifier length limit used by this backend (31)."""
        return self._max_name_length

    def query_set_class(self, DefaultQuerySet):
        """
        Build and return the QuerySet class used with Firebird.

        DefaultQuerySet becomes the base class of the returned
        FirebirdQuerySet, which overrides _get_sql_clause() and iterator() so
        that row limiting is written as FIRST/SKIP right after SELECT rather
        than as a trailing LIMIT/OFFSET clause.
        """
        from django.db import connection
        from django.db.models.query import EmptyResultSet, GET_ITERATOR_CHUNK_SIZE, quote_only_if_word

        class FirebirdQuerySet(DefaultQuerySet):
            # TODO: Optimize for Firebird and take full advantage of its power.
            # Now it's just a copy of django.db.models.query._QuerySet
            # with LIMIT/OFFSET removed and FIRST/SKIP added.
            def _get_sql_clause(self):
                """
                Assemble the pieces of the SELECT statement.

                Returns a (select, sql, params) tuple: the list of selected
                column expressions, the statement text from " FROM" onwards
                (joins, extra tables, WHERE, ORDER BY) and the query
                parameters. No limit/offset text is produced here.
                """
                from django.db.models.query import SortedDict, handle_legacy_orderlist, orderfield2column, fill_table_cache
                qn = connection.ops.quote_name
                opts = self.model._meta

                # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z.
                select = ["%s.%s" % (qn(opts.db_table), qn(f.column)) for f in opts.fields]
                tables = [quote_only_if_word(t) for t in self._tables]
                joins = SortedDict()
                # Copies, so the extensions below do not modify the queryset's own lists.
                where = self._where[:]
                params = self._params[:]

                # Convert self._filters into SQL.
                joins2, where2, params2 = self._filters.get_sql(opts)
                joins.update(joins2)
                where.extend(where2)
                params.extend(params2)

                # Add additional tables and WHERE clauses based on select_related.
                if self._select_related:
                    fill_table_cache(opts, select, tables, where,
                                     old_prefix=opts.db_table,
                                     cache_tables_seen=[opts.db_table],
                                     max_depth=self._max_related_depth)

                # Add any additional SELECTs.
                if self._select:
                    select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in self._select.items()])

                # Start composing the body of the SQL statement.
                sql = [" FROM", qn(opts.db_table)]

                # Compose the join dictionary into SQL describing the joins.
                if joins:
                    sql.append(" ".join(["%s %s AS %s ON %s" % (join_type, table, alias, condition)
                                    for (alias, (table, join_type, condition)) in joins.items()]))

                # Compose the tables clause into SQL.
                if tables:
                    sql.append(", " + ", ".join(tables))

                # Compose the where clause into SQL.
                # `where` is already known to be non-empty inside this branch,
                # so the `where and ...` expression always yields the WHERE text.
                if where:
                    sql.append(where and "WHERE " + " AND ".join(where))

                # ORDER BY clause
                order_by = []
                if self._order_by is not None:
                    ordering_to_use = self._order_by
                else:
                    ordering_to_use = opts.ordering
                for f in handle_legacy_orderlist(ordering_to_use):
                    if f == '?': # Special case.
                        order_by.append(connection.ops.random_function_sql())
                    else:
                        # A leading '-' selects descending order.
                        if f.startswith('-'):
                            col_name = f[1:]
                            order = "DESC"
                        else:
                            col_name = f
                            order = "ASC"
                        if "." in col_name:
                            table_prefix, col_name = col_name.split('.', 1)
                            table_prefix = qn(table_prefix) + '.'
                        else:
                            # Use the database table as a column prefix if it wasn't given,
                            # and if the requested column isn't a custom SELECT.
                            # (The '"." not in col_name' test is always true in this branch.)
                            if "." not in col_name and col_name not in (self._select or ()):
                                table_prefix = qn(opts.db_table) + '.'
                            else:
                                table_prefix = ''
                        order_by.append('%s%s %s' % (table_prefix, qn(orderfield2column(col_name, opts)), order))
                if order_by:
                    sql.append("ORDER BY " + ", ".join(order_by))

                return select, " ".join(sql), params

            def iterator(self):
                """
                Performs the SELECT database lookup of this QuerySet.

                Generator that yields one model instance per fetched row. The
                limit and offset are rendered as "FIRST n" / "SKIP m" directly
                after SELECT; an offset without a limit trips an assertion.
                """
                from django.db.models.query import get_cached_row

                try:
                    select, sql, params = self._get_sql_clause()
                except EmptyResultSet:
                    # Nothing can match: finish the generator without querying.
                    raise StopIteration

                # self._select is a dictionary, and dictionaries' key order is
                # undefined, so we convert it to a list of tuples.
                extra_select = self._select.items()

                cursor = connection.cursor()
                # Text inserted between SELECT and the column list.
                limit_offset_before = ""
                if self._limit is not None:
                    limit_offset_before += "FIRST %s " % self._limit
                    if self._offset:
                        limit_offset_before += "SKIP %s " % self._offset
                else:
                    assert self._offset is None, "'offset' is not allowed without 'limit'"
                cursor.execute("SELECT " + limit_offset_before + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
                fill_cache = self._select_related
                fields = self.model._meta.fields
                index_end = len(fields)
                has_resolve_columns = hasattr(self, 'resolve_columns')
                # Fetch in chunks of GET_ITERATOR_CHUNK_SIZE until the cursor is exhausted.
                while 1:
                    rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
                    if not rows:
                        raise StopIteration
                    for row in rows:
                        if has_resolve_columns:
                            row = self.resolve_columns(row, fields)
                        if fill_cache:
                            obj, index_end = get_cached_row(klass=self.model, row=row,
                                                            index_start=0, max_depth=self._max_related_depth)
                        else:
                            obj = self.model(*row[:index_end])
                        # Extra SELECT values follow the model columns in the row.
                        for i, k in enumerate(extra_select):
                            setattr(obj, k[0], row[index_end+i])
                        yield obj
        return FirebirdQuerySet
    

    def quote_name(self, name):
        """
        Return `name` upper-cased and wrapped in double quotes.

        A name that already starts or ends with a double quote is only
        upper-cased; any other name is first truncated to the identifier
        limit and then quoted, so quoting happens just once.
        """
        # Generally works without upper() but some tests fail without it,
        # so behave like the oracle backend.
        already_quoted = name.startswith('"') or name.endswith('"')
        if already_quoted:
            return name.upper()
        truncated = util.truncate_name(name, self._max_name_length)
        return ('"%s"' % truncated).upper()

    def last_insert_id(self, cursor, table_name, pk_name):
        """
        Return the current value of the table's generator.

        GEN_ID(<generator>, 0) reads the generator with an increment of 0.
        `pk_name` is not used.
        """
        generator = self.get_generator_name(table_name)
        cursor.execute('SELECT GEN_ID(%s, 0) from RDB$DATABASE' % generator)
        row = cursor.fetchone()
        return row[0]

    def date_extract_sql(self, lookup_type, column_name):
        """Return an EXTRACT(<lookup_type> FROM <column_name>) expression."""
        # lookup_type is 'year', 'month', 'day'
        return "EXTRACT(%s FROM %s)" % (lookup_type, column_name)

    def date_trunc_sql(self, lookup_type, column_name):
        """
        Return SQL that truncates `column_name` to the given precision.

        The date parts are extracted, concatenated into a timestamp string
        padded with the lowest values for the discarded parts, and cast to
        TIMESTAMP.

        lookup_type must be 'year', 'month' or 'day'; any other value raises
        ValueError.
        """
        if lookup_type == 'year':
            sql = "EXTRACT(year FROM %s)||'-01-01 00:00:00'" % column_name
        elif lookup_type == 'month':
            sql = "EXTRACT(year FROM %s)||'-'||EXTRACT(month FROM %s)||'-01 00:00:00'" % (column_name, column_name)
        elif lookup_type == 'day':
            sql = "EXTRACT(year FROM %s)||'-'||EXTRACT(month FROM %s)||'-'||EXTRACT(day FROM %s)||' 00:00:00'" % (column_name, column_name, column_name)
        else:
            # Without this branch `sql` would be unbound below.
            raise ValueError("Unsupported lookup type for date truncation: %r" % (lookup_type,))
        return "CAST(%s AS TIMESTAMP)" % sql
    
    def cascade_delete_update_sql(self):
        """Return the cascading referential-action clause for foreign keys."""
        # Solves FK problems with sql_flush
        return " ON DELETE CASCADE ON UPDATE CASCADE"
    
    def datetime_cast_sql(self):
        """Return None: this backend supplies no datetime cast expression."""
        return None

    def limit_offset_sql(self, limit, offset=None):
        """
        Guard against use of the generic LIMIT/OFFSET hook.

        Row limiting is emitted by the custom FirebirdQuerySet instead, so a
        call here trips the assertion. The trailing return is only reached
        when assertions are disabled (python -O).
        """
        # limits are handled in custom FirebirdQuerySet
        assert False, 'Limits are handled in a different way in Firebird'
        return ""

    def random_function_sql(self):
        """Return the SQL call used for random ordering."""
        # `rand` is the external function that DatabaseWrapper._connect declares.
        return "rand()"

    def pk_default_value(self):
        """Return the SQL literal inserted for an unspecified primary key."""
        # The auto-increment trigger replaces a NULL (or 0) key with GEN_ID().
        return "NULL"
    
    def start_transaction_sql(self):
        """Return an empty string: no explicit start-transaction SQL is emitted."""
        return ""

    def sequence_reset_sql(self, style, model_list):
        """
        Return "SET GENERATOR ... TO 0;" statements for the given models.

        One statement is produced for each model that has an AutoField and
        one for every many-to-many table of each model. `style` is not used.
        """
        from django.db import models
        reset_stmt = "SET GENERATOR %s TO 0;"
        table_names = []
        for model in model_list:
            meta = model._meta
            # Only one AutoField is allowed per model, so stop at the first.
            has_auto_field = False
            for field in meta.fields:
                if isinstance(field, models.AutoField):
                    has_auto_field = True
                    break
            if has_auto_field:
                table_names.append(meta.db_table)
            table_names.extend([m2m.m2m_db_table() for m2m in meta.many_to_many])
        return [reset_stmt % self.get_generator_name(table) for table in table_names]
    
    def sql_flush(self, style, tables, sequences):
        """
        Return the statements that empty `tables` and reset their generators.

        Produces one DELETE FROM per table followed by one
        "SET GENERATOR ... TO 0;" per entry in `sequences` (keyed by each
        entry's 'table' value). Returns an empty list when `tables` is empty,
        regardless of `sequences`.
        """
        if not tables:
            return []
        # FK constraints gave us a lot of trouble with default values;
        # that is solved with "ON DELETE CASCADE" on all FK references,
        # so plain DELETEs are enough here.
        statements = []
        for table in tables:
            statements.append('%s %s %s;' % (
                style.SQL_KEYWORD('DELETE'),
                style.SQL_KEYWORD('FROM'),
                style.SQL_FIELD(self.quote_name(table)),
            ))
        for sequence in sequences:
            generator = self.get_generator_name(sequence['table'])
            statements.append("SET GENERATOR %s TO 0;" % generator)
        return statements

#    def fulltext_search_sql(self, field_name):
#        return field_name + ' CONTAINING %s'
        
    def drop_sequence_sql(self, table):
        """Return the statement that drops the generator belonging to `table`."""
        generator = self.get_generator_name(table)
        return "DROP GENERATOR %s;" % generator
        
    def last_executed_query(self, cursor, sql, params):
        """
        Return the last executed query as a unicode string with the
        parameters interpolated.

        `params` may be a list/tuple or a mapping; its values (and keys, for a
        mapping) are converted to unicode first. The text of the cursor's
        prepared statement (cursor.query) is tried first; if formatting it
        raises TypeError, the raw `sql` passed in is formatted instead.
        """
        from django.utils.encoding import smart_unicode, force_unicode

        def as_unicode(value):
            return force_unicode(value, strings_only=True)

        if isinstance(params, (list, tuple)):
            u_params = tuple([as_unicode(item) for item in params])
        else:
            u_params = dict([(as_unicode(key), as_unicode(val))
                             for key, val in params.items()])
        try:
            # Take the SQL straight from KInterbasDB's prepared statement.
            prepared_sql = smart_unicode(cursor.query)
            return prepared_sql % u_params
        except TypeError:
            return smart_unicode(sql) % u_params

class FirebirdCursorWrapper(object):
    """
    Django uses "format" ('%s') style placeholders, but firebird uses "qmark" ('?') style.
    This fixes it -- but note that if you want to use a literal "%s" in a query,
    you'll need to use "%%s".
    
    We also do all automatic type conversions here.
    """
    import kinterbasdb.typeconv_datetime_stdlib as tc_dt
    import kinterbasdb.typeconv_fixed_decimal as tc_fd
    import kinterbasdb.typeconv_text_unicode as tc_tu
    import django.utils.encoding as dj_eu
   
    def timestamp_conv_in(self, timestamp):
        """
        Convert an outgoing timestamp value for the driver.

        String values are cut to their first 24 characters before being
        handed to KInterbasDB's stdlib timestamp converter; other values are
        passed through unchanged.
        """
        if isinstance(timestamp, basestring):
            # Cuts 6-digit microseconds down to the 4 digits allowed in Firebird
            # (presumably 'YYYY-MM-DD HH:MM:SS.ffff' is 24 characters -- confirm
            # the string format used by callers).
            timestamp = timestamp[:24]
        return self.tc_dt.timestamp_conv_in(timestamp)
    
    def time_conv_in(self, value):
        """
        Convert an outgoing TIME value for the driver.

        A datetime.datetime is reduced to a datetime.time built from its
        hour, minute, second and microsecond; the result (or any other value,
        unchanged) is handed to KInterbasDB's stdlib time converter.
        """
        import datetime
        if isinstance(value, datetime.datetime):
            # Keep only the time-of-day fields. The attribute is
            # `microsecond`; a misspelling here raises AttributeError.
            value = datetime.time(value.hour, value.minute, value.second, value.microsecond)
        return self.tc_dt.time_conv_in(value)
    
    def ascii_conv_in(self, text):
        """Encode an outgoing TEXT value as an ASCII byte string via smart_str."""
        return self.dj_eu.smart_str(text, 'ascii')

    def unicode_conv_in(self, text):
        """
        Convert an outgoing TEXT_UNICODE value for the driver.

        The first element of `text` is forced to unicode and passed, together
        with this wrapper's charset code, to KInterbasDB's unicode converter.
        """
        # NOTE(review): `text` is indexed, so it is presumably the
        # (value, charset code) pair the driver passes for TEXT_UNICODE -- confirm.
        return self.tc_tu.unicode_conv_in((self.dj_eu.force_unicode(text[0]), self.FB_CHARSET_CODE))

    def blob_conv_in(self, text):
        """
        Convert an outgoing BLOB value for the driver.

        `text` is forced to unicode and passed, together with this wrapper's
        charset code, to KInterbasDB's unicode converter.
        """
        return self.tc_tu.unicode_conv_in((self.dj_eu.force_unicode(text), self.FB_CHARSET_CODE))

    def blob_conv_out(self, text):
        """
        Convert an incoming BLOB value from the driver.

        `text` is paired with this wrapper's charset code and passed to
        KInterbasDB's outgoing unicode converter.
        """
        return self.tc_tu.unicode_conv_out((text, self.FB_CHARSET_CODE))
        
    def __init__(self, cursor):
        """
        Wrap a KInterbasDB cursor and install the type converters.

        The charset code defaults to 3 (UNICODE_FSS) and becomes 4 (UTF-8,
        Firebird 2.0+) when settings.FIREBIRD_CHARSET is 'UTF8'. Input and
        output type translators are then registered on the wrapped cursor.
        """
        from django.conf import settings
        wants_utf8 = (hasattr(settings, 'FIREBIRD_CHARSET')
                      and settings.FIREBIRD_CHARSET == 'UTF8')
        if wants_utf8:
            self.FB_CHARSET_CODE = 4  # UTF-8 with Firebird 2.0+
        else:
            self.FB_CHARSET_CODE = 3  # UNICODE_FSS
        self.cursor = cursor

        # Prepared Statement
        # http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_prepared_statements
        # Need to decide whether they are useful or not.
        # Maybe add prepare, execute_prep and executemany_prep methods here
        # and rewrite QuerySet to take advantage of them?
        # Could speed things up.
        self._statement = None

        # Converters applied to values sent to the database.
        converters_in = {
            'DATE': self.tc_dt.date_conv_in,
            'TIME': self.time_conv_in,
            'TIMESTAMP': self.timestamp_conv_in,
            'FIXED': self.tc_fd.fixed_conv_in_imprecise,
            'TEXT': self.ascii_conv_in,
            'TEXT_UNICODE': self.unicode_conv_in,
            'BLOB': self.blob_conv_in,
        }
        # Converters applied to values read from the database.
        converters_out = {
            'DATE': self.tc_dt.date_conv_out,
            'TIME': self.tc_dt.time_conv_out,
            'TIMESTAMP': self.tc_dt.timestamp_conv_out,
            'FIXED': self.tc_fd.fixed_conv_out_imprecise,
            'TEXT': self.dj_eu.force_unicode,
            'TEXT_UNICODE': self.tc_tu.unicode_conv_out,
            'BLOB': self.blob_conv_out,
        }
        self.cursor.set_type_trans_in(converters_in)
        self.cursor.set_type_trans_out(converters_out)
    
    def _get_query(self):
        """Return the SQL text of the current prepared statement, or None."""
        prepared = self._statement
        if not prepared:
            return None
        return prepared.sql

    def _get_statement(self):
        """Return the current prepared statement, or None if there is none."""
        prepared = self._statement
        if not prepared:
            return None
        return prepared

    query = property(_get_query)
    statement = property(_get_statement)
        
    def execute(self, query, params=()):
        """
        Execute `query` with `params` through a prepared statement.

        The '%s' placeholders are converted to '?' first. The statement is
        prepared again only when the converted SQL differs from the one held
        by the current prepared statement.

        Raises DatabaseError (wrapping the driver's ProgrammingError) when the
        statement cannot be prepared.
        """
        import sys
        query = self.convert_query(query, len(params))
        if self._get_query() != query:
            try:
                self._statement = self.cursor.prep(query)
            except Database.ProgrammingError:
                # No debug output here: the query already holds '?' instead of
                # '%s', so interpolating params into it raises TypeError and
                # would hide the driver error being reported.
                raise DatabaseError(sys.exc_info()[1])
        return self.cursor.execute(self._statement, params)

    def executemany(self, query, param_list):
        """
        Execute `query` once for each parameter sequence in `param_list`.

        The number of placeholders is taken from the first parameter
        sequence. The statement is prepared again only when the converted SQL
        differs from the current prepared statement. An empty `param_list` is
        a no-op and returns None.
        """
        if not param_list:
            # Nothing to run, and no first row to size the placeholders from.
            return None
        query = self.convert_query(query, len(param_list[0]))
        if self._get_query() != query:
            self._statement = self.cursor.prep(query)
        return self.cursor.executemany(self._statement, param_list)

    def convert_query(self, query, num_params):
        """
        Replace the '%s' placeholders in `query` with '?' markers.

        `query` is %-formatted with `num_params` question marks, so a literal
        percent sign has to be written as '%%' in the query.
        """
        markers = ("?",) * num_params
        return query % markers
    
    def __getattr__(self, attr):
        """
        Delegate attributes not found on the wrapper to the wrapped cursor.

        Raises AttributeError for 'cursor' itself when it has not been set.
        """
        # __getattr__ only runs after normal lookup has failed. If 'cursor' is
        # the missing attribute (an instance created without __init__, as copy
        # and pickle do), reading self.cursor below would re-enter this method
        # without end.
        if attr == 'cursor':
            raise AttributeError(attr)
        if attr in self.__dict__:
            return self.__dict__[attr]
        return getattr(self.cursor, attr)

class DatabaseWrapper(BaseDatabaseWrapper):
    features = DatabaseFeatures()
    ops = DatabaseOperations()
    operators = {
        'exact': '= %s',
        'iexact': '= UPPER(%s)',
        'contains': "LIKE %s ESCAPE'\\'",
        'icontains': 'CONTAINING %s', #case is ignored
        'gt': '> %s',
        'gte': '>= %s',
        'lt': '< %s',
        'lte': '<= %s',
        'startswith': 'STARTING WITH %s', #looks to be faster then LIKE
        'endswith': "LIKE %s ESCAPE'\\'",
        'istartswith': 'STARTING WITH UPPER(%s)',
        'iendswith': "LIKE UPPER(%s) ESCAPE'\\'",
    }
    _current_cursor = None
    def _connect(self, settings):
        """
        Open the KInterbasDB connection described by `settings`.

        The charset is 'UTF8' when settings.FIREBIRD_CHARSET says so,
        otherwise 'UNICODE_FSS'. The DSN is "<host>:<database>" with
        'localhost' used when no DATABASE_HOST is set; user and password are
        passed only when non-empty. After connecting, the external function
        `rand` is declared, ignoring the error raised when it already exists.

        Raises ImproperlyConfigured when DATABASE_NAME is empty.
        """
        if settings.DATABASE_NAME == '':
            from django.core.exceptions import ImproperlyConfigured
            raise ImproperlyConfigured("You need to specify DATABASE_NAME in your Django settings file.")

        if hasattr(settings, 'FIREBIRD_CHARSET') and settings.FIREBIRD_CHARSET == 'UTF8':
            charset = 'UTF8'
        else:
            charset = 'UNICODE_FSS'

        host = settings.DATABASE_HOST or 'localhost'
        connect_args = {
            'charset': charset,
            'dsn': "%s:%s" % (host, settings.DATABASE_NAME),
        }
        credentials = (
            ('user', settings.DATABASE_USER),
            ('password', settings.DATABASE_PASSWORD),
        )
        for key, value in credentials:
            if value:
                connect_args[key] = value

        self.connection = Database.connect(**connect_args)
        # `charset` is not an attribute of this wrapper, so the lookup is
        # delegated to the connection through __getattr__.
        assert self.charset == charset
        try:
            self.connection.execute_immediate("""
                DECLARE EXTERNAL FUNCTION rand
                RETURNS DOUBLE PRECISION
                BY VALUE ENTRY_POINT 'IB_UDF_rand' MODULE_NAME 'ib_udf';
            """)
        except Database.ProgrammingError:
            pass  # Already defined
        
    def cursor(self, name=None):
        """
        Return a new cursor, wrapped in a debug cursor when settings.DEBUG is on.

        Cursors can be named
        (http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_named_cursors),
        which may be useful for scrolling updates and deletes.
        """
        from django.conf import settings
        wrapped = self._cursor(settings, name)
        if not settings.DEBUG:
            return wrapped
        return self.make_debug_cursor(wrapped)
    
    def _cursor(self, settings, name=None):
        """
        Create a driver cursor and return it wrapped in FirebirdCursorWrapper.

        Connects first if there is no connection yet. A non-empty `name` is
        assigned to the driver cursor. The wrapper is also remembered as the
        current cursor of this connection wrapper.
        """
        if self.connection is None:
            self._connect(settings)
        raw_cursor = self.connection.cursor()
        if name:
            raw_cursor.name = name
        wrapper = FirebirdCursorWrapper(raw_cursor)
        self._current_cursor = wrapper
        return wrapper
    
    # SQL text of the prepared statement held by the most recent cursor.
    def _get_query(self):
        """Return the current cursor's prepared SQL, or None without a cursor."""
        current = self._current_cursor
        if not current:
            return None
        return current.query
    query = property(_get_query)
    # Prepared statement object held by the most recent cursor.
    def _get_statement(self):
        """Return the current cursor's prepared statement, or None without a cursor."""
        current = self._current_cursor
        if not current:
            return None
        return current.statement
    statement = property(_get_statement)
        
    
    def __getattr__(self, attr):
        """
        Resolve attributes that normal lookup did not find.

        The instance dictionary is consulted first; anything not stored there
        is delegated to the underlying driver connection.
        """
        try:
            return self.__dict__[attr]
        except KeyError:
            return getattr(self.connection, attr)
    

