Ticket #6344: inspectdb_refactor.diff

File inspectdb_refactor.diff, 13.1 KB (added by Daniel Pope <dan@…>, 8 years ago)

Patch for inspectdb command

  • django/core/management/commands/inspectdb.py

     
    11from django.core.management.base import NoArgsCommand, CommandError
     2import re
     3import keyword
    24
     5class Database(object):
     6    """Maintains a hash of all tables in the database.
     7
     8      Used for correcting any duplicate model names and resolving foreign key relationships.
     9    """
     10    def __init__(self):
     11        self.models={}
     12
     13    def add_model(self, model):
     14        self.models[model.table_name]=model
     15
     16    def get_model(self, table_name):
     17        return self.models[table_name]
     18
     19    def __str__(self):
     20        s=''
     21        for m in self.models.values():
     22                s+=str(m)+'\n\n'
     23        return s       
     24
     25class TableModel(object):
     26    def __init__(self, database, table_name):
     27        self.table_name=table_name
     28        self.model_name=self.model_name_for_table(table_name)
     29        self.columns=[]
     30        self.relations=[]
     31
     32        database.add_model(self)
     33        self.database=database
     34
     35    def add_column(self, col_name, type, extra_params, comments):
     36        """Add a column."""
     37        name=self.field_name_for_column(col_name)
     38        self.columns.append({'name': name, 'col_name': col_name, 'type': type, 'extra_params': extra_params, 'comments': comments})
     39
     40    def add_relation(self, col_name, rel_to, null):
     41        """Adds a related column.
     42
     43        We don't resolve the table relation at this point.
     44        """
     45        name=self.field_name_for_column(col_name)
     46        self.relations.append({'name': name, 'col_name': col_name, 'rel_to': rel_to, 'null': null})
     47
     48    def group_fields(self):
     49        """Groups fields and performs other heuristic fixups.
     50
     51        """
     52
     53        keys=[f for f in self.columns if f['type'] == 'AutoField' or 'primary_key' in f['extra_params']]
     54        ids=[f for f in self.columns if f not in keys and f['name'].endswith('_id')]
     55
     56        if len(keys) == 0:
     57                # use heuristics to locate a candidate primary key from amongst the ids
     58                tests=[self.model_name.lower()+'id']
     59                if self.model_name.lower().endswith('s'):
     60                        tests.append(self.model_name.lower()[:-1]+'id')
     61
     62                for id in ids:
     63                        if id['name'].replace('_', '') in tests:
     64                                id['comments'].append('NOTE: selected as primary_key')
     65                                id['extra_params']['primary_key']=True
     66                                keys.append(id)
     67                                ids.remove(id)
     68                                break
     69
     70        #resolve relations
     71        rels=self.relations[:]
     72        for r in rels:
     73            if r['rel_to'] == self.table_name:
     74                related = "'self'"
     75            else:
     76                related = "'%s'"%self.database.get_model(r['rel_to']).model_name
     77            r['related']=related
     78
     79        #upgrade integer primary keys to autofields if none exists
     80        if len([k for k in keys if f['type'] == 'AutoField']) == 0:
     81            for k in keys:
     82                if k['type'] == 'IntegerField':
     83                    k['comments'].append('upgraded from IntegerField')
     84                    k['type']='AutoField'
     85                    break
     86
     87        # rename the AutoFields, if it exists, to 'id'
     88        for k in keys:
     89            if k['type'] == 'AutoField':
     90                k['name']='id'  #no comment should be necessary
     91                break
     92
     93        other=[f for f in self.columns if f not in keys and f not in ids]
     94
     95        return (keys, rels, ids, other)
     96
     97    def model_name_for_table(self, table):
     98        """Compute a Python-friendly Model class name for a given table.
     99       
     100        - converts names to CamelCase
     101        - removes non-alphanumberic symbols
     102        """
     103        model=re.sub(r'\b([a-z])', lambda x: x.group(1).upper(), table)
     104        model=re.sub(r'[^A-Za-z0-9]', '', model)
     105        return model
     106
     107    def field_name_for_column(self, col):
     108        """Compute a Python-friendly Field name for a given column.
     109       
     110        - converts CamelCase names into lower_case_with_underscores
     111        - removes hypens
     112        - converts any contiguous sequences of other non-alphanumberic symbols to _
     113        - if the field name conflicts with a Python keyword, append '_field'
     114        - if the field name starts with a digit, prepend 'f_'
     115        """
     116
     117        field=re.sub(r'[A-Z]+', lambda x: '_'+x.group(0).lower(), col) #convert from CamelCase
     118        field=field.replace('-', '') #remove hyphen
     119        field=re.sub(r'[^A-Za-z0-9]+', '_', field) #replace non-alphanumberics
     120
     121        field=re.sub('(^_)|(_$)', '', field) #strip leading underscore
     122
     123        field=re.sub(r'^([0-9])', r'f_\1', field) #fix field names starting with digits
     124
     125        if keyword.iskeyword(field):
     126                field+='_field' #avoid conflict with Python keywords
     127
     128        return field
     129
     130    def _field_as_str(self, f):
     131        params = ', '.join(['%s=%r'%p for p in f['extra_params'].items()])
     132
     133        if f['comments']:
     134            comments=' # '+'; '.join(f['comments'])
     135        else:
     136            comments=''
     137
     138        return '%s = models.%s(%s)%s'%(f['name'], f['type'], params, comments)
     139
     140    def _rel_as_str(self, f):
     141        if f['null']:
     142                return '%s = models.ForeignKey(%r, db_column=%r, null=True, blank=True)'%(f['name'], f['related'], f['col_name'])
     143        else:
     144                return '%s = models.ForeignKey(%r, db_column=%r)'%(f['name'], f['related'], f['col_name'])
     145
     146    def __str__(self):
     147        s='class %s(models.Model):\n' % self.model_name
     148
     149        keys, rels, ids, other=self.group_fields()
     150        if not keys:
     151            s+='    # Warning: this model needs a field with primary_key=True\n\n'
     152
     153        for f in keys:
     154            s+='    %s\n'%self._field_as_str(f)
     155        if keys:
     156            s+='\n'
     157
     158        for f in ids:
     159            s+='    %s\n'%self._field_as_str(f)
     160        if ids:
     161            s+='\n'
     162
     163        for r in rels:
     164            s+='    %s\n'%self._rel_as_str(f)
     165        if rels:
     166            s+='\n'
     167
     168        for f in other:
     169            s+='    %s\n'%self._field_as_str(f)
     170        if other:
     171            s+='\n'
     172
     173        s+='    class Meta:\n'
     174        s+='        db_table = %r\n' % self.table_name
     175
     176        return s
     177
     178
    3179class Command(NoArgsCommand):
    4180    help = "Introspects the database tables in the given database and outputs a Django model module."
    5181
     
    14190
    15191    def handle_inspection(self):
    16192        from django.db import connection, get_introspection_module
    17         import keyword
    18193
    19194        introspection_module = get_introspection_module()
    20195
    21         table2model = lambda table_name: table_name.title().replace('_', '')
     196        cursor = connection.cursor()
     197        database=Database()
    22198
    23         cursor = connection.cursor()
    24199        yield "# This is an auto-generated Django model module."
    25200        yield "# You'll have to do the following manually to clean this up:"
    26         yield "#     * Rearrange models' order"
    27201        yield "#     * Make sure each model has one field with primary_key=True"
    28         yield "# Feel free to rename the models, but don't rename db_table values or field names."
     202        yield "#     * Rename models if desired (by convention, model names are singular)"
     203        yield "# Note: do NOT rename db_table or db_column values."
    29204        yield "#"
    30205        yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [appname]'"
    31206        yield "# into your database."
     
    33208        yield 'from django.db import models'
    34209        yield ''
    35210        for table_name in introspection_module.get_table_list(cursor):
    36             yield 'class %s(models.Model):' % table2model(table_name)
     211            model=TableModel(database, table_name)
     212
    37213            try:
    38214                relations = introspection_module.get_relations(cursor, table_name)
    39215            except NotImplementedError:
     
    43219            except NotImplementedError:
    44220                indexes = {}
    45221            for i, row in enumerate(introspection_module.get_table_description(cursor, table_name)):
    46                 att_name = row[0].lower()
     222
     223                # This is from Python DB-API spec v2 http://www.python.org/dev/peps/pep-0249/
     224                column_name, type_code, display_size, internal_size, precision, scale, null_ok = row
     225
    47226                comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
    48                 extra_params = {}  # Holds Field parameters such as 'db_column'.
    49227
    50                 if ' ' in att_name:
    51                     extra_params['db_column'] = att_name
    52                     att_name = att_name.replace(' ', '')
    53                     comment_notes.append('Field renamed to remove spaces.')
    54                 if keyword.iskeyword(att_name):
    55                     extra_params['db_column'] = att_name
    56                     att_name += '_field'
    57                     comment_notes.append('Field renamed because it was a Python reserved word.')
     228                extra_params = {'db_column': column_name}  # Holds Field parameters
     229                # Always specify the db column name - even if the field name matches the column,
     230                # we expect users to fix/refactor their models.
    58231
     232
    59233                if i in relations:
    60                     rel_to = relations[i][1] == table_name and "'self'" or table2model(relations[i][1])
    61                     field_type = 'ForeignKey(%s' % rel_to
    62                     if att_name.endswith('_id'):
    63                         att_name = att_name[:-3]
    64                     else:
    65                         extra_params['db_column'] = att_name
     234                    model.add_relation(column_name, relations[i][1], bool(null_ok))
     235
     236#                    rel_to = relations[i][1] == table_name and "'self'" or self.model_name_for_table(relations[i][1])
     237#                    field_type = 'ForeignKey(%s' % rel_to
     238#                    if att_name.endswith('_id'):
     239#                        att_name = att_name[:-3]
     240#                        if 'db_column' not in extra_params:
     241#                               extra_params['db_column'] = att_name
    66242                else:
    67243                    try:
    68                         field_type = introspection_module.DATA_TYPES_REVERSE[row[1]]
     244                        field_type = introspection_module.DATA_TYPES_REVERSE[type_code]
    69245                    except KeyError:
    70246                        field_type = 'TextField'
    71247                        comment_notes.append('This field type is a guess.')
     
    77253                        extra_params.update(new_params)
    78254
    79255                    # Add max_length for all CharFields.
    80                     if field_type == 'CharField' and row[3]:
    81                         extra_params['max_length'] = row[3]
     256                    if field_type == 'CharField' and internal_size:
     257                        extra_params['max_length'] = internal_size
    82258
    83259                    if field_type == 'DecimalField':
    84                         extra_params['max_digits'] = row[4]
    85                         extra_params['decimal_places'] = row[5]
     260                        extra_params['max_digits'] = precision
     261                        extra_params['decimal_places'] = scale
    86262
    87263                    # Add primary_key and unique, if necessary.
    88                     column_name = extra_params.get('db_column', att_name)
    89264                    if column_name in indexes:
    90265                        if indexes[column_name]['primary_key']:
    91266                            extra_params['primary_key'] = True
    92267                        elif indexes[column_name]['unique']:
    93268                            extra_params['unique'] = True
    94269
    95                     field_type += '('
     270#                    field_type += '('
    96271
    97272                # Don't output 'id = meta.AutoField(primary_key=True)', because
    98273                # that's assumed if it doesn't exist.
    99                 if att_name == 'id' and field_type == 'AutoField(' and extra_params == {'primary_key': True}:
    100                     continue
     274                #if att_name == 'id' and field_type == 'AutoField' and extra_params == {'primary_key': True}:
     275                #    continue
    101276
    102277                # Add 'null' and 'blank', if the 'null_ok' flag was present in the
    103278                # table description.
    104                 if row[6]: # If it's NULL...
     279                if null_ok: # If it's NULL...
    105280                    extra_params['blank'] = True
    106                     if not field_type in ('TextField(', 'CharField('):
     281                    if not field_type in ('TextField', 'CharField'):
    107282                        extra_params['null'] = True
    108283
    109                 field_desc = '%s = models.%s' % (att_name, field_type)
    110                 if extra_params:
    111                     if not field_desc.endswith('('):
    112                         field_desc += ', '
    113                     field_desc += ', '.join(['%s=%r' % (k, v) for k, v in extra_params.items()])
    114                 field_desc += ')'
    115                 if comment_notes:
    116                     field_desc += ' # ' + ' '.join(comment_notes)
    117                 yield '    %s' % field_desc
    118             yield '    class Meta:'
    119             yield '        db_table = %r' % table_name
    120             yield ''
     284                model.add_column(column_name, field_type, extra_params, comment_notes)
     285
     286#                field_desc = '%s = models.%s' % (att_name, field_type)
     287#                if extra_params:
     288#                    if not field_desc.endswith('('):
     289#                        field_desc += ', '
     290#                    field_desc += ', '.join(['%s=%r' % (k, v) for k, v in extra_params.items()])
     291#                field_desc += ')'
     292#                if comment_notes:
     293#                    field_desc += ' # ' + ' '.join(comment_notes)
     294#                yield '    %s' % field_desc
     295#
     296#            yield '    '
     297#            yield '    class Meta:'
     298#            yield '        db_table = %r' % table_name
     299#            yield ''
     300#            yield ''
     301        yield str(database)
Back to Top