Ticket #5538: testcases.py

File testcases.py, 8.5 KB (added by pat.m.boyd@…, 10 years ago)

patch

Line 
1import re
2import unittest
3from urlparse import urlsplit, urlunsplit
4
5from django.http import QueryDict
6from django.db import transaction
7from django.core import mail
8from django.core.management import call_command
9from django.test import _doctest as doctest
10from django.test.client import Client
11
# Matches an integer literal carrying the Python 2 long suffix ("22L"), but
# only when it is not embedded in a larger word ("x22L", "22Lx").
_LONG_INT_RE = re.compile(r'(?<![\w])(\d+)L(?![\w])')


def normalize_long_ints(s):
    """Return ``s`` with the ``L`` suffix removed from long-integer literals.

    For example ``"[1L, 2L]"`` becomes ``"[1, 2]"``; an ``L`` that is part of
    a larger word, as in ``"x22L"`` or ``"22Lx"``, is left untouched.
    """
    return _LONG_INT_RE.sub(r'\1', s)
13
def to_list(value):
    """
    Return ``value`` as a list.

    ``None`` yields a new empty list, a list is handed back as-is (the same
    object, not a copy), and any other value is wrapped in a one-element list.
    """
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]
24
25
class OutputChecker(doctest.OutputChecker):
    """Doctest output checker that tolerates long-integer suffixes.

    Doctest compares expected and actual output as exact strings, so a long
    integer printed as ``22L`` would not equal a plain ``22``. When the
    standard comparison fails, both outputs are compared again after the
    ``L`` suffix has been stripped from integer literals.
    """

    def check_output(self, want, got, optionflags):
        matched = doctest.OutputChecker.check_output(self, want, got,
                                                     optionflags)
        if matched:
            return matched
        # Second chance: ignore "L" suffixes on both sides.
        return normalize_long_ints(want) == normalize_long_ints(got)
37
class DocTestRunner(doctest.DocTestRunner):
    """Doctest runner that always enables ``doctest.ELLIPSIS`` and, after an
    example raises unexpectedly, calls ``transaction.rollback_unless_managed()``.
    """

    def __init__(self, *positional, **named):
        doctest.DocTestRunner.__init__(self, *positional, **named)
        # Replaces whatever ``optionflags`` the base initializer stored,
        # including any value passed in by the caller.
        self.optionflags = doctest.ELLIPSIS

    def report_unexpected_exception(self, out, test, example, exc_info):
        base_report = doctest.DocTestRunner.report_unexpected_exception
        base_report(self, out, test, example, exc_info)
        # Roll back so a database error raised by one example does not have
        # side effects on the tests that follow.
        transaction.rollback_unless_managed()
49
class TestCase(unittest.TestCase):
    """A ``unittest.TestCase`` with Django-specific setup and assertions.

    Before each test runs, a fresh test ``Client`` is stored on
    ``self.client``, the ``flush`` management command is run, any fixtures
    listed in the class's ``fixtures`` attribute are loaded, and the test
    mail outbox is emptied.
    """

    def _pre_setup(self):
        """Performs any pre-test setup. This includes:

            * Running the ``flush`` management command (always).
            * If the Test Case class has a 'fixtures' member, installing the
              named fixtures.
            * Clearing the mail test outbox.
        """
        call_command('flush', verbosity=0, interactive=False)
        if hasattr(self, 'fixtures'):
            # We have to use this slightly awkward syntax due to the fact
            # that we're using *args and **kwargs together.
            call_command('loaddata', *self.fixtures, **{'verbosity': 0})
        mail.outbox = []

    def __call__(self, result=None):
        """
        Wrapper around default __call__ method to perform common Django test
        set up. This means that user-defined Test Cases aren't required to
        include a call to super().setUp().

        Returns whatever the wrapped ``__call__`` returns.
        """
        self.client = Client()
        self._pre_setup()
        return super(TestCase, self).__call__(result)

    def assertRedirects(self, response, expected_url, status_code=302,
                        target_status_code=200):
        """Asserts that a response redirected to a specific URL, and that the
        redirect URL can be loaded.

        ``expected_url`` may be absolute (``http://testserver/path/``) or a
        bare path (``/path/``). When it omits the scheme and/or host, those
        parts of the actual redirect URL are ignored; the path, query string
        and fragment must always match exactly.

        Note that assertRedirects won't work for external links since it uses
        TestClient to do a request.
        """
        self.assertEqual(response.status_code, status_code,
            ("Response didn't redirect as expected: Response code was %d"
             " (expected %d)" % (response.status_code, status_code)))
        url = response['Location']
        scheme, netloc, path, query, fragment = urlsplit(url)
        expected_parts = urlsplit(expected_url)

        # Only the scheme and host may be left out of expected_url. Path,
        # query and fragment always take part in the comparison so that,
        # e.g., '/foo/' does not match a redirect to '/foo/?next=/bar/'.
        if not expected_parts[0]:
            scheme = ''
        if not expected_parts[1]:
            netloc = ''
        url = urlunsplit((scheme, netloc, path, query, fragment))

        self.assertEqual(url, expected_url,
            "Response redirected to '%s', expected '%s'" % (url, expected_url))

        # Get the redirection page, using the same client that was used
        # to obtain the original response.
        redirect_response = response.client.get(path, QueryDict(query))
        self.assertEqual(redirect_response.status_code, target_status_code,
            ("Couldn't retrieve redirection page '%s': response code was %d"
             " (expected %d)") %
                 (path, redirect_response.status_code, target_status_code))

    def assertContains(self, response, text, count=None, status_code=200):
        """
        Asserts that a response indicates that a page was retrieved
        successfully, (i.e., the HTTP status code was as expected), and that
        ``text`` occurs ``count`` times in the content of the response.
        If ``count`` is None, the count doesn't matter - the assertion is true
        if the text occurs at least once in the response.
        """
        self.assertEqual(response.status_code, status_code,
            "Couldn't retrieve page: Response code was %d (expected %d)" %
                (response.status_code, status_code))
        real_count = response.content.count(text)
        if count is not None:
            self.assertEqual(real_count, count,
                "Found %d instances of '%s' in response (expected %d)" %
                    (real_count, text, count))
        else:
            self.failUnless(real_count != 0,
                            "Couldn't find '%s' in response" % text)

    def assertFormError(self, response, form, field, errors):
        """
        Asserts that a form used to render the response has a specific field
        error.

        If ``field`` is false (e.g. None or ''), ``errors`` are looked up in
        the form's non-field errors instead. ``errors`` may be a single error
        or a list of errors; every one of them must be present. Each rendering
        context that contains ``form`` is checked.
        """
        # Put context(s) into a list to simplify processing.
        contexts = to_list(response.context)
        if not contexts:
            self.fail('Response did not use any contexts to render the'
                      ' response')

        # Put error(s) into a list to simplify processing.
        errors = to_list(errors)

        # Search all contexts for the error.
        found_form = False
        for i, context in enumerate(contexts):
            if form not in context:
                continue
            found_form = True
            for err in errors:
                if field:
                    if field in context[form].errors:
                        field_errors = context[form].errors[field]
                        self.failUnless(err in field_errors,
                                        "The field '%s' on form '%s' in"
                                        " context %d does not contain the"
                                        " error '%s' (actual errors: %s)" %
                                            (field, form, i, err,
                                             list(field_errors)))
                    elif field in context[form].fields:
                        self.fail("The field '%s' on form '%s' in context %d"
                                  " contains no errors" % (field, form, i))
                    else:
                        self.fail("The form '%s' in context %d does not"
                                  " contain the field '%s'" %
                                      (form, i, field))
                else:
                    non_field_errors = context[form].non_field_errors()
                    self.failUnless(err in non_field_errors,
                        "The form '%s' in context %d does not contain the"
                        " non-field error '%s' (actual errors: %s)" %
                            (form, i, err, non_field_errors))
        if not found_form:
            self.fail("The form '%s' was not used to render the response" %
                          form)

    def assertTemplateUsed(self, response, template_name):
        """
        Asserts that the template with the provided name was used in rendering
        the response.
        """
        template_names = [t.name for t in to_list(response.template)]
        if not template_names:
            self.fail('No templates used to render the response')
        self.failUnless(template_name in template_names,
            (u"Template '%s' was not a template used to render the response."
             u" Actual template(s) used: %s") % (template_name,
                                                 u', '.join(template_names)))

    def assertTemplateNotUsed(self, response, template_name):
        """
        Asserts that the template with the provided name was NOT used in
        rendering the response.
        """
        template_names = [t.name for t in to_list(response.template)]
        self.failIf(template_name in template_names,
            (u"Template '%s' was used unexpectedly in rendering the"
             u" response") % template_name)
Back to Top