1 | import re
|
---|
2 | import unittest
|
---|
3 | from urlparse import urlsplit, urlunsplit
|
---|
4 |
|
---|
5 | from django.http import QueryDict
|
---|
6 | from django.db import transaction
|
---|
7 | from django.core import mail
|
---|
8 | from django.core.management import call_command
|
---|
9 | from django.test import _doctest as doctest
|
---|
10 | from django.test.client import Client
|
---|
11 |
|
---|
def normalize_long_ints(s):
    """
    Return ``s`` with the trailing ``L`` removed from long-integer literals.

    A run of digits immediately followed by ``L`` is rewritten to just the
    digits (``"22L"`` -> ``"22"``), but only when the run is not glued to a
    word character on either side, so tokens like ``x22L`` or ``22Lx`` are
    left alone.
    """
    return re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
|
---|
13 |
|
---|
def to_list(value):
    """
    Normalize ``value`` to a list.

    - A list is handed back unchanged (the very same object, not a copy).
    - ``None`` becomes a new empty list.
    - Anything else is wrapped as the single element of a new list.
    """
    if isinstance(value, list):
        return value
    if value is None:
        return []
    return [value]
|
---|
24 |
|
---|
25 |
|
---|
class OutputChecker(doctest.OutputChecker):
    """
    Doctest output checker that tolerates Python 2 long-integer suffixes.

    Expected and actual output are first compared the normal doctest way.
    Only if that fails are both sides compared again with the ``L`` suffix
    stripped from long-integer literals, so ``22L`` matches ``22``.
    """

    def check_output(self, want, got, optionflags):
        matched = doctest.OutputChecker.check_output(self, want, got,
                                                     optionflags)
        if matched:
            return matched
        # Doctest compares output as exact strings, so "22L" != "22"; retry
        # with long-integer suffixes normalized away on both sides.
        return normalize_long_ints(want) == normalize_long_ints(got)
|
---|
37 |
|
---|
class DocTestRunner(doctest.DocTestRunner):
    """
    Doctest runner used for Django doctests.

    Differences from the stock runner:

    * ``optionflags`` is always forced to ``doctest.ELLIPSIS``, whatever
      was passed to the constructor.
    * After an unexpected exception is reported, the database transaction
      is rolled back (``transaction.rollback_unless_managed()``) so a failed
      example cannot leak database side effects into later tests.
    """

    def __init__(self, *args, **kwargs):
        base = doctest.DocTestRunner
        base.__init__(self, *args, **kwargs)
        # Deliberately overrides any optionflags supplied by the caller.
        self.optionflags = doctest.ELLIPSIS

    def report_unexpected_exception(self, out, test, example, exc_info):
        base = doctest.DocTestRunner
        base.report_unexpected_exception(self, out, test, example, exc_info)
        # Undo any pending database work left behind by the failing example.
        transaction.rollback_unless_managed()
|
---|
49 |
|
---|
class TestCase(unittest.TestCase):
    """
    unittest.TestCase subclass with Django-specific setup and assertions.

    Before each test it flushes the database, loads any fixtures named in
    the class's ``fixtures`` attribute, empties ``mail.outbox`` and creates
    a fresh test ``Client`` as ``self.client``.
    """

    def _pre_setup(self):
        """Performs any pre-test setup. This includes:

            * Flushing the database (the 'flush' management command runs
              for every test, whether or not fixtures are declared).
            * If the Test Case class has a 'fixtures' member, installing the
              named fixtures.
            * Clearing the mail test outbox.
        """
        call_command('flush', verbosity=0, interactive=False)
        if hasattr(self, 'fixtures'):
            # Keyword arguments can't follow *args in a call on older Python
            # versions, so verbosity is passed through **{...} instead.
            call_command('loaddata', *self.fixtures, **{'verbosity': 0})
        mail.outbox = []

    def __call__(self, result=None):
        """
        Wrapper around default __call__ method to perform common Django test
        set up. This means that user-defined Test Cases aren't required to
        include a call to super().setUp().
        """
        self.client = Client()
        self._pre_setup()
        return super(TestCase, self).__call__(result)

    def assertRedirects(self, response, expected_url, status_code=302,
                        target_status_code=200):
        """Asserts that a response redirected to a specific URL, and that the
        redirect URL can be loaded.

        Only the URL components present in ``expected_url`` are compared:
        components it leaves empty (e.g. scheme and host for '/next/') are
        blanked out of the actual Location before comparing, so '/next/'
        matches 'http://example.com/next/'.

        Note that assertRedirects won't work for external links since it uses
        TestClient to do a request.
        """
        self.assertEqual(response.status_code, status_code,
            ("Response didn't redirect as expected: Response code was %d"
             " (expected %d)" % (response.status_code, status_code)))
        url = response['Location']
        scheme, netloc, path, query, fragment = urlsplit(url)

        # Rebuild the actual URL keeping only the components that are set in
        # expected_url; everything else is replaced with ''.
        actual_parts = (scheme, netloc, path, query, fragment)
        url = urlunsplit(tuple([actual_part if expected_part else ''
                                for actual_part, expected_part
                                in zip(actual_parts, urlsplit(expected_url))]))

        self.assertEqual(url, expected_url,
            "Response redirected to '%s', expected '%s'" % (url, expected_url))

        # Get the redirection page, using the same client that was used
        # to obtain the original response.
        redirect_response = response.client.get(path, QueryDict(query))
        self.assertEqual(redirect_response.status_code, target_status_code,
            ("Couldn't retrieve redirection page '%s': response code was %d"
             " (expected %d)") %
             (path, redirect_response.status_code, target_status_code))

    def assertContains(self, response, text, count=None, status_code=200):
        """
        Asserts that a response indicates that a page was retrieved
        successfully, (i.e., the HTTP status code was as expected), and that
        ``text`` occurs ``count`` times in the content of the response.
        If ``count`` is None, the count doesn't matter - the assertion is true
        if the text occurs at least once in the response.
        """
        self.assertEqual(response.status_code, status_code,
            "Couldn't retrieve page: Response code was %d (expected %d)" %
                (response.status_code, status_code))
        real_count = response.content.count(text)
        if count is not None:
            self.assertEqual(real_count, count,
                "Found %d instances of '%s' in response (expected %d)" %
                    (real_count, text, count))
        else:
            self.assertTrue(real_count != 0,
                            "Couldn't find '%s' in response" % text)

    def assertFormError(self, response, form, field, errors):
        """
        Asserts that a form used to render the response has a specific field
        error.

        ``form`` is the name of the form in the template context(s); ``field``
        is the field name, or a false value to check the form's non-field
        errors instead; ``errors`` is a single error message or a list of
        them, every one of which must be present.
        """
        # Put context(s) into a list to simplify processing.
        contexts = to_list(response.context)
        if not contexts:
            self.fail('Response did not use any contexts to render the'
                      ' response')

        # Put error(s) into a list to simplify processing.
        errors = to_list(errors)

        # Search all contexts for the error.
        found_form = False
        for i, context in enumerate(contexts):
            if form not in context:
                continue
            found_form = True
            for err in errors:
                if field:
                    if field in context[form].errors:
                        field_errors = context[form].errors[field]
                        self.assertTrue(err in field_errors,
                                        "The field '%s' on form '%s' in"
                                        " context %d does not contain the"
                                        " error '%s' (actual errors: %s)" %
                                            (field, form, i, err,
                                             list(field_errors)))
                    elif field in context[form].fields:
                        self.fail("The field '%s' on form '%s' in context %d"
                                  " contains no errors" % (field, form, i))
                    else:
                        self.fail("The form '%s' in context %d does not"
                                  " contain the field '%s'" %
                                      (form, i, field))
                else:
                    non_field_errors = context[form].non_field_errors()
                    # Render as a plain list, matching the field-error branch.
                    self.assertTrue(err in non_field_errors,
                        "The form '%s' in context %d does not contain the"
                        " non-field error '%s' (actual errors: %s)" %
                            (form, i, err, list(non_field_errors)))
        if not found_form:
            self.fail("The form '%s' was not used to render the response" %
                      form)

    def assertTemplateUsed(self, response, template_name):
        """
        Asserts that the template with the provided name was used in rendering
        the response.

        ``response.template`` may be a single template, a list of templates,
        or None; it is normalized to a list of template names first.
        """
        template_names = [t.name for t in to_list(response.template)]
        if not template_names:
            self.fail('No templates used to render the response')
        self.assertTrue(template_name in template_names,
            (u"Template '%s' was not a template used to render the response."
             u" Actual template(s) used: %s") % (template_name,
                                                 u', '.join(template_names)))

    def assertTemplateNotUsed(self, response, template_name):
        """
        Asserts that the template with the provided name was NOT used in
        rendering the response.
        """
        template_names = [t.name for t in to_list(response.template)]
        self.assertFalse(template_name in template_names,
            (u"Template '%s' was used unexpectedly in rendering the"
             u" response") % template_name)
|
---|