Code

Ticket #6095: 6095-alpha-01.diff

File 6095-alpha-01.diff, 9.6 KB (added by floguy, 7 years ago)

Rudimentary patch to get the base functionality working. Tests were first, then implementation.

Line 
1Index: django/db/models/fields/related.py
2===================================================================
3--- django/db/models/fields/related.py  (revision 6888)
4+++ django/db/models/fields/related.py  (working copy)
5@@ -1,10 +1,10 @@
6 from django.db import connection, transaction
7-from django.db.models import signals, get_model
8+from django.db.models import signals, get_model, get_models
9 from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, get_ul_class
10 from django.db.models.related import RelatedObject
11 from django.utils.text import capfirst
12 from django.utils.translation import ugettext_lazy, string_concat, ungettext, ugettext as _
13-from django.utils.functional import curry
14+from django.utils.functional import curry, memoize
15 from django.utils.encoding import smart_unicode
16 from django.core import validators
17 from django import oldforms
18@@ -23,6 +23,10 @@
19 
20 pending_lookups = {}
21 
22+memoized_fk_field_reversals = {}
23+
24+model_db_table_cache = {}
25+
26 def add_lookup(rel_cls, field):
27     name = field.rel.to
28     module = rel_cls.__module__
29@@ -54,6 +58,30 @@
30     except klass.DoesNotExist:
31         raise validators.ValidationError, _("Please enter a valid %s.") % f.verbose_name
32 
33+def get_reverse_rel_field(from_model, to_model, related_name):
34+    key = (from_model._meta.app_label, from_model._meta.object_name,
35+            to_model._meta.app_label, to_model._meta.object_name,
36+            related_name)
37+    try:
38+        found_field = memoized_fk_field_reversals[key]
39+    except KeyError:
40+        found_field = None
41+        for field in from_model._meta.fields:
42+            if field.__class__ in (ForeignKey, OneToOneField, ManyToManyField):
43+                if field.rel.related_name == related_name:
44+                    if field.rel.to == to_model:
45+                        found_field = field
46+                        break
47+        memoized_fk_field_reversals[key] = found_field
48+    return found_field
49+
50+def get_model_for_db_table(db_table):
51+    for model in get_models():
52+        if model._meta.db_table == db_table:
53+            return model
54+    return None
55+get_model_for_db_table = memoize(get_model_for_db_table, model_db_table_cache, 1)
56+
57 #HACK
58 class RelatedField(object):
59     def contribute_to_class(self, cls, name):
60@@ -276,6 +304,7 @@
61             self.join_table = join_table
62             self.source_col_name = source_col_name
63             self.target_col_name = target_col_name
64+            self.intermediary_model = get_model_for_db_table(self.join_table.replace('"',''))
65             self._pk_val = self.instance._get_pk_val()
66             if self._pk_val is None:
67                 raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % model)
68@@ -340,9 +369,15 @@
69 
70                 # Add the ones that aren't there already
71                 for obj_id in (new_ids - existing_ids):
72-                    cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
73+                    if self.intermediary_model == None:
74+                        cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
75                         (self.join_table, source_col_name, target_col_name),
76                         [self._pk_val, obj_id])
77+                    else:
78+                        new_obj = self.intermediary_model()
79+                        setattr(new_obj, source_col_name.replace('"', ''), self._pk_val)
80+                        setattr(new_obj, target_col_name.replace('"', ''), obj_id)
81+                        new_obj.save()
82                 transaction.commit_unless_managed()
83 
84         def _remove_items(self, source_col_name, target_col_name, *objs):
85@@ -648,7 +683,8 @@
86             filter_interface=kwargs.pop('filter_interface', None),
87             limit_choices_to=kwargs.pop('limit_choices_to', None),
88             raw_id_admin=kwargs.pop('raw_id_admin', False),
89-            symmetrical=kwargs.pop('symmetrical', True))
90+            symmetrical=kwargs.pop('symmetrical', True),
91+            through=kwargs.pop('through', None))
92         self.db_table = kwargs.pop('db_table', None)
93         if kwargs["rel"].raw_id_admin:
94             kwargs.setdefault("validator_list", []).append(self.isValidIDList)
95@@ -672,7 +708,9 @@
96 
97     def _get_m2m_db_table(self, opts):
98         "Function that can be curried to provide the m2m table name for this relation"
99-        if self.db_table:
100+        if self.rel.through != None:
101+            return get_model(opts.app_label, self.rel.through)._meta.db_table
102+        elif self.db_table:
103             return self.db_table
104         else:
105             return '%s_%s' % (opts.db_table, self.name)
106@@ -680,7 +718,12 @@
107     def _get_m2m_column_name(self, related):
108         "Function that can be curried to provide the source column name for the m2m table"
109         # If this is an m2m relation to self, avoid the inevitable name clash
110-        if related.model == related.parent_model:
111+        if self.rel.through != None:
112+            through = get_model(related.opts.app_label, self.rel.through)
113+            field = get_reverse_rel_field(through, related.model, self.rel.related_name)
114+            attname, column = field.get_attname_column()
115+            return column
116+        elif related.model == related.parent_model:
117             return 'from_' + related.model._meta.object_name.lower() + '_id'
118         else:
119             return related.model._meta.object_name.lower() + '_id'
120@@ -688,7 +731,12 @@
121     def _get_m2m_reverse_name(self, related):
122         "Function that can be curried to provide the related column name for the m2m table"
123         # If this is an m2m relation to self, avoid the inevitable name clash
124-        if related.model == related.parent_model:
125+        if self.rel.through != None:
126+            through = get_model(related.opts.app_label, self.rel.through)
127+            field = get_reverse_rel_field(through, related.parent_model, self.rel.related_name)
128+            attname, column = field.get_attname_column()
129+            return column
130+        elif related.model == related.parent_model:
131             return 'to_' + related.parent_model._meta.object_name.lower() + '_id'
132         else:
133             return related.parent_model._meta.object_name.lower() + '_id'
134@@ -809,7 +857,8 @@
135 
136 class ManyToManyRel(object):
137     def __init__(self, to, num_in_admin=0, related_name=None,
138-        filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True):
139+        filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True,
140+        through = None):
141         self.to = to
142         self.num_in_admin = num_in_admin
143         self.related_name = related_name
144@@ -821,5 +870,6 @@
145         self.raw_id_admin = raw_id_admin
146         self.symmetrical = symmetrical
147         self.multiple = True
148+        self.through = through
149 
150         assert not (self.raw_id_admin and self.filter_interface), "ManyToManyRels may not use both raw_id_admin and filter_interface"
151Index: django/core/management/sql.py
152===================================================================
153--- django/core/management/sql.py       (revision 6888)
154+++ django/core/management/sql.py       (working copy)
155@@ -349,7 +349,7 @@
156     qn = connection.ops.quote_name
157     inline_references = connection.features.inline_fk_references
158     for f in opts.many_to_many:
159-        if not isinstance(f.rel, generic.GenericRel):
160+        if not isinstance(f.rel, generic.GenericRel) and getattr(f.rel, 'through', None) == None:
161             tablespace = f.db_tablespace or opts.db_tablespace
162             if tablespace and connection.features.supports_tablespaces and connection.features.autoindexes_primary_keys:
163                 tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace, inline=True)
164Index: tests/modeltests/m2m_manual/__init__.py
165===================================================================
166Index: tests/modeltests/m2m_manual/models.py
167===================================================================
168--- tests/modeltests/m2m_manual/models.py       (revision 0)
169+++ tests/modeltests/m2m_manual/models.py       (revision 0)
170@@ -0,0 +1,61 @@
171+from django.db import models
172+from datetime import datetime
173+
174+# M2M described on one of the models
175+class Person(models.Model):
176+    name = models.CharField(max_length=128)
177+
178+    def __unicode__(self):
179+        return self.name
180+
181+class Group(models.Model):
182+    name = models.CharField(max_length=128)
183+    members = models.ManyToManyField(Person, through='Membership')
184+   
185+    def __unicode__(self):
186+        return self.name
187+
188+class Membership(models.Model):
189+    person = models.ForeignKey(Person)
190+    group = models.ForeignKey(Group)
191+    date_joined = models.DateTimeField(default=datetime.now)
192+   
193+    def __unicode__(self):
194+        return "%s is a member of %s" % (self.person.name, self.group.name)
195+
196+__test__ = {'API_TESTS':"""
197+>>> bob = Person(name = 'Bob')
198+>>> bob.save()
199+>>> jim = Person(name = 'Jim')
200+>>> jim.save()
201+>>> jane = Person(name = 'Jane')
202+>>> jane.save()
203+>>> rock = Group(name = 'Rock')
204+>>> rock.save()
205+>>> roll = Group(name = 'Roll')
206+>>> roll.save()
207+
208+>>> rock.members.add(jim, jane)
209+>>> rock.members.all()
210+[<Person: Jim>, <Person: Jane>]
211+
212+>>> roll.members.add(bob, jim)
213+>>> roll.members.all()
214+[<Person: Bob>, <Person: Jim>]
215+
216+>>> jane.group_set.all()
217+[<Group: Rock>]
218+
219+>>> jane.group_set.add(roll)
220+>>> jane.group_set.all()
221+[<Group: Rock>, <Group: Roll>]
222+
223+>>> jim.group_set.all()
224+[<Group: Rock>, <Group: Roll>]
225+
226+>>> Membership.objects.filter(person = jane, group = rock)
227+[<Membership: Jane is a member of Rock>]
228+
229+>>> Membership.objects.filter(person = jim)
230+[<Membership: Jim is a member of Rock>, <Membership: Jim is a member of Roll>]
231+"""}
232\ No newline at end of file