Code

Ticket #10159: geowherenode_expressions_fix_v3.diff

File geowherenode_expressions_fix_v3.diff, 16.1 KB (added by jbronn, 5 years ago)

Now takes into account different SRIDs.

Line 
1Index: django/contrib/gis/db/models/sql/where.py
2===================================================================
3--- django/contrib/gis/db/models/sql/where.py   (revision 9814)
4+++ django/contrib/gis/db/models/sql/where.py   (working copy)
5@@ -1,7 +1,11 @@
6-import datetime
7+from django.db import connection
8 from django.db.models.fields import Field
9+from django.db.models.sql.constants import LOOKUP_SEP
10+from django.db.models.sql.expressions import SQLEvaluator
11 from django.db.models.sql.where import WhereNode
12 from django.contrib.gis.db.backend import get_geo_where_clause, SpatialBackend
13+from django.contrib.gis.db.models.fields import GeometryField
14+qn = connection.ops.quote_name
15 
16 class GeoAnnotation(object):
17     """
18@@ -37,10 +41,36 @@
19             # Not a geographic field, so call `WhereNode.add`.
20             return super(GeoWhereNode, self).add(data, connector)
21         else:
22-            # `GeometryField.get_db_prep_lookup` returns a where clause
23-            # substitution array in addition to the parameters.
24-            where, params = field.get_db_prep_lookup(lookup_type, value)
25+            if isinstance(value, SQLEvaluator):
26+                # Getting the geographic field to compare with from the expression.
27+                geo_fld = self._check_geo_field(value.opts, value.expression.name)
28+                if not geo_fld:
29+                    raise ValueError('No geographic field found in expression.')
30 
31+                # Get the SRID of the geometry field that the expression was meant
32+                # to operate on -- it's needed to determine whether transformation
33+                # SQL is necessary.
34+                srid = geo_fld._srid
35+
36+                # Getting the quoted representation of the geometry column that
37+                # the expression is operating on.
38+                geo_col = '%s.%s' % tuple(map(qn, value.cols[value.expression]))
39+
40+                # If it's in a different SRID, we'll need to wrap in
41+                # transformation SQL.
42+                if srid != field._srid and SpatialBackend.transform:
43+                    placeholder = '%s(%%s, %s)' % (SpatialBackend.transform, field._srid)
44+                else:
45+                    placeholder = '%s'
46+
47+                # Setting these up as if we had called `field.get_db_prep_lookup()`.
48+                where =  [placeholder % geo_col]
49+                params = ()
50+            else:
51+                # `GeometryField.get_db_prep_lookup` returns a where clause
52+                # substitution array in addition to the parameters.
53+                where, params = field.get_db_prep_lookup(lookup_type, value)
54+
55             # The annotation will be a `GeoAnnotation` object that
56             # will contain the necessary geometry field metadata for
57             # the `get_geo_where_clause` to construct the appropriate
58@@ -64,3 +94,42 @@
59             # If not a GeometryField, call the `make_atom` from the
60             # base class.
61             return super(GeoWhereNode, self).make_atom(child, qn)
62+
63+    @classmethod
64+    def _check_geo_field(cls, opts, lookup):
65+        """
66+        Utility for checking the given lookup with the given model options. 
67+        The lookup is a string either specifying the geographic field, e.g.
68+        'point, 'the_geom', or a related lookup on a geographic field like
69+        'address__point'.
70+
71+        If a GeometryField exists according to the given lookup on the model
72+        options, it will be returned.  Otherwise returns None.
73+        """
74+        # This takes into account the situation where the lookup is a
75+        # lookup to a related geographic field, e.g., 'address__point'.
76+        field_list = lookup.split(LOOKUP_SEP)
77+
78+        # Reversing so list operates like a queue of related lookups,
79+        # and popping the top lookup.
80+        field_list.reverse()
81+        fld_name = field_list.pop()
82+
83+        try:
84+            geo_fld = opts.get_field(fld_name)
85+            # If the field list is still around, then it means that the
86+            # lookup was for a geometry field across a relationship --
87+            # thus we keep on getting the related model options and the
88+            # model field associated with the next field in the list
89+            # until there's no more left.
90+            while len(field_list):
91+                opts = geo_fld.rel.to._meta
92+                geo_fld = opts.get_field(field_list.pop())
93+        except (FieldDoesNotExist, AttributeError):
94+            return False
95+
96+        # Finally, make sure we got a Geographic field and return.
97+        if isinstance(geo_fld, GeometryField):
98+            return geo_fld
99+        else:
100+            return False
101Index: django/contrib/gis/db/models/sql/query.py
102===================================================================
103--- django/contrib/gis/db/models/sql/query.py   (revision 9814)
104+++ django/contrib/gis/db/models/sql/query.py   (working copy)
105@@ -270,48 +270,17 @@
106             # Because WKT doesn't contain spatial reference information,
107             # the SRID is prefixed to the returned WKT to ensure that the
108             # transformed geometries have an SRID different than that of the
109-            # field -- this is only used by `transform` for Oracle backends.
110-            if self.transformed_srid and SpatialBackend.oracle:
111+            # field -- this is only used by `transform` for Oracle and
112+            # SpatiaLite backends.  It's not clear that this is a complete
113+            # solution (though maybe it is?).
114+            if self.transformed_srid and ( SpatialBackend.oracle or
115+                                           SpatialBackend.sqlite3 ):
116                 sel_fmt = "'SRID=%d;'||%s" % (self.transformed_srid, sel_fmt)
117         else:
118             sel_fmt = '%s'
119         return sel_fmt
120 
121     # Private API utilities, subject to change.
122-    def _check_geo_field(self, model, name_param):
123-        """
124-        Recursive utility routine for checking the given name parameter
125-        on the given model.  Initially, the name parameter is a string,
126-        of the field on the given model e.g., 'point', 'the_geom'.
127-        Related model field strings like 'address__point', may also be
128-        used.
129-
130-        If a GeometryField exists according to the given name parameter
131-        it will be returned, otherwise returns False.
132-        """
133-        if isinstance(name_param, basestring):
134-            # This takes into account the situation where the name is a
135-            # lookup to a related geographic field, e.g., 'address__point'.
136-            name_param = name_param.split(sql.constants.LOOKUP_SEP)
137-            name_param.reverse() # Reversing so list operates like a queue of related lookups.
138-        elif not isinstance(name_param, list):
139-            raise TypeError
140-        try:
141-            # Getting the name of the field for the model (by popping the first
142-            # name from the `name_param` list created above).
143-            fld, mod, direct, m2m = model._meta.get_field_by_name(name_param.pop())
144-        except (FieldDoesNotExist, IndexError):
145-            return False
146-        # TODO: ManyToManyField?
147-        if isinstance(fld, GeometryField):
148-            return fld # A-OK.
149-        elif isinstance(fld, ForeignKey):
150-            # ForeignKey encountered, return the output of this utility called
151-            # on the _related_ model with the remaining name parameters.
152-            return self._check_geo_field(fld.rel.to, name_param) # Recurse to check ForeignKey relation.
153-        else:
154-            return False
155-
156     def _field_column(self, field, table_alias=None):
157         """
158         Helper function that returns the database column for the given field.
159@@ -339,4 +308,4 @@
160         else:
161             # Otherwise, check by the given field name -- which may be
162             # a lookup to a _related_ geographic field.
163-            return self._check_geo_field(self.model, field_name)
164+            return GeoWhereNode._check_geo_field(self.model._meta, field_name)
165Index: django/contrib/gis/tests/relatedapp/tests.py
166===================================================================
167--- django/contrib/gis/tests/relatedapp/tests.py        (revision 9814)
168+++ django/contrib/gis/tests/relatedapp/tests.py        (working copy)
169@@ -1,8 +1,9 @@
170 import os, unittest
171 from django.contrib.gis.geos import *
172-from django.contrib.gis.tests.utils import no_mysql, postgis
173+from django.contrib.gis.db.models import F, Extent, Union
174+from django.contrib.gis.tests.utils import no_mysql, mysql, postgis
175 from django.conf import settings
176-from models import City, Location, DirectoryEntry
177+from models import City, Location, DirectoryEntry, Parcel
178 
179 cities = (('Aurora', 'TX', -97.516111, 33.058333),
180           ('Roswell', 'NM', -104.528056, 33.387222),
181@@ -14,10 +15,8 @@
182     def test01_setup(self):
183         "Setting up for related model tests."
184         for name, state, lon, lat in cities:
185-            loc = Location(point=Point(lon, lat))
186-            loc.save()
187-            c = City(name=name, state=state, location=loc)
188-            c.save()
189+            loc = Location.objects.create(point=Point(lon, lat))
190+            c = City.objects.create(name=name, state=state, location=loc)
191             
192     def test02_select_related(self):
193         "Testing `select_related` on geographic models (see #7126)."
194@@ -39,22 +38,14 @@
195         # US Survey Feet (thus a tolerance of 0 implies error w/in 1 survey foot).
196         if postgis:
197             tol = 3
198-            nqueries = 4 # +1 for `postgis_lib_version`
199         else:
200             tol = 0
201-            nqueries = 3
202             
203         def check_pnt(ref, pnt):
204             self.assertAlmostEqual(ref.x, pnt.x, tol)
205             self.assertAlmostEqual(ref.y, pnt.y, tol)
206             self.assertEqual(ref.srid, pnt.srid)
207 
208-        # Turning on debug so we can manually verify the number of SQL queries issued.
209-        # DISABLED: the number of queries count testing mechanism is way too brittle.
210-        #dbg = settings.DEBUG
211-        #settings.DEBUG = True
212-        from django.db import connection
213-
214         # Each city transformed to the SRID of their state plane coordinate system.
215         transformed = (('Kecksburg', 2272, 'POINT(1490553.98959621 314792.131023984)'),
216                        ('Roswell', 2257, 'POINT(481902.189077221 868477.766629735)'),
217@@ -65,30 +56,34 @@
218             # Doing this implicitly sets `select_related` select the location.
219             qs = list(City.objects.filter(name=name).transform(srid, field_name='location__point'))
220             check_pnt(GEOSGeometry(wkt, srid), qs[0].location.point)
221-        #settings.DEBUG= dbg
222 
223-        # Verifying the number of issued SQL queries.
224-        #self.assertEqual(nqueries, len(connection.queries))
225-
226     @no_mysql
227     def test04_related_aggregate(self):
228         "Testing the `extent` and `unionagg` GeoQuerySet aggregates on related geographic models."
229-        if postgis:
230-            # One for all locations, one that excludes Roswell.
231-            all_extent = (-104.528060913086, 33.0583305358887,-79.4607315063477, 40.1847610473633)
232-            txpa_extent = (-97.51611328125, 33.0583305358887,-79.4607315063477, 40.1847610473633)
233-            e1 = City.objects.extent(field_name='location__point')
234-            e2 = City.objects.exclude(name='Roswell').extent(field_name='location__point')
235-            for ref, e in [(all_extent, e1), (txpa_extent, e2)]:
236-                for ref_val, e_val in zip(ref, e): self.assertAlmostEqual(ref_val, e_val)
237 
238+        # This combines the Extent and Union aggregates into one query
239+        aggs = City.objects.aggregate(Extent('location__point'), Union('location__point'))
240+
241+        # One for all locations, one that excludes Roswell.
242+        all_extent = (-104.528060913086, 33.0583305358887,-79.4607315063477, 40.1847610473633)
243+        txpa_extent = (-97.51611328125, 33.0583305358887,-79.4607315063477, 40.1847610473633)
244+        e1 = City.objects.extent(field_name='location__point')
245+        e2 = City.objects.exclude(name='Roswell').extent(field_name='location__point')
246+        e3 = aggs['location__point__extent']
247+
248+        for ref, e in [(all_extent, e1), (txpa_extent, e2), (all_extent, e3)]:
249+            for ref_val, e_val in zip(ref, e): self.assertAlmostEqual(ref_val, e_val)
250+
251         # The second union is for a query that has something in the WHERE clause.
252         ref_u1 = GEOSGeometry('MULTIPOINT(-104.528056 33.387222,-97.516111 33.058333,-79.460734 40.18476)', 4326)
253         ref_u2 = GEOSGeometry('MULTIPOINT(-97.516111 33.058333,-79.460734 40.18476)', 4326)
254         u1 = City.objects.unionagg(field_name='location__point')
255         u2 = City.objects.exclude(name='Roswell').unionagg(field_name='location__point')
256+        u3 = aggs['location__point__union']
257+
258         self.assertEqual(ref_u1, u1)
259         self.assertEqual(ref_u2, u2)
260+        self.assertEqual(ref_u1, u3)
261         
262     def test05_select_related_fk_to_subclass(self):
263         "Testing that calling select_related on a query over a model with an FK to a model subclass works"
264@@ -96,7 +91,52 @@
265         l = list(DirectoryEntry.objects.all().select_related())
266 
267     # TODO: Related tests for KML, GML, and distance lookups.
268+    def test6_f_expressions(self):
269+        "Testing F() expressions on Geometry fields."
270+        # Constructing a dummy parcel border and getting the City FK
271+        b1 = GEOSGeometry('POLYGON((-97.501205 33.052520,-97.501205 33.052576,-97.501150 33.052576,-97.501150 33.052520,-97.501205 33.052520))', srid=4326)
272+        pcity = City.objects.get(name='Aurora')
273+
274+        # First parcel has incorrect center point that is equal to the City;
275+        # it also has a second border that is different from the first as a
276+        # 100ft buffer around the City.
277+        c1 = pcity.location.point
278+        c2 = c1.transform(2276, clone=True)
279+        b2 = c2.buffer(100)
280+        p1 = Parcel.objects.create(name='P1', city=pcity, center1=c1, center2=c2, border1=b1, border2=b2)
281+
282+        # Now creating a second Parcel where the borders are the same
283+        # _border1_ but in different coordinate systems.  The borders are the
284+        # same here
285+        c1 = b1.centroid
286+        c2 = c1.transform(2276, clone=True)
287+        p2 = Parcel.objects.create(name='P2', city=pcity, center1=c1, center2=c2, border1=b1, border2=b1)
288+
289+        # Should return the second Parcel, which has the center within the
290+        # border.
291+        qs = Parcel.objects.filter(center1__within=F('border1'))
292+        self.assertEqual(1, len(qs))
293+        self.assertEqual('P2', qs[0].name)
294         
295+        if not mysql:
296+            # This time center2 is in a different coordinate system and needs
297+            # to be wrapped in transformation SQL.
298+            qs = Parcel.objects.filter(center2__within=F('border1'))
299+            self.assertEqual(1, len(qs))
300+            self.assertEqual('P2', qs[0].name)           
301+       
302+        # Should return the first Parcel, which has the center point equal
303+        # to the point in the City ForeignKey.
304+        qs = Parcel.objects.filter(center1=F('city__location__point'))
305+        self.assertEqual(1, len(qs))
306+        self.assertEqual('P1', qs[0].name)
307+
308+        if not mysql:
309+            # This time the city column should be wrapped in transformation SQL.
310+            qs = Parcel.objects.filter(border2__contains=F('city__location__point'))
311+            self.assertEqual(1, len(qs))
312+            self.assertEqual('P1', qs[0].name)
313+       
314 def suite():
315     s = unittest.TestSuite()
316     s.addTest(unittest.makeSuite(RelatedGeoModelTest))
317Index: django/contrib/gis/tests/relatedapp/models.py
318===================================================================
319--- django/contrib/gis/tests/relatedapp/models.py       (revision 9814)
320+++ django/contrib/gis/tests/relatedapp/models.py       (working copy)
321@@ -20,3 +20,14 @@
322     listing_text = models.CharField(max_length=50)
323     location = models.ForeignKey(AugmentedLocation)
324     objects = models.GeoManager()
325+
326+class Parcel(models.Model):
327+    name = models.CharField(max_length=30)
328+    city = models.ForeignKey(City)
329+    center1 = models.PointField()
330+    # Throwing a curveball w/`db_column` here.
331+    center2 = models.PointField(srid=2276, db_column='mycenter')
332+    border1 = models.PolygonField()
333+    border2 = models.PolygonField(srid=2276)
334+    objects = models.GeoManager()
335+    def __unicode__(self): return self.name