Code

Ticket #9475: 9475.m2m_add_remove.r16133.diff

File 9475.m2m_add_remove.r16133.diff, 25.1 KB (added by emulbreh, 3 years ago)
Line 
1Index: tests/modeltests/m2m_add_and_remove/__init__.py
2===================================================================
3Index: tests/modeltests/m2m_add_and_remove/tests.py
4===================================================================
5--- tests/modeltests/m2m_add_and_remove/tests.py        (revision 0)
6+++ tests/modeltests/m2m_add_and_remove/tests.py        (revision 0)
7@@ -0,0 +1,129 @@
8+from django.db import models
9+from django.test import TestCase
10+
11+from models import M, R, Through, ThroughDefault, ThroughAuto, ThroughUT, ThroughCanRemove, ThroughCanAdd
12+
13+
14+class M2mAddRemoveTests(TestCase):
15+    def assert_cannot_remove(self, name):
16+        m = M.objects.create()
17+        r = R.objects.create()
18+        manager = getattr(m, name)
19+        reverse_manager = getattr(r, "%s_m_set" % name)
20+        self.assertRaises(AttributeError, getattr, manager, 'remove')
21+        self.assertRaises(AttributeError, getattr, reverse_manager, 'remove')
22+
23+    def assert_cannot_add(self, name):
24+        reverse_name = "%s_m_set" % name
25+        m = M.objects.create()
26+        r = R.objects.create()
27+        manager = getattr(m, name)
28+        reverse_manager = getattr(r, reverse_name)
29+        self.assertRaises(AttributeError, getattr, manager, 'add')
30+        self.assertRaises(AttributeError, getattr, reverse_manager, 'add')
31+        self.assertRaises(AttributeError, manager.create)
32+        self.assertRaises(AttributeError, reverse_manager.create)
33+        self.assertRaises(AttributeError, setattr, m, name, [])
34+        self.assertRaises(AttributeError, setattr, r, reverse_name, [])
35+
36+    def assert_can_add(self, name):
37+        reverse_name = "%s_m_set" % name
38+        m = M.objects.create()
39+        r = R.objects.create()
40+        manager = getattr(m, name)
41+        reverse_manager = getattr(r, reverse_name)
42+
43+        manager.add(r)
44+        self.assertEqual(list(manager.all()), [r])
45+        self.assertEqual(list(reverse_manager.all()), [m])
46+        manager.add(r)
47+        self.assertEqual(list(manager.all()), [r])
48+        self.assertEqual(list(reverse_manager.all()), [m])
49+        manager.clear()
50+
51+        reverse_manager.add(m)
52+        self.assertEqual(list(manager.all()), [r])
53+        self.assertEqual(list(reverse_manager.all()), [m])
54+        reverse_manager.add(m)
55+        self.assertEqual(list(manager.all()), [r])
56+        self.assertEqual(list(reverse_manager.all()), [m])
57+        reverse_manager.clear()
58+
59+        r2 = manager.create()
60+        reverse_manager2 = getattr(r2, reverse_name)
61+        self.assertEqual(list(manager.all()), [r2])
62+        self.assertEqual(list(reverse_manager2.all()), [m])
63+        manager.clear()
64+
65+        m2 = reverse_manager.create()
66+        manager2 = getattr(m2, name)
67+        self.assertEqual(list(manager2.all()), [r])
68+        self.assertEqual(list(reverse_manager.all()), [m2])
69+        reverse_manager.clear()
70+
71+        setattr(m, name, [r])
72+        self.assertEqual(list(manager.all()), [r])
73+        manager.clear()
74+
75+        setattr(r, reverse_name, [m])
76+        self.assertEqual(list(reverse_manager.all()), [m])
77+        reverse_manager.clear()
78+
79+    def assert_can_remove(self, name, extra):
80+        through = M._meta.get_field(name).rel.through
81+        m = M.objects.create()
82+        r = R.objects.create()
83+
84+        def fill():
85+            for extra_kwargs in extra:
86+                kwargs = {'m': m, 'r': r}
87+                kwargs.update(extra_kwargs)
88+                through.objects.create(**kwargs)
89+
90+        manager = getattr(m, name)
91+        reverse_manager = getattr(r, "%s_m_set" % name)
92+
93+        fill()
94+        manager.remove(r)
95+        self.failIf(manager.exists())
96+        self.failIf(reverse_manager.exists())
97+
98+        fill()
99+        reverse_manager.remove(m)
100+        self.failIf(reverse_manager.exists())
101+        self.failIf(manager.exists())
102+
103+    def _test_managers(self, name, can_remove=False, can_add=False, extra=()):
104+        if can_add:
105+            self.assert_can_add(name)
106+        else:
107+            self.assert_cannot_add(name)
108+        if can_remove:
109+            self.assert_can_remove(name, extra)
110+        else:
111+            self.assert_cannot_remove(name)
112+
113+    def test_default(self):
114+        self._test_managers('default', can_add=True, can_remove=True, extra=[{}])
115+
116+    def test_default_cannot_remove(self):
117+        self._test_managers('default_cannot_remove', can_add=True, can_remove=False)
118+
119+    def test_default_cannot_add(self):
120+        self._test_managers('default_cannot_add', can_add=False, can_remove=True, extra=[{}])
121+
122+    def test_through_default(self):
123+        self._test_managers('through_default', can_add=False, can_remove=False)
124+
125+    def test_through_auto(self):
126+        self._test_managers('through_auto', can_add=True, can_remove=False)
127+
128+    def test_through_ut(self):
129+        self._test_managers('through_ut', can_add=False, can_remove=True, extra=[{'extra': 'foo'}])
130+
131+    def test_through_can_add(self):
132+        self._test_managers('through_can_add', can_add=True, can_remove=False)
133+
134+    def test_through_can_remove(self):
135+        self._test_managers('through_can_remove', can_add=False, can_remove=True, extra=[{'extra': 'foo'}, {'extra': 'bar'}])
136+       
137\ No newline at end of file
138Index: tests/modeltests/m2m_add_and_remove/models.py
139===================================================================
140--- tests/modeltests/m2m_add_and_remove/models.py       (revision 0)
141+++ tests/modeltests/m2m_add_and_remove/models.py       (revision 0)
142@@ -0,0 +1,56 @@
143+from django.db import models
144+from django.test import TestCase
145+
146+
147+class M(models.Model):
148+    default = models.ManyToManyField("R", related_name="default_m_set")
149+    default_cannot_remove = models.ManyToManyField("R", can_remove=False, related_name="default_cannot_remove_m_set")
150+    default_cannot_add = models.ManyToManyField("R", can_add=False, related_name="default_cannot_add_m_set")
151+    through_default = models.ManyToManyField("R", through="ThroughDefault", related_name="through_default_m_set")
152+    through_auto = models.ManyToManyField("R", through="ThroughAuto", related_name="through_auto_m_set")
153+    through_ut = models.ManyToManyField("R", through="ThroughUT", related_name="through_ut_m_set")
154+    through_can_add = models.ManyToManyField("R", can_add=True, through="ThroughCanAdd", related_name="through_can_add_m_set")
155+    through_can_remove = models.ManyToManyField("R", can_remove=True, through="ThroughCanRemove", related_name="through_can_remove_m_set")
156+
157+
158+class R(models.Model):
159+    name = models.CharField(max_length=30)
160+
161+
162+class Through(models.Model):
163+    m = models.ForeignKey(M, related_name="%(class)s_set")
164+    r = models.ForeignKey(R, related_name="%(class)s_set")
165+
166+    class Meta:
167+        abstract = True
168+
169+
170+class ThroughDefault(Through):
171+    extra = models.CharField(max_length=10)
172+
173+
174+class ThroughAuto(Through):
175+    ctime = models.DateTimeField(auto_now_add=True)
176+    mtime = models.DateTimeField(auto_now=True)
177+    default = models.IntegerField(default=42)
178+    null = models.DateTimeField(null=True)
179+
180+
181+class ThroughUT(Through):
182+    extra = models.CharField(max_length=10)
183+
184+    class Meta:
185+        unique_together = ('m', 'r')
186+
187+
188+class ThroughCanRemove(Through):
189+    extra = models.CharField(max_length=10)
190+
191+
192+class ThroughCanAdd(Through):
193+    extra = models.CharField(max_length=10)
194+
195+    def save(self, **kwargs):
196+        self.extra = "foo"
197+        return super(ThroughCanAdd, self).save(**kwargs)
198+       
199\ No newline at end of file
200Index: tests/modeltests/m2m_through/tests.py
201===================================================================
202--- tests/modeltests/m2m_through/tests.py       (revision 16133)
203+++ tests/modeltests/m2m_through/tests.py       (working copy)
204@@ -67,12 +67,7 @@
205 
206 
207     def test_forward_descriptors(self):
208-        # Due to complications with adding via an intermediary model,
209-        # the add method is not provided.
210-        self.assertRaises(AttributeError, lambda: self.rock.members.add(self.bob))
211-        # Create is also disabled as it suffers from the same problems as add.
212-        self.assertRaises(AttributeError, lambda: self.rock.members.create(name='Anne'))
213-        # Remove has similar complications, and is not provided either.
214+        # Remove doesn't work, because Membership `person` and `group` are not unique_together.
215         self.assertRaises(AttributeError, lambda: self.rock.members.remove(self.jim))
216 
217         m1 = Membership.objects.create(person=self.jim, group=self.rock)
218@@ -93,9 +88,6 @@
219             []
220         )
221 
222-        # Assignment should not work with models specifying a through model for many of
223-        # the same reasons as adding.
224-        self.assertRaises(AttributeError, setattr, self.rock, "members", backup)
225         # Let's re-save those instances that we've cleared.
226         m1.save()
227         m2.save()
228@@ -109,12 +101,7 @@
229         )
230 
231     def test_reverse_descriptors(self):
232-        # Due to complications with adding via an intermediary model,
233-        # the add method is not provided.
234-        self.assertRaises(AttributeError, lambda: self.bob.group_set.add(self.rock))
235-        # Create is also disabled as it suffers from the same problems as add.
236-        self.assertRaises(AttributeError, lambda: self.bob.group_set.create(name="funk"))
237-        # Remove has similar complications, and is not provided either.
238+        # Remove doesn't work, because Membership `person` and `group` are not unique_together.
239         self.assertRaises(AttributeError, lambda: self.jim.group_set.remove(self.rock))
240 
241         m1 = Membership.objects.create(person=self.jim, group=self.rock)
242@@ -133,11 +120,8 @@
243             self.jim.group_set.all(),
244             []
245         )
246-        # Assignment should not work with models specifying a through model for many of
247-        # the same reasons as adding.
248-        self.assertRaises(AttributeError, setattr, self.jim, "group_set", backup)
249-        # Let's re-save those instances that we've cleared.
250 
251+        # Let's re-save those instances that we've cleared.
252         m1.save()
253         m2.save()
254         # Verifying that those instances were re-saved successfully.
255Index: tests/regressiontests/m2m_through_regress/tests.py
256===================================================================
257--- tests/regressiontests/m2m_through_regress/tests.py  (revision 16133)
258+++ tests/regressiontests/m2m_through_regress/tests.py  (working copy)
259@@ -39,12 +39,6 @@
260             ]
261         )
262 
263-        self.assertRaises(AttributeError, setattr, bob, "group_set", [])
264-        self.assertRaises(AttributeError, setattr, roll, "members", [])
265-
266-        self.assertRaises(AttributeError, rock.members.create, name="Anne")
267-        self.assertRaises(AttributeError, bob.group_set.create, name="Funk")
268-
269         UserMembership.objects.create(user=frank, group=rock)
270         UserMembership.objects.create(user=frank, group=roll)
271         UserMembership.objects.create(user=jane, group=rock)
272Index: django/db/models/fields/related.py
273===================================================================
274--- django/db/models/fields/related.py  (revision 16133)
275+++ django/db/models/fields/related.py  (working copy)
276@@ -472,10 +472,13 @@
277 
278         return manager
279 
280-def create_many_related_manager(superclass, rel=False):
281+def create_many_related_manager(superclass, field):
282     """Creates a manager that subclasses 'superclass' (which is a Manager)
283     and adds behavior for many-to-many related objects."""
284-    through = rel.through
285+    through = field.rel.through
286+    can_add = field.can_add()
287+    can_remove = field.can_remove()
288+
289     class ManyRelatedManager(superclass):
290         def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
291                 join_table=None, source_field_name=None, target_field_name=None,
292@@ -497,9 +500,7 @@
293             db = self._db or router.db_for_read(self.instance.__class__, instance=self.instance)
294             return superclass.get_query_set(self).using(db)._next_is_sticky().filter(**(self.core_filters))
295 
296-        # If the ManyToMany relation has an intermediary model,
297-        # the add and remove methods do not exist.
298-        if rel.through._meta.auto_created:
299+        if can_add:
300             def add(self, *objs):
301                 self._add_items(self.source_field_name, self.target_field_name, *objs)
302 
303@@ -508,6 +509,7 @@
304                     self._add_items(self.target_field_name, self.source_field_name, *objs)
305             add.alters_data = True
306 
307+        if can_remove:
308             def remove(self, *objs):
309                 self._remove_items(self.source_field_name, self.target_field_name, *objs)
310 
311@@ -527,9 +529,9 @@
312         def create(self, **kwargs):
313             # This check needs to be done here, since we can't later remove this
314             # from the method lookup table, as we do with add and remove.
315-            if not rel.through._meta.auto_created:
316+            if not can_add:
317                 opts = through._meta
318-                raise AttributeError("Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
319+                raise AttributeError("Cannot use create() on this ManyToManyField. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
320             db = router.db_for_write(self.instance.__class__, instance=self.instance)
321             new_obj = super(ManyRelatedManager, self.db_manager(db)).create(**kwargs)
322             self.add(new_obj)
323@@ -537,6 +539,11 @@
324         create.alters_data = True
325 
326         def get_or_create(self, **kwargs):
327+            # This check needs to be done here, since we can't later remove this
328+            # from the method lookup table, as we do with add and remove.
329+            if not can_add:
330+                opts = through._meta
331+                raise AttributeError, "Cannot use get_or_create() on this ManyToManyField. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)
332             db = router.db_for_write(self.instance.__class__, instance=self.instance)
333             obj, created = \
334                 super(ManyRelatedManager, self.db_manager(db)).get_or_create(**kwargs)
335@@ -578,7 +585,7 @@
336                 if self.reverse or source_field_name == self.source_field_name:
337                     # Don't send the signal when we are inserting the
338                     # duplicate data row for symmetrical reverse entries.
339-                    signals.m2m_changed.send(sender=rel.through, action='pre_add',
340+                    signals.m2m_changed.send(sender=through, action='pre_add',
341                         instance=self.instance, reverse=self.reverse,
342                         model=self.model, pk_set=new_ids, using=db)
343                 # Add the ones that aren't there already
344@@ -590,7 +597,7 @@
345                 if self.reverse or source_field_name == self.source_field_name:
346                     # Don't send the signal when we are inserting the
347                     # duplicate data row for symmetrical reverse entries.
348-                    signals.m2m_changed.send(sender=rel.through, action='post_add',
349+                    signals.m2m_changed.send(sender=through, action='post_add',
350                         instance=self.instance, reverse=self.reverse,
351                         model=self.model, pk_set=new_ids, using=db)
352 
353@@ -614,7 +621,7 @@
354                 if self.reverse or source_field_name == self.source_field_name:
355                     # Don't send the signal when we are deleting the
356                     # duplicate data row for symmetrical reverse entries.
357-                    signals.m2m_changed.send(sender=rel.through, action="pre_remove",
358+                    signals.m2m_changed.send(sender=through, action="pre_remove",
359                         instance=self.instance, reverse=self.reverse,
360                         model=self.model, pk_set=old_ids, using=db)
361                 # Remove the specified objects from the join table
362@@ -625,7 +632,7 @@
363                 if self.reverse or source_field_name == self.source_field_name:
364                     # Don't send the signal when we are deleting the
365                     # duplicate data row for symmetrical reverse entries.
366-                    signals.m2m_changed.send(sender=rel.through, action="post_remove",
367+                    signals.m2m_changed.send(sender=through, action="post_remove",
368                         instance=self.instance, reverse=self.reverse,
369                         model=self.model, pk_set=old_ids, using=db)
370 
371@@ -635,7 +642,7 @@
372             if self.reverse or source_field_name == self.source_field_name:
373                 # Don't send the signal when we are clearing the
374                 # duplicate data rows for symmetrical reverse entries.
375-                signals.m2m_changed.send(sender=rel.through, action="pre_clear",
376+                signals.m2m_changed.send(sender=through, action="pre_clear",
377                     instance=self.instance, reverse=self.reverse,
378                     model=self.model, pk_set=None, using=db)
379             self.through._default_manager.using(db).filter(**{
380@@ -644,7 +651,7 @@
381             if self.reverse or source_field_name == self.source_field_name:
382                 # Don't send the signal when we are clearing the
383                 # duplicate data rows for symmetrical reverse entries.
384-                signals.m2m_changed.send(sender=rel.through, action="post_clear",
385+                signals.m2m_changed.send(sender=through, action="post_clear",
386                     instance=self.instance, reverse=self.reverse,
387                     model=self.model, pk_set=None, using=db)
388 
389@@ -668,7 +675,7 @@
390         # model's default manager.
391         rel_model = self.related.model
392         superclass = rel_model._default_manager.__class__
393-        RelatedManager = create_many_related_manager(superclass, self.related.field.rel)
394+        RelatedManager = create_many_related_manager(superclass, self.related.field)
395 
396         manager = RelatedManager(
397             model=rel_model,
398@@ -686,9 +693,9 @@
399         if instance is None:
400             raise AttributeError("Manager must be accessed via instance")
401 
402-        if not self.related.field.rel.through._meta.auto_created:
403+        if not self.related.field.can_add():
404             opts = self.related.field.rel.through._meta
405-            raise AttributeError("Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
406+            raise AttributeError("Cannot set values on this ManyToManyField. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
407 
408         manager = self.__get__(instance)
409         manager.clear()
410@@ -720,7 +727,7 @@
411         # model's default manager.
412         rel_model=self.field.rel.to
413         superclass = rel_model._default_manager.__class__
414-        RelatedManager = create_many_related_manager(superclass, self.field.rel)
415+        RelatedManager = create_many_related_manager(superclass, self.field)
416 
417         manager = RelatedManager(
418             model=rel_model,
419@@ -738,9 +745,9 @@
420         if instance is None:
421             raise AttributeError("Manager must be accessed via instance")
422 
423-        if not self.field.rel.through._meta.auto_created:
424+        if not self.field.can_add():
425             opts = self.field.rel.through._meta
426-            raise AttributeError("Cannot set values on a ManyToManyField which specifies an intermediary model.  Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
427+            raise AttributeError("Cannot set values a this ManyToManyField.  Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
428 
429         manager = self.__get__(instance)
430         manager.clear()
431@@ -1019,6 +1026,9 @@
432         if kwargs['rel'].through is not None:
433             assert self.db_table is None, "Cannot specify a db_table if an intermediary model is used."
434 
435+        self._can_add = kwargs.pop('can_add', None)
436+        self._can_remove = kwargs.pop('can_remove', None)
437+
438         Field.__init__(self, **kwargs)
439 
440         msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.')
441@@ -1037,37 +1047,78 @@
442             return util.truncate_name('%s_%s' % (opts.db_table, self.name),
443                                       connection.ops.max_name_length())
444 
445-    def _get_m2m_attr(self, related, attr):
446-        "Function that can be curried to provide the source accessor or DB column name for the m2m table"
447-        cache_attr = '_m2m_%s_cache' % attr
448+    def _get_intermediary_fields(self, related):
449+        cache_attr = '_m2m_intermediary_fields_cache'
450         if hasattr(self, cache_attr):
451             return getattr(self, cache_attr)
452+
453+        candidates = []
454+        related_candidates = []
455+        auto_add = True
456         for f in self.rel.through._meta.fields:
457-            if hasattr(f,'rel') and f.rel and f.rel.to == related.model:
458-                setattr(self, cache_attr, getattr(f, attr))
459-                return getattr(self, cache_attr)
460+            if hasattr(f,'rel') and f.rel:
461+                if f.rel.to == related.model:
462+                    candidates.append(f)
463+                    continue
464+                elif f.rel.to == related.parent_model:
465+                    related_candidates.append(f)
466+                    continue
467+            if isinstance(f, AutoField) or f.null or f.has_default():
468+                continue
469+            elif getattr(f, 'auto_now_add', False) or getattr(f, 'auto_now', False):
470+                continue
471+            else:
472+                auto_add = False
473+        if related.model == related.parent_model:
474+            # m2m to self
475+            assert len(candidates) == 2, "There are too many ForeignKeys to %s" % related.model
476+            field, related_field = candidates
477+        else:
478+            # TODO: intelligently pick a candidate if there is more than one (See ticket #8618).
479+            # For now, model validation will reject models with more than suitable FK anyway.
480+            assert len(candidates) == 1, "There are no ForeignKeys to %s" % related.model
481+            assert len(related_candidates) == 1, "There are no ForeignKeys to %s" % related.parent_model
482+            field, related_field = candidates[0], related_candidates[0]
483 
484+        if self._can_add is None:
485+            self._can_add = auto_add
486+
487+        if self._can_remove is None:
488+            self._can_remove = False
489+            unique_together = [frozenset(ut) for ut in self.rel.through._meta.unique_together]
490+            names = frozenset([field.name, related_field.name])
491+            for ut in unique_together:
492+                if names <= ut:
493+                    self._can_remove = True
494+                    break
495+
496+        setattr(self, cache_attr, (field, related_field))
497+        return (field, related_field)
498+
499+    def _get_can_add(self, related):
500+        if self._can_add is None:
501+            self._get_intermediary_fields(related)
502+        return self._can_add
503+
504+    def _get_can_remove(self, related):
505+        if self._can_remove is None:
506+            self._get_intermediary_fields(related)
507+        return self._can_remove
508+
509+    def _get_m2m_attr(self, related, attr):
510+        "Function that can be curried to provide a source field attribute"
511+        cache_attr = '_m2m_%s_cache' % attr
512+        if not hasattr(self, cache_attr):
513+            field, _ = self._get_intermediary_fields(related)
514+            setattr(self, cache_attr, getattr(field, attr))
515+        return getattr(self, cache_attr)
516+
517     def _get_m2m_reverse_attr(self, related, attr):
518-        "Function that can be curried to provide the related accessor or DB column name for the m2m table"
519+        "Function that can be curried to provide a related field attribute"
520         cache_attr = '_m2m_reverse_%s_cache' % attr
521-        if hasattr(self, cache_attr):
522-            return getattr(self, cache_attr)
523-        found = False
524-        for f in self.rel.through._meta.fields:
525-            if hasattr(f,'rel') and f.rel and f.rel.to == related.parent_model:
526-                if related.model == related.parent_model:
527-                    # If this is an m2m-intermediate to self,
528-                    # the first foreign key you find will be
529-                    # the source column. Keep searching for
530-                    # the second foreign key.
531-                    if found:
532-                        setattr(self, cache_attr, getattr(f, attr))
533-                        break
534-                    else:
535-                        found = True
536-                else:
537-                    setattr(self, cache_attr, getattr(f, attr))
538-                    break
539+        if not hasattr(self, cache_attr):
540+            _, related_field = self._get_intermediary_fields(related)
541+            setattr(self, cache_attr, getattr(related_field, attr))
542         return getattr(self, cache_attr)
543 
544     def value_to_string(self, obj):
545@@ -1134,6 +1185,9 @@
546         self.m2m_field_name = curry(self._get_m2m_attr, related, 'name')
547         self.m2m_reverse_field_name = curry(self._get_m2m_reverse_attr, related, 'name')
548 
549+        self.can_add = curry(self._get_can_add, related)
550+        self.can_remove = curry(self._get_can_remove, related)
551+
552         get_m2m_rel = curry(self._get_m2m_attr, related, 'rel')
553         self.m2m_target_field_name = lambda: get_m2m_rel().field_name
554         get_m2m_reverse_rel = curry(self._get_m2m_reverse_attr, related, 'rel')