Code

Ticket #6095: 6095-alpha-03.diff

File 6095-alpha-03.diff, 15.1 KB (added by floguy, 7 years ago)
Line 
1Index: django/db/models/fields/related.py
2===================================================================
3--- django/db/models/fields/related.py  (revision 6898)
4+++ django/db/models/fields/related.py  (working copy)
5@@ -1,10 +1,10 @@
6 from django.db import connection, transaction
7-from django.db.models import signals, get_model
8+from django.db.models import signals, get_model, get_models
9 from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, get_ul_class
10 from django.db.models.related import RelatedObject
11 from django.utils.text import capfirst
12 from django.utils.translation import ugettext_lazy, string_concat, ungettext, ugettext as _
13-from django.utils.functional import curry
14+from django.utils.functional import curry, memoize
15 from django.utils.encoding import smart_unicode
16 from django.core import validators
17 from django import oldforms
18@@ -23,6 +23,10 @@
19 
20 pending_lookups = {}
21 
22+memoized_fk_field_reversals = {}
23+
24+model_db_table_cache = {}
25+
26 def add_lookup(rel_cls, field):
27     name = field.rel.to
28     module = rel_cls.__module__
29@@ -54,6 +58,29 @@
30     except klass.DoesNotExist:
31         raise validators.ValidationError, _("Please enter a valid %s.") % f.verbose_name
32 
33+def get_reverse_rel_field(from_model, to_model, related_name):
34+    key = (from_model._meta.app_label, from_model._meta.object_name,
35+            to_model._meta.app_label, to_model._meta.object_name,
36+            related_name)
37+    try:
38+        found_field = memoized_fk_field_reversals[key]
39+    except KeyError:
40+        found_field = None
41+        for field in from_model._meta.fields:
42+            if field.__class__ in (ForeignKey, OneToOneField, ManyToManyField):
43+                if field.rel.to == to_model:
44+                    found_field = field
45+                    break
46+        memoized_fk_field_reversals[key] = found_field
47+    return found_field
48+
49+def get_model_for_db_table(db_table):
50+    for model in get_models():
51+        if model._meta.db_table == db_table:
52+            return model
53+    return None
54+get_model_for_db_table = memoize(get_model_for_db_table, model_db_table_cache, 1)
55+
56 #HACK
57 class RelatedField(object):
58     def contribute_to_class(self, cls, name):
59@@ -267,7 +294,8 @@
60     and adds behavior for many-to-many related objects."""
61     class ManyRelatedManager(superclass):
62         def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
63-                join_table=None, source_col_name=None, target_col_name=None):
64+                join_table=None, source_col_name=None, source_attname=None,
65+                target_attname=None, target_col_name=None):
66             super(ManyRelatedManager, self).__init__()
67             self.core_filters = core_filters
68             self.model = model
69@@ -276,6 +304,9 @@
70             self.join_table = join_table
71             self.source_col_name = source_col_name
72             self.target_col_name = target_col_name
73+            self.source_attname = source_attname
74+            self.target_attname = target_attname
75+            self.intermediary_model = get_model_for_db_table(self.join_table.replace('"',''))
76             self._pk_val = self.instance._get_pk_val()
77             if self._pk_val is None:
78                 raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % model)
79@@ -340,9 +371,15 @@
80 
81                 # Add the ones that aren't there already
82                 for obj_id in (new_ids - existing_ids):
83-                    cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
84+                    if self.intermediary_model == None:
85+                        cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
86                         (self.join_table, source_col_name, target_col_name),
87                         [self._pk_val, obj_id])
88+                    else:
89+                        new_obj = self.intermediary_model()
90+                        setattr(new_obj, self.source_attname, self._pk_val)
91+                        setattr(new_obj, self.target_attname, obj_id)
92+                        new_obj.save()
93                 transaction.commit_unless_managed()
94 
95         def _remove_items(self, source_col_name, target_col_name, *objs):
96@@ -398,14 +435,17 @@
97         RelatedManager = create_many_related_manager(superclass)
98 
99         qn = connection.ops.quote_name
100+        rel_field = self.related.field
101         manager = RelatedManager(
102             model=rel_model,
103             core_filters={'%s__pk' % self.related.field.name: instance._get_pk_val()},
104             instance=instance,
105             symmetrical=False,
106-            join_table=qn(self.related.field.m2m_db_table()),
107-            source_col_name=qn(self.related.field.m2m_reverse_name()),
108-            target_col_name=qn(self.related.field.m2m_column_name())
109+            join_table=qn(rel_field.m2m_db_table()),
110+            source_col_name=qn(rel_field.m2m_reverse_name()),
111+            target_col_name=qn(rel_field.m2m_column_name()),
112+            source_attname=rel_field.m2m_reverse_attname(),
113+            target_attname=rel_field.m2m_attname()
114         )
115 
116         return manager
117@@ -446,7 +486,9 @@
118             symmetrical=(self.field.rel.symmetrical and instance.__class__ == rel_model),
119             join_table=qn(self.field.m2m_db_table()),
120             source_col_name=qn(self.field.m2m_column_name()),
121-            target_col_name=qn(self.field.m2m_reverse_name())
122+            target_col_name=qn(self.field.m2m_reverse_name()),
123+            source_attname=self.field.m2m_attname(),
124+            target_attname=self.field.m2m_reverse_attname()
125         )
126 
127         return manager
128@@ -648,8 +690,11 @@
129             filter_interface=kwargs.pop('filter_interface', None),
130             limit_choices_to=kwargs.pop('limit_choices_to', None),
131             raw_id_admin=kwargs.pop('raw_id_admin', False),
132-            symmetrical=kwargs.pop('symmetrical', True))
133+            symmetrical=kwargs.pop('symmetrical', True),
134+            through=kwargs.pop('through', None))
135         self.db_table = kwargs.pop('db_table', None)
136+        if kwargs['rel'].through:
137+            assert not self.db_table, "Cannot specify a db_table if an intermediary model is used."
138         if kwargs["rel"].raw_id_admin:
139             kwargs.setdefault("validator_list", []).append(self.isValidIDList)
140         Field.__init__(self, **kwargs)
141@@ -672,23 +717,53 @@
142 
143     def _get_m2m_db_table(self, opts):
144         "Function that can be curried to provide the m2m table name for this relation"
145-        if self.db_table:
146+        if self.rel.through != None:
147+            return get_model(opts.app_label, self.rel.through)._meta.db_table
148+        elif self.db_table:
149             return self.db_table
150         else:
151             return '%s_%s' % (opts.db_table, self.name)
152 
153+    def _get_m2m_attname(self, related):
154+        try:
155+            through = get_model(related.opts.app_label, self.rel.through)
156+            field = get_reverse_rel_field(through, related.model, self.rel.related_name)
157+            attname, column = field.get_attname_column()
158+            return attname
159+        except:
160+            return None
161+
162     def _get_m2m_column_name(self, related):
163         "Function that can be curried to provide the source column name for the m2m table"
164         # If this is an m2m relation to self, avoid the inevitable name clash
165-        if related.model == related.parent_model:
166+        if self.rel.through != None:
167+            through = get_model(related.opts.app_label, self.rel.through)
168+            field = get_reverse_rel_field(through, related.model, self.rel.related_name)
169+            attname, column = field.get_attname_column()
170+            return column
171+        elif related.model == related.parent_model:
172             return 'from_' + related.model._meta.object_name.lower() + '_id'
173         else:
174             return related.model._meta.object_name.lower() + '_id'
175 
176+    def _get_m2m_reverse_attname(self, related):
177+        try:
178+            through = get_model(related.opts.app_label, self.rel.through)
179+            field = get_reverse_rel_field(through, related.parent_model, self.rel.related_name)
180+            attname, column = field.get_attname_column()
181+            return attname
182+        except:
183+            return None
184+
185     def _get_m2m_reverse_name(self, related):
186         "Function that can be curried to provide the related column name for the m2m table"
187         # If this is an m2m relation to self, avoid the inevitable name clash
188-        if related.model == related.parent_model:
189+        if self.rel.through != None:
190+            through = get_model(related.opts.app_label, self.rel.through)
191+            field = get_reverse_rel_field(through, related.parent_model, self.rel.related_name)
192+            attname, column = field.get_attname_column()
193+            return column
194+        elif related.model == related.parent_model:
195             return 'to_' + related.parent_model._meta.object_name.lower() + '_id'
196         else:
197             return related.parent_model._meta.object_name.lower() + '_id'
198@@ -745,6 +820,8 @@
199         # Set up the accessors for the column names on the m2m table
200         self.m2m_column_name = curry(self._get_m2m_column_name, related)
201         self.m2m_reverse_name = curry(self._get_m2m_reverse_name, related)
202+        self.m2m_attname = curry(self._get_m2m_attname, related)
203+        self.m2m_reverse_attname = curry(self._get_m2m_reverse_attname, related)
204 
205     def set_attributes_from_rel(self):
206         pass
207@@ -809,7 +886,8 @@
208 
209 class ManyToManyRel(object):
210     def __init__(self, to, num_in_admin=0, related_name=None,
211-        filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True):
212+        filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True,
213+        through = None):
214         self.to = to
215         self.num_in_admin = num_in_admin
216         self.related_name = related_name
217@@ -821,5 +899,6 @@
218         self.raw_id_admin = raw_id_admin
219         self.symmetrical = symmetrical
220         self.multiple = True
221+        self.through = through
222 
223         assert not (self.raw_id_admin and self.filter_interface), "ManyToManyRels may not use both raw_id_admin and filter_interface"
224Index: django/core/management/sql.py
225===================================================================
226--- django/core/management/sql.py       (revision 6898)
227+++ django/core/management/sql.py       (working copy)
228@@ -349,7 +349,7 @@
229     qn = connection.ops.quote_name
230     inline_references = connection.features.inline_fk_references
231     for f in opts.many_to_many:
232-        if not isinstance(f.rel, generic.GenericRel):
233+        if not isinstance(f.rel, generic.GenericRel) and getattr(f.rel, 'through', None) == None:
234             tablespace = f.db_tablespace or opts.db_tablespace
235             if tablespace and connection.features.supports_tablespaces and connection.features.autoindexes_primary_keys:
236                 tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace, inline=True)
237Index: tests/modeltests/m2m_manual/__init__.py
238===================================================================
239Index: tests/modeltests/m2m_manual/models.py
240===================================================================
241--- tests/modeltests/m2m_manual/models.py       (revision 0)
242+++ tests/modeltests/m2m_manual/models.py       (revision 0)
243@@ -0,0 +1,121 @@
244+from django.db import models
245+from datetime import datetime
246+
247+# M2M described on one of the models
248+class Person(models.Model):
249+    name = models.CharField(max_length=128)
250+
251+    def __unicode__(self):
252+        return self.name
253+
254+class Group(models.Model):
255+    name = models.CharField(max_length=128)
256+    members = models.ManyToManyField(Person, through='Membership')
257+    custom_members = models.ManyToManyField(Person, through='CustomMembership', related_name="custom")
258+   
259+    def __unicode__(self):
260+        return self.name
261+
262+class Membership(models.Model):
263+    person = models.ForeignKey(Person)
264+    group = models.ForeignKey(Group)
265+    date_joined = models.DateTimeField(default=datetime.now)
266+    invite_reason = models.CharField(max_length=64, null=True, blank=True)
267+   
268+    def __unicode__(self):
269+        return "%s is a member of %s" % (self.person.name, self.group.name)
270+
271+class CustomMembership(models.Model):
272+    person = models.ForeignKey(Person, db_column="custom_person_column", related_name="custom_person_related_name")
273+    group = models.ForeignKey(Group)
274+    date_joined = models.DateTimeField(default=datetime.now)
275+   
276+    def __unicode__(self):
277+        return "%s is a member of %s" % (self.person.name, self.group.name)
278+
279+__test__ = {'API_TESTS':"""
280+>>> from datetime import datetime
281+
282+>>> bob = Person(name = 'Bob')
283+>>> bob.save()
284+>>> jim = Person(name = 'Jim')
285+>>> jim.save()
286+>>> jane = Person(name = 'Jane')
287+>>> jane.save()
288+>>> rock = Group(name = 'Rock')
289+>>> rock.save()
290+>>> roll = Group(name = 'Roll')
291+>>> roll.save()
292+
293+>>> rock.members.add(jim, jane)
294+>>> rock.members.all()
295+[<Person: Jim>, <Person: Jane>]
296+
297+>>> roll.members.add(bob, jim)
298+>>> roll.members.all()
299+[<Person: Bob>, <Person: Jim>]
300+
301+>>> jane.group_set.all()
302+[<Group: Rock>]
303+
304+>>> jane.group_set.add(roll)
305+>>> jane.group_set.all()
306+[<Group: Rock>, <Group: Roll>]
307+
308+>>> jim.group_set.all()
309+[<Group: Rock>, <Group: Roll>]
310+
311+# Check to make sure that the associated Membership object is created.
312+>>> m = Membership.objects.get(person = jane, group = rock)
313+>>> m
314+<Membership: Jane is a member of Rock>
315+
316+>>> m.invite_reason = "She was just so awesome."
317+>>> m.save()
318+
319+>>> Membership.objects.filter(person = jim)
320+[<Membership: Jim is a member of Rock>, <Membership: Jim is a member of Roll>]
321+
322+>>> rock.custom_members.add(bob)
323+>>> rock.custom_members.all()
324+[<Person: Bob>]
325+
326+>>> jim.custom.add(rock)
327+>>> rock.custom_members.all()
328+[<Person: Bob>, <Person: Jim>]
329+
330+>>> jim.custom.all()
331+[<Group: Rock>]
332+
333+>>> jim.custom_person_related_name.all()
334+[<CustomMembership: Jim is a member of Rock>]
335+
336+###QUERY TESTS###
337+# Queries involving the related model (Person, in the case of Group) use its attname
338+>>> Group.objects.filter(members__name='Bob')
339+[<Group: Roll>]
340+
341+# Queries involving the relationship model (Membership, in the case of Group) use its model name
342+>>> Group.objects.filter(membership__invite_reason = "She was just so awesome.")
343+[<Group: Rock>]
344+
345+# Queries involving the reverse related model (Group, in the case of Person) use its model name
346+>>> Person.objects.filter(group__name="Rock")
347+[<Person: Jim>, <Person: Jane>]
348+
349+# If the m2m field has specified a related_name, using that will work.
350+>>> Person.objects.filter(custom__name="Rock")
351+[<Person: Bob>, <Person: Jim>]
352+
353+# Queries involving the relationship model (Membership, in the case of Group) use its model name
354+>>> Person.objects.filter(membership__invite_reason = "She was just so awesome.")
355+[<Person: Jane>]
356+
357+# Sometimes these queries can return non-distinct resultsets.
358+>>> Person.objects.filter(membership__date_joined__day = datetime.now().day)
359+[<Person: Bob>, <Person: Jim>, <Person: Jim>, <Person: Jane>, <Person: Jane>]
360+
361+# Adding a .distinct() works to correct this.
362+>>> Person.objects.filter(membership__date_joined__day = datetime.now().day).distinct()
363+[<Person: Bob>, <Person: Jim>, <Person: Jane>]
364+"""}
365\ No newline at end of file