Code

Ticket #8138: 8138alternate-tzcommit.diff

File 8138alternate-tzcommit.diff, 20.8 KB (added by kmtracey, 6 years ago)
Line 
1Index: django/test/client.py
2===================================================================
3--- django/test/client.py       (revision 9726)
4+++ django/test/client.py       (working copy)
5@@ -19,6 +19,7 @@
6 from django.utils.encoding import smart_str
7 from django.utils.http import urlencode
8 from django.utils.itercompat import is_iterable
9+from django.db import transaction, close_connection
10 
11 BOUNDARY = 'BoUnDaRyStRiNg'
12 MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY
13@@ -69,7 +70,9 @@
14                 response = middleware_method(request, response)
15             response = self.apply_response_fixes(request, response)
16         finally:
17+            signals.request_finished.disconnect(close_connection)           
18             signals.request_finished.send(sender=self.__class__)
19+            signals.request_finished.connect(close_connection)
20 
21         return response
22 
23Index: django/test/testcases.py
24===================================================================
25--- django/test/testcases.py    (revision 9726)
26+++ django/test/testcases.py    (working copy)
27@@ -7,7 +7,7 @@
28 from django.core import mail
29 from django.core.management import call_command
30 from django.core.urlresolvers import clear_url_caches
31-from django.db import transaction
32+from django.db import transaction, connection
33 from django.http import QueryDict
34 from django.test import _doctest as doctest
35 from django.test.client import Client
36@@ -26,7 +26,32 @@
37         value = [value]
38     return value
39 
40+real_commit = transaction.commit
41+real_rollback = transaction.rollback
42+real_enter_transaction_management = transaction.enter_transaction_management
43+real_leave_transaction_management = transaction.leave_transaction_management
44+real_savepoint_commit = transaction.savepoint_commit
45+real_savepoint_rollback = transaction.savepoint_rollback
46 
47+def nop(x=None):
48+    return
49+
50+def disable_transaction_methods():
51+    transaction.commit = nop
52+    transaction.rollback = nop
53+    transaction.savepoint_commit = nop
54+    transaction.savepoint_rollback = nop
55+    transaction.enter_transaction_management = nop
56+    transaction.leave_transaction_management = nop       
57+
58+def restore_transaction_methods():
59+    transaction.commit = real_commit
60+    transaction.rollback = real_rollback
61+    transaction.savepoint_commit = real_savepoint_commit
62+    transaction.savepoint_rollback = real_savepoint_rollback
63+    transaction.enter_transaction_management = real_enter_transaction_management
64+    transaction.leave_transaction_management = real_leave_transaction_management
65+
66 class OutputChecker(doctest.OutputChecker):
67     def check_output(self, want, got, optionflags):
68         "The entry method for doctest output checking. Defers to a sequence of child checkers"
69@@ -168,8 +193,19 @@
70         # Rollback, in case of database errors. Otherwise they'd have
71         # side effects on other tests.
72         transaction.rollback_unless_managed()
73+       
74+    def run(self, test, compileflags=None, out=None, clear_globs=True):
75+        """
76+        Wraps the parent run() and encloses it in a transaction.
77+        """
78+        transaction.enter_transaction_management()
79+        transaction.managed(True)
80+        result = doctest.DocTestRunner.run(self, test, compileflags, out, clear_globs)
81+        transaction.rollback()
82+        transaction.leave_transaction_management()
83+        return result
84 
85-class TestCase(unittest.TestCase):
86+class TransactionTestCase(unittest.TestCase):
87     def _pre_setup(self):
88         """Performs any pre-test setup. This includes:
89 
90@@ -180,16 +216,22 @@
91               ROOT_URLCONF with it.
92             * Clearing the mail test outbox.
93         """
94+        self._fixture_setup()
95+        self._urlconf_setup()
96+        mail.outbox = []
97+
98+    def _fixture_setup(self):
99         call_command('flush', verbosity=0, interactive=False)
100         if hasattr(self, 'fixtures'):
101             # We have to use this slightly awkward syntax due to the fact
102             # that we're using *args and **kwargs together.
103             call_command('loaddata', *self.fixtures, **{'verbosity': 0})
104+
105+    def _urlconf_setup(self):
106         if hasattr(self, 'urls'):
107             self._old_root_urlconf = settings.ROOT_URLCONF
108             settings.ROOT_URLCONF = self.urls
109             clear_url_caches()
110-        mail.outbox = []
111 
112     def __call__(self, result=None):
113         """
114@@ -206,7 +248,7 @@
115             import sys
116             result.addError(self, sys.exc_info())
117             return
118-        super(TestCase, self).__call__(result)
119+        super(TransactionTestCase, self).__call__(result)       
120         try:
121             self._post_teardown()
122         except (KeyboardInterrupt, SystemExit):
123@@ -221,6 +263,13 @@
124 
125             * Putting back the original ROOT_URLCONF if it was changed.
126         """
127+        self._fixture_teardown()
128+        self._urlconf_teardown()
129+
130+    def _fixture_teardown(self):
131+        pass
132+
133+    def _urlconf_teardown(self):       
134         if hasattr(self, '_old_root_urlconf'):
135             settings.ROOT_URLCONF = self._old_root_urlconf
136             clear_url_caches()
137@@ -354,3 +403,36 @@
138         self.failIf(template_name in template_names,
139             (u"Template '%s' was used unexpectedly in rendering the"
140              u" response") % template_name)
141+
142+class TestCase(TransactionTestCase):
143+    """
144+    Does basically the same as TransactionTestCase, but surrounds every test
145+    with a transaction, monkey-patches the real transaction management routines to
146+    do nothing, and rollsback the test transaction at the end of the test. You have
147+    to use TransactionTestCase, if you need transaction management inside a test.
148+    """
149+
150+    def _fixture_setup(self):
151+        if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
152+            return super(TestCase, self)._fixture_setup()
153+       
154+        transaction.enter_transaction_management()
155+        transaction.managed(True)
156+        disable_transaction_methods()
157+
158+        from django.contrib.sites.models import Site
159+        Site.objects.clear_cache()
160+
161+        if hasattr(self, 'fixtures'):
162+            call_command('loaddata', *self.fixtures, **{
163+                                                        'verbosity': 0,
164+                                                        'commit': False
165+                                                        })
166+
167+    def _fixture_teardown(self):
168+        if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
169+            return super(TestCase, self)._fixture_teardown()
170+               
171+        restore_transaction_methods()
172+        transaction.rollback()
173+        transaction.leave_transaction_management()
174\ No newline at end of file
175Index: django/test/__init__.py
176===================================================================
177--- django/test/__init__.py     (revision 9726)
178+++ django/test/__init__.py     (working copy)
179@@ -3,4 +3,4 @@
180 """
181 
182 from django.test.client import Client
183-from django.test.testcases import TestCase
184+from django.test.testcases import TestCase, TransactionTestCase
185Index: django/db/backends/postgresql/base.py
186===================================================================
187--- django/db/backends/postgresql/base.py       (revision 9726)
188+++ django/db/backends/postgresql/base.py       (working copy)
189@@ -116,6 +116,7 @@
190         cursor = self.connection.cursor()
191         if set_tz:
192             cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
193+            self.connection.commit()
194             if not hasattr(self, '_version'):
195                 self.__class__._version = get_version(cursor)
196             if self._version < (8, 0):
197Index: django/db/backends/postgresql_psycopg2/base.py
198===================================================================
199--- django/db/backends/postgresql_psycopg2/base.py      (revision 9726)
200+++ django/db/backends/postgresql_psycopg2/base.py      (working copy)
201@@ -88,6 +88,7 @@
202         cursor.tzinfo_factory = None
203         if set_tz:
204             cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE])
205+            self.connection.commit()
206             if not hasattr(self, '_version'):
207                 self.__class__._version = get_version(cursor)
208             if self._version < (8, 0):
209Index: django/db/backends/creation.py
210===================================================================
211--- django/db/backends/creation.py      (revision 9726)
212+++ django/db/backends/creation.py      (working copy)
213@@ -311,7 +311,8 @@
214 
215         self.connection.close()
216         settings.DATABASE_NAME = test_database_name
217-
218+        settings.DATABASE_SUPPORTS_TRANSACTIONS = self._rollback_works()
219+       
220         call_command('syncdb', verbosity=verbosity, interactive=False)
221 
222         if settings.CACHE_BACKEND.startswith('db://'):
223@@ -362,7 +363,19 @@
224                 sys.exit(1)
225 
226         return test_database_name
227-
228+   
229+    def _rollback_works(self):
230+        cursor = self.connection.cursor()
231+        cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
232+        self.connection._commit()
233+        cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
234+        self.connection._rollback()
235+        cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
236+        count, = cursor.fetchone()
237+        cursor.execute('DROP TABLE ROLLBACK_TEST')
238+        self.connection._commit()
239+        return count == 0
240+       
241     def destroy_test_db(self, old_database_name, verbosity=1):
242         """
243         Destroy a test database, prompting the user for confirmation if the
244Index: tests/regressiontests/generic_inline_admin/tests.py
245===================================================================
246--- tests/regressiontests/generic_inline_admin/tests.py (revision 9726)
247+++ tests/regressiontests/generic_inline_admin/tests.py (working copy)
248@@ -21,8 +21,10 @@
249         # relies on content type IDs, which will vary depending on what
250         # other tests have been run), thus we do it here.
251         e = Episode.objects.create(name='This Week in Django')
252+        self.episode_pk = e.pk
253         m = Media(content_object=e, url='http://example.com/podcast.mp3')
254         m.save()
255+        self.media_pk = m.pk
256     
257     def tearDown(self):
258         self.client.logout()
259@@ -39,7 +41,7 @@
260         """
261         A smoke test to ensure GET on the change_view works.
262         """
263-        response = self.client.get('/generic_inline_admin/admin/generic_inline_admin/episode/1/')
264+        response = self.client.get('/generic_inline_admin/admin/generic_inline_admin/episode/%d/' % self.episode_pk)
265         self.failUnlessEqual(response.status_code, 200)
266     
267     def testBasicAddPost(self):
268@@ -64,10 +66,11 @@
269             # inline data
270             "generic_inline_admin-media-content_type-object_id-TOTAL_FORMS": u"2",
271             "generic_inline_admin-media-content_type-object_id-INITIAL_FORMS": u"1",
272-            "generic_inline_admin-media-content_type-object_id-0-id": u"1",
273+            "generic_inline_admin-media-content_type-object_id-0-id": u"%d" % self.media_pk,
274             "generic_inline_admin-media-content_type-object_id-0-url": u"http://example.com/podcast.mp3",
275             "generic_inline_admin-media-content_type-object_id-1-id": u"",
276             "generic_inline_admin-media-content_type-object_id-1-url": u"",
277         }
278-        response = self.client.post('/generic_inline_admin/admin/generic_inline_admin/episode/1/', post_data)
279+        url = '/generic_inline_admin/admin/generic_inline_admin/episode/%d/' % self.episode_pk
280+        response = self.client.post(url, post_data)
281         self.failUnlessEqual(response.status_code, 302) # redirect somewhere
282Index: tests/regressiontests/comment_tests/tests/moderation_view_tests.py
283===================================================================
284--- tests/regressiontests/comment_tests/tests/moderation_view_tests.py  (revision 9726)
285+++ tests/regressiontests/comment_tests/tests/moderation_view_tests.py  (working copy)
286@@ -8,39 +8,43 @@
287 
288     def testFlagGet(self):
289         """GET the flag view: render a confirmation page."""
290-        self.createSomeComments()
291+        comments = self.createSomeComments()
292+        pk = comments[0].pk
293         self.client.login(username="normaluser", password="normaluser")
294-        response = self.client.get("/flag/1/")
295+        response = self.client.get("/flag/%d/" % pk)
296         self.assertTemplateUsed(response, "comments/flag.html")
297 
298     def testFlagPost(self):
299         """POST the flag view: actually flag the view (nice for XHR)"""
300-        self.createSomeComments()
301+        comments = self.createSomeComments()
302+        pk = comments[0].pk
303         self.client.login(username="normaluser", password="normaluser")
304-        response = self.client.post("/flag/1/")
305-        self.assertEqual(response["Location"], "http://testserver/flagged/?c=1")
306-        c = Comment.objects.get(pk=1)
307+        response = self.client.post("/flag/%d/" % pk)
308+        self.assertEqual(response["Location"], "http://testserver/flagged/?c=%d" % pk)
309+        c = Comment.objects.get(pk=pk)
310         self.assertEqual(c.flags.filter(flag=CommentFlag.SUGGEST_REMOVAL).count(), 1)
311         return c
312 
313     def testFlagPostTwice(self):
314         """Users don't get to flag comments more than once."""
315         c = self.testFlagPost()
316-        self.client.post("/flag/1/")
317-        self.client.post("/flag/1/")
318+        self.client.post("/flag/%d/" % c.pk)
319+        self.client.post("/flag/%d/" % c.pk)
320         self.assertEqual(c.flags.filter(flag=CommentFlag.SUGGEST_REMOVAL).count(), 1)
321 
322     def testFlagAnon(self):
323         """GET/POST the flag view while not logged in: redirect to log in."""
324-        self.createSomeComments()
325-        response = self.client.get("/flag/1/")
326-        self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/flag/1/")
327-        response = self.client.post("/flag/1/")
328-        self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/flag/1/")
329+        comments = self.createSomeComments()
330+        pk = comments[0].pk       
331+        response = self.client.get("/flag/%d/" % pk)
332+        self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/flag/%d/" % pk)
333+        response = self.client.post("/flag/%d/" % pk)
334+        self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/flag/%d/" % pk)
335 
336     def testFlaggedView(self):
337-        self.createSomeComments()
338-        response = self.client.get("/flagged/", data={"c":1})
339+        comments = self.createSomeComments()
340+        pk = comments[0].pk       
341+        response = self.client.get("/flagged/", data={"c":pk})
342         self.assertTemplateUsed(response, "comments/flagged.html")
343 
344     def testFlagSignals(self):
345@@ -70,23 +74,25 @@
346 
347     def testDeletePermissions(self):
348         """The delete view should only be accessible to 'moderators'"""
349-        self.createSomeComments()
350+        comments = self.createSomeComments()
351+        pk = comments[0].pk       
352         self.client.login(username="normaluser", password="normaluser")
353-        response = self.client.get("/delete/1/")
354-        self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/delete/1/")
355+        response = self.client.get("/delete/%d/" % pk)
356+        self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/delete/%d/" % pk)
357 
358         makeModerator("normaluser")
359-        response = self.client.get("/delete/1/")
360+        response = self.client.get("/delete/%d/" % pk)
361         self.assertEqual(response.status_code, 200)
362 
363     def testDeletePost(self):
364         """POSTing the delete view should mark the comment as removed"""
365-        self.createSomeComments()
366+        comments = self.createSomeComments()
367+        pk = comments[0].pk
368         makeModerator("normaluser")
369         self.client.login(username="normaluser", password="normaluser")
370-        response = self.client.post("/delete/1/")
371-        self.assertEqual(response["Location"], "http://testserver/deleted/?c=1")
372-        c = Comment.objects.get(pk=1)
373+        response = self.client.post("/delete/%d/" % pk)
374+        self.assertEqual(response["Location"], "http://testserver/deleted/?c=%d" % pk)
375+        c = Comment.objects.get(pk=pk)
376         self.failUnless(c.is_removed)
377         self.assertEqual(c.flags.filter(flag=CommentFlag.MODERATOR_DELETION, user__username="normaluser").count(), 1)
378 
379@@ -103,21 +109,23 @@
380         self.assertEqual(received_signals, [signals.comment_was_flagged])
381 
382     def testDeletedView(self):
383-        self.createSomeComments()
384-        response = self.client.get("/deleted/", data={"c":1})
385+        comments = self.createSomeComments()
386+        pk = comments[0].pk       
387+        response = self.client.get("/deleted/", data={"c":pk})
388         self.assertTemplateUsed(response, "comments/deleted.html")
389 
390 class ApproveViewTests(CommentTestCase):
391 
392     def testApprovePermissions(self):
393         """The delete view should only be accessible to 'moderators'"""
394-        self.createSomeComments()
395+        comments = self.createSomeComments()
396+        pk = comments[0].pk       
397         self.client.login(username="normaluser", password="normaluser")
398-        response = self.client.get("/approve/1/")
399-        self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/approve/1/")
400+        response = self.client.get("/approve/%d/" % pk)
401+        self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/approve/%d/" % pk)
402 
403         makeModerator("normaluser")
404-        response = self.client.get("/approve/1/")
405+        response = self.client.get("/approve/%d/" % pk)
406         self.assertEqual(response.status_code, 200)
407 
408     def testApprovePost(self):
409@@ -127,9 +135,9 @@
410 
411         makeModerator("normaluser")
412         self.client.login(username="normaluser", password="normaluser")
413-        response = self.client.post("/approve/1/")
414-        self.assertEqual(response["Location"], "http://testserver/approved/?c=1")
415-        c = Comment.objects.get(pk=1)
416+        response = self.client.post("/approve/%d/" % c1.pk)
417+        self.assertEqual(response["Location"], "http://testserver/approved/?c=%d" % c1.pk)
418+        c = Comment.objects.get(pk=c1.pk)
419         self.failUnless(c.is_public)
420         self.assertEqual(c.flags.filter(flag=CommentFlag.MODERATOR_APPROVAL, user__username="normaluser").count(), 1)
421 
422@@ -146,8 +154,9 @@
423         self.assertEqual(received_signals, [signals.comment_was_flagged])
424 
425     def testApprovedView(self):
426-        self.createSomeComments()
427-        response = self.client.get("/approved/", data={"c":1})
428+        comments = self.createSomeComments()
429+        pk = comments[0].pk       
430+        response = self.client.get("/approved/", data={"c":pk})
431         self.assertTemplateUsed(response, "comments/approved.html")
432 
433 
434Index: tests/regressiontests/comment_tests/tests/comment_view_tests.py
435===================================================================
436--- tests/regressiontests/comment_tests/tests/comment_view_tests.py     (revision 9726)
437+++ tests/regressiontests/comment_tests/tests/comment_view_tests.py     (working copy)
438@@ -1,3 +1,4 @@
439+import re
440 from django.conf import settings
441 from django.contrib.auth.models import User
442 from django.contrib.comments import signals
443@@ -5,6 +6,8 @@
444 from regressiontests.comment_tests.models import Article
445 from regressiontests.comment_tests.tests import CommentTestCase
446 
447+post_redirect_re = re.compile(r'^http://testserver/posted/\?c=(?P<pk>\d+$)')
448+
449 class CommentViewTests(CommentTestCase):
450 
451     def testPostCommentHTTPMethods(self):
452@@ -181,18 +184,26 @@
453         a = Article.objects.get(pk=1)
454         data = self.getValidData(a)
455         response = self.client.post("/post/", data)
456-        self.assertEqual(response["Location"], "http://testserver/posted/?c=1")
457-
458+        location = response["Location"]
459+        match = post_redirect_re.match(location)
460+        self.failUnless(match != None, "Unexpected redirect location: %s" % location)
461+       
462         data["next"] = "/somewhere/else/"
463         data["comment"] = "This is another comment"
464         response = self.client.post("/post/", data)
465-        self.assertEqual(response["Location"], "http://testserver/somewhere/else/?c=2")
466+        location = response["Location"]       
467+        match = re.search(r"^http://testserver/somewhere/else/\?c=\d+$", location)
468+        self.failUnless(match != None, "Unexpected redirect location: %s" % location)
469 
470     def testCommentDoneView(self):
471         a = Article.objects.get(pk=1)
472         data = self.getValidData(a)
473         response = self.client.post("/post/", data)
474-        response = self.client.get("/posted/", {'c':1})
475+        location = response["Location"]       
476+        match = post_redirect_re.match(location)
477+        self.failUnless(match != None, "Unexpected redirect location: %s" % location)
478+        pk = int(match.group('pk'))
479+        response = self.client.get(location)
480         self.assertTemplateUsed(response, "comments/posted.html")
481-        self.assertEqual(response.context[0]["comment"], Comment.objects.get(pk=1))
482+        self.assertEqual(response.context[0]["comment"], Comment.objects.get(pk=pk))
483