Ticket #9475: 9475.m2m_add_remove.r12281.diff

File 9475.m2m_add_remove.r12281.diff, 24.4 KB (added by Johannes Dollinger, 14 years ago)
  • tests/modeltests/m2m_add_and_remove/__init__.py

     
     1
     2
  • tests/modeltests/m2m_add_and_remove/models.py

     
     1from django.db import models
     2from django.test import TestCase
     3
     4
     5class M(models.Model):
     6    default = models.ManyToManyField("R", related_name="default_m_set")
     7    default_cannot_remove = models.ManyToManyField("R", can_remove=False, related_name="default_cannot_remove_m_set")
     8    default_cannot_add = models.ManyToManyField("R", can_add=False, related_name="default_cannot_add_m_set")
     9    through_default = models.ManyToManyField("R", through="ThroughDefault", related_name="through_default_m_set")
     10    through_auto = models.ManyToManyField("R", through="ThroughAuto", related_name="through_auto_m_set")
     11    through_ut = models.ManyToManyField("R", through="ThroughUT", related_name="through_ut_m_set")
     12    through_can_add = models.ManyToManyField("R", can_add=True, through="ThroughCanAdd", related_name="through_can_add_m_set")
     13    through_can_remove = models.ManyToManyField("R", can_remove=True, through="ThroughCanRemove", related_name="through_can_remove_m_set")
     14
     15class R(models.Model):
     16    name = models.CharField(max_length=30)
     17
     18class Through(models.Model):
     19    m = models.ForeignKey(M, related_name="%(class)s_set")
     20    r = models.ForeignKey(R, related_name="%(class)s_set")
     21
     22    class Meta:
     23        abstract = True
     24       
     25class ThroughDefault(Through):
     26    extra = models.CharField(max_length=10)
     27
     28class ThroughAuto(Through):
     29    ctime = models.DateTimeField(auto_now_add=True)
     30    mtime = models.DateTimeField(auto_now=True)
     31    default = models.IntegerField(default=42)
     32    null = models.DateTimeField(null=True)
     33   
     34class ThroughUT(Through):
     35    extra = models.CharField(max_length=10)
     36
     37    class Meta:
     38        unique_together = ('m', 'r')
     39       
     40class ThroughCanRemove(Through):
     41    extra = models.CharField(max_length=10)
     42
     43class ThroughCanAdd(Through):
     44    extra = models.CharField(max_length=10)
     45   
     46    def save(self, **kwargs):
     47        self.extra = "foo"
     48        return super(ThroughCanAdd, self).save(**kwargs)
     49
     50class M2mAddRemoveTests(TestCase):
     51    def assert_cannot_remove(self, name):
     52        m = M.objects.create()
     53        r = R.objects.create()
     54        manager = getattr(m, name)
     55        reverse_manager = getattr(r, "%s_m_set" % name)
     56        self.assertRaises(AttributeError, getattr, manager, 'remove')
     57        self.assertRaises(AttributeError, getattr, reverse_manager, 'remove')
     58       
     59    def assert_cannot_add(self, name):
     60        reverse_name = "%s_m_set" % name
     61        m = M.objects.create()
     62        r = R.objects.create()
     63        manager = getattr(m, name)
     64        reverse_manager = getattr(r, reverse_name)
     65        self.assertRaises(AttributeError, getattr, manager, 'add')
     66        self.assertRaises(AttributeError, getattr, reverse_manager, 'add')
     67        self.assertRaises(AttributeError, manager.create)
     68        self.assertRaises(AttributeError, reverse_manager.create)
     69        def assign():
     70            setattr(m, name, [])
     71        self.assertRaises(AttributeError, assign)
     72        def assign_reverse():
     73            setattr(r, reverse_name, [])
     74        self.assertRaises(AttributeError, assign_reverse)
     75       
     76    def assert_can_add(self, name):
     77        reverse_name = "%s_m_set" % name
     78        m = M.objects.create()
     79        r = R.objects.create()
     80        manager = getattr(m, name)
     81        reverse_manager = getattr(r, reverse_name)
     82       
     83        manager.add(r)
     84        self.failUnlessEqual(list(manager.all()), [r])
     85        self.failUnlessEqual(list(reverse_manager.all()), [m])
     86        manager.add(r)
     87        self.failUnlessEqual(list(manager.all()), [r])
     88        self.failUnlessEqual(list(reverse_manager.all()), [m])
     89        manager.clear()
     90       
     91        reverse_manager.add(m)
     92        self.failUnlessEqual(list(manager.all()), [r])
     93        self.failUnlessEqual(list(reverse_manager.all()), [m])
     94        reverse_manager.add(m)
     95        self.failUnlessEqual(list(manager.all()), [r])
     96        self.failUnlessEqual(list(reverse_manager.all()), [m])
     97        reverse_manager.clear()
     98       
     99        r2 = manager.create()
     100        reverse_manager2 = getattr(r2, reverse_name)
     101        self.failUnlessEqual(list(manager.all()), [r2])
     102        self.failUnlessEqual(list(reverse_manager2.all()), [m])
     103        manager.clear()
     104       
     105        m2 = reverse_manager.create()
     106        manager2 = getattr(m2, name)
     107        self.failUnlessEqual(list(manager2.all()), [r])
     108        self.failUnlessEqual(list(reverse_manager.all()), [m2])
     109        reverse_manager.clear()
     110       
     111        setattr(m, name, [r])
     112        self.failUnlessEqual(list(manager.all()), [r])
     113        manager.clear()
     114       
     115        setattr(r, reverse_name, [m])
     116        self.failUnlessEqual(list(reverse_manager.all()), [m])
     117        reverse_manager.clear()
     118       
     119    def assert_can_remove(self, name, extra):
     120        through = M._meta.get_field(name).rel.through
     121        m = M.objects.create()
     122        r = R.objects.create()
     123       
     124        def fill():           
     125            for extra_kwargs in extra:
     126                kwargs = {'m': m, 'r': r}
     127                kwargs.update(extra_kwargs)
     128                through.objects.create(**kwargs)
     129
     130        manager = getattr(m, name)
     131        reverse_manager = getattr(r, "%s_m_set" % name)
     132
     133        fill()
     134        manager.remove(r)
     135        self.failIf(manager.exists())
     136        self.failIf(reverse_manager.exists())
     137       
     138        fill()
     139        reverse_manager.remove(m)
     140        self.failIf(reverse_manager.exists())
     141        self.failIf(manager.exists())
     142       
     143    def _test_managers(self, name, can_remove=False, can_add=False, extra=()):
     144        if can_add:
     145            self.assert_can_add(name)
     146        else:
     147            self.assert_cannot_add(name)
     148        if can_remove:
     149            self.assert_can_remove(name, extra)
     150        else:
     151            self.assert_cannot_remove(name)
     152
     153    def test_default(self):
     154        self._test_managers('default', can_add=True, can_remove=True, extra=[{}])
     155       
     156    def test_default_cannot_remove(self):
     157        self._test_managers('default_cannot_remove', can_add=True, can_remove=False)
     158       
     159    def test_default_cannot_add(self):
     160        self._test_managers('default_cannot_add', can_add=False, can_remove=True, extra=[{}])
     161
     162    def test_through_default(self):
     163        self._test_managers('through_default', can_add=False, can_remove=False)
     164       
     165    def test_through_auto(self):
     166        self._test_managers('through_auto', can_add=True, can_remove=False)
     167       
     168    def test_through_ut(self):
     169        self._test_managers('through_ut', can_add=False, can_remove=True, extra=[{'extra': 'foo'}])
     170       
     171    def test_through_can_add(self):
     172        self._test_managers('through_can_add', can_add=True, can_remove=False)
     173       
     174    def test_through_can_remove(self):
     175        self._test_managers('through_can_remove', can_add=False, can_remove=True, extra=[{'extra': 'foo'}, {'extra': 'bar'}])
     176       
     177 No newline at end of file
  • tests/modeltests/m2m_through/models.py

     
    122122
    123123### Forward Descriptors Tests ###
    124124
    125 # Due to complications with adding via an intermediary model,
    126 # the add method is not provided.
    127 >>> rock.members.add(bob)
    128 Traceback (most recent call last):
    129 ...
    130 AttributeError: 'ManyRelatedManager' object has no attribute 'add'
    131 
    132 # Create is also disabled as it suffers from the same problems as add.
    133 >>> rock.members.create(name='Anne')
    134 Traceback (most recent call last):
    135 ...
    136 AttributeError: Cannot use create() on a ManyToManyField which specifies an intermediary model. Use m2m_through.Membership's Manager instead.
    137 
    138 # Remove has similar complications, and is not provided either.
    139 >>> rock.members.remove(jim)
    140 Traceback (most recent call last):
    141 ...
    142 AttributeError: 'ManyRelatedManager' object has no attribute 'remove'
    143 
    144 # Here we back up the list of all members of Rock.
    145 >>> backup = list(rock.members.all())
    146 
    147 # ...and we verify that it has worked.
    148 >>> backup
    149 [<Person: Jane>, <Person: Jim>]
    150 
    151 # The clear function should still work.
     125# The clear function should work.
    152126>>> rock.members.clear()
    153127
    154128# Now there will be no members of Rock.
    155129>>> rock.members.all()
    156130[]
    157131
    158 # Assignment should not work with models specifying a through model for many of
    159 # the same reasons as adding.
    160 >>> rock.members = backup
    161 Traceback (most recent call last):
    162 ...
    163 AttributeError: Cannot set values on a ManyToManyField which specifies an intermediary model.  Use m2m_through.Membership's Manager instead.
    164 
    165132# Let's re-save those instances that we've cleared.
    166133>>> m1.save()
    167134>>> m2.save()
     
    173140
    174141### Reverse Descriptors Tests ###
    175142
    176 # Due to complications with adding via an intermediary model,
    177 # the add method is not provided.
    178 >>> bob.group_set.add(rock)
    179 Traceback (most recent call last):
    180 ...
    181 AttributeError: 'ManyRelatedManager' object has no attribute 'add'
    182 
    183 # Create is also disabled as it suffers from the same problems as add.
    184 >>> bob.group_set.create(name='Funk')
    185 Traceback (most recent call last):
    186 ...
    187 AttributeError: Cannot use create() on a ManyToManyField which specifies an intermediary model. Use m2m_through.Membership's Manager instead.
    188 
    189 # Remove has similar complications, and is not provided either.
    190 >>> jim.group_set.remove(rock)
    191 Traceback (most recent call last):
    192 ...
    193 AttributeError: 'ManyRelatedManager' object has no attribute 'remove'
    194 
    195 # Here we back up the list of all of Jim's groups.
    196 >>> backup = list(jim.group_set.all())
    197 >>> backup
    198 [<Group: Rock>, <Group: Roll>]
    199 
    200 # The clear function should still work.
     143# The clear function should work.
    201144>>> jim.group_set.clear()
    202145
    203146# Now Jim will be in no groups.
    204147>>> jim.group_set.all()
    205148[]
    206149
    207 # Assignment should not work with models specifying a through model for many of
    208 # the same reasons as adding.
    209 >>> jim.group_set = backup
    210 Traceback (most recent call last):
    211 ...
    212 AttributeError: Cannot set values on a ManyToManyField which specifies an intermediary model.  Use m2m_through.Membership's Manager instead.
    213 
    214150# Let's re-save those instances that we've cleared.
    215151>>> m1.save()
    216152>>> m4.save()
  • tests/regressiontests/m2m_through_regress/models.py

     
    8080>>> roll.members.all()
    8181[<Person: Bob>]
    8282
    83 # Error messages use the model name, not repr of the class name
    84 >>> bob.group_set = []
    85 Traceback (most recent call last):
    86 ...
    87 AttributeError: Cannot set values on a ManyToManyField which specifies an intermediary model.  Use m2m_through_regress.Membership's Manager instead.
    88 
    89 >>> roll.members = []
    90 Traceback (most recent call last):
    91 ...
    92 AttributeError: Cannot set values on a ManyToManyField which specifies an intermediary model.  Use m2m_through_regress.Membership's Manager instead.
    93 
    94 >>> rock.members.create(name='Anne')
    95 Traceback (most recent call last):
    96 ...
    97 AttributeError: Cannot use create() on a ManyToManyField which specifies an intermediary model.  Use m2m_through_regress.Membership's Manager instead.
    98 
    99 >>> bob.group_set.create(name='Funk')
    100 Traceback (most recent call last):
    101 ...
    102 AttributeError: Cannot use create() on a ManyToManyField which specifies an intermediary model.  Use m2m_through_regress.Membership's Manager instead.
    103 
    10483# Now test that the intermediate with a relationship outside
    10584# the current app (i.e., UserMembership) workds
    10685>>> UserMembership.objects.create(user=frank, group=rock)
  • django/db/models/fields/related.py

     
    433433
    434434        return manager
    435435
    436 def create_many_related_manager(superclass, rel=False):
     436def create_many_related_manager(superclass, field):
    437437    """Creates a manager that subclasses 'superclass' (which is a Manager)
    438438    and adds behavior for many-to-many related objects."""
    439     through = rel.through
     439    through = field.rel.through
     440    can_add = field.can_add()
     441    can_remove = field.can_remove()
     442   
    440443    class ManyRelatedManager(superclass):
    441444        def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
    442445                join_table=None, source_field_name=None, target_field_name=None,
     
    458461            db = router.db_for_read(self.instance.__class__, instance=self.instance)
    459462            return superclass.get_query_set(self).using(db)._next_is_sticky().filter(**(self.core_filters))
    460463
    461         # If the ManyToMany relation has an intermediary model,
    462         # the add and remove methods do not exist.
    463         if rel.through._meta.auto_created:
     464        if can_add:
    464465            def add(self, *objs):
    465466                self._add_items(self.source_field_name, self.target_field_name, *objs)
    466467
     
    469470                    self._add_items(self.target_field_name, self.source_field_name, *objs)
    470471            add.alters_data = True
    471472
     473        if can_remove:
    472474            def remove(self, *objs):
    473475                self._remove_items(self.source_field_name, self.target_field_name, *objs)
    474476
     
    488490        def create(self, **kwargs):
    489491            # This check needs to be done here, since we can't later remove this
    490492            # from the method lookup table, as we do with add and remove.
    491             if not rel.through._meta.auto_created:
     493            if not can_add:
    492494                opts = through._meta
    493                 raise AttributeError("Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
     495                raise AttributeError("Cannot use create() on this ManyToManyField. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
    494496            db = router.db_for_write(self.instance.__class__, instance=self.instance)
    495497            new_obj = super(ManyRelatedManager, self).using(db).create(**kwargs)
    496498            self.add(new_obj)
     
    498500        create.alters_data = True
    499501
    500502        def get_or_create(self, **kwargs):
     503            # This check needs to be done here, since we can't later remove this
     504            # from the method lookup table, as we do with add and remove.
     505            if not can_add:
     506                opts = through._meta
     507                raise AttributeError, "Cannot use get_or_create() on this ManyToManyField. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)
    501508            db = router.db_for_write(self.instance.__class__, instance=self.instance)
    502509            obj, created = \
    503510                super(ManyRelatedManager, self).using(db).get_or_create(**kwargs)
     
    544551                if self.reverse or source_field_name == self.source_field_name:
    545552                    # Don't send the signal when we are inserting the
    546553                    # duplicate data row for symmetrical reverse entries.
    547                     signals.m2m_changed.send(sender=rel.through, action='add',
     554                    signals.m2m_changed.send(sender=through, action='add',
    548555                        instance=self.instance, reverse=self.reverse,
    549556                        model=self.model, pk_set=new_ids)
    550557
     
    571578                if self.reverse or source_field_name == self.source_field_name:
    572579                    # Don't send the signal when we are deleting the
    573580                    # duplicate data row for symmetrical reverse entries.
    574                     signals.m2m_changed.send(sender=rel.through, action="remove",
     581                    signals.m2m_changed.send(sender=through, action="remove",
    575582                        instance=self.instance, reverse=self.reverse,
    576583                        model=self.model, pk_set=old_ids)
    577584
     
    580587            if self.reverse or source_field_name == self.source_field_name:
    581588                # Don't send the signal when we are clearing the
    582589                # duplicate data rows for symmetrical reverse entries.
    583                 signals.m2m_changed.send(sender=rel.through, action="clear",
     590                signals.m2m_changed.send(sender=through, action="clear",
    584591                    instance=self.instance, reverse=self.reverse,
    585592                    model=self.model, pk_set=None)
    586593            db = router.db_for_write(self.through.__class__, instance=self.instance)
     
    608615        # model's default manager.
    609616        rel_model = self.related.model
    610617        superclass = rel_model._default_manager.__class__
    611         RelatedManager = create_many_related_manager(superclass, self.related.field.rel)
     618        RelatedManager = create_many_related_manager(superclass, self.related.field)
    612619
    613620        manager = RelatedManager(
    614621            model=rel_model,
     
    626633        if instance is None:
    627634            raise AttributeError("Manager must be accessed via instance")
    628635
    629         if not self.related.field.rel.through._meta.auto_created:
     636        if not self.related.field.can_add():
    630637            opts = self.related.field.rel.through._meta
    631             raise AttributeError("Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
     638            raise AttributeError("Cannot set values on this ManyToManyField. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
    632639
    633640        manager = self.__get__(instance)
    634641        manager.clear()
     
    660667        # model's default manager.
    661668        rel_model=self.field.rel.to
    662669        superclass = rel_model._default_manager.__class__
    663         RelatedManager = create_many_related_manager(superclass, self.field.rel)
     670        RelatedManager = create_many_related_manager(superclass, self.field)
    664671
    665672        manager = RelatedManager(
    666673            model=rel_model,
     
    678685        if instance is None:
    679686            raise AttributeError("Manager must be accessed via instance")
    680687
    681         if not self.field.rel.through._meta.auto_created:
     688        if not self.field.can_add():
    682689            opts = self.field.rel.through._meta
    683             raise AttributeError("Cannot set values on a ManyToManyField which specifies an intermediary model.  Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
     690            raise AttributeError("Cannot set values a this ManyToManyField.  Use %s.%s's Manager instead." % (opts.app_label, opts.object_name))
    684691
    685692        manager = self.__get__(instance)
    686693        manager.clear()
     
    953960        self.db_table = kwargs.pop('db_table', None)
    954961        if kwargs['rel'].through is not None:
    955962            assert self.db_table is None, "Cannot specify a db_table if an intermediary model is used."
     963           
     964        self._can_add = kwargs.pop('can_add', None)
     965        self._can_remove = kwargs.pop('can_remove', None)
    956966
    957967        Field.__init__(self, **kwargs)
    958968
     
    971981        else:
    972982            return util.truncate_name('%s_%s' % (opts.db_table, self.name),
    973983                                      connection.ops.max_name_length())
    974 
    975     def _get_m2m_attr(self, related, attr):
    976         "Function that can be curried to provide the source column name for the m2m table"
    977         cache_attr = '_m2m_%s_cache' % attr
     984                                     
     985    def _get_intermediary_fields(self, related):
     986        cache_attr = '_m2m_intermediary_fields_cache'
    978987        if hasattr(self, cache_attr):
    979988            return getattr(self, cache_attr)
     989           
     990        candidates = []
     991        related_candidates = []
     992        auto_add = True
    980993        for f in self.rel.through._meta.fields:
    981             if hasattr(f,'rel') and f.rel and f.rel.to == related.model:
    982                 setattr(self, cache_attr, getattr(f, attr))
    983                 return getattr(self, cache_attr)
     994            if hasattr(f,'rel') and f.rel:
     995                if f.rel.to == related.model:
     996                    candidates.append(f)
     997                    continue
     998                elif f.rel.to == related.parent_model:
     999                    related_candidates.append(f)
     1000                    continue
     1001            if isinstance(f, AutoField) or f.null or f.has_default():
     1002                continue
     1003            elif getattr(f, 'auto_now_add', False) or getattr(f, 'auto_now', False):
     1004                continue
     1005            else:
     1006                auto_add = False
     1007        if related.model == related.parent_model:
     1008            # m2m to self
     1009            assert len(candidates) == 2, "There are too many ForeignKeys to %s" % related.model
     1010            field, related_field = candidates
     1011        else:
     1012            assert len(candidates) == 1, "There are no ForeignKeys to %s" % related.model
     1013            assert len(related_candidates) == 1, "There are no ForeignKeys to %s" % related.parent_model
     1014            # TODO: intelligently pick a candidate if there is more than one. For now, just use the first.
     1015            field, related_field = candidates[0], related_candidates[0]
    9841016
     1017        if self._can_add is None:
     1018            self._can_add = auto_add
     1019
     1020        if self._can_remove is None:
     1021            self._can_remove = False
     1022            unique_together = [frozenset(ut) for ut in self.rel.through._meta.unique_together]               
     1023            names = frozenset([field.name, related_field.name])
     1024            for ut in unique_together:
     1025                if names <= ut:
     1026                    self._can_remove = True
     1027                    break
     1028       
     1029        setattr(self, cache_attr, (field, related_field))
     1030        return (field, related_field)
     1031       
     1032    def _get_can_add(self, related):
     1033        if self._can_add is None:
     1034            self._get_intermediary_fields(related)
     1035        return self._can_add
     1036       
     1037    def _get_can_remove(self, related):
     1038        if self._can_remove is None:
     1039            self._get_intermediary_fields(related)
     1040        return self._can_remove
     1041
     1042    def _get_m2m_attr(self, related, attr):
     1043        "Function that can be curried to provide a source field attribute"
     1044        cache_attr = '_m2m_%s_cache' % attr
     1045        if not hasattr(self, cache_attr):
     1046            field, _ = self._get_intermediary_fields(related)
     1047            setattr(self, cache_attr, getattr(field, attr))
     1048        return getattr(self, cache_attr)
     1049
    9851050    def _get_m2m_reverse_attr(self, related, attr):
    986         "Function that can be curried to provide the related column name for the m2m table"
     1051        "Function that can be curried to provide a related field attribute"
    9871052        cache_attr = '_m2m_reverse_%s_cache' % attr
    988         if hasattr(self, cache_attr):
    989             return getattr(self, cache_attr)
    990         found = False
    991         for f in self.rel.through._meta.fields:
    992             if hasattr(f,'rel') and f.rel and f.rel.to == related.parent_model:
    993                 if related.model == related.parent_model:
    994                     # If this is an m2m-intermediate to self,
    995                     # the first foreign key you find will be
    996                     # the source column. Keep searching for
    997                     # the second foreign key.
    998                     if found:
    999                         setattr(self, cache_attr, getattr(f, attr))
    1000                         break
    1001                     else:
    1002                         found = True
    1003                 else:
    1004                     setattr(self, cache_attr, getattr(f, attr))
    1005                     break
     1053        if not hasattr(self, cache_attr):
     1054            _, related_field = self._get_intermediary_fields(related)
     1055            setattr(self, cache_attr, getattr(related_field, attr))           
    10061056        return getattr(self, cache_attr)
    10071057
    10081058    def isValidIDList(self, field_data, all_data):
     
    10871137
    10881138        self.m2m_field_name = curry(self._get_m2m_attr, related, 'name')
    10891139        self.m2m_reverse_field_name = curry(self._get_m2m_reverse_attr, related, 'name')
     1140       
     1141        self.can_add = curry(self._get_can_add, related)
     1142        self.can_remove = curry(self._get_can_remove, related)
    10901143
    10911144    def set_attributes_from_rel(self):
    10921145        pass
Back to Top