Ticket #18823: m2m_through_field.patch

File m2m_through_field.patch, 5.5 KB (added by anonymous, 12 years ago)
  • django/db/models/fields/related.py

    diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
    index 08cc0a7..c6b9f90 100644
    a b def create_many_related_manager(superclass, rel):  
    558558            self.reverse = reverse
    559559            self.through = through
    560560            self.prefetch_cache_name = prefetch_cache_name
    561             self._pk_val = self.instance.pk
     561            self._pk_val = self._get_fk_val(self.instance, source_field_name)
    562562            if self._pk_val is None:
    563563                raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__)
    564564
     565        def _get_fk_val(self, obj, field_name):
     566            # Get's the correct value for this relationship
     567            # takes to_field into account
     568            fk = self.through._meta.get_field(field_name)
     569            value = obj.pk
     570            if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname:
     571                attname = fk.rel.get_related_field().get_attname()
     572                value = fk.get_prep_lookup('exact', getattr(obj, attname))
     573            return value
     574
    565575        def get_query_set(self):
    566576            try:
    567577                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
    def create_many_related_manager(superclass, rel):  
    662672                        if not router.allow_relation(obj, self.instance):
    663673                           raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
    664674                                               (obj, self.instance._state.db, obj._state.db))
    665                         new_ids.add(obj.pk)
     675                        new_ids.add(self._get_fk_val(obj, target_field_name))
    666676                    elif isinstance(obj, Model):
    667677                        raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
    668678                    else:
    def create_many_related_manager(superclass, rel):  
    689699                    })
    690700                    for obj_id in new_ids
    691701                ])
     702
    692703                if self.reverse or source_field_name == self.source_field_name:
    693704                    # Don't send the signal when we are inserting the
    694705                    # duplicate data row for symmetrical reverse entries.
    def create_many_related_manager(superclass, rel):  
    707718                old_ids = set()
    708719                for obj in objs:
    709720                    if isinstance(obj, self.model):
    710                         old_ids.add(obj.pk)
     721                        old_ids.add(self._get_fk_val(obj, target_field_name))
    711722                    else:
    712723                        old_ids.add(obj)
    713724                # Work out what DB we're operating on
  • tests/regressiontests/m2m_through_regress/models.py

    diff --git a/tests/regressiontests/m2m_through_regress/models.py b/tests/regressiontests/m2m_through_regress/models.py
    index 47c24ed..73a4645 100644
    a b class CarDriver(models.Model):  
    8080    car = models.ForeignKey('Car', to_field='make')
    8181    driver = models.ForeignKey('Driver', to_field='name')
    8282
     83    class Meta:
     84        auto_created = Car
     85
    8386    def __str__(self):
    8487        return "pk=%s car=%s driver=%s" % (str(self.pk), self.car, self.driver)
  • tests/regressiontests/m2m_through_regress/tests.py

    diff --git a/tests/regressiontests/m2m_through_regress/tests.py b/tests/regressiontests/m2m_through_regress/tests.py
    index 458c194..9808bd4 100644
    a b class ToFieldThroughTests(TestCase):  
    136136            ["<Car: Toyota>"]
    137137            )
    138138
     139    def test_to_field_clear_reverse(self):
     140        self.driver.car_set.clear()
     141        self.assertQuerysetEqual(
     142            self.driver.car_set.all(),[])
     143
     144    def test_to_field_clear(self):
     145        self.car.drivers.clear()
     146        self.assertQuerysetEqual(
     147            self.car.drivers.all(),[])
     148
     149class AutoToFieldThroughTests(TestCase):
     150    def setUp(self):
     151        self.car = Car.objects.create(make="Toyota")
     152        self.driver = Driver.objects.create(name="Ryan Briscoe")
     153
     154    def test_add(self):
     155        self.assertQuerysetEqual(
     156            self.car.drivers.all(),[])
     157        self.car.drivers.add(self.driver)
     158        self.assertQuerysetEqual(
     159            self.car.drivers.all(),
     160            ["<Driver: Ryan Briscoe>"]
     161            )
     162
     163    def test_add_reverse(self):
     164        self.assertQuerysetEqual(
     165            self.driver.car_set.all(),[])
     166        self.driver.car_set.add(self.car)
     167        self.assertQuerysetEqual(
     168            self.driver.car_set.all(),
     169            ["<Car: Toyota>"]
     170            )
     171
     172    def test_remove(self):
     173        CarDriver.objects.create(car=self.car, driver=self.driver)
     174        self.assertQuerysetEqual(
     175            self.car.drivers.all(),
     176            ["<Driver: Ryan Briscoe>"]
     177            )
     178        self.car.drivers.remove(self.driver)
     179        self.assertQuerysetEqual(
     180            self.car.drivers.all(),[])
     181
     182
     183    def test_remove_reverse(self):
     184        CarDriver.objects.create(car=self.car, driver=self.driver)
     185        self.assertQuerysetEqual(
     186            self.driver.car_set.all(),
     187            ["<Car: Toyota>"]
     188            )
     189        self.driver.car_set.remove(self.car)
     190        self.assertQuerysetEqual(
     191            self.driver.car_set.all(),[])
     192
     193
    139194class ThroughLoadDataTestCase(TestCase):
    140195    fixtures = ["m2m_through"]
    141196
Back to Top