Code

Ticket #9475: 9475.m2m_add_remove.r12281.diff

File 9475.m2m_add_remove.r12281.diff, 24.4 KB (added by emulbreh, 4 years ago)
Line 
1Index: tests/modeltests/m2m_add_and_remove/__init__.py
2===================================================================
3--- tests/modeltests/m2m_add_and_remove/__init__.py     (revision 0)
4+++ tests/modeltests/m2m_add_and_remove/__init__.py     (revision 0)
5@@ -0,0 +1,2 @@
6+
7+
8Index: tests/modeltests/m2m_add_and_remove/models.py
9===================================================================
10--- tests/modeltests/m2m_add_and_remove/models.py       (revision 0)
11+++ tests/modeltests/m2m_add_and_remove/models.py       (revision 0)
12@@ -0,0 +1,176 @@
13+from django.db import models
14+from django.test import TestCase
15+
16+
17+class M(models.Model):
18+    default = models.ManyToManyField("R", related_name="default_m_set")
19+    default_cannot_remove = models.ManyToManyField("R", can_remove=False, related_name="default_cannot_remove_m_set")
20+    default_cannot_add = models.ManyToManyField("R", can_add=False, related_name="default_cannot_add_m_set")
21+    through_default = models.ManyToManyField("R", through="ThroughDefault", related_name="through_default_m_set")
22+    through_auto = models.ManyToManyField("R", through="ThroughAuto", related_name="through_auto_m_set")
23+    through_ut = models.ManyToManyField("R", through="ThroughUT", related_name="through_ut_m_set")
24+    through_can_add = models.ManyToManyField("R", can_add=True, through="ThroughCanAdd", related_name="through_can_add_m_set")
25+    through_can_remove = models.ManyToManyField("R", can_remove=True, through="ThroughCanRemove", related_name="through_can_remove_m_set")
26+
27+class R(models.Model):
28+    name = models.CharField(max_length=30)
29+
30+class Through(models.Model):
31+    m = models.ForeignKey(M, related_name="%(class)s_set")
32+    r = models.ForeignKey(R, related_name="%(class)s_set")
33+
34+    class Meta:
35+        abstract = True
36+       
37+class ThroughDefault(Through):
38+    extra = models.CharField(max_length=10)
39+
40+class ThroughAuto(Through):
41+    ctime = models.DateTimeField(auto_now_add=True)
42+    mtime = models.DateTimeField(auto_now=True)
43+    default = models.IntegerField(default=42)
44+    null = models.DateTimeField(null=True)
45+   
46+class ThroughUT(Through):
47+    extra = models.CharField(max_length=10)
48+
49+    class Meta:
50+        unique_together = ('m', 'r')
51+       
52+class ThroughCanRemove(Through):
53+    extra = models.CharField(max_length=10)
54+
55+class ThroughCanAdd(Through):
56+    extra = models.CharField(max_length=10)
57+   
58+    def save(self, **kwargs):
59+        self.extra = "foo"
60+        return super(ThroughCanAdd, self).save(**kwargs)
61+
62+class M2mAddRemoveTests(TestCase):
63+    def assert_cannot_remove(self, name):
64+        m = M.objects.create()
65+        r = R.objects.create()
66+        manager = getattr(m, name)
67+        reverse_manager = getattr(r, "%s_m_set" % name)
68+        self.assertRaises(AttributeError, getattr, manager, 'remove')
69+        self.assertRaises(AttributeError, getattr, reverse_manager, 'remove')
70+       
71+    def assert_cannot_add(self, name):
72+        reverse_name = "%s_m_set" % name
73+        m = M.objects.create()
74+        r = R.objects.create()
75+        manager = getattr(m, name)
76+        reverse_manager = getattr(r, reverse_name)
77+        self.assertRaises(AttributeError, getattr, manager, 'add')
78+        self.assertRaises(AttributeError, getattr, reverse_manager, 'add')
79+        self.assertRaises(AttributeError, manager.create)
80+        self.assertRaises(AttributeError, reverse_manager.create)
81+        def assign():
82+            setattr(m, name, [])
83+        self.assertRaises(AttributeError, assign)
84+        def assign_reverse():
85+            setattr(r, reverse_name, [])
86+        self.assertRaises(AttributeError, assign_reverse)
87+       
88+    def assert_can_add(self, name):
89+        reverse_name = "%s_m_set" % name
90+        m = M.objects.create()
91+        r = R.objects.create()
92+        manager = getattr(m, name)
93+        reverse_manager = getattr(r, reverse_name)
94+       
95+        manager.add(r)
96+        self.failUnlessEqual(list(manager.all()), [r])
97+        self.failUnlessEqual(list(reverse_manager.all()), [m])
98+        manager.add(r)
99+        self.failUnlessEqual(list(manager.all()), [r])
100+        self.failUnlessEqual(list(reverse_manager.all()), [m])
101+        manager.clear()
102+       
103+        reverse_manager.add(m)
104+        self.failUnlessEqual(list(manager.all()), [r])
105+        self.failUnlessEqual(list(reverse_manager.all()), [m])
106+        reverse_manager.add(m)
107+        self.failUnlessEqual(list(manager.all()), [r])
108+        self.failUnlessEqual(list(reverse_manager.all()), [m])
109+        reverse_manager.clear()
110+       
111+        r2 = manager.create()
112+        reverse_manager2 = getattr(r2, reverse_name)
113+        self.failUnlessEqual(list(manager.all()), [r2])
114+        self.failUnlessEqual(list(reverse_manager2.all()), [m])
115+        manager.clear()
116+       
117+        m2 = reverse_manager.create()
118+        manager2 = getattr(m2, name)
119+        self.failUnlessEqual(list(manager2.all()), [r])
120+        self.failUnlessEqual(list(reverse_manager.all()), [m2])
121+        reverse_manager.clear()
122+       
123+        setattr(m, name, [r])
124+        self.failUnlessEqual(list(manager.all()), [r])
125+        manager.clear()
126+       
127+        setattr(r, reverse_name, [m])
128+        self.failUnlessEqual(list(reverse_manager.all()), [m])
129+        reverse_manager.clear()
130+       
131+    def assert_can_remove(self, name, extra):
132+        through = M._meta.get_field(name).rel.through
133+        m = M.objects.create()
134+        r = R.objects.create()
135+       
136+        def fill():           
137+            for extra_kwargs in extra:
138+                kwargs = {'m': m, 'r': r}
139+                kwargs.update(extra_kwargs)
140+                through.objects.create(**kwargs)
141+
142+        manager = getattr(m, name)
143+        reverse_manager = getattr(r, "%s_m_set" % name)
144+
145+        fill()
146+        manager.remove(r)
147+        self.failIf(manager.exists())
148+        self.failIf(reverse_manager.exists())
149+       
150+        fill()
151+        reverse_manager.remove(m)
152+        self.failIf(reverse_manager.exists())
153+        self.failIf(manager.exists())
154+       
155+    def _test_managers(self, name, can_remove=False, can_add=False, extra=()):
156+        if can_add:
157+            self.assert_can_add(name)
158+        else:
159+            self.assert_cannot_add(name)
160+        if can_remove:
161+            self.assert_can_remove(name, extra)
162+        else:
163+            self.assert_cannot_remove(name)
164+
165+    def test_default(self):
166+        self._test_managers('default', can_add=True, can_remove=True, extra=[{}])
167+       
168+    def test_default_cannot_remove(self):
169+        self._test_managers('default_cannot_remove', can_add=True, can_remove=False)
170+       
171+    def test_default_cannot_add(self):
172+        self._test_managers('default_cannot_add', can_add=False, can_remove=True, extra=[{}])
173+
174+    def test_through_default(self):
175+        self._test_managers('through_default', can_add=False, can_remove=False)
176+       
177+    def test_through_auto(self):
178+        self._test_managers('through_auto', can_add=True, can_remove=False)
179+       
180+    def test_through_ut(self):
181+        self._test_managers('through_ut', can_add=False, can_remove=True, extra=[{'extra': 'foo'}])
182+       
183+    def test_through_can_add(self):
184+        self._test_managers('through_can_add', can_add=True, can_remove=False)
185+       
186+    def test_through_can_remove(self):
187+        self._test_managers('through_can_remove', can_add=False, can_remove=True, extra=[{'extra': 'foo'}, {'extra': 'bar'}])
188+       
189\ No newline at end of file
190Index: tests/modeltests/m2m_through/models.py
191===================================================================
192--- tests/modeltests/m2m_through/models.py      (revision 12281)
193+++ tests/modeltests/m2m_through/models.py      (working copy)
194@@ -122,46 +122,13 @@
195 
196 ### Forward Descriptors Tests ###
197 
198-# Due to complications with adding via an intermediary model,
199-# the add method is not provided.
200->>> rock.members.add(bob)
201-Traceback (most recent call last):
202-...
203-AttributeError: 'ManyRelatedManager' object has no attribute 'add'
204-
205-# Create is also disabled as it suffers from the same problems as add.
206->>> rock.members.create(name='Anne')
207-Traceback (most recent call last):
208-...
209-AttributeError: Cannot use create() on a ManyToManyField which specifies an intermediary model. Use m2m_through.Membership's Manager instead.
210-
211-# Remove has similar complications, and is not provided either.
212->>> rock.members.remove(jim)
213-Traceback (most recent call last):
214-...
215-AttributeError: 'ManyRelatedManager' object has no attribute 'remove'
216-
217-# Here we back up the list of all members of Rock.
218->>> backup = list(rock.members.all())
219-
220-# ...and we verify that it has worked.
221->>> backup
222-[<Person: Jane>, <Person: Jim>]
223-
224-# The clear function should still work.
225+# The clear function should work.
226 >>> rock.members.clear()
227 
228 # Now there will be no members of Rock.
229 >>> rock.members.all()
230 []
231 
232-# Assignment should not work with models specifying a through model for many of
233-# the same reasons as adding.
234->>> rock.members = backup
235-Traceback (most recent call last):
236-...
237-AttributeError: Cannot set values on a ManyToManyField which specifies an intermediary model.  Use m2m_through.Membership's Manager instead.
238-
239 # Let's re-save those instances that we've cleared.
240 >>> m1.save()
241 >>> m2.save()
242@@ -173,44 +140,13 @@
243 
244 ### Reverse Descriptors Tests ###
245 
246-# Due to complications with adding via an intermediary model,
247-# the add method is not provided.
248->>> bob.group_set.add(rock)
249-Traceback (most recent call last):
250-...
251-AttributeError: 'ManyRelatedManager' object has no attribute 'add'
252-
253-# Create is also disabled as it suffers from the same problems as add.
254->>> bob.group_set.create(name='Funk')
255-Traceback (most recent call last):
256-...
257-AttributeError: Cannot use create() on a ManyToManyField which specifies an intermediary model. Use m2m_through.Membership's Manager instead.
258-
259-# Remove has similar complications, and is not provided either.
260->>> jim.group_set.remove(rock)
261-Traceback (most recent call last):
262-...
263-AttributeError: 'ManyRelatedManager' object has no attribute 'remove'
264-
265-# Here we back up the list of all of Jim's groups.
266->>> backup = list(jim.group_set.all())
267->>> backup
268-[<Group: Rock>, <Group: Roll>]
269-
270-# The clear function should still work.
271+# The clear function should work.
272 >>> jim.group_set.clear()
273 
274 # Now Jim will be in no groups.
275 >>> jim.group_set.all()
276 []
277 
278-# Assignment should not work with models specifying a through model for many of
279-# the same reasons as adding.
280->>> jim.group_set = backup
281-Traceback (most recent call last):
282-...
283-AttributeError: Cannot set values on a ManyToManyField which specifies an intermediary model.  Use m2m_through.Membership's Manager instead.
284-
285 # Let's re-save those instances that we've cleared.
286 >>> m1.save()
287 >>> m4.save()
288Index: tests/regressiontests/m2m_through_regress/models.py
289===================================================================
290--- tests/regressiontests/m2m_through_regress/models.py (revision 12281)
291+++ tests/regressiontests/m2m_through_regress/models.py (working copy)
292@@ -80,27 +80,6 @@
293 >>> roll.members.all()
294 [<Person: Bob>]
295 
296-# Error messages use the model name, not repr of the class name
297->>> bob.group_set = []
298-Traceback (most recent call last):
299-...
300-AttributeError: Cannot set values on a ManyToManyField which specifies an intermediary model.  Use m2m_through_regress.Membership's Manager instead.
301-
302->>> roll.members = []
303-Traceback (most recent call last):
304-...
305-AttributeError: Cannot set values on a ManyToManyField which specifies an intermediary model.  Use m2m_through_regress.Membership's Manager instead.
306-
307->>> rock.members.create(name='Anne')
308-Traceback (most recent call last):
309-...
310-AttributeError: Cannot use create() on a ManyToManyField which specifies an intermediary model.  Use m2m_through_regress.Membership's Manager instead.
311-
312->>> bob.group_set.create(name='Funk')
313-Traceback (most recent call last):
314-...
315-AttributeError: Cannot use create() on a ManyToManyField which specifies an intermediary model.  Use m2m_through_regress.Membership's Manager instead.
316-
317 # Now test that the intermediate with a relationship outside
318 # the current app (i.e., UserMembership) workds
319 >>> UserMembership.objects.create(user=frank, group=rock)
320Index: django/db/models/fields/related.py
321===================================================================
322--- django/db/models/fields/related.py  (revision 12281)
323+++ django/db/models/fields/related.py  (working copy)
324@@ -433,10 +433,13 @@
325 
326         return manager
327 
328-def create_many_related_manager(superclass, rel=False):
329+def create_many_related_manager(superclass, field):
330     """Creates a manager that subclasses 'superclass' (which is a Manager)
331     and adds behavior for many-to-many related objects."""
332-    through = rel.through
333+    through = field.rel.through
334+    can_add = field.can_add()
335+    can_remove = field.can_remove()
336+   
337     class ManyRelatedManager(superclass):
338         def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
339                 join_table=None, source_field_name=None, target_field_name=None,
340@@ -458,9 +461,7 @@
341             db = router.db_for_read(self.instance.__class__, instance=self.instance)
342             return superclass.get_query_set(self).using(db)._next_is_sticky().filter(**(self.core_filters))
343 
344-        # If the ManyToMany relation has an intermediary model,
345-        # the add and remove methods do not exist.
346-        if rel.through._meta.auto_created:
347+        if can_add:
348             def add(self, *objs):
349                 self._add_items(self.source_field_name, self.target_field_name, *objs)
350 
351@@ -469,6 +470,7 @@
352                     self._add_items(self.target_field_name, self.source_field_name, *objs)
353             add.alters_data = True
354 
355+        if can_remove:
356             def remove(self, *objs):
357                 self._remove_items(self.source_field_name, self.target_field_name, *objs)
358 
359@@ -488,9 +490,9 @@
360         def create(self, **kwargs):
361             # This check needs to be done here, since we can't later remove this
362             # from the method lookup table, as we do with add and remove.
363-            if not rel.through._meta.auto_created:
364+            if not can_add:
365                 opts = through._meta
366-                raise AttributeError("Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
367+                raise AttributeError("Cannot use create() on this ManyToManyField. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
368             db = router.db_for_write(self.instance.__class__, instance=self.instance)
369             new_obj = super(ManyRelatedManager, self).using(db).create(**kwargs)
370             self.add(new_obj)
371@@ -498,6 +500,11 @@
372         create.alters_data = True
373 
374         def get_or_create(self, **kwargs):
375+            # This check needs to be done here, since we can't later remove this
376+            # from the method lookup table, as we do with add and remove.
377+            if not can_add:
378+                opts = through._meta
379+                raise AttributeError, "Cannot use get_or_create() on this ManyToManyField. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)
380             db = router.db_for_write(self.instance.__class__, instance=self.instance)
381             obj, created = \
382                 super(ManyRelatedManager, self).using(db).get_or_create(**kwargs)
383@@ -544,7 +551,7 @@
384                 if self.reverse or source_field_name == self.source_field_name:
385                     # Don't send the signal when we are inserting the
386                     # duplicate data row for symmetrical reverse entries.
387-                    signals.m2m_changed.send(sender=rel.through, action='add',
388+                    signals.m2m_changed.send(sender=through, action='add',
389                         instance=self.instance, reverse=self.reverse,
390                         model=self.model, pk_set=new_ids)
391 
392@@ -571,7 +578,7 @@
393                 if self.reverse or source_field_name == self.source_field_name:
394                     # Don't send the signal when we are deleting the
395                     # duplicate data row for symmetrical reverse entries.
396-                    signals.m2m_changed.send(sender=rel.through, action="remove",
397+                    signals.m2m_changed.send(sender=through, action="remove",
398                         instance=self.instance, reverse=self.reverse,
399                         model=self.model, pk_set=old_ids)
400 
401@@ -580,7 +587,7 @@
402             if self.reverse or source_field_name == self.source_field_name:
403                 # Don't send the signal when we are clearing the
404                 # duplicate data rows for symmetrical reverse entries.
405-                signals.m2m_changed.send(sender=rel.through, action="clear",
406+                signals.m2m_changed.send(sender=through, action="clear",
407                     instance=self.instance, reverse=self.reverse,
408                     model=self.model, pk_set=None)
409             db = router.db_for_write(self.through.__class__, instance=self.instance)
410@@ -608,7 +615,7 @@
411         # model's default manager.
412         rel_model = self.related.model
413         superclass = rel_model._default_manager.__class__
414-        RelatedManager = create_many_related_manager(superclass, self.related.field.rel)
415+        RelatedManager = create_many_related_manager(superclass, self.related.field)
416 
417         manager = RelatedManager(
418             model=rel_model,
419@@ -626,9 +633,9 @@
420         if instance is None:
421             raise AttributeError("Manager must be accessed via instance")
422 
423-        if not self.related.field.rel.through._meta.auto_created:
424+        if not self.related.field.can_add():
425             opts = self.related.field.rel.through._meta
426-            raise AttributeError("Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
427+            raise AttributeError("Cannot set values on this ManyToManyField. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
428 
429         manager = self.__get__(instance)
430         manager.clear()
431@@ -660,7 +667,7 @@
432         # model's default manager.
433         rel_model=self.field.rel.to
434         superclass = rel_model._default_manager.__class__
435-        RelatedManager = create_many_related_manager(superclass, self.field.rel)
436+        RelatedManager = create_many_related_manager(superclass, self.field)
437 
438         manager = RelatedManager(
439             model=rel_model,
440@@ -678,9 +685,9 @@
441         if instance is None:
442             raise AttributeError("Manager must be accessed via instance")
443 
444-        if not self.field.rel.through._meta.auto_created:
445+        if not self.field.can_add():
446             opts = self.field.rel.through._meta
447-            raise AttributeError("Cannot set values on a ManyToManyField which specifies an intermediary model.  Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
448+            raise AttributeError("Cannot set values a this ManyToManyField.  Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
449 
450         manager = self.__get__(instance)
451         manager.clear()
452@@ -953,6 +960,9 @@
453         self.db_table = kwargs.pop('db_table', None)
454         if kwargs['rel'].through is not None:
455             assert self.db_table is None, "Cannot specify a db_table if an intermediary model is used."
456+           
457+        self._can_add = kwargs.pop('can_add', None)
458+        self._can_remove = kwargs.pop('can_remove', None)
459 
460         Field.__init__(self, **kwargs)
461 
462@@ -971,38 +981,78 @@
463         else:
464             return util.truncate_name('%s_%s' % (opts.db_table, self.name),
465                                       connection.ops.max_name_length())
466-
467-    def _get_m2m_attr(self, related, attr):
468-        "Function that can be curried to provide the source column name for the m2m table"
469-        cache_attr = '_m2m_%s_cache' % attr
470+                                     
471+    def _get_intermediary_fields(self, related):
472+        cache_attr = '_m2m_intermediary_fields_cache'
473         if hasattr(self, cache_attr):
474             return getattr(self, cache_attr)
475+           
476+        candidates = []
477+        related_candidates = []
478+        auto_add = True
479         for f in self.rel.through._meta.fields:
480-            if hasattr(f,'rel') and f.rel and f.rel.to == related.model:
481-                setattr(self, cache_attr, getattr(f, attr))
482-                return getattr(self, cache_attr)
483+            if hasattr(f,'rel') and f.rel:
484+                if f.rel.to == related.model:
485+                    candidates.append(f)
486+                    continue
487+                elif f.rel.to == related.parent_model:
488+                    related_candidates.append(f)
489+                    continue
490+            if isinstance(f, AutoField) or f.null or f.has_default():
491+                continue
492+            elif getattr(f, 'auto_now_add', False) or getattr(f, 'auto_now', False):
493+                continue
494+            else:
495+                auto_add = False
496+        if related.model == related.parent_model:
497+            # m2m to self
498+            assert len(candidates) == 2, "There are too many ForeignKeys to %s" % related.model
499+            field, related_field = candidates
500+        else:
501+            assert len(candidates) == 1, "There are no ForeignKeys to %s" % related.model
502+            assert len(related_candidates) == 1, "There are no ForeignKeys to %s" % related.parent_model
503+            # TODO: intelligently pick a candidate if there is more than one. For now, just use the first.
504+            field, related_field = candidates[0], related_candidates[0]
505 
506+        if self._can_add is None:
507+            self._can_add = auto_add
508+
509+        if self._can_remove is None:
510+            self._can_remove = False
511+            unique_together = [frozenset(ut) for ut in self.rel.through._meta.unique_together]               
512+            names = frozenset([field.name, related_field.name])
513+            for ut in unique_together:
514+                if names <= ut:
515+                    self._can_remove = True
516+                    break
517+       
518+        setattr(self, cache_attr, (field, related_field))
519+        return (field, related_field)
520+       
521+    def _get_can_add(self, related):
522+        if self._can_add is None:
523+            self._get_intermediary_fields(related)
524+        return self._can_add
525+       
526+    def _get_can_remove(self, related):
527+        if self._can_remove is None:
528+            self._get_intermediary_fields(related)
529+        return self._can_remove
530+
531+    def _get_m2m_attr(self, related, attr):
532+        "Function that can be curried to provide a source field attribute"
533+        cache_attr = '_m2m_%s_cache' % attr
534+        if not hasattr(self, cache_attr):
535+            field, _ = self._get_intermediary_fields(related)
536+            setattr(self, cache_attr, getattr(field, attr))
537+        return getattr(self, cache_attr)
538+
539     def _get_m2m_reverse_attr(self, related, attr):
540-        "Function that can be curried to provide the related column name for the m2m table"
541+        "Function that can be curried to provide a related field attribute"
542         cache_attr = '_m2m_reverse_%s_cache' % attr
543-        if hasattr(self, cache_attr):
544-            return getattr(self, cache_attr)
545-        found = False
546-        for f in self.rel.through._meta.fields:
547-            if hasattr(f,'rel') and f.rel and f.rel.to == related.parent_model:
548-                if related.model == related.parent_model:
549-                    # If this is an m2m-intermediate to self,
550-                    # the first foreign key you find will be
551-                    # the source column. Keep searching for
552-                    # the second foreign key.
553-                    if found:
554-                        setattr(self, cache_attr, getattr(f, attr))
555-                        break
556-                    else:
557-                        found = True
558-                else:
559-                    setattr(self, cache_attr, getattr(f, attr))
560-                    break
561+        if not hasattr(self, cache_attr):
562+            _, related_field = self._get_intermediary_fields(related)
563+            setattr(self, cache_attr, getattr(related_field, attr))           
564         return getattr(self, cache_attr)
565 
566     def isValidIDList(self, field_data, all_data):
567@@ -1087,6 +1137,9 @@
568 
569         self.m2m_field_name = curry(self._get_m2m_attr, related, 'name')
570         self.m2m_reverse_field_name = curry(self._get_m2m_reverse_attr, related, 'name')
571+       
572+        self.can_add = curry(self._get_can_add, related)
573+        self.can_remove = curry(self._get_can_remove, related)
574 
575     def set_attributes_from_rel(self):
576         pass