"""
Classes for representing query functions and expressions. Not useful outside
the SQL domain.

Expressions do not need to inherit from E, but must implement the same methods.
"""

class E(object):
    """
    Base class for expressions and functions.

    Not abstract: it can wrap another expression in order to give it an alias,
    e.g. E(Sum('field1') + Sum('field2'), alias='a_name'). Apart from
    output_alias() and the arithmetic operators, every method simply delegates
    to the wrapped expression.
    """
    def __init__(self, expr, alias=None):
        """
        Store the wrapped expression and the optional output alias.
        """
        self.expr = expr
        self._alias = alias

    # The arithmetic operators build a binary-operator node with this
    # expression as the left-hand side and rhs as the right-hand side.
    def __add__(self, rhs):
        return Add(self, rhs)

    def __sub__(self, rhs):
        # Subtraction builds a Sub node (not Add) so the '-' operator is used.
        return Sub(self, rhs)

    def __mul__(self, rhs):
        return Mul(self, rhs)

    def __div__(self, rhs):
        return Div(self, rhs)

    # The '/' operator looks up __truediv__ when true division is in effect
    # (always on Python 3); route it to the same Div node as __div__.
    __truediv__ = __div__

    def relabel_aliases(self, table_map, column_map=None):
        """
        Relabel the column alias, if necessary.

        Both maps are forwarded unchanged to the wrapped expression.
        """
        self.expr.relabel_aliases(table_map, column_map)

    def as_sql(self, quote_func=None):
        """
        Returns the SQL string fragment for this object.

        The quote_func function is used to quote the column components. If
        None, it defaults to doing nothing.
        """
        return self.expr.as_sql(quote_func)

    def get_cols(self):
        """
        Returns a list of columns used in the expression.
        """
        return self.expr.get_cols()

    def output_alias(self):
        """
        Returns string to use as key in values dictionary, or None when no
        alias was given.
        """
        return self._alias

    def do_group(self):
        """
        Returns whether this expression should trigger auto-grouping.
        Only expressions containing Aggregates trigger auto-grouping.
        """
        return self.expr.do_group()


class F(E):
    """
    Class to use a column in expressions, e.g:
    F('price') * F('quantity')
    F objects cannot contain arbitrary expressions, but only a column name.

    The column is stored in ``expr`` either as a plain string or, once
    relabel_aliases() has run, as a (table_alias, column) pair.
    """
    def relabel_aliases(self, table_map, column_map=None):
        """
        Rewrite the column reference as a (table_alias, column) pair.

        A pair has its table alias looked up in table_map (kept as-is when
        absent). A plain string is looked up in column_map when one is given
        and is paired with the table_map entry for the '' key.
        """
        c = self.expr
        if isinstance(c, (list, tuple)):
            self.expr = (table_map.get(c[0], c[0]), c[1])
        elif isinstance(c, str):
            # Without a column_map the column name is kept unchanged; leaving
            # it unassigned here would raise UnboundLocalError below.
            col = column_map.get(c, c) if column_map else c
            self.expr = (table_map.get('', ''), col)

    def get_cols(self):
        """
        Returns a one-element list holding the column name.
        """
        c = self.expr
        if isinstance(c, (list, tuple)):
            return [c[1]]
        elif isinstance(c, str):
            return [c]
        # Always return a list so callers can concatenate the result.
        return []

    def output_alias(self):
        """
        Returns the explicit alias if set, otherwise the stored column.
        """
        return self._alias or self.expr

    def as_sql(self, quote_func=None):
        """
        Returns the column reference as SQL.

        Only (table_alias, column) pairs are passed through quote_func; a
        plain string column is emitted as-is.
        """
        if not quote_func:
            quote_func = lambda x: x
        expr = self.expr
        if isinstance(expr, (list, tuple)):
            expr = '%s.%s' % tuple([quote_func(c) for c in expr])
        return '%s' % expr

    def do_group(self):
        """
        A bare column never triggers auto-grouping.
        """
        return False


class UnaryOp(E):
    """
    Base class for representing a prefix unary operator such as NOT.

    Subclasses provide the operator keyword in ``_op``; it is rendered in
    front of the operand's SQL, separated by a single space.
    """
    def as_sql(self, quote_func=None):
        operator = self._op
        operand_sql = self.expr.as_sql(quote_func)
        return '%s %s' % (operator, operand_sql)


class BinaryOp(E):
    """
    Base class for representing an infix binary operator such as +, -, * or /.

    Subclasses provide the operator symbol in ``_op``. The two operands are
    kept in ``lhs`` and ``rhs`` and every method combines the results of
    calling the same method on both of them.
    """
    _op = None

    def __init__(self, lhs, rhs, alias=None):
        self.lhs, self.rhs = lhs, rhs
        self._alias = alias

    def relabel_aliases(self, table_map, column_map=None):
        # Left operand first, then right, each with the same maps.
        for operand in (self.lhs, self.rhs):
            operand.relabel_aliases(table_map, column_map)

    def get_cols(self):
        left_cols = self.lhs.get_cols()
        right_cols = self.rhs.get_cols()
        return left_cols + right_cols

    def as_sql(self, quote_func=None):
        left_sql = self.lhs.as_sql(quote_func)
        operator = self._op
        right_sql = self.rhs.as_sql(quote_func)
        # Always parenthesised so nesting keeps the intended precedence.
        return '(%s %s %s)' % (left_sql, operator, right_sql)

    def do_group(self):
        # The right operand is consulted first; the left one only if needed.
        right_groups = self.rhs.do_group()
        if right_groups:
            return right_groups
        return self.lhs.do_group()


class Add(BinaryOp):
    """
    Binary operator node whose operator symbol is '+'.
    """
    _op = '+'

class Sub(BinaryOp):
    """
    Binary operator node whose operator symbol is '-'.
    """
    _op = '-'

class Mul(BinaryOp):
    """
    Binary operator node whose operator symbol is '*'.
    """
    _op = '*'

class Div(BinaryOp):
    """
    Binary operator node whose operator symbol is '/'.
    """
    _op = '/'

class Not(UnaryOp):
    """
    Unary operator node whose operator keyword is 'NOT'.
    """
    _op = 'NOT'


class Function(E):
    """
    Base class for query functions.

    Subclasses provide the SQL function name in ``_func``. The argument kept
    in ``expr`` may be a column name (str), a (table_alias, column) pair, or a
    nested expression object implementing the expression methods.
    """
    _func = None

    def relabel_aliases(self, table_map, column_map=None):
        """
        Relabel the table (and optionally column) reference of the argument.

        A pair has its table alias looked up in table_map. A plain string is
        looked up in column_map when one is given and is paired with the
        table_map entry for the '' key. Anything else is treated as a nested
        expression and asked to relabel itself.
        """
        c = self.expr
        if isinstance(c, (list, tuple)):
            self.expr = (table_map.get(c[0], c[0]), c[1])
        elif isinstance(c, str):
            # Without a column_map the column name is kept unchanged; leaving
            # it unassigned here would raise UnboundLocalError below.
            col = column_map.get(c, c) if column_map else c
            self.expr = (table_map.get('', ''), col)
        else:
            # Forward both maps so nested columns are relabelled the same way.
            c.relabel_aliases(table_map, column_map)

    def get_cols(self):
        """
        Returns a list of the columns used by the argument.
        """
        c = self.expr
        if isinstance(c, (list, tuple)):
            return [c[1]]
        elif isinstance(c, str):
            return [c]
        else:
            return c.get_cols()

    def output_alias(self):
        """
        Returns the explicit alias if set, otherwise '<column>_<func>' with
        the function name lower-cased.

        For a nested expression the base name is that expression's own
        output_alias(); when it has none, None is returned.
        """
        if self._alias:
            return self._alias
        c = self.expr
        if isinstance(c, (list, tuple)):
            # After relabelling the argument is (table_alias, column).
            base = c[1]
        elif isinstance(c, str):
            base = c
        else:
            # Using '+' on an expression object would build an Add node
            # rather than a string, so ask the nested expression for a name.
            base = c.output_alias()
            if base is None:
                return None
        return base + '_' + self._func.lower()

    def as_sql(self, quote_func=None):
        """
        Returns the SQL function call, '<func>(<argument>)'.

        Nested expressions render themselves; (table_alias, column) pairs are
        passed through quote_func; a plain string is emitted as-is.
        """
        if not quote_func:
            quote_func = lambda x: x
        expr = self.expr
        if hasattr(expr, 'as_sql'):
            expr = expr.as_sql(quote_func)
        elif isinstance(expr, (list, tuple)):
            expr = '%s.%s' % tuple([quote_func(c) for c in expr])
        return '%s(%s)' % (self._func, expr)


class Aggregate(Function):
    """
    Base class for query Aggregates.

    An Aggregate in an expression will trigger auto-grouping: do_group()
    always returns True.
    Aggregates are meant to contain only a column name, not arbitrary
    expressions; nothing in this class enforces that.
    """
    # NOTE: a constructor check rejecting arguments that are not a str, list
    # or tuple was sketched here but is disabled, so no validation is done.
    
    def do_group(self):
        """
        Aggregates always trigger auto-grouping.
        """
        return True


class Sum(Aggregate):
    """
    Perform a sum on the given column (SQL function name 'SUM').
    """
    _func = 'SUM'
  
class Avg(Aggregate):
    """
    Perform an average on the given column (SQL function name 'AVG').
    """
    _func = 'AVG'

class Min(Aggregate):
    """
    Select the minimum of the given column (SQL function name 'MIN').
    """
    _func = 'MIN'

class Max(Aggregate):
    """
    Select the maximum of the given column (SQL function name 'MAX').
    """
    _func = 'MAX'

class Count(Aggregate):
    """
    Perform a count on the given column, or on '*' when no column is given.

    With distinct=True the argument is rendered as 'DISTINCT <argument>'.
    """
    _func = 'COUNT'

    def __init__(self, expr='*', alias=None, distinct=False):
        """
        Store the distinct flag, then the argument and alias.
        """
        self.distinct = distinct
        super(Count, self).__init__(expr, alias=alias)

    def get_cols(self):
        """
        Returns the columns used by the argument; '*' uses none.
        """
        if self.expr == '*':
            return []
        return super(Count, self).get_cols()

    def relabel_aliases(self, table_map, column_map=None):
        """
        Relabel the argument; '*' is left untouched.
        """
        if self.expr == '*':
            return
        return super(Count, self).relabel_aliases(table_map, column_map)

    def as_sql(self, quote_func=None):
        """
        Returns the SQL call, e.g. 'COUNT(*)' or 'COUNT(DISTINCT t.col)'.

        Nested expressions render themselves; (table_alias, column) pairs are
        passed through quote_func; a plain string is emitted as-is.
        """
        if not quote_func:
            quote_func = lambda x: x
        expr = self.expr
        if hasattr(expr, 'as_sql'):
            expr = expr.as_sql(quote_func)
        elif isinstance(expr, (list, tuple)):
            expr = '%s.%s' % tuple([quote_func(c) for c in expr])
        # Use _func rather than repeating the function name literal.
        if self.distinct:
            return '%s(DISTINCT %s)' % (self._func, expr)
        return '%s(%s)' % (self._func, expr)


