From 39b1c73f7743f5a6d26d161db0d79f8de4402f78 Mon Sep 17 00:00:00 2001
From: Nate Bragg <jonathan.bragg@alum.rpi.edu>
Date: Tue, 17 Jan 2012 22:28:55 -0500
Subject: [PATCH] Modification of dgouldin's patch that fixes PEP8ness and
 tests.

---
 django/db/models/sql/query.py          |    6 ++++++
 tests/modeltests/many_to_many/tests.py |   10 ++++++++++
 tests/modeltests/many_to_one/tests.py  |    8 ++++++++
 tests/modeltests/one_to_one/tests.py   |   22 +++++++++++++++-------
 4 files changed, 39 insertions(+), 7 deletions(-)

diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
index ed2bc06..f53bd4f 100644
--- a/django/db/models/sql/query.py
+++ b/django/db/models/sql/query.py
@@ -14,8 +14,10 @@ from django.utils.encoding import force_unicode
 from django.utils.tree import Node
 from django.db import connections, DEFAULT_DB_ALIAS
 from django.db.models import signals
+from django.db.models.base import Model
 from django.db.models.expressions import ExpressionNode
 from django.db.models.fields import FieldDoesNotExist
+from django.db.models.fields.related import RelatedField
 from django.db.models.query_utils import InvalidQuery
 from django.db.models.sql import aggregates as base_aggregates_module
 from django.db.models.sql.constants import *
@@ -1108,6 +1110,10 @@ class Query(object):
                     can_reuse)
             return
 
+        if (isinstance(field, RelatedField) and isinstance(value, Model) and
+                not isinstance(value, target.model)):
+            raise TypeError, "'%s' instance expected" % target.model._meta.object_name
+
         table_promote = False
         join_promote = False
 
diff --git a/tests/modeltests/many_to_many/tests.py b/tests/modeltests/many_to_many/tests.py
index b00d7da..6370eb0 100644
--- a/tests/modeltests/many_to_many/tests.py
+++ b/tests/modeltests/many_to_many/tests.py
@@ -199,6 +199,11 @@ class ManyToManyTests(TestCase):
         self.assertQuerysetEqual(Article.objects.exclude(publications=self.p2),
                                  ['<Article: Django lets you build Web apps easily>'])
 
+        # Filter values on related fields are checked to ensure the correct
+        # model class is being used.
+        self.assertRaises(TypeError, Article.objects.filter,
+            publications=self.a1)
+
     def test_reverse_selects(self):
         # Reverse m2m queries are supported (i.e., starting at the table that
         # doesn't have a ManyToManyField).
@@ -249,6 +254,11 @@ class ManyToManyTests(TestCase):
                 '<Publication: The Python Journal>',
             ])
 
+        # Filter values on related fields are checked to ensure the correct
+        # model class is being used.
+        self.assertRaises(TypeError, Publication.objects.filter,
+             article=self.p1)
+
     def test_delete(self):
         # If we delete a Publication, its Articles won't be able to access it.
         self.p1.delete()
diff --git a/tests/modeltests/many_to_one/tests.py b/tests/modeltests/many_to_one/tests.py
index 922506e..b43bf03 100644
--- a/tests/modeltests/many_to_one/tests.py
+++ b/tests/modeltests/many_to_one/tests.py
@@ -232,6 +232,10 @@ class ManyToOneTests(TestCase):
                 "<Article: This is a test>",
             ])
 
+        # Filter values on related fields are checked to ensure the correct
+        # model class is being used.
+        self.assertRaises(TypeError, Article.objects.filter, reporter=self.a)
+
     def test_reverse_selects(self):
         a3 = Article.objects.create(id=None, headline="Third article",
                                     pub_date=datetime(2005, 7, 27), reporter_id=self.r.id)
@@ -303,6 +307,10 @@ class ManyToOneTests(TestCase):
             list(Article.objects.filter(reporter=self.r).distinct().order_by()
                  .values('reporter__first_name', 'reporter__last_name')))
 
+        # Filter values on related fields are checked to ensure the correct
+        # model class is being used.
+        self.assertRaises(TypeError, Reporter.objects.filter, article=self.r)
+
     def test_select_related(self):
         # Check that Article.objects.select_related().dates() works properly when
         # there are multiple Articles with the same date but different foreign-key
diff --git a/tests/modeltests/one_to_one/tests.py b/tests/modeltests/one_to_one/tests.py
index 6ee7852..6800aef 100644
--- a/tests/modeltests/one_to_one/tests.py
+++ b/tests/modeltests/one_to_one/tests.py
@@ -65,6 +65,10 @@ class OneToOneTests(TestCase):
         assert_get_restaurant(place__pk=self.p1.pk)
         assert_get_restaurant(place__name__startswith="Demon")
 
+        # Filter values on related fields are checked to ensure the correct
+        # model class is being used.
+        self.assertRaises(TypeError, Restaurant.objects.get, place=self.r)
+
         def assert_get_place(**params):
             self.assertEqual(repr(Place.objects.get(**params)),
                              '<Place: Demon Dogs the place>')
@@ -79,6 +83,10 @@ class OneToOneTests(TestCase):
         assert_get_place(id__exact=self.p1.pk)
         assert_get_place(pk=self.p1.pk)
 
+        # Filter values on related fields are checked to ensure the correct
+        # model class is being used.
+        self.assertRaises(TypeError, Place.objects.get, restaurant=self.p1)
+
     def test_foreign_key(self):
         # Add a Waiter to the Restaurant.
         w = self.r.waiter_set.create(name='Joe')
@@ -92,15 +100,15 @@ class OneToOneTests(TestCase):
         assert_filter_waiters(restaurant__place__exact=self.p1.pk)
         assert_filter_waiters(restaurant__place__exact=self.p1)
         assert_filter_waiters(restaurant__place__pk=self.p1.pk)
-        assert_filter_waiters(restaurant__exact=self.p1.pk)
-        assert_filter_waiters(restaurant__exact=self.p1)
-        assert_filter_waiters(restaurant__pk=self.p1.pk)
-        assert_filter_waiters(restaurant=self.p1.pk)
+        assert_filter_waiters(restaurant__exact=self.r.pk)
+        assert_filter_waiters(restaurant__exact=self.r)
+        assert_filter_waiters(restaurant__pk=self.r.pk)
+        assert_filter_waiters(restaurant=self.r.pk)
         assert_filter_waiters(restaurant=self.r)
-        assert_filter_waiters(id__exact=self.p1.pk)
-        assert_filter_waiters(pk=self.p1.pk)
+        assert_filter_waiters(id__exact=self.r.pk)
+        assert_filter_waiters(pk=self.r.pk)
         # Delete the restaurant; the waiter should also be removed
-        r = Restaurant.objects.get(pk=self.p1.pk)
+        r = Restaurant.objects.get(pk=self.r.pk)
         r.delete()
         self.assertEqual(Waiter.objects.count(), 0)
 
-- 
1.7.5.4

