Code

Ticket #6095: 6095-alpha-02.diff

File 6095-alpha-02.diff, 13.5 KB (added by floguy, 6 years ago)

Updated patch to handle custom related_name and custom db_column on manually created foreign keys.

Line 
1Index: django/db/models/fields/related.py
2===================================================================
3--- django/db/models/fields/related.py  (revision 6898)
4+++ django/db/models/fields/related.py  (working copy)
5@@ -1,10 +1,10 @@
6 from django.db import connection, transaction
7-from django.db.models import signals, get_model
8+from django.db.models import signals, get_model, get_models
9 from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, get_ul_class
10 from django.db.models.related import RelatedObject
11 from django.utils.text import capfirst
12 from django.utils.translation import ugettext_lazy, string_concat, ungettext, ugettext as _
13-from django.utils.functional import curry
14+from django.utils.functional import curry, memoize
15 from django.utils.encoding import smart_unicode
16 from django.core import validators
17 from django import oldforms
18@@ -23,6 +23,10 @@
19 
20 pending_lookups = {}
21 
22+memoized_fk_field_reversals = {}
23+
24+model_db_table_cache = {}
25+
26 def add_lookup(rel_cls, field):
27     name = field.rel.to
28     module = rel_cls.__module__
29@@ -54,6 +58,29 @@
30     except klass.DoesNotExist:
31         raise validators.ValidationError, _("Please enter a valid %s.") % f.verbose_name
32 
33+def get_reverse_rel_field(from_model, to_model, related_name):
34+    key = (from_model._meta.app_label, from_model._meta.object_name,
35+            to_model._meta.app_label, to_model._meta.object_name,
36+            related_name)
37+    try:
38+        found_field = memoized_fk_field_reversals[key]
39+    except KeyError:
40+        found_field = None
41+        for field in from_model._meta.fields:
42+            if field.__class__ in (ForeignKey, OneToOneField, ManyToManyField):
43+                if field.rel.to == to_model:
44+                    found_field = field
45+                    break
46+        memoized_fk_field_reversals[key] = found_field
47+    return found_field
48+
49+def get_model_for_db_table(db_table):
50+    for model in get_models():
51+        if model._meta.db_table == db_table:
52+            return model
53+    return None
54+get_model_for_db_table = memoize(get_model_for_db_table, model_db_table_cache, 1)
55+
56 #HACK
57 class RelatedField(object):
58     def contribute_to_class(self, cls, name):
59@@ -267,7 +294,8 @@
60     and adds behavior for many-to-many related objects."""
61     class ManyRelatedManager(superclass):
62         def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
63-                join_table=None, source_col_name=None, target_col_name=None):
64+                join_table=None, source_col_name=None, source_attname=None,
65+                target_attname=None, target_col_name=None):
66             super(ManyRelatedManager, self).__init__()
67             self.core_filters = core_filters
68             self.model = model
69@@ -276,6 +304,9 @@
70             self.join_table = join_table
71             self.source_col_name = source_col_name
72             self.target_col_name = target_col_name
73+            self.source_attname = source_attname
74+            self.target_attname = target_attname
75+            self.intermediary_model = get_model_for_db_table(self.join_table.replace('"',''))
76             self._pk_val = self.instance._get_pk_val()
77             if self._pk_val is None:
78                 raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % model)
79@@ -340,9 +371,15 @@
80 
81                 # Add the ones that aren't there already
82                 for obj_id in (new_ids - existing_ids):
83-                    cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
84+                    if self.intermediary_model == None:
85+                        cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \
86                         (self.join_table, source_col_name, target_col_name),
87                         [self._pk_val, obj_id])
88+                    else:
89+                        new_obj = self.intermediary_model()
90+                        setattr(new_obj, self.source_attname, self._pk_val)
91+                        setattr(new_obj, self.target_attname, obj_id)
92+                        new_obj.save()
93                 transaction.commit_unless_managed()
94 
95         def _remove_items(self, source_col_name, target_col_name, *objs):
96@@ -398,14 +435,18 @@
97         RelatedManager = create_many_related_manager(superclass)
98 
99         qn = connection.ops.quote_name
100+        source_attname, target_attname = None, None
101+        rel_field = self.related.field
102         manager = RelatedManager(
103             model=rel_model,
104             core_filters={'%s__pk' % self.related.field.name: instance._get_pk_val()},
105             instance=instance,
106             symmetrical=False,
107-            join_table=qn(self.related.field.m2m_db_table()),
108-            source_col_name=qn(self.related.field.m2m_reverse_name()),
109-            target_col_name=qn(self.related.field.m2m_column_name())
110+            join_table=qn(rel_field.m2m_db_table()),
111+            source_col_name=qn(rel_field.m2m_reverse_name()),
112+            target_col_name=qn(rel_field.m2m_column_name()),
113+            source_attname=rel_field.m2m_reverse_attname(),
114+            target_attname=rel_field.m2m_attname()
115         )
116 
117         return manager
118@@ -446,7 +487,9 @@
119             symmetrical=(self.field.rel.symmetrical and instance.__class__ == rel_model),
120             join_table=qn(self.field.m2m_db_table()),
121             source_col_name=qn(self.field.m2m_column_name()),
122-            target_col_name=qn(self.field.m2m_reverse_name())
123+            target_col_name=qn(self.field.m2m_reverse_name()),
124+            source_attname=self.field.m2m_attname(),
125+            target_attname=self.field.m2m_reverse_attname()
126         )
127 
128         return manager
129@@ -648,7 +691,8 @@
130             filter_interface=kwargs.pop('filter_interface', None),
131             limit_choices_to=kwargs.pop('limit_choices_to', None),
132             raw_id_admin=kwargs.pop('raw_id_admin', False),
133-            symmetrical=kwargs.pop('symmetrical', True))
134+            symmetrical=kwargs.pop('symmetrical', True),
135+            through=kwargs.pop('through', None))
136         self.db_table = kwargs.pop('db_table', None)
137         if kwargs["rel"].raw_id_admin:
138             kwargs.setdefault("validator_list", []).append(self.isValidIDList)
139@@ -672,23 +716,53 @@
140 
141     def _get_m2m_db_table(self, opts):
142         "Function that can be curried to provide the m2m table name for this relation"
143-        if self.db_table:
144+        if self.rel.through != None:
145+            return get_model(opts.app_label, self.rel.through)._meta.db_table
146+        elif self.db_table:
147             return self.db_table
148         else:
149             return '%s_%s' % (opts.db_table, self.name)
150 
151+    def _get_m2m_attname(self, related):
152+        try:
153+            through = get_model(related.opts.app_label, self.rel.through)
154+            field = get_reverse_rel_field(through, related.model, self.rel.related_name)
155+            attname, column = field.get_attname_column()
156+            return attname
157+        except:
158+            return None
159+
160     def _get_m2m_column_name(self, related):
161         "Function that can be curried to provide the source column name for the m2m table"
162         # If this is an m2m relation to self, avoid the inevitable name clash
163-        if related.model == related.parent_model:
164+        if self.rel.through != None:
165+            through = get_model(related.opts.app_label, self.rel.through)
166+            field = get_reverse_rel_field(through, related.model, self.rel.related_name)
167+            attname, column = field.get_attname_column()
168+            return column
169+        elif related.model == related.parent_model:
170             return 'from_' + related.model._meta.object_name.lower() + '_id'
171         else:
172             return related.model._meta.object_name.lower() + '_id'
173 
174+    def _get_m2m_reverse_attname(self, related):
175+        try:
176+            through = get_model(related.opts.app_label, self.rel.through)
177+            field = get_reverse_rel_field(through, related.parent_model, self.rel.related_name)
178+            attname, column = field.get_attname_column()
179+            return attname
180+        except:
181+            return None
182+
183     def _get_m2m_reverse_name(self, related):
184         "Function that can be curried to provide the related column name for the m2m table"
185         # If this is an m2m relation to self, avoid the inevitable name clash
186-        if related.model == related.parent_model:
187+        if self.rel.through != None:
188+            through = get_model(related.opts.app_label, self.rel.through)
189+            field = get_reverse_rel_field(through, related.parent_model, self.rel.related_name)
190+            attname, column = field.get_attname_column()
191+            return column
192+        elif related.model == related.parent_model:
193             return 'to_' + related.parent_model._meta.object_name.lower() + '_id'
194         else:
195             return related.parent_model._meta.object_name.lower() + '_id'
196@@ -745,6 +819,8 @@
197         # Set up the accessors for the column names on the m2m table
198         self.m2m_column_name = curry(self._get_m2m_column_name, related)
199         self.m2m_reverse_name = curry(self._get_m2m_reverse_name, related)
200+        self.m2m_attname = curry(self._get_m2m_attname, related)
201+        self.m2m_reverse_attname = curry(self._get_m2m_reverse_attname, related)
202 
203     def set_attributes_from_rel(self):
204         pass
205@@ -809,7 +885,8 @@
206 
207 class ManyToManyRel(object):
208     def __init__(self, to, num_in_admin=0, related_name=None,
209-        filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True):
210+        filter_interface=None, limit_choices_to=None, raw_id_admin=False, symmetrical=True,
211+        through = None):
212         self.to = to
213         self.num_in_admin = num_in_admin
214         self.related_name = related_name
215@@ -821,5 +898,6 @@
216         self.raw_id_admin = raw_id_admin
217         self.symmetrical = symmetrical
218         self.multiple = True
219+        self.through = through
220 
221         assert not (self.raw_id_admin and self.filter_interface), "ManyToManyRels may not use both raw_id_admin and filter_interface"
222Index: django/core/management/sql.py
223===================================================================
224--- django/core/management/sql.py       (revision 6898)
225+++ django/core/management/sql.py       (working copy)
226@@ -349,7 +349,7 @@
227     qn = connection.ops.quote_name
228     inline_references = connection.features.inline_fk_references
229     for f in opts.many_to_many:
230-        if not isinstance(f.rel, generic.GenericRel):
231+        if not isinstance(f.rel, generic.GenericRel) and getattr(f.rel, 'through', None) == None:
232             tablespace = f.db_tablespace or opts.db_tablespace
233             if tablespace and connection.features.supports_tablespaces and connection.features.autoindexes_primary_keys:
234                 tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace, inline=True)
235Index: tests/modeltests/m2m_manual/__init__.py
236===================================================================
237Index: tests/modeltests/m2m_manual/models.py
238===================================================================
239--- tests/modeltests/m2m_manual/models.py       (revision 0)
240+++ tests/modeltests/m2m_manual/models.py       (revision 0)
241@@ -0,0 +1,84 @@
242+from django.db import models
243+from datetime import datetime
244+
245+# M2M described on one of the models
246+class Person(models.Model):
247+    name = models.CharField(max_length=128)
248+
249+    def __unicode__(self):
250+        return self.name
251+
252+class Group(models.Model):
253+    name = models.CharField(max_length=128)
254+    members = models.ManyToManyField(Person, through='Membership')
255+    custom_members = models.ManyToManyField(Person, through='CustomMembership', related_name="custom")
256+   
257+    def __unicode__(self):
258+        return self.name
259+
260+class Membership(models.Model):
261+    person = models.ForeignKey(Person)
262+    group = models.ForeignKey(Group)
263+    date_joined = models.DateTimeField(default=datetime.now)
264+   
265+    def __unicode__(self):
266+        return "%s is a member of %s" % (self.person.name, self.group.name)
267+
268+class CustomMembership(models.Model):
269+    person = models.ForeignKey(Person, db_column="custom_person_column", related_name="custom_person_related_name")
270+    group = models.ForeignKey(Group)
271+    date_joined = models.DateTimeField(default=datetime.now)
272+   
273+    def __unicode__(self):
274+        return "%s is a member of %s" % (self.person.name, self.group.name)
275+
276+__test__ = {'API_TESTS':"""
277+>>> bob = Person(name = 'Bob')
278+>>> bob.save()
279+>>> jim = Person(name = 'Jim')
280+>>> jim.save()
281+>>> jane = Person(name = 'Jane')
282+>>> jane.save()
283+>>> rock = Group(name = 'Rock')
284+>>> rock.save()
285+>>> roll = Group(name = 'Roll')
286+>>> roll.save()
287+
288+>>> rock.members.add(jim, jane)
289+>>> rock.members.all()
290+[<Person: Jim>, <Person: Jane>]
291+
292+>>> roll.members.add(bob, jim)
293+>>> roll.members.all()
294+[<Person: Bob>, <Person: Jim>]
295+
296+>>> jane.group_set.all()
297+[<Group: Rock>]
298+
299+>>> jane.group_set.add(roll)
300+>>> jane.group_set.all()
301+[<Group: Rock>, <Group: Roll>]
302+
303+>>> jim.group_set.all()
304+[<Group: Rock>, <Group: Roll>]
305+
306+>>> Membership.objects.filter(person = jane, group = rock)
307+[<Membership: Jane is a member of Rock>]
308+
309+>>> Membership.objects.filter(person = jim)
310+[<Membership: Jim is a member of Rock>, <Membership: Jim is a member of Roll>]
311+
312+>>> rock.custom_members.add(bob)
313+>>> rock.custom_members.all()
314+[<Person: Bob>]
315+
316+>>> jim.custom.add(rock)
317+>>> rock.custom_members.all()
318+[<Person: Bob>, <Person: Jim>]
319+
320+>>> jim.custom.all()
321+[<Group: Rock>]
322+
323+>>> jim.custom_person_related_name.all()
324+[<CustomMembership: Jim is a member of Rock>]
325+"""}
326\ No newline at end of file