Ticket #6095: 6095-alpha-02.diff

File 6095-alpha-02.diff, 13.5 KB (added by floguy, 16 years ago)

Updated patch to handle custom related_name and custom db_column on manually created foreign keys.

  • django/db/models/fields/related.py

     
    11from django.db import connection, transaction
    2 from django.db.models import signals, get_model
     2from django.db.models import signals, get_model, get_models
    33from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, get_ul_class
    44from django.db.models.related import RelatedObject
    55from django.utils.text import capfirst
    66from django.utils.translation import ugettext_lazy, string_concat, ungettext, ugettext as _
    7 from django.utils.functional import curry
     7from django.utils.functional import curry, memoize
    88from django.utils.encoding import smart_unicode
    99from django.core import validators
    1010from django import oldforms
     
    2323
    2424pending_lookups = {}
    2525
     26memoized_fk_field_reversals = {}
     27
     28model_db_table_cache = {}
     29
    2630def add_lookup(rel_cls, field):
    2731    name = field.rel.to
    2832    module = rel_cls.__module__
     
    5458    except klass.DoesNotExist:
    5559        raise validators.ValidationError, _("Please enter a valid %s.") % f.verbose_name
    5660
     61def get_reverse_rel_field(from_model, to_model, related_name):
     62    key = (from_model._meta.app_label, from_model._meta.object_name,
     63            to_model._meta.app_label, to_model._meta.object_name,
     64            related_name)
     65    try:
     66        found_field = memoized_fk_field_reversals[key]
     67    except KeyError:
     68        found_field = None
     69        for field in from_model._meta.fields:
     70            if field.__class__ in (ForeignKey, OneToOneField, ManyToManyField):
     71                if field.rel.to == to_model:
     72                    found_field = field
     73                    break
     74        memoized_fk_field_reversals[key] = found_field
     75    return found_field
     76
     77def get_model_for_db_table(db_table):
     78    for model in get_models():
     79        if model._meta.db_table == db_table:
     80            return model
     81    return None
     82get_model_for_db_table = memoize(get_model_for_db_table, model_db_table_cache, 1)
     83
    5784#HACK
    5885class RelatedField(object):
    5986    def contribute_to_class(self, cls, name):
     
    267294    and adds behavior for many-to-many related objects."""
    268295    class ManyRelatedManager(superclass):
    269296        def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
    270                 join_table=None, source_col_name=None, target_col_name=None):
     297                join_table=None, source_col_name=None, source_attname=None,
     298                target_attname=None, target_col_name=None):
    271299            super(ManyRelatedManager, self).__init__()
    272300            self.core_filters = core_filters
    273301            self.model = model
     
    276304            self.join_table = join_table
    277305            self.source_col_name = source_col_name
    278306            self.target_col_name = target_col_name
     307            self.source_attname = source_attname
     308            self.target_attname = target_attname
     309            self.intermediary_model = get_model_for_db_table(self.join_table.replace('"',''))
    279310            self._pk_val = self.instance._get_pk_val()
    280311            if self._pk_val is None:
    281312                raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % model)
     
    340371
    341372                # Add the ones that aren't there already
    342373                for obj_id in (new_ids - existing_ids):
    343                     cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
     374                    if self.intermediary_model == None:
     375                        cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
    344376                        (self.join_table, source_col_name, target_col_name),
    345377                        [self._pk_val, obj_id])
     378                    else:
     379                        new_obj = self.intermediary_model()
     380                        setattr(new_obj, self.source_attname, self._pk_val)
     381                        setattr(new_obj, self.target_attname, obj_id)
     382                        new_obj.save()
    346383                transaction.commit_unless_managed()
    347384
    348385        def _remove_items(self, source_col_name, target_col_name, *objs):
     
    398435        RelatedManager = create_many_related_manager(superclass)
    399436
    400437        qn = connection.ops.quote_name
     438        source_attname, target_attname = None, None
     439        rel_field = self.related.field
    401440        manager = RelatedManager(
    402441            model=rel_model,
    403442            core_filters={'%s__pk' % self.related.field.name: instance._get_pk_val()},
    404443            instance=instance,
    405444            symmetrical=False,
    406             join_table=qn(self.related.field.m2m_db_table()),
    407             source_col_name=qn(self.related.field.m2m_reverse_name()),
    408             target_col_name=qn(self.related.field.m2m_column_name())
     445            join_table=qn(rel_field.m2m_db_table()),
     446            source_col_name=qn(rel_field.m2m_reverse_name()),
     447            target_col_name=qn(rel_field.m2m_column_name()),
     448            source_attname=rel_field.m2m_reverse_attname(),
     449            target_attname=rel_field.m2m_attname()
    409450        )
    410451
    411452        return manager
     
    446487            symmetrical=(self.field.rel.symmetrical and instance.__class__ == rel_model),
    447488            join_table=qn(self.field.m2m_db_table()),
    448489            source_col_name=qn(self.field.m2m_column_name()),
    449             target_col_name=qn(self.field.m2m_reverse_name())
     490            target_col_name=qn(self.field.m2m_reverse_name()),
     491            source_attname=self.field.m2m_attname(),
     492            target_attname=self.field.m2m_reverse_attname()
    450493        )
    451494
    452495        return manager
     
    648691            filter_interface=kwargs.pop('filter_interface', None),
    649692            limit_choices_to=kwargs.pop('limit_choices_to', None),
    650693            raw_id_admin=kwargs.pop('raw_id_admin', False),
    651             symmetrical=kwargs.pop('symmetrical', True))
     694            symmetrical=kwargs.pop('symmetrical', True),
     695            through=kwargs.pop('through', None))
    652696        self.db_table = kwargs.pop('db_table', None)
    653697        if kwargs["rel"].raw_id_admin:
    654698            kwargs.setdefault("validator_list", []).append(self.isValidIDList)
     
    672716
    673717    def _get_m2m_db_table(self, opts):
    674718        "Function that can be curried to provide the m2m table name for this relation"
    675         if self.db_table:
     719        if self.rel.through != None:
     720            return get_model(opts.app_label, self.rel.through)._meta.db_table
     721        elif self.db_table:
    676722            return self.db_table
    677723        else:
    678724            return '%s_%s' % (opts.db_table, self.name)
    679725
     726    def _get_m2m_attname(self, related):
     727        try:
     728            through = get_model(related.opts.app_label, self.rel.through)
     729            field = get_reverse_rel_field(through, related.model, self.rel.related_name)
     730            attname, column = field.get_attname_column()
     731            return attname
     732        except:
     733            return None
     734
    680735    def _get_m2m_column_name(self, related):
    681736        "Function that can be curried to provide the source column name for the m2m table"
    682737        # If this is an m2m relation to self, avoid the inevitable name clash
    683         if related.model == related.parent_model:
     738        if self.rel.through != None:
     739            through = get_model(related.opts.app_label, self.rel.through)
     740            field = get_reverse_rel_field(through, related.model, self.rel.related_name)
     741            attname, column = field.get_attname_column()
     742            return column
     743        elif related.model == related.parent_model:
    684744            return 'from_' + related.model._meta.object_name.lower() + '_id'
    685745        else:
    686746            return related.model._meta.object_name.lower() + '_id'
    687747
     748    def _get_m2m_reverse_attname(self, related):
     749        try:
     750            through = get_model(related.opts.app_label, self.rel.through)
     751            field = get_reverse_rel_field(through, related.parent_model, self.rel.related_name)
     752            attname, column = field.get_attname_column()
     753            return attname
     754        except:
     755            return None
     756
    688757    def _get_m2m_reverse_name(self, related):
    689758        "Function that can be curried to provide the related column name for the m2m table"
    690759        # If this is an m2m relation to self, avoid the inevitable name clash
    691         if related.model == related.parent_model:
     760        if self.rel.through != None:
     761            through = get_model(related.opts.app_label, self.rel.through)
     762            field = get_reverse_rel_field(through, related.parent_model, self.rel.related_name)
     763            attname, column = field.get_attname_column()
     764            return column
     765        elif related.model == related.parent_model:
    692766            return 'to_' + related.parent_model._meta.object_name.lower() + '_id'
    693767        else:
    694768            return related.parent_model._meta.object_name.lower() + '_id'
     
    745819        # Set up the accessors for the column names on the m2m table
    746820        self.m2m_column_name = curry(self._get_m2m_column_name, related)
    747821        self.m2m_reverse_name = curry(self._get_m2m_reverse_name, related)
     822        self.m2m_attname = curry(self._get_m2m_attname, related)
     823        self.m2m_reverse_attname = curry(self._get_m2m_reverse_attname, related)
    748824
    749825    def set_attributes_from_rel(self):
    750826        pass
     
    809885
    810886class ManyToManyRel(object):
    811887    def __init__(self, to, num_in_admin=0, related_name=None,
    812         filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True):
     888        filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True,
     889        through = None):
    813890        self.to = to
    814891        self.num_in_admin = num_in_admin
    815892        self.related_name = related_name
     
    821898        self.raw_id_admin = raw_id_admin
    822899        self.symmetrical = symmetrical
    823900        self.multiple = True
     901        self.through = through
    824902
    825903        assert not (self.raw_id_admin and self.filter_interface), "ManyToManyRels may not use both raw_id_admin and filter_interface"
  • django/core/management/sql.py

     
    349349    qn = connection.ops.quote_name
    350350    inline_references = connection.features.inline_fk_references
    351351    for f in opts.many_to_many:
    352         if not isinstance(f.rel, generic.GenericRel):
     352        if not isinstance(f.rel, generic.GenericRel) and getattr(f.rel, 'through', None) == None:
    353353            tablespace = f.db_tablespace or opts.db_tablespace
    354354            if tablespace and connection.features.supports_tablespaces and connection.features.autoindexes_primary_keys:
    355355                tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace, inline=True)
  • tests/modeltests/m2m_manual/models.py

     
     1from django.db import models
     2from datetime import datetime
     3
     4# M2M described on one of the models
     5class Person(models.Model):
     6    name = models.CharField(max_length=128)
     7
     8    def __unicode__(self):
     9        return self.name
     10
     11class Group(models.Model):
     12    name = models.CharField(max_length=128)
     13    members = models.ManyToManyField(Person, through='Membership')
     14    custom_members = models.ManyToManyField(Person, through='CustomMembership', related_name="custom")
     15   
     16    def __unicode__(self):
     17        return self.name
     18
     19class Membership(models.Model):
     20    person = models.ForeignKey(Person)
     21    group = models.ForeignKey(Group)
     22    date_joined = models.DateTimeField(default=datetime.now)
     23   
     24    def __unicode__(self):
     25        return "%s is a member of %s" % (self.person.name, self.group.name)
     26
     27class CustomMembership(models.Model):
     28    person = models.ForeignKey(Person, db_column="custom_person_column", related_name="custom_person_related_name")
     29    group = models.ForeignKey(Group)
     30    date_joined = models.DateTimeField(default=datetime.now)
     31   
     32    def __unicode__(self):
     33        return "%s is a member of %s" % (self.person.name, self.group.name)
     34
     35__test__ = {'API_TESTS':"""
     36>>> bob = Person(name = 'Bob')
     37>>> bob.save()
     38>>> jim = Person(name = 'Jim')
     39>>> jim.save()
     40>>> jane = Person(name = 'Jane')
     41>>> jane.save()
     42>>> rock = Group(name = 'Rock')
     43>>> rock.save()
     44>>> roll = Group(name = 'Roll')
     45>>> roll.save()
     46
     47>>> rock.members.add(jim, jane)
     48>>> rock.members.all()
     49[<Person: Jim>, <Person: Jane>]
     50
     51>>> roll.members.add(bob, jim)
     52>>> roll.members.all()
     53[<Person: Bob>, <Person: Jim>]
     54
     55>>> jane.group_set.all()
     56[<Group: Rock>]
     57
     58>>> jane.group_set.add(roll)
     59>>> jane.group_set.all()
     60[<Group: Rock>, <Group: Roll>]
     61
     62>>> jim.group_set.all()
     63[<Group: Rock>, <Group: Roll>]
     64
     65>>> Membership.objects.filter(person = jane, group = rock)
     66[<Membership: Jane is a member of Rock>]
     67
     68>>> Membership.objects.filter(person = jim)
     69[<Membership: Jim is a member of Rock>, <Membership: Jim is a member of Roll>]
     70
     71>>> rock.custom_members.add(bob)
     72>>> rock.custom_members.all()
     73[<Person: Bob>]
     74
     75>>> jim.custom.add(rock)
     76>>> rock.custom_members.all()
     77[<Person: Bob>, <Person: Jim>]
     78
     79>>> jim.custom.all()
     80[<Group: Rock>]
     81
     82>>> jim.custom_person_related_name.all()
     83[<CustomMembership: Jim is a member of Rock>]
     84"""}
     85 No newline at end of file
Back to Top