Code

Ticket #6095: 6095-alpha-04.diff

File 6095-alpha-04.diff, 20.4 KB (added by floguy, 6 years ago)
Line 
1Index: django/db/models/fields/related.py
2===================================================================
3--- django/db/models/fields/related.py  (revision 6903)
4+++ django/db/models/fields/related.py  (working copy)
5@@ -1,10 +1,10 @@
6 from django.db import connection, transaction
7-from django.db.models import signals, get_model
8+from django.db.models import signals, get_model, get_models
9 from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, get_ul_class
10 from django.db.models.related import RelatedObject
11 from django.utils.text import capfirst
12 from django.utils.translation import ugettext_lazy, string_concat, ungettext, ugettext as _
13-from django.utils.functional import curry
14+from django.utils.functional import curry, memoize
15 from django.utils.encoding import smart_unicode
16 from django.core import validators
17 from django import oldforms
18@@ -23,6 +23,10 @@
19 
20 pending_lookups = {}
21 
22+memoized_fk_field_reversals = {}
23+
24+model_db_table_cache = {}
25+
26 def add_lookup(rel_cls, field):
27     name = field.rel.to
28     module = rel_cls.__module__
29@@ -54,6 +58,31 @@
30     except klass.DoesNotExist:
31         raise validators.ValidationError, _("Please enter a valid %s.") % f.verbose_name
32 
33+def get_reverse_rel_field(from_model, to_model, related_name):
34+    "Gets the related field which points from one model to another."
35+    key = (from_model._meta.app_label, from_model._meta.object_name,
36+            to_model._meta.app_label, to_model._meta.object_name,
37+            related_name)
38+    try:
39+        found_field = memoized_fk_field_reversals[key]
40+    except KeyError:
41+        found_field = None
42+        for field in from_model._meta.fields:
43+            if field.__class__ in (ForeignKey, OneToOneField, ManyToManyField):
44+                if field.rel.to == to_model:
45+                    found_field = field
46+                    break
47+        memoized_fk_field_reversals[key] = found_field
48+    return found_field
49+
50+def get_model_for_db_table(db_table):
51+    "Gets a model class from a db_table string."
52+    for model in get_models():
53+        if model._meta.db_table == db_table:
54+            return model
55+    return None
56+get_model_for_db_table = memoize(get_model_for_db_table, model_db_table_cache, 1)
57+
58 #HACK
59 class RelatedField(object):
60     def contribute_to_class(self, cls, name):
61@@ -267,7 +296,8 @@
62     and adds behavior for many-to-many related objects."""
63     class ManyRelatedManager(superclass):
64         def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
65-                join_table=None, source_col_name=None, target_col_name=None):
66+                join_table=None, source_col_name=None, source_attname=None,
67+                target_attname=None, target_col_name=None):
68             super(ManyRelatedManager, self).__init__()
69             self.core_filters = core_filters
70             self.model = model
71@@ -276,6 +306,9 @@
72             self.join_table = join_table
73             self.source_col_name = source_col_name
74             self.target_col_name = target_col_name
75+            self.source_attname = source_attname
76+            self.target_attname = target_attname
77+            self.intermediary_model = get_model_for_db_table(self.join_table.replace('"',''))
78             self._pk_val = self.instance._get_pk_val()
79             if self._pk_val is None:
80                 raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % model)
81@@ -340,9 +373,15 @@
82 
83                 # Add the ones that aren't there already
84                 for obj_id in (new_ids - existing_ids):
85-                    cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
86+                    if self.intermediary_model == None:
87+                        cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
88                         (self.join_table, source_col_name, target_col_name),
89                         [self._pk_val, obj_id])
90+                    else:
91+                        new_obj = self.intermediary_model()
92+                        setattr(new_obj, self.source_attname, self._pk_val)
93+                        setattr(new_obj, self.target_attname, obj_id)
94+                        new_obj.save()
95                 transaction.commit_unless_managed()
96 
97         def _remove_items(self, source_col_name, target_col_name, *objs):
98@@ -398,14 +437,17 @@
99         RelatedManager = create_many_related_manager(superclass)
100 
101         qn = connection.ops.quote_name
102+        rel_field = self.related.field
103         manager = RelatedManager(
104             model=rel_model,
105             core_filters={'%s__pk' % self.related.field.name: instance._get_pk_val()},
106             instance=instance,
107             symmetrical=False,
108-            join_table=qn(self.related.field.m2m_db_table()),
109-            source_col_name=qn(self.related.field.m2m_reverse_name()),
110-            target_col_name=qn(self.related.field.m2m_column_name())
111+            join_table=qn(rel_field.m2m_db_table()),
112+            source_col_name=qn(rel_field.m2m_reverse_name()),
113+            target_col_name=qn(rel_field.m2m_column_name()),
114+            source_attname=rel_field.m2m_reverse_attname(),
115+            target_attname=rel_field.m2m_attname()
116         )
117 
118         return manager
119@@ -446,7 +488,9 @@
120             symmetrical=(self.field.rel.symmetrical and instance.__class__ == rel_model),
121             join_table=qn(self.field.m2m_db_table()),
122             source_col_name=qn(self.field.m2m_column_name()),
123-            target_col_name=qn(self.field.m2m_reverse_name())
124+            target_col_name=qn(self.field.m2m_reverse_name()),
125+            source_attname=self.field.m2m_attname(),
126+            target_attname=self.field.m2m_reverse_attname()
127         )
128 
129         return manager
130@@ -648,8 +692,11 @@
131             filter_interface=kwargs.pop('filter_interface', None),
132             limit_choices_to=kwargs.pop('limit_choices_to', None),
133             raw_id_admin=kwargs.pop('raw_id_admin', False),
134-            symmetrical=kwargs.pop('symmetrical', True))
135+            symmetrical=kwargs.pop('symmetrical', True),
136+            through=kwargs.pop('through', None))
137         self.db_table = kwargs.pop('db_table', None)
138+        if kwargs['rel'].through:
139+            assert not self.db_table, "Cannot specify a db_table if an intermediary model is used."
140         if kwargs["rel"].raw_id_admin:
141             kwargs.setdefault("validator_list", []).append(self.isValidIDList)
142         Field.__init__(self, **kwargs)
143@@ -672,23 +719,53 @@
144 
145     def _get_m2m_db_table(self, opts):
146         "Function that can be curried to provide the m2m table name for this relation"
147-        if self.db_table:
148+        if self.rel.through != None:
149+            return get_model(opts.app_label, self.rel.through)._meta.db_table
150+        elif self.db_table:
151             return self.db_table
152         else:
153             return '%s_%s' % (opts.db_table, self.name)
154 
155+    def _get_m2m_attname(self, related):
156+        try:
157+            through = get_model(related.opts.app_label, self.rel.through)
158+            field = get_reverse_rel_field(through, related.model, self.rel.related_name)
159+            attname, column = field.get_attname_column()
160+            return attname
161+        except:
162+            return None
163+
164     def _get_m2m_column_name(self, related):
165         "Function that can be curried to provide the source column name for the m2m table"
166         # If this is an m2m relation to self, avoid the inevitable name clash
167-        if related.model == related.parent_model:
168+        if self.rel.through != None:
169+            through = get_model(related.opts.app_label, self.rel.through)
170+            field = get_reverse_rel_field(through, related.model, self.rel.related_name)
171+            attname, column = field.get_attname_column()
172+            return column
173+        elif related.model == related.parent_model:
174             return 'from_' + related.model._meta.object_name.lower() + '_id'
175         else:
176             return related.model._meta.object_name.lower() + '_id'
177 
178+    def _get_m2m_reverse_attname(self, related):
179+        try:
180+            through = get_model(related.opts.app_label, self.rel.through)
181+            field = get_reverse_rel_field(through, related.parent_model, self.rel.related_name)
182+            attname, column = field.get_attname_column()
183+            return attname
184+        except:
185+            return None
186+
187     def _get_m2m_reverse_name(self, related):
188         "Function that can be curried to provide the related column name for the m2m table"
189         # If this is an m2m relation to self, avoid the inevitable name clash
190-        if related.model == related.parent_model:
191+        if self.rel.through != None:
192+            through = get_model(related.opts.app_label, self.rel.through)
193+            field = get_reverse_rel_field(through, related.parent_model, self.rel.related_name)
194+            attname, column = field.get_attname_column()
195+            return column
196+        elif related.model == related.parent_model:
197             return 'to_' + related.parent_model._meta.object_name.lower() + '_id'
198         else:
199             return related.parent_model._meta.object_name.lower() + '_id'
200@@ -745,6 +822,8 @@
201         # Set up the accessors for the column names on the m2m table
202         self.m2m_column_name = curry(self._get_m2m_column_name, related)
203         self.m2m_reverse_name = curry(self._get_m2m_reverse_name, related)
204+        self.m2m_attname = curry(self._get_m2m_attname, related)
205+        self.m2m_reverse_attname = curry(self._get_m2m_reverse_attname, related)
206 
207     def set_attributes_from_rel(self):
208         pass
209@@ -809,7 +888,8 @@
210 
211 class ManyToManyRel(object):
212     def __init__(self, to, num_in_admin=0, related_name=None,
213-        filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True):
214+        filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True,
215+        through = None):
216         self.to = to
217         self.num_in_admin = num_in_admin
218         self.related_name = related_name
219@@ -821,5 +901,6 @@
220         self.raw_id_admin = raw_id_admin
221         self.symmetrical = symmetrical
222         self.multiple = True
223+        self.through = through
224 
225         assert not (self.raw_id_admin and self.filter_interface), "ManyToManyRels may not use both raw_id_admin and filter_interface"
226Index: django/core/management/validation.py
227===================================================================
228--- django/core/management/validation.py        (revision 6903)
229+++ django/core/management/validation.py        (working copy)
230@@ -104,6 +104,8 @@
231                         if r.get_accessor_name() == rel_query_name:
232                             e.add(opts, "Reverse query name for field '%s' clashes with related field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.get_accessor_name(), f.name))
233 
234+        seen_intermediary_signatures = []
235+
236         for i, f in enumerate(opts.many_to_many):
237             # Check to see if the related m2m field will clash with any
238             # existing fields, m2m fields, m2m related objects or related objects
239@@ -113,6 +115,28 @@
240                 # so skip the next section
241                 if isinstance(f.rel.to, (str, unicode)):
242                     continue
243+            if hasattr(f.rel, 'through') and f.rel.through != None:
244+                intermediary_model = None
245+                for model in models.get_models():
246+                    if model._meta.module_name == f.rel.through.lower():
247+                        intermediary_model = model
248+                if intermediary_model == None:
249+                    e.add(opts, "%s has a manually-defined m2m relationship through a model (%s) which does not exist." % (f.name, f.rel.through))
250+                else:
251+                    signature = (f.rel.to, cls, intermediary_model)
252+                    if signature in seen_intermediary_signatures:
253+                        e.add(opts, "%s has two manually defined m2m relationships through the same model (%s), which is not possible.  Please use a field on your intermediary model instead." % (cls._meta.object_name, intermediary_model._meta.object_name))
254+                    else:
255+                        seen_intermediary_signatures.append(signature)
256+                    seen_related_fk, seen_this_fk = False, False
257+                    for field in intermediary_model._meta.fields:
258+                        if field.rel:
259+                            if field.rel.to == f.rel.to:
260+                                seen_related_fk = True
261+                            elif field.rel.to == cls:
262+                                seen_this_fk = True
263+                    if not seen_related_fk or not seen_this_fk:
264+                        e.add(opts, "%s has a manualy-defined m2m relationship through a model (%s) which does not have foreign keys to %s and %s" % (f.name, f.rel.through, f.rel.to._meta.object_name, cls._meta.object_name))
265 
266             rel_opts = f.rel.to._meta
267             rel_name = RelatedObject(f.rel.to, cls, f).get_accessor_name()
268Index: django/core/management/sql.py
269===================================================================
270--- django/core/management/sql.py       (revision 6903)
271+++ django/core/management/sql.py       (working copy)
272@@ -349,7 +349,7 @@
273     qn = connection.ops.quote_name
274     inline_references = connection.features.inline_fk_references
275     for f in opts.many_to_many:
276-        if not isinstance(f.rel, generic.GenericRel):
277+        if not isinstance(f.rel, generic.GenericRel) and getattr(f.rel, 'through', None) == None:
278             tablespace = f.db_tablespace or opts.db_tablespace
279             if tablespace and connection.features.supports_tablespaces and connection.features.autoindexes_primary_keys:
280                 tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace, inline=True)
281Index: tests/modeltests/invalid_models/models.py
282===================================================================
283--- tests/modeltests/invalid_models/models.py   (revision 6903)
284+++ tests/modeltests/invalid_models/models.py   (working copy)
285@@ -111,7 +111,23 @@
286 class MissingRelations(models.Model):
287     rel1 = models.ForeignKey("Rel1")
288     rel2 = models.ManyToManyField("Rel2")
289+   
290+class MissingManualM2MModel(models.Model):
291+    name = models.CharField(max_length=5)
292+    missing_m2m = models.ManyToManyField(Model, through="MissingM2MModel")
293+   
294+class Person(models.Model):
295+    name = models.CharField(max_length=5)
296 
297+class Group(models.Model):
298+    name = models.CharField(max_length=5)
299+    primary = models.ManyToManyField(Person, through="Membership", related_name="primary")
300+    secondary = models.ManyToManyField(Person, through="Membership", related_name="secondary")
301+
302+class Membership(models.Model):
303+    person = models.ForeignKey(Person)
304+    group = models.ForeignKey(Group)
305+
306 model_errors = """invalid_models.fielderrors: "charfield": CharFields require a "max_length" attribute.
307 invalid_models.fielderrors: "decimalfield": DecimalFields require a "decimal_places" attribute.
308 invalid_models.fielderrors: "decimalfield": DecimalFields require a "max_digits" attribute.
309@@ -197,4 +213,6 @@
310 invalid_models.selfclashm2m: Reverse query name for m2m field 'm2m_4' clashes with field 'SelfClashM2M.selfclashm2m'. Add a related_name argument to the definition for 'm2m_4'.
311 invalid_models.missingrelations: 'rel2' has m2m relation with model Rel2, which has not been installed
312 invalid_models.missingrelations: 'rel1' has relation with model Rel1, which has not been installed
313+invalid_models.group: Group has two manually defined m2m relationships through the same model (Membership), which is not possible.  Please use a field on your intermediary model instead.
314+invalid_models.missingmanualm2mmodel: missing_m2m has a manually-defined m2m relationship through a model (MissingM2MModel) which does not exist.
315 """
316Index: tests/modeltests/m2m_manual/__init__.py
317===================================================================
318Index: tests/modeltests/m2m_manual/models.py
319===================================================================
320--- tests/modeltests/m2m_manual/models.py       (revision 0)
321+++ tests/modeltests/m2m_manual/models.py       (revision 0)
322@@ -0,0 +1,136 @@
323+from django.db import models
324+from datetime import datetime
325+
326+# M2M described on one of the models
327+class Person(models.Model):
328+    name = models.CharField(max_length=128)
329+
330+    def __unicode__(self):
331+        return self.name
332+
333+class Group(models.Model):
334+    name = models.CharField(max_length=128)
335+    members = models.ManyToManyField(Person, through='Membership')
336+    custom_members = models.ManyToManyField(Person, through='CustomMembership', related_name="custom")
337+   
338+    def __unicode__(self):
339+        return self.name
340+
341+class Membership(models.Model):
342+    person = models.ForeignKey(Person)
343+    group = models.ForeignKey(Group)
344+    date_joined = models.DateTimeField(default=datetime.now)
345+    invite_reason = models.CharField(max_length=64, null=True, blank=True)
346+   
347+    def __unicode__(self):
348+        return "%s is a member of %s" % (self.person.name, self.group.name)
349+
350+class CustomMembership(models.Model):
351+    person = models.ForeignKey(Person, db_column="custom_person_column", related_name="custom_person_related_name")
352+    group = models.ForeignKey(Group)
353+    weird_fk = models.ForeignKey(Membership, null=True)
354+    date_joined = models.DateTimeField(default=datetime.now)
355+   
356+    def __unicode__(self):
357+        return "%s is a member of %s" % (self.person.name, self.group.name)
358+
359+__test__ = {'API_TESTS':"""
360+>>> from datetime import datetime
361+
362+>>> bob = Person(name = 'Bob')
363+>>> bob.save()
364+>>> jim = Person(name = 'Jim')
365+>>> jim.save()
366+>>> jane = Person(name = 'Jane')
367+>>> jane.save()
368+>>> rock = Group(name = 'Rock')
369+>>> rock.save()
370+>>> roll = Group(name = 'Roll')
371+>>> roll.save()
372+
373+>>> rock.members.add(jim, jane)
374+>>> rock.members.all()
375+[<Person: Jim>, <Person: Jane>]
376+
377+>>> roll.members.add(bob, jim)
378+>>> roll.members.all()
379+[<Person: Bob>, <Person: Jim>]
380+
381+>>> jane.group_set.all()
382+[<Group: Rock>]
383+
384+>>> jane.group_set.add(roll)
385+>>> jane.group_set.all()
386+[<Group: Rock>, <Group: Roll>]
387+
388+>>> jim.group_set.all()
389+[<Group: Rock>, <Group: Roll>]
390+
391+# Check to make sure that the associated Membership object is created.
392+>>> m = Membership.objects.get(person = jane, group = rock)
393+>>> m
394+<Membership: Jane is a member of Rock>
395+
396+# Setting some date_joined dates
397+>>> m.invite_reason = "She was just so awesome."
398+>>> m.date_joined = datetime(2004, 1, 1)
399+>>> m.save()
400+
401+>>> m = Membership.objects.get(person = jane, group = roll)
402+>>> m.date_joined = datetime(2004, 1, 1)
403+>>> m.save()
404+
405+>>> m = Membership.objects.get(person = bob, group = roll)
406+>>> m.date_joined = datetime(2004, 1, 1)
407+>>> m.save()
408+
409+>>> Membership.objects.filter(person = jim)
410+[<Membership: Jim is a member of Rock>, <Membership: Jim is a member of Roll>]
411+
412+>>> rock.custom_members.add(bob)
413+>>> rock.custom_members.all()
414+[<Person: Bob>]
415+
416+>>> jim.custom.add(rock)
417+>>> rock.custom_members.all()
418+[<Person: Bob>, <Person: Jim>]
419+
420+>>> jim.custom.all()
421+[<Group: Rock>]
422+
423+>>> jim.custom_person_related_name.all()
424+[<CustomMembership: Jim is a member of Rock>]
425+
426+###QUERY TESTS###
427+# Queries involving the related model (Person, in the case of Group) use its attname
428+>>> Group.objects.filter(members__name='Bob')
429+[<Group: Roll>]
430+
431+# Queries involving the relationship model (Membership, in the case of Group) use its model name
432+>>> Group.objects.filter(membership__invite_reason = "She was just so awesome.")
433+[<Group: Rock>]
434+
435+# Queries involving the reverse related model (Group, in the case of Person) use its model name
436+>>> Person.objects.filter(group__name="Rock")
437+[<Person: Jim>, <Person: Jane>]
438+
439+# If the m2m field has specified a related_name, using that will work.
440+>>> Person.objects.filter(custom__name="Rock")
441+[<Person: Bob>, <Person: Jim>]
442+
443+# Queries involving the relationship model (Membership, in the case of Group) use its model name
444+>>> Person.objects.filter(membership__invite_reason = "She was just so awesome.")
445+[<Person: Jane>]
446+
447+# Let's see all of the groups that Jane joined after 1 Jan 2005:
448+>>> Group.objects.filter(membership__date_joined__gt = datetime(2005, 1, 1))
449+[<Group: Rock>, <Group: Roll>]
450+
451+# Now let's see all of the people that have joined Rock since 1 Jan 2005:
452+>>> Person.objects.filter(membership__date_joined__gt = datetime(2005, 1, 1))
453+[<Person: Jim>, <Person: Jim>]
454+
455+# Oops, that returned non-distinct results, let's fix that:
456+>>> Person.objects.filter(membership__date_joined__gt = datetime(2005, 1, 1)).distinct()
457+[<Person: Jim>]
458+"""}
459\ No newline at end of file