diff --git a/django/test/testcases.py b/django/test/testcases.py
index a79a304..b91decc 100644
a
|
b
|
from xml.dom.minidom import parseString, Node
|
6 | 6 | from django.conf import settings |
7 | 7 | from django.core import mail |
8 | 8 | from django.core.management import call_command |
| 9 | from django.core.signals import request_started |
9 | 10 | from django.core.urlresolvers import clear_url_caches |
10 | | from django.db import transaction, connection, connections, DEFAULT_DB_ALIAS |
| 11 | from django.db import (transaction, connection, connections, DEFAULT_DB_ALIAS, |
| 12 | reset_queries) |
11 | 13 | from django.http import QueryDict |
12 | 14 | from django.test import _doctest as doctest |
13 | 15 | from django.test.client import Client |
… |
… |
class _AssertNumQueriesContext(object):
|
215 | 217 | self.test_case = test_case |
216 | 218 | self.num = num |
217 | 219 | self.connection = connection |
| 220 | self.executed = 0 |
| 221 | |
| 222 | def reset_queries_wrapper(self, **kwargs): |
| 223 | self.executed += len(self.connection.queries) - self.starting_queries |
| 224 | self.starting_queries = 0 |
| 225 | reset_queries(**kwargs) |
218 | 226 | |
219 | 227 | def __enter__(self): |
220 | 228 | self.old_debug_cursor = self.connection.use_debug_cursor |
221 | 229 | self.connection.use_debug_cursor = True |
222 | 230 | self.starting_queries = len(self.connection.queries) |
| 231 | request_started.disconnect(reset_queries) |
| 232 | request_started.connect(self.reset_queries_wrapper, weak=False) |
223 | 233 | return self |
224 | 234 | |
225 | 235 | def __exit__(self, exc_type, exc_value, traceback): |
226 | 236 | self.connection.use_debug_cursor = self.old_debug_cursor |
| 237 | request_started.connect(reset_queries) |
| 238 | request_started.disconnect(self.reset_queries_wrapper) |
227 | 239 | if exc_type is not None: |
228 | 240 | return |
229 | 241 | |
230 | 242 | final_queries = len(self.connection.queries) |
231 | | executed = final_queries - self.starting_queries |
232 | | |
| 243 | executed = final_queries - self.starting_queries + self.executed |
| 244 | |
233 | 245 | self.test_case.assertEqual( |
234 | 246 | executed, self.num, "%d queries executed, %d expected" % ( |
235 | 247 | executed, self.num |
diff --git a/tests/regressiontests/test_utils/tests.py b/tests/regressiontests/test_utils/tests.py
index 2a9c826..6c5161e 100644
a
|
b
|
import sys
|
2 | 2 | |
3 | 3 | from django.test import TestCase, skipUnlessDBFeature, skipIfDBFeature |
4 | 4 | |
| 5 | from models import Person |
5 | 6 | |
6 | 7 | if sys.version_info >= (2, 5): |
7 | 8 | from tests_25 import AssertNumQueriesTests |
… |
… |
class SkippingTestCase(TestCase):
|
15 | 16 | self.assertRaises(ValueError, |
16 | 17 | self.assertNumQueries, 2, test_func |
17 | 18 | ) |
| 19 | |
| 20 | def test_assert_num_queries_with_client(self): |
| 21 | person = Person.objects.create(name='test') |
| 22 | |
| 23 | self.assertNumQueries( |
| 24 | 1, |
| 25 | self.client.get, |
| 26 | '/test_utils/get_person/%s/' % person.pk |
| 27 | ) |
| 28 | |
| 29 | self.assertNumQueries( |
| 30 | 1, |
| 31 | self.client.get, |
| 32 | '/test_utils/get_person/%s/' % person.pk |
| 33 | ) |
| 34 | |
| 35 | def test_func(): |
| 36 | self.client.get('/test_utils/get_person/%s/' % person.pk) |
| 37 | self.client.get('/test_utils/get_person/%s/' % person.pk) |
| 38 | self.assertNumQueries(2, test_func) |
| 39 | |
18 | 40 | |
19 | 41 | def test_skip_unless_db_feature(self): |
20 | 42 | "A test that might be skipped is actually called." |
diff --git a/tests/regressiontests/test_utils/tests_25.py b/tests/regressiontests/test_utils/tests_25.py
index 4adea6c..43f3312 100644
a
|
b
|
class AssertNumQueriesTests(TestCase):
|
26 | 26 | with self.assertRaises(TypeError): |
27 | 27 | with self.assertNumQueries(4000): |
28 | 28 | raise TypeError |
| 29 | |
| 30 | def test_with_client(self): |
| 31 | person = Person.objects.create(name='test') |
| 32 | |
| 33 | with self.assertNumQueries(1): |
| 34 | self.client.get('/test_utils/get_person/%s/' % person.pk) |
| 35 | |
| 36 | with self.assertNumQueries(1): |
| 37 | self.client.get('/test_utils/get_person/%s/' % person.pk) |
| 38 | |
| 39 | with self.assertNumQueries(2): |
| 40 | self.client.get('/test_utils/get_person/%s/' % person.pk) |
| 41 | self.client.get('/test_utils/get_person/%s/' % person.pk) |
| 42 | No newline at end of file |
diff --git a/tests/regressiontests/test_utils/urls.py b/tests/regressiontests/test_utils/urls.py
new file mode 100644
index 0000000..1109b82
-
|
+
|
|
| 1 | from django.conf.urls.defaults import * |
| 2 | import views |
| 3 | |
| 4 | urlpatterns = patterns('', |
| 5 | (r'^get_person/(\d+)/$', views.get_person), |
| 6 | ) |
| 7 | No newline at end of file |
diff --git a/tests/regressiontests/test_utils/views.py b/tests/regressiontests/test_utils/views.py
new file mode 100644
index 0000000..62af0d9
-
|
+
|
|
| 1 | from django.http import HttpResponse |
| 2 | from django.shortcuts import get_object_or_404 |
| 3 | from models import Person |
| 4 | |
| 5 | def get_person(request, pk): |
| 6 | person = get_object_or_404(Person, pk=pk) |
| 7 | return HttpResponse(person.name) |
| 8 | No newline at end of file |
diff --git a/tests/urls.py b/tests/urls.py
index 01d6408..5fd4e7d 100644
a
|
b
|
urlpatterns = patterns('',
|
41 | 41 | |
42 | 42 | # special headers views |
43 | 43 | (r'special_headers/', include('regressiontests.special_headers.urls')), |
| 44 | |
| 45 | # test util views |
| 46 | (r'test_utils/', include('regressiontests.test_utils.urls')), |
44 | 47 | ) |