Code

Ticket #2705: 2705-for_update-r15021.diff

File 2705-for_update-r15021.diff, 18.7 KB (added by danfairs, 4 years ago)

Update to prevent spurious failure on MySQL MyISAM

Line 
1diff --git a/AUTHORS b/AUTHORS
2--- a/AUTHORS
3+++ b/AUTHORS
4@@ -164,6 +164,7 @@
5     eriks@win.tue.nl
6     Tomáš Ehrlich <tomas.ehrlich@gmail.com>
7     Dirk Eschler <dirk.eschler@gmx.net>
8+    Dan Fairs <dan@fezconsulting.com>
9     Marc Fargas <telenieko@telenieko.com>
10     Szilveszter Farkas <szilveszter.farkas@gmail.com>
11     Grigory Fateyev <greg@dial.com.ru>
12diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
13--- a/django/db/backends/__init__.py
14+++ b/django/db/backends/__init__.py
15@@ -103,6 +103,8 @@
16     # integer primary keys.
17     related_fields_match_type = False
18     allow_sliced_subqueries = True
19+    has_select_for_update = False
20+    has_select_for_update_nowait = False
21 
22     # Does the default test database allow multiple connections?
23     # Usually an indication that the test database is in-memory
24@@ -282,6 +284,15 @@
25         """
26         return []
27 
28+    def for_update_sql(self, nowait=False):
29+        """
30+        Returns the FOR UPDATE SQL clause to lock rows for an update operation.
31+        """
32+        if nowait:
33+            return 'FOR UPDATE NOWAIT'
34+        else:
35+            return 'FOR UPDATE'
36+
37     def fulltext_search_sql(self, field_name):
38         """
39         Returns the SQL WHERE clause to use in order to perform a full-text
40diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py
41--- a/django/db/backends/mysql/base.py
42+++ b/django/db/backends/mysql/base.py
43@@ -124,6 +124,8 @@
44     allows_group_by_pk = True
45     related_fields_match_type = True
46     allow_sliced_subqueries = False
47+    has_select_for_update = True
48+    has_select_for_update_nowait = False
49     supports_forward_references = False
50     supports_long_model_names = False
51     supports_microsecond_precision = False
52diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py
53--- a/django/db/backends/oracle/base.py
54+++ b/django/db/backends/oracle/base.py
55@@ -70,6 +70,8 @@
56     needs_datetime_string_cast = False
57     interprets_empty_strings_as_nulls = True
58     uses_savepoints = True
59+    has_select_for_update = True
60+    has_select_for_update_nowait = True
61     can_return_id_from_insert = True
62     allow_sliced_subqueries = False
63     supports_subqueries_in_group_by = False
64diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py
65--- a/django/db/backends/postgresql_psycopg2/base.py
66+++ b/django/db/backends/postgresql_psycopg2/base.py
67@@ -70,6 +70,9 @@
68     requires_rollback_on_dirty_transaction = True
69     has_real_datatype = True
70     can_defer_constraint_checks = True
71+    has_select_for_update = True
72+    has_select_for_update_nowait = True
73+   
74 
75 class DatabaseOperations(PostgresqlDatabaseOperations):
76     def last_executed_query(self, cursor, sql, params):
77diff --git a/django/db/models/manager.py b/django/db/models/manager.py
78--- a/django/db/models/manager.py
79+++ b/django/db/models/manager.py
80@@ -164,6 +164,9 @@
81     def order_by(self, *args, **kwargs):
82         return self.get_query_set().order_by(*args, **kwargs)
83 
84+    def select_for_update(self, *args, **kwargs):
85+        return self.get_query_set().select_for_update(*args, **kwargs)
86+
87     def select_related(self, *args, **kwargs):
88         return self.get_query_set().select_related(*args, **kwargs)
89 
90diff --git a/django/db/models/query.py b/django/db/models/query.py
91--- a/django/db/models/query.py
92+++ b/django/db/models/query.py
93@@ -432,6 +432,7 @@
94         del_query._for_write = True
95 
96         # Disable non-supported fields.
97+        del_query.query.select_for_update = False
98         del_query.query.select_related = False
99         del_query.query.clear_ordering()
100 
101@@ -580,6 +581,18 @@
102         else:
103             return self._filter_or_exclude(None, **filter_obj)
104 
105+    def select_for_update(self, **kwargs):
106+        """
107+        Returns a new QuerySet instance that will select objects with a
108+        FOR UPDATE lock.
109+        """
110+        # Default to false for nowait
111+        nowait = kwargs.pop('nowait', False)
112+        obj = self._clone()
113+        obj.query.select_for_update = True
114+        obj.query.select_for_update_nowait = nowait
115+        return obj
116+
117     def select_related(self, *fields, **kwargs):
118         """
119         Returns a new QuerySet instance that will select related objects.
120diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
121--- a/django/db/models/sql/compiler.py
122+++ b/django/db/models/sql/compiler.py
123@@ -1,5 +1,6 @@
124 from django.core.exceptions import FieldError
125 from django.db import connections
126+from django.db import transaction
127 from django.db.backends.util import truncate_name
128 from django.db.models.sql.constants import *
129 from django.db.models.sql.datastructures import EmptyResultSet
130@@ -117,6 +118,10 @@
131                         result.append('LIMIT %d' % val)
132                 result.append('OFFSET %d' % self.query.low_mark)
133 
134+        if self.query.select_for_update and self.connection.features.has_select_for_update:
135+            nowait = self.query.select_for_update_nowait and self.connection.features.has_select_for_update
136+            result.append(self.connection.ops.for_update_sql(nowait=nowait))
137+
138         return ' '.join(result), tuple(params)
139 
140     def as_nested_sql(self):
141@@ -677,6 +682,11 @@
142         resolve_columns = hasattr(self, 'resolve_columns')
143         fields = None
144         has_aggregate_select = bool(self.query.aggregate_select)
145+        # Set transaction dirty if we're using SELECT FOR UPDATE to ensure
146+        # a subsequent commit/rollback is executed, so any database locks
147+        # are released.
148+        if self.query.select_for_update and transaction.is_managed(self.using):
149+            transaction.set_dirty(self.using)
150         for rows in self.execute_sql(MULTI):
151             for row in rows:
152                 if resolve_columns:
153diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
154--- a/django/db/models/sql/query.py
155+++ b/django/db/models/sql/query.py
156@@ -131,6 +131,8 @@
157         self.order_by = []
158         self.low_mark, self.high_mark = 0, None  # Used for offset/limit
159         self.distinct = False
160+        self.select_for_update = False
161+        self.select_for_update_nowait = False
162         self.select_related = False
163         self.related_select_cols = []
164 
165@@ -260,6 +262,8 @@
166         obj.order_by = self.order_by[:]
167         obj.low_mark, obj.high_mark = self.low_mark, self.high_mark
168         obj.distinct = self.distinct
169+        obj.select_for_update = self.select_for_update
170+        obj.select_for_update_nowait = self.select_for_update_nowait
171         obj.select_related = self.select_related
172         obj.related_select_cols = []
173         obj.aggregates = deepcopy(self.aggregates, memo=memo)
174@@ -366,6 +370,7 @@
175 
176         query.clear_ordering(True)
177         query.clear_limits()
178+        query.select_for_update = False
179         query.select_related = False
180         query.related_select_cols = []
181         query.related_select_fields = []
182diff --git a/docs/ref/databases.txt b/docs/ref/databases.txt
183--- a/docs/ref/databases.txt
184+++ b/docs/ref/databases.txt
185@@ -364,6 +364,15 @@
186 column types have a maximum length restriction of 255 characters, regardless
187 of whether ``unique=True`` is specified or not.
188 
189+Row locking with ``QuerySet.select_for_update()``
190+-------------------------------------------------
191+
192+MySQL does not support the ``NOWAIT`` option to the ``SELECT ... FOR UPDATE``
193+statement. However, you may call the ``select_for_update()`` method of a
194+queryset with ``nowait=True``. In that case, the argument will be silently
195+discarded and the generated query will block until the requested lock can be
196+acquired.
197+
198 .. _sqlite-notes:
199 
200 SQLite notes
201diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
202--- a/docs/ref/models/querysets.txt
203+++ b/docs/ref/models/querysets.txt
204@@ -975,6 +975,40 @@
205     # queries the database with the 'backup' alias
206     >>> Entry.objects.using('backup')
207 
208+select_for_update
209+~~~~~~~~~~~~~~~~~
210+
211+.. method:: select_for_update(nowait=False)
212+
213+.. versionadded:: 1.3
214+
215+Returns a queryset that will lock rows until the end of the transaction,
216+generating a ``SELECT ... FOR UPDATE`` SQL statement on supported databases.
217+
218+For example::
219+
220+    entries = Entry.objects.select_for_update().filter(author=request.user)
221+
222+All matched entries will be locked until the end of the transaction block,
223+meaning that other transactions will be prevented from changing or acquiring
224+locks on them.
225+
226+Usually, if another transaction has already acquired a lock on one of the
227+selected rows, the query will block until the lock is released. If this is
228+not the behaviour you want, call ``select_for_update(nowait=True)``. This will
229+make the call non-blocking. If a conflicting lock is already acquired by
230+another transaction, ``django.db.utils.DatabaseError`` will be raised when
231+the queryset is evaluated.
232+
233+Note that using ``select_related`` will cause the current transaction to be set
234+dirty, if under transaction management. This is to ensure that Django issues a
235+``COMMIT`` or ``ROLLBACK``, releasing any locks held by the ``SELECT FOR
236+UPDATE``.
237+
238+Currently the ``postgresql_psycopg2``, ``oracle``, and ``mysql``
239+database backends support ``select_for_update()`` but MySQL has no
240+support for the ``nowait`` argument. Other backends will simply
241+generate queries as if ``select_for_update()`` had not been used.
242 
243 Methods that do not return QuerySets
244 ------------------------------------
245@@ -1253,7 +1287,7 @@
246 the only restriction on the :class:`QuerySet` that is updated is that it can
247 only update columns in the model's main table. Filtering based on related
248 fields is still possible. You cannot call ``update()`` on a
249-:class:`QuerySet` that has had a slice taken or can otherwise no longer be
250+:class:`QuerySet` that has had a slice taken or can otherwise no longer be
251 filtered.
252 
253 For example, if you wanted to update all the entries in a particular blog
254diff --git a/tests/modeltests/select_for_update/__init__.py b/tests/modeltests/select_for_update/__init__.py
255new file mode 100644
256--- /dev/null
257+++ b/tests/modeltests/select_for_update/__init__.py
258@@ -0,0 +1,1 @@
259+#
260diff --git a/tests/modeltests/select_for_update/models.py b/tests/modeltests/select_for_update/models.py
261new file mode 100644
262--- /dev/null
263+++ b/tests/modeltests/select_for_update/models.py
264@@ -0,0 +1,4 @@
265+from django.db import models
266+
267+class Person(models.Model):
268+    name = models.CharField(max_length=30)
269diff --git a/tests/modeltests/select_for_update/tests.py b/tests/modeltests/select_for_update/tests.py
270new file mode 100644
271--- /dev/null
272+++ b/tests/modeltests/select_for_update/tests.py
273@@ -0,0 +1,218 @@
274+import time
275+from django.conf import settings
276+from django.db import transaction, connection
277+from django.db.utils import ConnectionHandler, DEFAULT_DB_ALIAS, DatabaseError
278+from django.test import TransactionTestCase, skipUnlessDBFeature
279+from django.utils.functional import wraps
280+from django.utils import unittest
281+
282+from models import Person
283+
284+try:
285+    import threading
286+    def requires_threading(func):
287+        return func
288+except ImportError:
289+    # Note we can't use dummy_threading here, as our tests will actually
290+    # block. We just have to skip the test completely.
291+    def requires_threading(func):
292+        @wraps(func)
293+        def wrapped(*args, **kw):
294+            raise unittest.SkipTest('threading required')
295+
296+class SelectForUpdateTests(TransactionTestCase):
297+
298+    def setUp(self):
299+        connection._rollback()
300+        connection._enter_transaction_management(True)
301+        self.new_connections = ConnectionHandler(settings.DATABASES)
302+        self.person = Person.objects.create(name='Reinhardt')
303+
304+        # We need to set settings.DEBUG to True so we can capture
305+        # the output SQL to examine.
306+        self._old_debug = settings.DEBUG
307+        settings.DEBUG = True
308+
309+    def tearDown(self):
310+        connection._leave_transaction_management(True)
311+        settings.DEBUG = self._old_debug
312+        try:
313+            self.end_blocking_transaction()
314+        except (DatabaseError, AttributeError):
315+            pass
316+
317+    def start_blocking_transaction(self):
318+        # Start a blocking transaction. At some point,
319+        # end_blocking_transaction() should be called.
320+        self.new_connection = self.new_connections[DEFAULT_DB_ALIAS]
321+        self.new_connection._enter_transaction_management(True)
322+        self.cursor = self.new_connection.cursor()
323+        sql = 'SELECT * FROM %(db_table)s %(for_update)s;' % {
324+            'db_table': Person._meta.db_table,
325+            'for_update': self.new_connection.ops.for_update_sql(),
326+            }
327+        self.cursor.execute(sql, ())
328+        result = self.cursor.fetchone()
329+
330+    def end_blocking_transaction(self):
331+        # Roll back the blocking transaction.
332+        self.new_connection._rollback()
333+        self.new_connection.close()
334+        self.new_connection._leave_transaction_management(True)
335+
336+    def has_for_update_sql(self, tested_connection, nowait=False):
337+        # Examine the SQL that was executed to determine whether it
338+        # contains the 'SELECT..FOR UPDATE' stanza.
339+        for_update_sql = tested_connection.ops.for_update_sql(nowait)
340+        sql = tested_connection.queries[-1]['sql']
341+        return bool(sql.find(for_update_sql) > -1)
342+
343+    def check_exc(self, exc):
344+        self.failUnless(isinstance(exc, DatabaseError))
345+
346+    @skipUnlessDBFeature('has_select_for_update')
347+    def test_for_update_sql_generated(self):
348+        """
349+        Test that the backend's FOR UPDATE variant appears in
350+        generated SQL when select_for_update is invoked.
351+        """
352+        list(Person.objects.all().select_for_update())
353+        self.assertTrue(self.has_for_update_sql(connection))
354+
355+    @skipUnlessDBFeature('has_select_for_update_nowait')
356+    def test_for_update_sql_generated_nowait(self):
357+        """
358+        Test that the backend's FOR UPDATE NOWAIT variant appears in
359+        generated SQL when select_for_update is invoked.
360+        """
361+        list(Person.objects.all().select_for_update(nowait=True))
362+        self.assertTrue(self.has_for_update_sql(connection, nowait=True))
363+
364+    @requires_threading
365+    @skipUnlessDBFeature('has_select_for_update_nowait')
366+    def test_nowait_raises_error_on_block(self):
367+        """
368+        If nowait is specified, we expect an error to be raised rather
369+        than blocking.
370+        """
371+        self.start_blocking_transaction()
372+        status = []
373+        thread = threading.Thread(
374+            target=self.run_select_for_update,
375+            args=(status,),
376+            kwargs={'nowait': True},
377+        )
378+
379+        thread.start()
380+        time.sleep(1)
381+        thread.join()
382+        self.end_blocking_transaction()
383+        self.check_exc(status[-1])
384+
385+    def run_select_for_update(self, status, nowait=False):
386+        status.append('started')
387+        try:
388+            connection._rollback()
389+            people = list(Person.objects.all().select_for_update(nowait=nowait))
390+            people[0].name = 'Fred'
391+            people[0].save()
392+            connection._commit()
393+        except DatabaseError, e:
394+            status.append(e)
395+        except Exception, e:
396+            raise
397+
398+    @requires_threading
399+    @skipUnlessDBFeature('has_select_for_update')
400+    @skipUnlessDBFeature('supports_transactions')
401+    def test_block(self):
402+        """
403+        Check that a thread running a select_for_update that
404+        accesses rows being touched by a similar operation
405+        on another connection blocks correctly.
406+        """
407+        # First, let's start the transaction in our thread.
408+        self.start_blocking_transaction()
409+
410+        # Now, try it again using the ORM's select_for_update
411+        # facility. Do this in a separate thread.
412+        status = []
413+        thread = threading.Thread(target=self.run_select_for_update, args=(status,))
414+
415+        # The thread should immediately block, but we'll sleep
416+        # for a bit to make sure
417+        thread.start()
418+        sanity_count = 0
419+        while len(status) != 1 and sanity_count < 10:
420+            sanity_count += 1
421+            time.sleep(1)
422+        if sanity_count >= 10:
423+            raise ValueError, 'Thread did not run and block'
424+
425+        # Check the person hasn't been updated. Since this isn't
426+        # using FOR UPDATE, it won't block.
427+        p = Person.objects.get(pk=self.person.pk)
428+        self.assertEqual('Reinhardt', p.name)
429+
430+        # When we end our blocking transaction, our thread should
431+        # be able to continue.
432+        self.end_blocking_transaction()
433+        thread.join(5.0)
434+
435+        # Check the thread has finished. Assuming it has, we should
436+        # find that it has updated the person's name.
437+        self.failIf(thread.isAlive())
438+        p = Person.objects.get(pk=self.person.pk)
439+        self.assertEqual('Fred', p.name)
440+
441+    @requires_threading
442+    @skipUnlessDBFeature('has_select_for_update')
443+    def test_raw_lock_not_available(self):
444+        """
445+        Check that running a raw query which can't obtain a FOR UPDATE lock
446+        raises the correct exception
447+        """
448+        self.start_blocking_transaction()
449+        def raw(status):
450+            try:
451+                list(
452+                    Person.objects.raw(
453+                        'SELECT * FROM %s %s' % (
454+                            Person._meta.db_table,
455+                            connection.ops.for_update_sql(nowait=True)
456+                        )
457+                    )
458+                )
459+            except DatabaseError, e:
460+                status.append(e)
461+        status = []
462+        thread = threading.Thread(target=raw, kwargs={'status': status})
463+        thread.start()
464+        time.sleep(1)
465+        thread.join()
466+        self.end_blocking_transaction()
467+        self.check_exc(status[-1])
468+
469+    @skipUnlessDBFeature('has_select_for_update')
470+    def test_transaction_dirty_managed(self):
471+        """ Check that a select_for_update sets the transaction to be
472+        dirty when executed under txn management. Setting the txn dirty
473+        means that it will be either committed or rolled back by Django,
474+        which will release any locks held by the SELECT FOR UPDATE.
475+        """
476+        transaction.enter_transaction_management(True)
477+        transaction.managed(True)
478+        try:
479+            people = list(Person.objects.select_for_update())
480+            self.assertTrue(transaction.is_dirty())
481+        finally:
482+            transaction.rollback()
483+            transaction.leave_transaction_management()
484+
485+    @skipUnlessDBFeature('has_select_for_update')
486+    def test_transaction_not_dirty_unmanaged(self):
487+        """ If we're not under txn management, the txn will never be
488+        marked as dirty.
489+        """
490+        people = list(Person.objects.select_for_update())
491+        self.assertFalse(transaction.is_dirty())