| 1 |
import re |
|---|
| 2 |
import unittest |
|---|
| 3 |
from urlparse import urlsplit, urlunsplit |
|---|
| 4 |
|
|---|
| 5 |
from django.http import QueryDict |
|---|
| 6 |
from django.db import transaction |
|---|
| 7 |
from django.conf import settings |
|---|
| 8 |
from django.core import mail |
|---|
| 9 |
from django.core.management import call_command |
|---|
| 10 |
from django.test import _doctest as doctest |
|---|
| 11 |
from django.test.client import Client |
|---|
| 12 |
from django.core.urlresolvers import clear_url_caches |
|---|
| 13 |
|
|---|
# Matches a run of digits immediately followed by "L" (a Python 2 long
# literal such as "22L"), provided neither side touches another word
# character, so identifiers like "x22L" or "22Lx" are left alone.
_LONG_INT_RE = re.compile(r'(?<![\w])(\d+)L(?![\w])')


def normalize_long_ints(s):
    """
    Return ``s`` with the ``L`` suffix removed from long-integer literals.

    For example ``"[1L, 22L]"`` becomes ``"[1, 22]"``. Digit runs preceded
    by a word character, or an ``L`` followed by one, are not rewritten.
    """
    return _LONG_INT_RE.sub(r'\1', s)
| 15 |
|
|---|
def to_list(value):
    """
    Normalize ``value`` into a list.

    A list is returned unchanged (the same object), ``None`` becomes an
    empty list, and any other value is wrapped as a one-element list.
    """
    if isinstance(value, list):
        return value
    return [] if value is None else [value]
| 26 |
|
|---|
| 27 |
|
|---|
class OutputChecker(doctest.OutputChecker):
    """
    Doctest output checker that also accepts output differing only in
    long-integer suffixes (``22L`` versus ``22``).
    """

    def check_output(self, want, got, optionflags):
        # doctest compares output as exact strings, so "22L" would not match
        # "22". Use the standard comparison first and fall back to comparing
        # both strings with long-integer suffixes normalized away.
        matched = doctest.OutputChecker.check_output(self, want, got,
                                                     optionflags)
        return matched or (normalize_long_ints(want) ==
                           normalize_long_ints(got))
| 39 |
|
|---|
class DocTestRunner(doctest.DocTestRunner):
    """
    Doctest runner that always uses the ELLIPSIS option and rolls back the
    database transaction after an example raises an unexpected exception.
    """

    def __init__(self, *args, **kwargs):
        base = doctest.DocTestRunner
        base.__init__(self, *args, **kwargs)
        # Overwrites whatever option flags the base constructor stored.
        self.optionflags = doctest.ELLIPSIS

    def report_unexpected_exception(self, out, test, example, exc_info):
        base = doctest.DocTestRunner
        base.report_unexpected_exception(self, out, test, example, exc_info)
        # Undo any pending database work so that a failed query in one
        # example does not leak side effects into other tests.
        transaction.rollback_unless_managed()
| 51 |
|
|---|
class TestCase(unittest.TestCase):
    """
    A unittest.TestCase that prepares Django state around every test and
    adds assertions for inspecting responses.

    Before each test, ``_pre_setup`` flushes the database, installs the
    fixtures named in a ``fixtures`` attribute (if present), swaps
    ``settings.ROOT_URLCONF`` for a ``urls`` attribute (if present) and
    empties ``mail.outbox``. After the test, ``_post_teardown`` restores
    the original ``ROOT_URLCONF``. A fresh test ``Client`` is exposed as
    ``self.client``.
    """

    def _pre_setup(self):
        """Performs any pre-test setup. This includes:

            * Flushing the database.
            * If the Test Case class has a 'fixtures' member, installing the
              named fixtures.
            * If the Test Case class has a 'urls' member, replace the
              ROOT_URLCONF with it.
            * Clearing the mail test outbox.
        """
        call_command('flush', verbosity=0, interactive=False)
        if hasattr(self, 'fixtures'):
            # We have to use this slightly awkward syntax due to the fact
            # that we're using *args and **kwargs together.
            call_command('loaddata', *self.fixtures, **{'verbosity': 0})
        if hasattr(self, 'urls'):
            self._old_root_urlconf = settings.ROOT_URLCONF
            settings.ROOT_URLCONF = self.urls
            clear_url_caches()
        mail.outbox = []

    def __call__(self, result=None):
        """
        Wrapper around default __call__ method to perform common Django test
        set up. This means that user-defined Test Cases aren't required to
        include a call to super().setUp().

        An exception raised by ``_pre_setup`` or ``_post_teardown`` (other
        than KeyboardInterrupt/SystemExit) is recorded on ``result`` with
        ``addError``; if setup fails, the test itself is not run. When no
        ``result`` is supplied there is nowhere to record the error, so the
        original exception is re-raised instead.
        """
        import sys
        self.client = Client()
        try:
            self._pre_setup()
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            if result is None:
                # Calling addError on None would replace the real setup
                # error with an unrelated AttributeError.
                raise
            result.addError(self, sys.exc_info())
            return
        super(TestCase, self).__call__(result)
        try:
            self._post_teardown()
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            if result is None:
                raise
            result.addError(self, sys.exc_info())

    def _post_teardown(self):
        """ Performs any post-test things. This includes:

            * Putting back the original ROOT_URLCONF if it was changed.
        """
        if hasattr(self, '_old_root_urlconf'):
            settings.ROOT_URLCONF = self._old_root_urlconf
            clear_url_caches()

    def assertRedirects(self, response, expected_url, status_code=302,
                        target_status_code=200, host=None):
        """Asserts that a response redirected to a specific URL, and that the
        redirect URL can be loaded.

        If ``expected_url`` has neither a scheme nor a host, it is expanded
        to ``http://<host><path>`` before comparison, where ``host`` defaults
        to ``'testserver'``. The redirect target is then fetched with
        ``response.client`` and must return ``target_status_code``.

        Note that assertRedirects won't work for external links since it uses
        TestClient to do a request.
        """
        self.assertEqual(response.status_code, status_code,
            ("Response didn't redirect as expected: Response code was %d"
             " (expected %d)" % (response.status_code, status_code)))
        url = response['Location']
        scheme, netloc, path, query, fragment = urlsplit(url)
        e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
        if not (e_scheme or e_netloc):
            expected_url = urlunsplit(('http', host or 'testserver', e_path,
                                       e_query, e_fragment))
        self.assertEqual(url, expected_url,
            "Response redirected to '%s', expected '%s'" % (url, expected_url))

        # Get the redirection page, using the same client that was used
        # to obtain the original response.
        redirect_response = response.client.get(path, QueryDict(query))
        self.assertEqual(redirect_response.status_code, target_status_code,
            ("Couldn't retrieve redirection page '%s': response code was %d"
             " (expected %d)") %
                 (path, redirect_response.status_code, target_status_code))

    def assertContains(self, response, text, count=None, status_code=200):
        """
        Asserts that a response indicates that a page was retrieved
        successfully, (i.e., the HTTP status code was as expected), and that
        ``text`` occurs ``count`` times in the content of the response.
        If ``count`` is None, the count doesn't matter - the assertion is true
        if the text occurs at least once in the response.
        """
        self.assertEqual(response.status_code, status_code,
            "Couldn't retrieve page: Response code was %d (expected %d)" %
                (response.status_code, status_code))
        real_count = response.content.count(text)
        if count is not None:
            self.assertEqual(real_count, count,
                "Found %d instances of '%s' in response (expected %d)" %
                    (real_count, text, count))
        else:
            self.failUnless(real_count != 0,
                            "Couldn't find '%s' in response" % text)

    def assertNotContains(self, response, text, status_code=200):
        """
        Asserts that a response indicates that a page was retrieved
        successfully, (i.e., the HTTP status code was as expected), and that
        ``text`` doesn't occur in the content of the response.
        """
        self.assertEqual(response.status_code, status_code,
            "Couldn't retrieve page: Response code was %d (expected %d)" %
                (response.status_code, status_code))
        self.assertEqual(response.content.count(text), 0,
                         "Response should not contain '%s'" % text)

    def assertFormError(self, response, form, field, errors):
        """
        Asserts that a form used to render the response has a specific field
        error.

        ``form`` is the name of the context variable holding the form.
        ``field`` is the field name, or a false value to check the form's
        non-field errors instead. ``errors`` is a single error message or a
        list of them; every one must be present. The check runs in each
        rendering context that contains ``form``.
        """
        # Put context(s) into a list to simplify processing.
        contexts = to_list(response.context)
        if not contexts:
            self.fail('Response did not use any contexts to render the'
                      ' response')

        # Put error(s) into a list to simplify processing.
        errors = to_list(errors)

        # Search all contexts for the error.
        found_form = False
        for i, context in enumerate(contexts):
            if form not in context:
                continue
            found_form = True
            for err in errors:
                if field:
                    if field in context[form].errors:
                        field_errors = context[form].errors[field]
                        self.failUnless(err in field_errors,
                                        "The field '%s' on form '%s' in"
                                        " context %d does not contain the"
                                        " error '%s' (actual errors: %s)" %
                                            (field, form, i, err,
                                             repr(field_errors)))
                    elif field in context[form].fields:
                        self.fail("The field '%s' on form '%s' in context %d"
                                  " contains no errors" % (field, form, i))
                    else:
                        self.fail("The form '%s' in context %d does not"
                                  " contain the field '%s'" %
                                      (form, i, field))
                else:
                    non_field_errors = context[form].non_field_errors()
                    self.failUnless(err in non_field_errors,
                        "The form '%s' in context %d does not contain the"
                        " non-field error '%s' (actual errors: %s)" %
                            (form, i, err, non_field_errors))
        if not found_form:
            self.fail("The form '%s' was not used to render the response" %
                          form)

    def assertTemplateUsed(self, response, template_name):
        """
        Asserts that the template with the provided name was used in rendering
        the response.
        """
        template_names = [t.name for t in to_list(response.template)]
        if not template_names:
            self.fail('No templates used to render the response')
        self.failUnless(template_name in template_names,
            (u"Template '%s' was not a template used to render the response."
             u" Actual template(s) used: %s") % (template_name,
                                                 u', '.join(template_names)))

    def assertTemplateNotUsed(self, response, template_name):
        """
        Asserts that the template with the provided name was NOT used in
        rendering the response.
        """
        template_names = [t.name for t in to_list(response.template)]
        self.failIf(template_name in template_names,
            (u"Template '%s' was used unexpectedly in rendering the"
             u" response") % template_name)