Django

Code

root/django/trunk/django/test/testcases.py

Revision 8003, 14.4 kB (checked in by russellm, 4 months ago)

Fixed #7441 -- Removed some of the shortcuts in the doctest output comparators, and added a wrapper to allow comparison of xml fragments. Thanks to Leo Soto for the report and fix.

  • Property svn:eol-style set to native
Line 
1 import re
2 import unittest
3 from urlparse import urlsplit, urlunsplit
4 from xml.dom.minidom import parseString, Node
5
6 from django.conf import settings
7 from django.core import mail
8 from django.core.management import call_command
9 from django.core.urlresolvers import clear_url_caches
10 from django.db import transaction
11 from django.http import QueryDict
12 from django.test import _doctest as doctest
13 from django.test.client import Client
14 from django.utils import simplejson
15
# Matches a run of digits immediately followed by "L" (a Python 2 long
# literal such as "22L"), but only when it is not embedded in a longer word.
_long_int_re = re.compile(r'(?<![\w])(\d+)L(?![\w])')

def normalize_long_ints(s):
    """
    Return ``s`` with the ``L`` suffix removed from long integer literals.

    For example ``"[1L, 22L]"`` becomes ``"[1, 22]"``. A digit run followed by
    ``L`` that is part of a larger word (``"x22L"``, ``"22Lx"``) is left
    untouched.
    """
    return _long_int_re.sub(r'\1', s)
17
def to_list(value):
    """
    Normalize ``value`` into a list.

    ``None`` yields a new empty list, an existing list is returned as-is
    (the same object, not a copy), and any other value is wrapped in a
    one-element list.
    """
    if isinstance(value, list):
        return value
    if value is None:
        return []
    return [value]
28
29
class OutputChecker(doctest.OutputChecker):
    """
    Doctest output checker that accepts several kinds of equivalent output.

    An example passes if any one of these comparisons succeeds, tried in
    order: the standard doctest comparison, the standard comparison after
    removing the ``L`` suffix from long integers, an XML-aware comparison,
    and a JSON-aware comparison.
    """
    # Runs of two or more spaces/tabs/newlines collapse to a single space
    # when XML text content is compared.
    _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')

    def check_output(self, want, got, optionflags):
        "The entry method for doctest output checking. Defers to a sequence of child checkers"
        checks = (self.check_output_default,
                  self.check_output_long,
                  self.check_output_xml,
                  self.check_output_json)
        # Accept as soon as any comparator considers the outputs equal.
        for check in checks:
            if check(want, got, optionflags):
                return True
        return False

    def check_output_default(self, want, got, optionflags):
        "The default comparator provided by doctest - not perfect, but good for most purposes"
        return doctest.OutputChecker.check_output(self, want, got, optionflags)

    def check_output_long(self, want, got, optionflags):
        """Doctest does an exact string comparison of output, which means long
        integers aren't equal to normal integers ("22L" vs. "22"). The
        following code normalizes long integers so that they equal normal
        integers.

        The normalized strings are compared with the standard doctest
        comparator so that option flags such as ELLIPSIS still apply.
        """
        return doctest.OutputChecker.check_output(
            self, normalize_long_ints(want), normalize_long_ints(got),
            optionflags)

    def check_output_xml(self, want, got, optionsflags):
        """Tries to do a 'xml-comparison' of want and got.  Plain string
        comparison doesn't always work because, for example, attribute
        ordering should not be important.

        Elements are equal when they have the same tag name, the same
        attributes (in any order), the same direct text content (after
        collapsing whitespace runs) and pairwise-equal child elements.
        Output that does not start with an XML declaration is treated as a
        fragment and wrapped in a synthetic root element. For a complete
        document, only the document element is compared, so a leading
        comment or DOCTYPE does not matter. Unparseable input compares as
        not equal.

        Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
        """
        norm_re = self._norm_whitespace_re

        def norm_whitespace(v):
            return norm_re.sub(' ', v)

        def child_text(element):
            return ''.join([c.data for c in element.childNodes
                            if c.nodeType == Node.TEXT_NODE])

        def children(element):
            return [c for c in element.childNodes
                    if c.nodeType == Node.ELEMENT_NODE]

        def norm_child_text(element):
            return norm_whitespace(child_text(element))

        def attrs_dict(element):
            return dict(element.attributes.items())

        def check_element(want_element, got_element):
            if want_element.tagName != got_element.tagName:
                return False
            if norm_child_text(want_element) != norm_child_text(got_element):
                return False
            if attrs_dict(want_element) != attrs_dict(got_element):
                return False
            want_children = children(want_element)
            got_children = children(got_element)
            if len(want_children) != len(got_children):
                return False
            for want_child, got_child in zip(want_children, got_children):
                if not check_element(want_child, got_child):
                    return False
            return True

        want, got = self._strip_quotes(want, got)
        # Doctest reprs of strings show newlines as a literal backslash-n.
        want = want.replace('\\n', '\n')
        got = got.replace('\\n', '\n')

        # If the string is not a complete xml document, we may need to add a
        # root element. This allows us to compare fragments, like "<foo/><bar/>"
        if not want.startswith('<?xml'):
            wrapper = '<root>%s</root>'
            want = wrapper % want
            got = wrapper % got

        # Parse the want and got strings, and compare the parsings.
        # documentElement (not firstChild) skips any comment, DOCTYPE or
        # processing instruction that precedes the root element.
        try:
            want_root = parseString(want).documentElement
            got_root = parseString(got).documentElement
        except Exception:
            # Not well-formed XML: this comparator does not apply.
            return False
        return check_element(want_root, got_root)

    def check_output_json(self, want, got, optionsflags):
        """Tries to compare want and got as if they were JSON-encoded data.

        Returns False when either value cannot be decoded as JSON.
        """
        want, got = self._strip_quotes(want, got)
        try:
            want_json = simplejson.loads(want)
            got_json = simplejson.loads(got)
        except Exception:
            # Not JSON: this comparator does not apply.
            return False
        return want_json == got_json

    def _strip_quotes(self, want, got):
        """
        Strip quotes of doctests output values.

        Quotes are only removed when both values are plain quoted strings,
        or both are u-prefixed quoted strings; otherwise both values are
        returned unchanged:

        >>> o = OutputChecker()
        >>> o._strip_quotes("'foo'", "'foo'")
        ('foo', 'foo')
        >>> o._strip_quotes('"foo"', '"foo"')
        ('foo', 'foo')
        >>> o._strip_quotes("u'foo'", "u'foo'")
        ('foo', 'foo')
        >>> o._strip_quotes('u"foo"', 'u"foo"')
        ('foo', 'foo')
        """
        def is_quoted_string(s):
            s = s.strip()
            return (len(s) >= 2
                    and s[0] == s[-1]
                    and s[0] in ('"', "'"))

        def is_quoted_unicode(s):
            s = s.strip()
            return (len(s) >= 3
                    and s[0] == 'u'
                    and s[1] == s[-1]
                    and s[1] in ('"', "'"))

        if is_quoted_string(want) and is_quoted_string(got):
            want = want.strip()[1:-1]
            got = got.strip()[1:-1]
        elif is_quoted_unicode(want) and is_quoted_unicode(got):
            want = want.strip()[2:-1]
            got = got.strip()[2:-1]
        return want, got
158
159
class DocTestRunner(doctest.DocTestRunner):
    """
    Doctest runner used by Django's test framework.

    It differs from the stock runner in two ways:

    * After construction, ``optionflags`` is set to exactly
      ``doctest.ELLIPSIS``. Any ``optionflags`` passed to the constructor
      are replaced, not combined.
    * When an example raises an unexpected exception, the current database
      transaction is rolled back (unless it is managed), so the failure
      cannot leak into later examples.
    """
    def __init__(self, *args, **kwargs):
        base = doctest.DocTestRunner
        base.__init__(self, *args, **kwargs)
        # Overrides whatever flags the base initializer stored.
        self.optionflags = doctest.ELLIPSIS

    def report_unexpected_exception(self, out, test, example, exc_info):
        # Let the standard runner write its report first.
        base = doctest.DocTestRunner
        base.report_unexpected_exception(self, out, test, example, exc_info)
        # Undo any half-finished database work so that a database error in
        # this example has no side effects on other tests.
        transaction.rollback_unless_managed()
171
class TestCase(unittest.TestCase):
    """
    A unittest.TestCase that runs Django-specific set up and tear down
    around each test, and adds assertions for inspecting test client
    responses.
    """

    def _pre_setup(self):
        """Performs any pre-test setup. This includes:

            * Flushing the database.
            * If the Test Case class has a 'fixtures' member, installing the
              named fixtures.
            * If the Test Case class has a 'urls' member, replacing the
              ROOT_URLCONF with it.
            * Clearing the mail test outbox.
        """
        call_command('flush', verbosity=0, interactive=False)
        if hasattr(self, 'fixtures'):
            # We have to use this slightly awkward syntax due to the fact
            # that we're using *args and **kwargs together.
            call_command('loaddata', *self.fixtures, **{'verbosity': 0})
        if hasattr(self, 'urls'):
            # Remember the original value so _post_teardown can restore it.
            self._old_root_urlconf = settings.ROOT_URLCONF
            settings.ROOT_URLCONF = self.urls
            clear_url_caches()
        mail.outbox = []

    def __call__(self, result=None):
        """
        Wrapper around default __call__ method to perform common Django test
        set up. This means that user-defined Test Cases aren't required to
        include a call to super().setUp().

        A fresh test client is created as ``self.client``. Exceptions raised
        by the Django set up or tear down (other than KeyboardInterrupt and
        SystemExit) are recorded on ``result`` as test errors. If the set up
        fails, the test itself is not run. When ``result`` is None, a default
        result object is created, as unittest.TestCase.run does.
        """
        import sys
        if result is None:
            # Without this, a set up/tear down error would surface as an
            # AttributeError on None and hide the real exception.
            result = self.defaultTestResult()
        self.client = Client()
        try:
            self._pre_setup()
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            result.addError(self, sys.exc_info())
            return
        super(TestCase, self).__call__(result)
        try:
            self._post_teardown()
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            result.addError(self, sys.exc_info())

    def _post_teardown(self):
        """ Performs any post-test things. This includes:

            * Putting back the original ROOT_URLCONF if it was changed.
        """
        if hasattr(self, '_old_root_urlconf'):
            settings.ROOT_URLCONF = self._old_root_urlconf
            clear_url_caches()

    def assertRedirects(self, response, expected_url, status_code=302,
                        target_status_code=200, host=None):
        """Asserts that a response redirected to a specific URL, and that the
        redirect URL can be loaded.

        A relative ``expected_url`` is made absolute using the 'http' scheme
        and ``host`` (default 'testserver') before it is compared with the
        response's Location header.

        Note that assertRedirects won't work for external links since it uses
        TestClient to do a request.
        """
        self.assertEqual(response.status_code, status_code,
            ("Response didn't redirect as expected: Response code was %d"
             " (expected %d)" % (response.status_code, status_code)))
        url = response['Location']
        scheme, netloc, path, query, fragment = urlsplit(url)
        e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
        if not (e_scheme or e_netloc):
            expected_url = urlunsplit(('http', host or 'testserver', e_path,
                    e_query, e_fragment))
        self.assertEqual(url, expected_url,
            "Response redirected to '%s', expected '%s'" % (url, expected_url))

        # Get the redirection page, using the same client that was used
        # to obtain the original response.
        redirect_response = response.client.get(path, QueryDict(query))
        self.assertEqual(redirect_response.status_code, target_status_code,
            ("Couldn't retrieve redirection page '%s': response code was %d"
             " (expected %d)") %
                 (path, redirect_response.status_code, target_status_code))

    def assertContains(self, response, text, count=None, status_code=200):
        """
        Asserts that a response indicates that a page was retrieved
        successfully, (i.e., the HTTP status code was as expected), and that
        ``text`` occurs ``count`` times in the content of the response.
        If ``count`` is None, the count doesn't matter - the assertion is true
        if the text occurs at least once in the response.
        """
        self.assertEqual(response.status_code, status_code,
            "Couldn't retrieve page: Response code was %d (expected %d)" %
                (response.status_code, status_code))
        real_count = response.content.count(text)
        if count is not None:
            self.assertEqual(real_count, count,
                "Found %d instances of '%s' in response (expected %d)" %
                    (real_count, text, count))
        else:
            self.failUnless(real_count != 0,
                            "Couldn't find '%s' in response" % text)

    def assertNotContains(self, response, text, status_code=200):
        """
        Asserts that a response indicates that a page was retrieved
        successfully, (i.e., the HTTP status code was as expected), and that
        ``text`` doesn't occur in the content of the response.
        """
        self.assertEqual(response.status_code, status_code,
            "Couldn't retrieve page: Response code was %d (expected %d)" %
                (response.status_code, status_code))
        self.assertEqual(response.content.count(text), 0,
                         "Response should not contain '%s'" % text)

    def assertFormError(self, response, form, field, errors):
        """
        Asserts that a form used to render the response has a specific field
        error.

        ``field`` may be a false value (e.g. None) to check the form's
        non-field errors instead. ``errors`` may be a single error or a list
        of errors; every one of them must be present.
        """
        # Put context(s) into a list to simplify processing.
        contexts = to_list(response.context)
        if not contexts:
            self.fail('Response did not use any contexts to render the'
                      ' response')

        # Put error(s) into a list to simplify processing.
        errors = to_list(errors)

        # Search all contexts for the error.
        found_form = False
        for i, context in enumerate(contexts):
            if form not in context:
                continue
            found_form = True
            for err in errors:
                if field:
                    if field in context[form].errors:
                        field_errors = context[form].errors[field]
                        self.failUnless(err in field_errors,
                                        "The field '%s' on form '%s' in"
                                        " context %d does not contain the"
                                        " error '%s' (actual errors: %s)" %
                                            (field, form, i, err,
                                             repr(field_errors)))
                    elif field in context[form].fields:
                        self.fail("The field '%s' on form '%s' in context %d"
                                  " contains no errors" % (field, form, i))
                    else:
                        self.fail("The form '%s' in context %d does not"
                                  " contain the field '%s'" %
                                      (form, i, field))
                else:
                    non_field_errors = context[form].non_field_errors()
                    self.failUnless(err in non_field_errors,
                        "The form '%s' in context %d does not contain the"
                        " non-field error '%s' (actual errors: %s)" %
                            (form, i, err, non_field_errors))
        if not found_form:
            self.fail("The form '%s' was not used to render the response" %
                          form)

    def assertTemplateUsed(self, response, template_name):
        """
        Asserts that the template with the provided name was used in rendering
        the response.
        """
        template_names = [t.name for t in to_list(response.template)]
        if not template_names:
            self.fail('No templates used to render the response')
        self.failUnless(template_name in template_names,
            (u"Template '%s' was not a template used to render the response."
             u" Actual template(s) used: %s") % (template_name,
                                                 u', '.join(template_names)))

    def assertTemplateNotUsed(self, response, template_name):
        """
        Asserts that the template with the provided name was NOT used in
        rendering the response.
        """
        template_names = [t.name for t in to_list(response.template)]
        self.failIf(template_name in template_names,
            (u"Template '%s' was used unexpectedly in rendering the"
             u" response") % template_name)
Note: See TracBrowser for help on using the browser.