Code

Ticket #16458: 16458_inheritance_eq.diff

File 16458_inheritance_eq.diff, 6.2 KB (added by akaariai, 3 years ago)

inherited models can be equal

Line 
1diff --git a/django/db/models/base.py b/django/db/models/base.py
2index 71fd1f7..42efe54 100644
3--- a/django/db/models/base.py
4+++ b/django/db/models/base.py
5@@ -129,7 +129,12 @@ class ModelBase(type):
6         # Do the appropriate setup for any model parents.
7         o2o_map = dict([(f.rel.to, f) for f in new_class._meta.local_fields
8                 if isinstance(f, OneToOneField)])
9-
10+        # The concrete_models is a set of all multi-table models this model
11+        # represents, that is new_class if it is a concrete model, plus all
12+        # multi-table inherited parents. Needed for fast __eq__ implementation.
13+        new_class._meta.concrete_models = set()
14+        if not is_proxy and not abstract:
15+            new_class._meta.concrete_models.add(new_class)
16         for base in parents:
17             original_base = base
18             if not hasattr(base, '_meta'):
19@@ -152,6 +157,8 @@ class ModelBase(type):
20                 while base._meta.proxy:
21                     # Skip over a proxy class to the "real" base it proxies.
22                     base = base._meta.proxy_for_model
23+                # Found a concrete model, and this class represents it.
24+                new_class._meta.concrete_models.update(base._meta.concrete_models)
25                 if base in o2o_map:
26                     field = o2o_map[base]
27                 elif not is_proxy:
28@@ -382,7 +389,14 @@ class Model(object):
29         return '%s object' % self.__class__.__name__
30 
31     def __eq__(self, other):
32-        return isinstance(other, self.__class__) and self._get_pk_val() == other._get_pk_val()
33+        """
34+        Two models are considered equal if they share a common multitable-inherited parent
35+        or if one is an instance of the other. This can be used by checking the _meta
36+        .concrete_models variable.
37+        """
38+        return bool(isinstance(other, Model) and
39+               self._meta.concrete_models.intersection(other._meta.concrete_models) and
40+               self._get_pk_val() == other._get_pk_val())
41 
42     def __ne__(self, other):
43         return not self.__eq__(other)
44diff --git a/tests/modeltests/basic/tests.py b/tests/modeltests/basic/tests.py
45index ff09d9b..c066f43 100644
46--- a/tests/modeltests/basic/tests.py
47+++ b/tests/modeltests/basic/tests.py
48@@ -206,6 +206,9 @@ class ModelTest(TestCase):
49         # Check that != and == operators behave as expecte on instances
50         self.assertTrue(a7 != a8)
51         self.assertFalse(a7 == a8)
52+        # And that they work when the other is not a Model instance
53+        self.assertFalse(a7 == None)
54+        self.assertFalse(a7 == self)
55         self.assertEqual(a8, Article.objects.get(id__exact=a8.id))
56 
57         self.assertTrue(Article.objects.get(id__exact=a8.id) != Article.objects.get(id__exact=a7.id))
58diff --git a/tests/modeltests/defer/tests.py b/tests/modeltests/defer/tests.py
59index 5f6c53d..9cf32ee 100644
60--- a/tests/modeltests/defer/tests.py
61+++ b/tests/modeltests/defer/tests.py
62@@ -23,10 +23,18 @@ class DeferTests(TestCase):
63         p1 = Primary.objects.create(name="p1", value="xx", related=s1)
64 
65         qs = Primary.objects.all()
66-
67         self.assert_delayed(qs.defer("name")[0], 1)
68         self.assert_delayed(qs.only("name")[0], 2)
69         self.assert_delayed(qs.defer("related__first")[0], 0)
70+       
71+        # deferred instances are equal to their non-deferred counterpart
72+        deferred_p1 = qs.defer("name")[0]
73+        self.assertTrue(deferred_p1==p1)
74+        # The __eq__ operator is symmetric as well as the == operator
75+        self.assertTrue(p1==deferred_p1)
76+        self.assertTrue(deferred_p1.__eq__(p1))
77+        self.assertTrue(p1.__eq__(deferred_p1))
78+
79 
80         obj = qs.select_related().only("related__first")[0]
81         self.assert_delayed(obj, 2)
82diff --git a/tests/modeltests/model_inheritance/tests.py b/tests/modeltests/model_inheritance/tests.py
83index 334297a..28b9258 100644
84--- a/tests/modeltests/model_inheritance/tests.py
85+++ b/tests/modeltests/model_inheritance/tests.py
86@@ -69,6 +69,20 @@ class ModelInheritanceTests(TestCase):
87             StudentWorker.objects.get, pk__lt=sw2.pk + 100
88         )
89 
90+        # A multi-table inherited instance is considered equal to its base
91+        # class
92+        w = Worker(pk=1)
93+        sw = StudentWorker(pk=1)
94+        self.assertTrue(w==sw)
95+        # The __eq__ operator is symmetric as well as the '==' operator
96+        self.assertTrue(sw==w)
97+        self.assertTrue(w.__eq__(sw))
98+        self.assertTrue(sw.__eq__(w))
99+        # A common abstract base class does not lead to equality
100+        s = Student(pk=1)
101+        self.assertFalse(s==w)
102+       
103+
104     def test_multiple_table(self):
105         post = Post.objects.create(title="Lorem Ipsum")
106         # The Post model has distinct accessors for the Comment and Link models.
107@@ -269,6 +283,13 @@ class ModelInheritanceTests(TestCase):
108         self.assertNumQueries(1,
109             lambda: ItalianRestaurant.objects.select_related("chef")[0].chef
110         )
111+        # Having a common parent works when there are models in the chain
112+        r = Place(pk=1)
113+        i = ItalianRestaurant(pk=1)
114+        self.assertTrue(r==i)
115+        # A common concrete (multitable-inherited) parent leads to equality
116+        s = Supplier(pk=1)
117+        self.assertTrue(i==s)
118 
119     def test_mixin_init(self):
120         m = MixinModel()
121diff --git a/tests/modeltests/proxy_models/tests.py b/tests/modeltests/proxy_models/tests.py
122index 0a46a25..a4c7d40 100644
123--- a/tests/modeltests/proxy_models/tests.py
124+++ b/tests/modeltests/proxy_models/tests.py
125@@ -45,6 +45,18 @@ class ProxyModelTests(TestCase):
126         self.assertEqual(MyPerson.objects.get(name="Foo McBar").id, person.id)
127         self.assertFalse(MyPerson.objects.get(id=person.id).has_special_name())
128 
129+    def test_proxy_eq(self):
130+        """
131+        Proxied models are considered equal to their concrete base class.
132+        """
133+        p = Person(pk=1)
134+        mp = MyPerson(pk=1)
135+        self.assertTrue(p==mp)
136+        # the __eq__ operator is symmetric as well as the == operator
137+        self.assertTrue(mp==p)
138+        self.assertTrue(p.__eq__(mp))
139+        self.assertTrue(mp.__eq__(p))
140+
141     def test_no_proxy(self):
142         """
143         Person is not proxied by StatusPerson subclass.