
import re

# Matches a backtick-quoted foreign-key clause (preceded by whitespace) of the form
#   CONSTRAINT `name` FOREIGN KEY (`col`) REFERENCES `table` (`col`)
# Capture groups: 1 = local column, 2 = referenced table, 3 = referenced column.
# NOTE(review): backtick quoting is MySQL identifier syntax, and this pattern is
# not referenced anywhere in this chunk -- possibly a leftover; confirm before use.
foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)")

from cx_Oracle import NUMBER, ROWID, LONG_STRING, STRING, FIXED_CHAR, Timestamp, LOB, BLOB, CLOB, BINARY


def get_table_list(cursor):
    """
    Returns a list of table names in the current database.

    The names are read from Oracle's USER_TABLES view, one entry per row,
    in the order the cursor returns them.
    """
    cursor.execute("SELECT TABLE_NAME FROM USER_TABLES")
    table_names = []
    for record in cursor.fetchall():
        # Each record has a single column: TABLE_NAME.
        table_names.append(record[0])
    return table_names

# Module-level cache dictionary. NOTE(review): nothing in this chunk reads or
# writes it; confirm whether other code depends on it before removing.
table_description_cache = {}

def get_table_description(cursor, table_name):
    """
    Returns a description of the table, with the DB-API cursor.description interface.

    A ``SELECT *`` limited by ``rownum < 2`` is executed against the
    double-quoted table name, and the cursor's ``description`` attribute
    is returned afterwards. The query replaces whatever the cursor held.
    """
    probe_sql = 'SELECT * FROM "%s" where rownum < 2' % table_name
    cursor.execute(probe_sql)
    described = cursor.description
    return described

# Per-table cache of {column_name: 0-based position}, keyed by table name only.
# Entries are never invalidated in this chunk.
_name_to_index_cache = {}

def _name_to_index(cursor, table_name):
    """
    Returns a dictionary of {field_name: field_index} for the given table.
    Indexes are 0-based.

    On a cache miss this calls get_table_description(), which runs a query on
    ``cursor``. A falsy cached entry (e.g. an empty dict) counts as a miss
    and is recomputed.
    """
    mapping = _name_to_index_cache.get(table_name)
    if not mapping:
        mapping = {}
        for position, column in enumerate(get_table_description(cursor, table_name)):
            # column[0] is the column name in the DB-API description tuple;
            # a repeated name keeps the last position, as dict() would.
            mapping[column[0]] = position
        _name_to_index_cache[table_name] = mapping
    return mapping


def columnum(cursor, table_name, column_name):
    """
    Returns the 0-based position of ``column_name`` within ``table_name``.

    Looks the name up in the mapping built by _name_to_index(); a column
    missing from that mapping raises KeyError.
    """
    positions = _name_to_index(cursor, table_name)
    return positions[column_name]

def get_relations(cursor, table_name):
    """
    Returns a dictionary of {field_index: (field_index_other_table, other_table)}
    representing the foreign-key relationships declared on the given table
    (columns of ``table_name`` that reference another table). Indexes are 0-based.

    The first primary-key column returned for ``table_name`` is excluded.
    Rows whose column positions cannot be resolved (columnum() raising for
    any reason) are skipped, so the result is best-effort.
    """
    # Oracle string literals escape an embedded single quote by doubling it;
    # without this a quote in the table name breaks (or injects into) the SQL.
    quoted_name = table_name.replace("'", "''")

    cursor.execute("select col.column_name, con.constraint_type from all_cons_columns col, all_constraints con where col.constraint_name = con.constraint_name and con.constraint_type = 'P' and col.table_name = '%s'" % quoted_name)
    pk_rows = cursor.fetchall()
    # Only the first primary-key column returned is considered.
    if pk_rows:
        primary_key = pk_rows[0][0]
    else:
        primary_key = None

    query = """
        select col.column_name, col.table_name, col.constraint_name,
               c2.table_name, c2.column_name
          from all_cons_columns col, all_constraints con, all_cons_columns c2
          where con.constraint_type = 'R'
            and con.r_constraint_name = c2.constraint_name
            and con.constraint_name = col.constraint_name
            and not (col.position = c2.position and
                     col.table_name = c2.table_name)
            and col.table_name = '%s'
       """
    cursor.execute(query % quoted_name)
    # Fetch everything before the loop: columnum() may execute new queries on
    # this same cursor, which would discard any unfetched results.
    fk_rows = cursor.fetchall()

    relations = {}
    for row in fk_rows:
        # row = (local column, local table, constraint, referenced table, referenced column)
        if row[0] == primary_key:
            continue
        try:
            relations[columnum(cursor, table_name, row[0])] = (
                columnum(cursor, row[3], row[4]),
                row[3],
            )
        except Exception:
            # Deliberately best-effort: the referenced table may be unreadable
            # or the column may be missing from the description; skip the row.
            pass
    return relations

def get_indexes(cursor, table_name):
    """
    Returns a dictionary of fieldname -> infodict for the given table,
    where each infodict is in the format:
        {'primary_key': boolean representing whether it's the primary key,
         'unique': boolean representing whether it's a unique index}

    The flags are stored as the integers 0/1. Only columns that take part in
    a primary-key ('P') or unique ('U') constraint appear in the result.
    """
    # Oracle string literals escape an embedded single quote by doubling it;
    # without this a quote in the table name breaks (or injects into) the SQL.
    quoted_name = table_name.replace("'", "''")
    cursor.execute("select col.column_name, con.constraint_type from all_cons_columns col, all_constraints con where col.constraint_name = con.constraint_name and con.constraint_type in ('P','U') and col.table_name = '%s'" % quoted_name)

    res = {}
    for column_name, constraint_type in cursor.fetchall():
        # A column may appear once per constraint; merge the flags.
        info = res.setdefault(column_name, {'primary_key': 0, 'unique': 0})
        if constraint_type == 'P':
            info['primary_key'] = 1
        elif constraint_type == 'U':
            info['unique'] = 1
    return res

# Maps cx_Oracle type objects to Django Field type names.
# Pairs are listed in the original order, so if two keys ever compare equal
# the later entry still wins, exactly as in a dict literal.
# NOTE(review): confirm each key matches the type codes the installed cx_Oracle
# version reports in cursor.description (e.g. Timestamp).
DATA_TYPES_REVERSE = dict([
    (NUMBER, 'IntegerField'),
    (ROWID, 'IntegerField'),
    (LONG_STRING, 'TextField'),
    (STRING, 'TextField'),
    (FIXED_CHAR, 'CharField'),
    (Timestamp, 'DateTimeField'),
    (LOB, 'TextField'),
    (BLOB, 'TextField'),
    (CLOB, 'TextField'),
    (BINARY, 'TextField'),
])
