Ticket #10159: geowherenode_expressions_fix_v3.diff

File geowherenode_expressions_fix_v3.diff, 16.1 KB (added by jbronn, 7 years ago)

Now takes into account different SRIDs.

  • django/contrib/gis/db/models/sql/where.py

     
    1 import datetime
     1from django.db import connection
    22from django.db.models.fields import Field
     3from django.db.models.sql.constants import LOOKUP_SEP
     4from django.db.models.sql.expressions import SQLEvaluator
    35from django.db.models.sql.where import WhereNode
    46from django.contrib.gis.db.backend import get_geo_where_clause, SpatialBackend
     7from django.contrib.gis.db.models.fields import GeometryField
     8qn = connection.ops.quote_name
    59
    610class GeoAnnotation(object):
    711    """
     
    3741            # Not a geographic field, so call `WhereNode.add`.
    3842            return super(GeoWhereNode, self).add(data, connector)
    3943        else:
    40             # `GeometryField.get_db_prep_lookup` returns a where clause
    41             # substitution array in addition to the parameters.
    42             where, params = field.get_db_prep_lookup(lookup_type, value)
     44            if isinstance(value, SQLEvaluator):
     45                # Getting the geographic field to compare with from the expression.
     46                geo_fld = self._check_geo_field(value.opts, value.expression.name)
     47                if not geo_fld:
     48                    raise ValueError('No geographic field found in expression.')
    4349
     50                # Get the SRID of the geometry field that the expression was meant
     51                # to operate on -- it's needed to determine whether transformation
     52                # SQL is necessary.
     53                srid = geo_fld._srid
     54
     55                # Getting the quoted representation of the geometry column that
     56                # the expression is operating on.
     57                geo_col = '%s.%s' % tuple(map(qn, value.cols[value.expression]))
     58
     59                # If it's in a different SRID, we'll need to wrap in
     60                # transformation SQL.
     61                if srid != field._srid and SpatialBackend.transform:
     62                    placeholder = '%s(%%s, %s)' % (SpatialBackend.transform, field._srid)
     63                else:
     64                    placeholder = '%s'
     65
     66                # Setting these up as if we had called `field.get_db_prep_lookup()`.
     67                where =  [placeholder % geo_col]
     68                params = ()
     69            else:
     70                # `GeometryField.get_db_prep_lookup` returns a where clause
     71                # substitution array in addition to the parameters.
     72                where, params = field.get_db_prep_lookup(lookup_type, value)
     73
    4474            # The annotation will be a `GeoAnnotation` object that
    4575            # will contain the necessary geometry field metadata for
    4676            # the `get_geo_where_clause` to construct the appropriate
     
    6494            # If not a GeometryField, call the `make_atom` from the
    6595            # base class.
    6696            return super(GeoWhereNode, self).make_atom(child, qn)
     97
     98    @classmethod
     99    def _check_geo_field(cls, opts, lookup):
     100        """
     101        Utility for checking the given lookup with the given model options. 
     102        The lookup is a string either specifying the geographic field, e.g.
     103        'point, 'the_geom', or a related lookup on a geographic field like
     104        'address__point'.
     105
     106        If a GeometryField exists according to the given lookup on the model
     107        options, it will be returned.  Otherwise returns None.
     108        """
     109        # This takes into account the situation where the lookup is a
     110        # lookup to a related geographic field, e.g., 'address__point'.
     111        field_list = lookup.split(LOOKUP_SEP)
     112
     113        # Reversing so list operates like a queue of related lookups,
     114        # and popping the top lookup.
     115        field_list.reverse()
     116        fld_name = field_list.pop()
     117
     118        try:
     119            geo_fld = opts.get_field(fld_name)
     120            # If the field list is still around, then it means that the
     121            # lookup was for a geometry field across a relationship --
     122            # thus we keep on getting the related model options and the
     123            # model field associated with the next field in the list
     124            # until there's no more left.
     125            while len(field_list):
     126                opts = geo_fld.rel.to._meta
     127                geo_fld = opts.get_field(field_list.pop())
     128        except (FieldDoesNotExist, AttributeError):
     129            return False
     130
     131        # Finally, make sure we got a Geographic field and return.
     132        if isinstance(geo_fld, GeometryField):
     133            return geo_fld
     134        else:
     135            return False
  • django/contrib/gis/db/models/sql/query.py

     
    270270            # Because WKT doesn't contain spatial reference information,
    271271            # the SRID is prefixed to the returned WKT to ensure that the
    272272            # transformed geometries have an SRID different than that of the
    273             # field -- this is only used by `transform` for Oracle backends.
    274             if self.transformed_srid and SpatialBackend.oracle:
     273            # field -- this is only used by `transform` for Oracle and
     274            # SpatiaLite backends.  It's not clear that this is a complete
     275            # solution (though maybe it is?).
     276            if self.transformed_srid and ( SpatialBackend.oracle or
     277                                           SpatialBackend.sqlite3 ):
    275278                sel_fmt = "'SRID=%d;'||%s" % (self.transformed_srid, sel_fmt)
    276279        else:
    277280            sel_fmt = '%s'
    278281        return sel_fmt
    279282
    280283    # Private API utilities, subject to change.
    281     def _check_geo_field(self, model, name_param):
    282         """
    283         Recursive utility routine for checking the given name parameter
    284         on the given model.  Initially, the name parameter is a string,
    285         of the field on the given model e.g., 'point', 'the_geom'.
    286         Related model field strings like 'address__point', may also be
    287         used.
    288 
    289         If a GeometryField exists according to the given name parameter
    290         it will be returned, otherwise returns False.
    291         """
    292         if isinstance(name_param, basestring):
    293             # This takes into account the situation where the name is a
    294             # lookup to a related geographic field, e.g., 'address__point'.
    295             name_param = name_param.split(sql.constants.LOOKUP_SEP)
    296             name_param.reverse() # Reversing so list operates like a queue of related lookups.
    297         elif not isinstance(name_param, list):
    298             raise TypeError
    299         try:
    300             # Getting the name of the field for the model (by popping the first
    301             # name from the `name_param` list created above).
    302             fld, mod, direct, m2m = model._meta.get_field_by_name(name_param.pop())
    303         except (FieldDoesNotExist, IndexError):
    304             return False
    305         # TODO: ManyToManyField?
    306         if isinstance(fld, GeometryField):
    307             return fld # A-OK.
    308         elif isinstance(fld, ForeignKey):
    309             # ForeignKey encountered, return the output of this utility called
    310             # on the _related_ model with the remaining name parameters.
    311             return self._check_geo_field(fld.rel.to, name_param) # Recurse to check ForeignKey relation.
    312         else:
    313             return False
    314 
    315284    def _field_column(self, field, table_alias=None):
    316285        """
    317286        Helper function that returns the database column for the given field.
     
    339308        else:
    340309            # Otherwise, check by the given field name -- which may be
    341310            # a lookup to a _related_ geographic field.
    342             return self._check_geo_field(self.model, field_name)
     311            return GeoWhereNode._check_geo_field(self.model._meta, field_name)
  • django/contrib/gis/tests/relatedapp/tests.py

     
    11import os, unittest
    22from django.contrib.gis.geos import *
    3 from django.contrib.gis.tests.utils import no_mysql, postgis
     3from django.contrib.gis.db.models import F, Extent, Union
     4from django.contrib.gis.tests.utils import no_mysql, mysql, postgis
    45from django.conf import settings
    5 from models import City, Location, DirectoryEntry
     6from models import City, Location, DirectoryEntry, Parcel
    67
    78cities = (('Aurora', 'TX', -97.516111, 33.058333),
    89          ('Roswell', 'NM', -104.528056, 33.387222),
     
    1415    def test01_setup(self):
    1516        "Setting up for related model tests."
    1617        for name, state, lon, lat in cities:
    17             loc = Location(point=Point(lon, lat))
    18             loc.save()
    19             c = City(name=name, state=state, location=loc)
    20             c.save()
     18            loc = Location.objects.create(point=Point(lon, lat))
     19            c = City.objects.create(name=name, state=state, location=loc)
    2120           
    2221    def test02_select_related(self):
    2322        "Testing `select_related` on geographic models (see #7126)."
     
    3938        # US Survey Feet (thus a tolerance of 0 implies error w/in 1 survey foot).
    4039        if postgis:
    4140            tol = 3
    42             nqueries = 4 # +1 for `postgis_lib_version`
    4341        else:
    4442            tol = 0
    45             nqueries = 3
    4643           
    4744        def check_pnt(ref, pnt):
    4845            self.assertAlmostEqual(ref.x, pnt.x, tol)
    4946            self.assertAlmostEqual(ref.y, pnt.y, tol)
    5047            self.assertEqual(ref.srid, pnt.srid)
    5148
    52         # Turning on debug so we can manually verify the number of SQL queries issued.
    53         # DISABLED: the number of queries count testing mechanism is way too brittle.
    54         #dbg = settings.DEBUG
    55         #settings.DEBUG = True
    56         from django.db import connection
    57 
    5849        # Each city transformed to the SRID of their state plane coordinate system.
    5950        transformed = (('Kecksburg', 2272, 'POINT(1490553.98959621 314792.131023984)'),
    6051                       ('Roswell', 2257, 'POINT(481902.189077221 868477.766629735)'),
     
    6556            # Doing this implicitly sets `select_related` select the location.
    6657            qs = list(City.objects.filter(name=name).transform(srid, field_name='location__point'))
    6758            check_pnt(GEOSGeometry(wkt, srid), qs[0].location.point)
    68         #settings.DEBUG= dbg
    6959
    70         # Verifying the number of issued SQL queries.
    71         #self.assertEqual(nqueries, len(connection.queries))
    72 
    7360    @no_mysql
    7461    def test04_related_aggregate(self):
    7562        "Testing the `extent` and `unionagg` GeoQuerySet aggregates on related geographic models."
    76         if postgis:
    77             # One for all locations, one that excludes Roswell.
    78             all_extent = (-104.528060913086, 33.0583305358887,-79.4607315063477, 40.1847610473633)
    79             txpa_extent = (-97.51611328125, 33.0583305358887,-79.4607315063477, 40.1847610473633)
    80             e1 = City.objects.extent(field_name='location__point')
    81             e2 = City.objects.exclude(name='Roswell').extent(field_name='location__point')
    82             for ref, e in [(all_extent, e1), (txpa_extent, e2)]:
    83                 for ref_val, e_val in zip(ref, e): self.assertAlmostEqual(ref_val, e_val)
    8463
     64        # This combines the Extent and Union aggregates into one query
     65        aggs = City.objects.aggregate(Extent('location__point'), Union('location__point'))
     66
     67        # One for all locations, one that excludes Roswell.
     68        all_extent = (-104.528060913086, 33.0583305358887,-79.4607315063477, 40.1847610473633)
     69        txpa_extent = (-97.51611328125, 33.0583305358887,-79.4607315063477, 40.1847610473633)
     70        e1 = City.objects.extent(field_name='location__point')
     71        e2 = City.objects.exclude(name='Roswell').extent(field_name='location__point')
     72        e3 = aggs['location__point__extent']
     73
     74        for ref, e in [(all_extent, e1), (txpa_extent, e2), (all_extent, e3)]:
     75            for ref_val, e_val in zip(ref, e): self.assertAlmostEqual(ref_val, e_val)
     76
    8577        # The second union is for a query that has something in the WHERE clause.
    8678        ref_u1 = GEOSGeometry('MULTIPOINT(-104.528056 33.387222,-97.516111 33.058333,-79.460734 40.18476)', 4326)
    8779        ref_u2 = GEOSGeometry('MULTIPOINT(-97.516111 33.058333,-79.460734 40.18476)', 4326)
    8880        u1 = City.objects.unionagg(field_name='location__point')
    8981        u2 = City.objects.exclude(name='Roswell').unionagg(field_name='location__point')
     82        u3 = aggs['location__point__union']
     83
    9084        self.assertEqual(ref_u1, u1)
    9185        self.assertEqual(ref_u2, u2)
     86        self.assertEqual(ref_u1, u3)
    9287       
    9388    def test05_select_related_fk_to_subclass(self):
    9489        "Testing that calling select_related on a query over a model with an FK to a model subclass works"
     
    9691        l = list(DirectoryEntry.objects.all().select_related())
    9792
    9893    # TODO: Related tests for KML, GML, and distance lookups.
     94    def test6_f_expressions(self):
     95        "Testing F() expressions on Geometry fields."
     96        # Constructing a dummy parcel border and getting the City FK
     97        b1 = GEOSGeometry('POLYGON((-97.501205 33.052520,-97.501205 33.052576,-97.501150 33.052576,-97.501150 33.052520,-97.501205 33.052520))', srid=4326)
     98        pcity = City.objects.get(name='Aurora')
     99
     100        # First parcel has incorrect center point that is equal to the City;
     101        # it also has a second border that is different from the first as a
     102        # 100ft buffer around the City.
     103        c1 = pcity.location.point
     104        c2 = c1.transform(2276, clone=True)
     105        b2 = c2.buffer(100)
     106        p1 = Parcel.objects.create(name='P1', city=pcity, center1=c1, center2=c2, border1=b1, border2=b2)
     107
     108        # Now creating a second Parcel where the borders are the same
     109        # _border1_ but in different coordinate systems.  The borders are the
     110        # same here
     111        c1 = b1.centroid
     112        c2 = c1.transform(2276, clone=True)
     113        p2 = Parcel.objects.create(name='P2', city=pcity, center1=c1, center2=c2, border1=b1, border2=b1)
     114
     115        # Should return the second Parcel, which has the center within the
     116        # border.
     117        qs = Parcel.objects.filter(center1__within=F('border1'))
     118        self.assertEqual(1, len(qs))
     119        self.assertEqual('P2', qs[0].name)
    99120       
     121        if not mysql:
     122            # This time center2 is in a different coordinate system and needs
     123            # to be wrapped in transformation SQL.
     124            qs = Parcel.objects.filter(center2__within=F('border1'))
     125            self.assertEqual(1, len(qs))
     126            self.assertEqual('P2', qs[0].name)           
     127       
     128        # Should return the first Parcel, which has the center point equal
     129        # to the point in the City ForeignKey.
     130        qs = Parcel.objects.filter(center1=F('city__location__point'))
     131        self.assertEqual(1, len(qs))
     132        self.assertEqual('P1', qs[0].name)
     133
     134        if not mysql:
     135            # This time the city column should be wrapped in transformation SQL.
     136            qs = Parcel.objects.filter(border2__contains=F('city__location__point'))
     137            self.assertEqual(1, len(qs))
     138            self.assertEqual('P1', qs[0].name)
     139       
    100140def suite():
    101141    s = unittest.TestSuite()
    102142    s.addTest(unittest.makeSuite(RelatedGeoModelTest))
  • django/contrib/gis/tests/relatedapp/models.py

     
    2020    listing_text = models.CharField(max_length=50)
    2121    location = models.ForeignKey(AugmentedLocation)
    2222    objects = models.GeoManager()
     23
     24class Parcel(models.Model):
     25    name = models.CharField(max_length=30)
     26    city = models.ForeignKey(City)
     27    center1 = models.PointField()
     28    # Throwing a curveball w/`db_column` here.
     29    center2 = models.PointField(srid=2276, db_column='mycenter')
     30    border1 = models.PolygonField()
     31    border2 = models.PolygonField(srid=2276)
     32    objects = models.GeoManager()
     33    def __unicode__(self): return self.name
Back to Top