Ticket #9475: 9475.m2m_add_remove.r16133.diff

File 9475.m2m_add_remove.r16133.diff, 25.1 KB (added by Johannes Dollinger, 14 years ago)
  • tests/modeltests/m2m_add_and_remove/tests.py

     
     1from django.db import models
     2from django.test import TestCase
     3
     4from models import M, R, Through, ThroughDefault, ThroughAuto, ThroughUT, ThroughCanRemove, ThroughCanAdd
     5
     6
     7class M2mAddRemoveTests(TestCase):
     8    def assert_cannot_remove(self, name):
     9        m = M.objects.create()
     10        r = R.objects.create()
     11        manager = getattr(m, name)
     12        reverse_manager = getattr(r, "%s_m_set" % name)
     13        self.assertRaises(AttributeError, getattr, manager, 'remove')
     14        self.assertRaises(AttributeError, getattr, reverse_manager, 'remove')
     15
     16    def assert_cannot_add(self, name):
     17        reverse_name = "%s_m_set" % name
     18        m = M.objects.create()
     19        r = R.objects.create()
     20        manager = getattr(m, name)
     21        reverse_manager = getattr(r, reverse_name)
     22        self.assertRaises(AttributeError, getattr, manager, 'add')
     23        self.assertRaises(AttributeError, getattr, reverse_manager, 'add')
     24        self.assertRaises(AttributeError, manager.create)
     25        self.assertRaises(AttributeError, reverse_manager.create)
     26        self.assertRaises(AttributeError, setattr, m, name, [])
     27        self.assertRaises(AttributeError, setattr, r, reverse_name, [])
     28
     29    def assert_can_add(self, name):
     30        reverse_name = "%s_m_set" % name
     31        m = M.objects.create()
     32        r = R.objects.create()
     33        manager = getattr(m, name)
     34        reverse_manager = getattr(r, reverse_name)
     35
     36        manager.add(r)
     37        self.assertEqual(list(manager.all()), [r])
     38        self.assertEqual(list(reverse_manager.all()), [m])
     39        manager.add(r)
     40        self.assertEqual(list(manager.all()), [r])
     41        self.assertEqual(list(reverse_manager.all()), [m])
     42        manager.clear()
     43
     44        reverse_manager.add(m)
     45        self.assertEqual(list(manager.all()), [r])
     46        self.assertEqual(list(reverse_manager.all()), [m])
     47        reverse_manager.add(m)
     48        self.assertEqual(list(manager.all()), [r])
     49        self.assertEqual(list(reverse_manager.all()), [m])
     50        reverse_manager.clear()
     51
     52        r2 = manager.create()
     53        reverse_manager2 = getattr(r2, reverse_name)
     54        self.assertEqual(list(manager.all()), [r2])
     55        self.assertEqual(list(reverse_manager2.all()), [m])
     56        manager.clear()
     57
     58        m2 = reverse_manager.create()
     59        manager2 = getattr(m2, name)
     60        self.assertEqual(list(manager2.all()), [r])
     61        self.assertEqual(list(reverse_manager.all()), [m2])
     62        reverse_manager.clear()
     63
     64        setattr(m, name, [r])
     65        self.assertEqual(list(manager.all()), [r])
     66        manager.clear()
     67
     68        setattr(r, reverse_name, [m])
     69        self.assertEqual(list(reverse_manager.all()), [m])
     70        reverse_manager.clear()
     71
     72    def assert_can_remove(self, name, extra):
     73        through = M._meta.get_field(name).rel.through
     74        m = M.objects.create()
     75        r = R.objects.create()
     76
     77        def fill():
     78            for extra_kwargs in extra:
     79                kwargs = {'m': m, 'r': r}
     80                kwargs.update(extra_kwargs)
     81                through.objects.create(**kwargs)
     82
     83        manager = getattr(m, name)
     84        reverse_manager = getattr(r, "%s_m_set" % name)
     85
     86        fill()
     87        manager.remove(r)
     88        self.failIf(manager.exists())
     89        self.failIf(reverse_manager.exists())
     90
     91        fill()
     92        reverse_manager.remove(m)
     93        self.failIf(reverse_manager.exists())
     94        self.failIf(manager.exists())
     95
     96    def _test_managers(self, name, can_remove=False, can_add=False, extra=()):
     97        if can_add:
     98            self.assert_can_add(name)
     99        else:
     100            self.assert_cannot_add(name)
     101        if can_remove:
     102            self.assert_can_remove(name, extra)
     103        else:
     104            self.assert_cannot_remove(name)
     105
     106    def test_default(self):
     107        self._test_managers('default', can_add=True, can_remove=True, extra=[{}])
     108
     109    def test_default_cannot_remove(self):
     110        self._test_managers('default_cannot_remove', can_add=True, can_remove=False)
     111
     112    def test_default_cannot_add(self):
     113        self._test_managers('default_cannot_add', can_add=False, can_remove=True, extra=[{}])
     114
     115    def test_through_default(self):
     116        self._test_managers('through_default', can_add=False, can_remove=False)
     117
     118    def test_through_auto(self):
     119        self._test_managers('through_auto', can_add=True, can_remove=False)
     120
     121    def test_through_ut(self):
     122        self._test_managers('through_ut', can_add=False, can_remove=True, extra=[{'extra': 'foo'}])
     123
     124    def test_through_can_add(self):
     125        self._test_managers('through_can_add', can_add=True, can_remove=False)
     126
     127    def test_through_can_remove(self):
     128        self._test_managers('through_can_remove', can_add=False, can_remove=True, extra=[{'extra': 'foo'}, {'extra': 'bar'}])
     129       
     130 No newline at end of file
  • tests/modeltests/m2m_add_and_remove/models.py

     
     1from django.db import models
     2from django.test import TestCase
     3
     4
     5class M(models.Model):
     6    default = models.ManyToManyField("R", related_name="default_m_set")
     7    default_cannot_remove = models.ManyToManyField("R", can_remove=False, related_name="default_cannot_remove_m_set")
     8    default_cannot_add = models.ManyToManyField("R", can_add=False, related_name="default_cannot_add_m_set")
     9    through_default = models.ManyToManyField("R", through="ThroughDefault", related_name="through_default_m_set")
     10    through_auto = models.ManyToManyField("R", through="ThroughAuto", related_name="through_auto_m_set")
     11    through_ut = models.ManyToManyField("R", through="ThroughUT", related_name="through_ut_m_set")
     12    through_can_add = models.ManyToManyField("R", can_add=True, through="ThroughCanAdd", related_name="through_can_add_m_set")
     13    through_can_remove = models.ManyToManyField("R", can_remove=True, through="ThroughCanRemove", related_name="through_can_remove_m_set")
     14
     15
     16class R(models.Model):
     17    name = models.CharField(max_length=30)
     18
     19
     20class Through(models.Model):
     21    m = models.ForeignKey(M, related_name="%(class)s_set")
     22    r = models.ForeignKey(R, related_name="%(class)s_set")
     23
     24    class Meta:
     25        abstract = True
     26
     27
     28class ThroughDefault(Through):
     29    extra = models.CharField(max_length=10)
     30
     31
     32class ThroughAuto(Through):
     33    ctime = models.DateTimeField(auto_now_add=True)
     34    mtime = models.DateTimeField(auto_now=True)
     35    default = models.IntegerField(default=42)
     36    null = models.DateTimeField(null=True)
     37
     38
     39class ThroughUT(Through):
     40    extra = models.CharField(max_length=10)
     41
     42    class Meta:
     43        unique_together = ('m', 'r')
     44
     45
     46class ThroughCanRemove(Through):
     47    extra = models.CharField(max_length=10)
     48
     49
     50class ThroughCanAdd(Through):
     51    extra = models.CharField(max_length=10)
     52
     53    def save(self, **kwargs):
     54        self.extra = "foo"
     55        return super(ThroughCanAdd, self).save(**kwargs)
     56       
     57 No newline at end of file
  • tests/modeltests/m2m_through/tests.py

     
    6767
    6868
    6969    def test_forward_descriptors(self):
    70         # Due to complications with adding via an intermediary model,
    71         # the add method is not provided.
    72         self.assertRaises(AttributeError, lambda: self.rock.members.add(self.bob))
    73         # Create is also disabled as it suffers from the same problems as add.
    74         self.assertRaises(AttributeError, lambda: self.rock.members.create(name='Anne'))
    75         # Remove has similar complications, and is not provided either.
     70        # Remove doesn't work, because Membership `person` and `group` are not unique_together.
    7671        self.assertRaises(AttributeError, lambda: self.rock.members.remove(self.jim))
    7772
    7873        m1 = Membership.objects.create(person=self.jim, group=self.rock)
     
    9388            []
    9489        )
    9590
    96         # Assignment should not work with models specifying a through model for many of
    97         # the same reasons as adding.
    98         self.assertRaises(AttributeError, setattr, self.rock, "members", backup)
    9991        # Let's re-save those instances that we've cleared.
    10092        m1.save()
    10193        m2.save()
     
    109101        )
    110102
    111103    def test_reverse_descriptors(self):
    112         # Due to complications with adding via an intermediary model,
    113         # the add method is not provided.
    114         self.assertRaises(AttributeError, lambda: self.bob.group_set.add(self.rock))
    115         # Create is also disabled as it suffers from the same problems as add.
    116         self.assertRaises(AttributeError, lambda: self.bob.group_set.create(name="funk"))
    117         # Remove has similar complications, and is not provided either.
     104        # Remove doesn't work, because Membership `person` and `group` are not unique_together.
    118105        self.assertRaises(AttributeError, lambda: self.jim.group_set.remove(self.rock))
    119106
    120107        m1 = Membership.objects.create(person=self.jim, group=self.rock)
     
    133120            self.jim.group_set.all(),
    134121            []
    135122        )
    136         # Assignment should not work with models specifying a through model for many of
    137         # the same reasons as adding.
    138         self.assertRaises(AttributeError, setattr, self.jim, "group_set", backup)
    139         # Let's re-save those instances that we've cleared.
    140123
     124        # Let's re-save those instances that we've cleared.
    141125        m1.save()
    142126        m2.save()
    143127        # Verifying that those instances were re-saved successfully.
  • tests/regressiontests/m2m_through_regress/tests.py

     
    3939            ]
    4040        )
    4141
    42         self.assertRaises(AttributeError, setattr, bob, "group_set", [])
    43         self.assertRaises(AttributeError, setattr, roll, "members", [])
    44 
    45         self.assertRaises(AttributeError, rock.members.create, name="Anne")
    46         self.assertRaises(AttributeError, bob.group_set.create, name="Funk")
    47 
    4842        UserMembership.objects.create(user=frank, group=rock)
    4943        UserMembership.objects.create(user=frank, group=roll)
    5044        UserMembership.objects.create(user=jane, group=rock)
  • django/db/models/fields/related.py

     
    472472
    473473        return manager
    474474
    475 def create_many_related_manager(superclass, rel=False):
     475def create_many_related_manager(superclass, field):
    476476    """Creates a manager that subclasses 'superclass' (which is a Manager)
    477477    and adds behavior for many-to-many related objects."""
    478     through = rel.through
     478    through = field.rel.through
     479    can_add = field.can_add()
     480    can_remove = field.can_remove()
     481
    479482    class ManyRelatedManager(superclass):
    480483        def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
    481484                join_table=None, source_field_name=None, target_field_name=None,
     
    497500            db = self._db or router.db_for_read(self.instance.__class__, instance=self.instance)
    498501            return superclass.get_query_set(self).using(db)._next_is_sticky().filter(**(self.core_filters))
    499502
    500         # If the ManyToMany relation has an intermediary model,
    501         # the add and remove methods do not exist.
    502         if rel.through._meta.auto_created:
     503        if can_add:
    503504            def add(self, *objs):
    504505                self._add_items(self.source_field_name, self.target_field_name, *objs)
    505506
     
    508509                    self._add_items(self.target_field_name, self.source_field_name, *objs)
    509510            add.alters_data = True
    510511
     512        if can_remove:
    511513            def remove(self, *objs):
    512514                self._remove_items(self.source_field_name, self.target_field_name, *objs)
    513515
     
    527529        def create(self, **kwargs):
    528530            # This check needs to be done here, since we can't later remove this
    529531            # from the method lookup table, as we do with add and remove.
    530             if not rel.through._meta.auto_created:
     532            if not can_add:
    531533                opts = through._meta
    532                 raise AttributeError("Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
     534                raise AttributeError("Cannot use create() on this ManyToManyField. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
    533535            db = router.db_for_write(self.instance.__class__, instance=self.instance)
    534536            new_obj = super(ManyRelatedManager, self.db_manager(db)).create(**kwargs)
    535537            self.add(new_obj)
     
    537539        create.alters_data = True
    538540
    539541        def get_or_create(self, **kwargs):
     542            # This check needs to be done here, since we can't later remove this
     543            # from the method lookup table, as we do with add and remove.
     544            if not can_add:
     545                opts = through._meta
     546                raise AttributeError, "Cannot use get_or_create() on this ManyToManyField. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)
    540547            db = router.db_for_write(self.instance.__class__, instance=self.instance)
    541548            obj, created = \
    542549                super(ManyRelatedManager, self.db_manager(db)).get_or_create(**kwargs)
     
    578585                if self.reverse or source_field_name == self.source_field_name:
    579586                    # Don't send the signal when we are inserting the
    580587                    # duplicate data row for symmetrical reverse entries.
    581                     signals.m2m_changed.send(sender=rel.through, action='pre_add',
     588                    signals.m2m_changed.send(sender=through, action='pre_add',
    582589                        instance=self.instance, reverse=self.reverse,
    583590                        model=self.model, pk_set=new_ids, using=db)
    584591                # Add the ones that aren't there already
     
    590597                if self.reverse or source_field_name == self.source_field_name:
    591598                    # Don't send the signal when we are inserting the
    592599                    # duplicate data row for symmetrical reverse entries.
    593                     signals.m2m_changed.send(sender=rel.through, action='post_add',
     600                    signals.m2m_changed.send(sender=through, action='post_add',
    594601                        instance=self.instance, reverse=self.reverse,
    595602                        model=self.model, pk_set=new_ids, using=db)
    596603
     
    614621                if self.reverse or source_field_name == self.source_field_name:
    615622                    # Don't send the signal when we are deleting the
    616623                    # duplicate data row for symmetrical reverse entries.
    617                     signals.m2m_changed.send(sender=rel.through, action="pre_remove",
     624                    signals.m2m_changed.send(sender=through, action="pre_remove",
    618625                        instance=self.instance, reverse=self.reverse,
    619626                        model=self.model, pk_set=old_ids, using=db)
    620627                # Remove the specified objects from the join table
     
    625632                if self.reverse or source_field_name == self.source_field_name:
    626633                    # Don't send the signal when we are deleting the
    627634                    # duplicate data row for symmetrical reverse entries.
    628                     signals.m2m_changed.send(sender=rel.through, action="post_remove",
     635                    signals.m2m_changed.send(sender=through, action="post_remove",
    629636                        instance=self.instance, reverse=self.reverse,
    630637                        model=self.model, pk_set=old_ids, using=db)
    631638
     
    635642            if self.reverse or source_field_name == self.source_field_name:
    636643                # Don't send the signal when we are clearing the
    637644                # duplicate data rows for symmetrical reverse entries.
    638                 signals.m2m_changed.send(sender=rel.through, action="pre_clear",
     645                signals.m2m_changed.send(sender=through, action="pre_clear",
    639646                    instance=self.instance, reverse=self.reverse,
    640647                    model=self.model, pk_set=None, using=db)
    641648            self.through._default_manager.using(db).filter(**{
     
    644651            if self.reverse or source_field_name == self.source_field_name:
    645652                # Don't send the signal when we are clearing the
    646653                # duplicate data rows for symmetrical reverse entries.
    647                 signals.m2m_changed.send(sender=rel.through, action="post_clear",
     654                signals.m2m_changed.send(sender=through, action="post_clear",
    648655                    instance=self.instance, reverse=self.reverse,
    649656                    model=self.model, pk_set=None, using=db)
    650657
     
    668675        # model's default manager.
    669676        rel_model = self.related.model
    670677        superclass = rel_model._default_manager.__class__
    671         RelatedManager = create_many_related_manager(superclass, self.related.field.rel)
     678        RelatedManager = create_many_related_manager(superclass, self.related.field)
    672679
    673680        manager = RelatedManager(
    674681            model=rel_model,
     
    686693        if instance is None:
    687694            raise AttributeError("Manager must be accessed via instance")
    688695
    689         if not self.related.field.rel.through._meta.auto_created:
     696        if not self.related.field.can_add():
    690697            opts = self.related.field.rel.through._meta
    691             raise AttributeError("Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
     698            raise AttributeError("Cannot set values on this ManyToManyField. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
    692699
    693700        manager = self.__get__(instance)
    694701        manager.clear()
     
    720727        # model's default manager.
    721728        rel_model=self.field.rel.to
    722729        superclass = rel_model._default_manager.__class__
    723         RelatedManager = create_many_related_manager(superclass, self.field.rel)
     730        RelatedManager = create_many_related_manager(superclass, self.field)
    724731
    725732        manager = RelatedManager(
    726733            model=rel_model,
     
    738745        if instance is None:
    739746            raise AttributeError("Manager must be accessed via instance")
    740747
    741         if not self.field.rel.through._meta.auto_created:
     748        if not self.field.can_add():
    742749            opts = self.field.rel.through._meta
    743             raise AttributeError("Cannot set values on a ManyToManyField which specifies an intermediary model.  Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
     750            raise AttributeError("Cannot set values a this ManyToManyField.  Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
    744751
    745752        manager = self.__get__(instance)
    746753        manager.clear()
     
    10191026        if kwargs['rel'].through is not None:
    10201027            assert self.db_table is None, "Cannot specify a db_table if an intermediary model is used."
    10211028
     1029        self._can_add = kwargs.pop('can_add', None)
     1030        self._can_remove = kwargs.pop('can_remove', None)
     1031
    10221032        Field.__init__(self, **kwargs)
    10231033
    10241034        msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.')
     
    10371047            return util.truncate_name('%s_%s' % (opts.db_table, self.name),
    10381048                                      connection.ops.max_name_length())
    10391049
    1040     def _get_m2m_attr(self, related, attr):
    1041         "Function that can be curried to provide the source accessor or DB column name for the m2m table"
    1042         cache_attr = '_m2m_%s_cache' % attr
     1050    def _get_intermediary_fields(self, related):
     1051        cache_attr = '_m2m_intermediary_fields_cache'
    10431052        if hasattr(self, cache_attr):
    10441053            return getattr(self, cache_attr)
     1054
     1055        candidates = []
     1056        related_candidates = []
     1057        auto_add = True
    10451058        for f in self.rel.through._meta.fields:
    1046             if hasattr(f,'rel') and f.rel and f.rel.to == related.model:
    1047                 setattr(self, cache_attr, getattr(f, attr))
    1048                 return getattr(self, cache_attr)
     1059            if hasattr(f,'rel') and f.rel:
     1060                if f.rel.to == related.model:
     1061                    candidates.append(f)
     1062                    continue
     1063                elif f.rel.to == related.parent_model:
     1064                    related_candidates.append(f)
     1065                    continue
     1066            if isinstance(f, AutoField) or f.null or f.has_default():
     1067                continue
     1068            elif getattr(f, 'auto_now_add', False) or getattr(f, 'auto_now', False):
     1069                continue
     1070            else:
     1071                auto_add = False
     1072        if related.model == related.parent_model:
     1073            # m2m to self
     1074            assert len(candidates) == 2, "There are too many ForeignKeys to %s" % related.model
     1075            field, related_field = candidates
     1076        else:
     1077            # TODO: intelligently pick a candidate if there is more than one (See ticket #8618).
     1078            # For now, model validation will reject models with more than suitable FK anyway.
     1079            assert len(candidates) == 1, "There are no ForeignKeys to %s" % related.model
     1080            assert len(related_candidates) == 1, "There are no ForeignKeys to %s" % related.parent_model
     1081            field, related_field = candidates[0], related_candidates[0]
    10491082
     1083        if self._can_add is None:
     1084            self._can_add = auto_add
     1085
     1086        if self._can_remove is None:
     1087            self._can_remove = False
     1088            unique_together = [frozenset(ut) for ut in self.rel.through._meta.unique_together]
     1089            names = frozenset([field.name, related_field.name])
     1090            for ut in unique_together:
     1091                if names <= ut:
     1092                    self._can_remove = True
     1093                    break
     1094
     1095        setattr(self, cache_attr, (field, related_field))
     1096        return (field, related_field)
     1097
     1098    def _get_can_add(self, related):
     1099        if self._can_add is None:
     1100            self._get_intermediary_fields(related)
     1101        return self._can_add
     1102
     1103    def _get_can_remove(self, related):
     1104        if self._can_remove is None:
     1105            self._get_intermediary_fields(related)
     1106        return self._can_remove
     1107
     1108    def _get_m2m_attr(self, related, attr):
     1109        "Function that can be curried to provide a source field attribute"
     1110        cache_attr = '_m2m_%s_cache' % attr
     1111        if not hasattr(self, cache_attr):
     1112            field, _ = self._get_intermediary_fields(related)
     1113            setattr(self, cache_attr, getattr(field, attr))
     1114        return getattr(self, cache_attr)
     1115
    10501116    def _get_m2m_reverse_attr(self, related, attr):
    1051         "Function that can be curried to provide the related accessor or DB column name for the m2m table"
     1117        "Function that can be curried to provide a related field attribute"
    10521118        cache_attr = '_m2m_reverse_%s_cache' % attr
    1053         if hasattr(self, cache_attr):
    1054             return getattr(self, cache_attr)
    1055         found = False
    1056         for f in self.rel.through._meta.fields:
    1057             if hasattr(f,'rel') and f.rel and f.rel.to == related.parent_model:
    1058                 if related.model == related.parent_model:
    1059                     # If this is an m2m-intermediate to self,
    1060                     # the first foreign key you find will be
    1061                     # the source column. Keep searching for
    1062                     # the second foreign key.
    1063                     if found:
    1064                         setattr(self, cache_attr, getattr(f, attr))
    1065                         break
    1066                     else:
    1067                         found = True
    1068                 else:
    1069                     setattr(self, cache_attr, getattr(f, attr))
    1070                     break
     1119        if not hasattr(self, cache_attr):
     1120            _, related_field = self._get_intermediary_fields(related)
     1121            setattr(self, cache_attr, getattr(related_field, attr))
    10711122        return getattr(self, cache_attr)
    10721123
    10731124    def value_to_string(self, obj):
     
    11341185        self.m2m_field_name = curry(self._get_m2m_attr, related, 'name')
    11351186        self.m2m_reverse_field_name = curry(self._get_m2m_reverse_attr, related, 'name')
    11361187
     1188        self.can_add = curry(self._get_can_add, related)
     1189        self.can_remove = curry(self._get_can_remove, related)
     1190
    11371191        get_m2m_rel = curry(self._get_m2m_attr, related, 'rel')
    11381192        self.m2m_target_field_name = lambda: get_m2m_rel().field_name
    11391193        get_m2m_reverse_rel = curry(self._get_m2m_reverse_attr, related, 'rel')
Back to Top