Ticket #12268: 12268_regex.diff

File 12268_regex.diff, 3.4 KB (added by Anssi Kääriäinen, 12 years ago)
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index 4afe288..30e551d 100644
    a b all about the internals of models in order to get the information it needs.  
    88"""
    99
    1010import copy
     11import re
    1112
    1213from django.utils.datastructures import SortedDict
    1314from django.utils.encoding import force_unicode
    class Query(object):  
    17231724        self.related_select_cols = []
    17241725        self.related_select_fields = []
    17251726
     1727
     1728    # This pattern is used to find non-escaped %s patterns in a string.
     1729    # The idea is to match %s, not %%s, match %%%s and so on. First match
     1730    # %, then consume double %%, and check if there is still a % left. If
     1731    # there is, we have even number of % before the s, and thus the s is
     1732    # not a param, otherwise there is odd number of % and we have a param.
     1733    params_pattern = re.compile(r'%(%%)*(%)?s', flags=re.MULTILINE)
     1734
    17261735    def add_extra(self, select, select_params, where, params, tables, order_by):
    17271736        """
    17281737        Adds data to the various extra_* attributes for user-created additions
    class Query(object):  
    17411750            for name, entry in select.items():
    17421751                entry = force_unicode(entry)
    17431752                entry_params = []
    1744                 pos = entry.find("%s")
    1745                 while pos != -1:
    1746                     entry_params.append(param_iter.next())
    1747                     pos = entry.find("%s", pos + 2)
     1753                # Find all non-escaped %s strings in the entry - see comment
     1754                # of params_pattern.
     1755                for group in self.params_pattern.findall(entry):
     1756                    if not group[1]:
     1757                        entry_params.append(param_iter.next())
    17481758                select_pairs[name] = (entry, entry_params)
    17491759            # This is order preserving, since self.extra_select is a SortedDict.
    17501760            self.extra.update(select_pairs)
  • tests/regressiontests/extra_regress/tests.py

    diff --git a/tests/regressiontests/extra_regress/tests.py b/tests/regressiontests/extra_regress/tests.py
    index 67efb42..9e1674c 100644
    a b class ExtraRegressTests(TestCase):  
    313313               TestObject.objects.extra(where=["id > %s"], params=[obj.pk]),
    314314            ['<TestObject: TestObject: first,second,third>']
    315315        )
     316
     317    def test_regression_12268(self):
     318        """
     319        Test that % escaping works correctly in .extra select. %s needs a
     320        select_param, %%s not, %%%s needs a param and so on.
     321        """
     322        sql = str(TestObject.objects.extra(
     323            select={'foo': '%s'}, select_params=['_mark']).query)
     324        self.assertTrue('_mark' in sql)
     325        sql = str(TestObject.objects.extra(
     326            select={'foo': '%%s'}, select_params=['_mark']).query)
     327        self.assertTrue('_mark' not in sql)
     328        sql = str(TestObject.objects.extra(
     329            select={'foo': '%%%s'}, select_params=['_mark']).query)
     330        self.assertTrue('_mark' in sql)
     331        sql = str(TestObject.objects.extra(
     332            select={'foo': '%%%s%%s'}, select_params=['_mark1', '_mark2']).query)
     333        self.assertTrue('_mark1' in sql)
     334        self.assertTrue('_mark2' not in sql)
     335        sql = str(TestObject.objects.extra(
     336            select={'foo': 'asdf%%%sasdf%s'}, select_params=['_mark1', '_mark2']).query)
     337        self.assertTrue('_mark1' in sql)
     338        self.assertTrue('_mark2' in sql)
Back to Top