Code

Ticket #16385: 2885579c6f8a2f41422788b0abdee3da816144f8.diff

File 2885579c6f8a2f41422788b0abdee3da816144f8.diff, 6.6 KB (added by jonash, 3 years ago)

Adds a possibility to disable exists() checks in save().

Line 
1diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
2index 01144eb..2e24f73 100644
3--- a/django/db/backends/__init__.py
4+++ b/django/db/backends/__init__.py
5@@ -281,6 +281,7 @@ class BaseDatabaseFeatures(object):
6     allow_sliced_subqueries = True
7 
8     supports_joins = True
9+    distinguishes_insert_from_update = True
10     supports_select_related = True
11 
12     # Does the default test database allow multiple connections?
13diff --git a/django/db/models/base.py b/django/db/models/base.py
14index 286f9b0..c0238b6 100644
15--- a/django/db/models/base.py
16+++ b/django/db/models/base.py
17@@ -273,6 +273,7 @@ class Model(object):
18     _deferred = False
19 
20     def __init__(self, *args, **kwargs):
21+        self._entity_exists = kwargs.pop('__entity_exists', False)
22         signals.pre_init.send(sender=self.__class__, args=args, kwargs=kwargs)
23 
24         # Set up the storage for instance state
25@@ -362,6 +363,7 @@ class Model(object):
26                     pass
27             if kwargs:
28                 raise TypeError("'%s' is an invalid keyword argument for this function" % kwargs.keys()[0])
29+        self._original_pk = self.pk if self._meta.pk is not None else None
30         super(Model, self).__init__()
31         signals.post_init.send(sender=self.__class__, instance=self)
32 
33@@ -470,6 +472,7 @@ class Model(object):
34         ('raw', 'cls', and 'origin').
35         """
36         using = using or router.db_for_write(self.__class__, instance=self)
37+        entity_exists = bool(self._entity_exists and self._original_pk == self.pk)
38         connection = connections[using]
39         assert not (force_insert and force_update)
40         if cls is None:
41@@ -516,7 +519,19 @@ class Model(object):
42             pk_set = pk_val is not None
43             record_exists = True
44             manager = cls._base_manager
45-            if pk_set:
46+            # TODO/NONREL: Some backends could emulate force_insert/_update
47+            # with an optimistic transaction, but since it's costly we should
48+            # only do it when the user explicitly wants it.
49+            # By adding support for an optimistic locking transaction
50+            # in Django (SQL: SELECT ... FOR UPDATE) we could even make that
51+            # part fully reusable on all backends (the current .exists()
52+            # check below isn't really safe if you have lots of concurrent
53+            # requests. BTW, and neither is QuerySet.get_or_create).
54+            try_update = connection.features.distinguishes_insert_from_update
55+            if not try_update:
56+                record_exists = False
57+
58+            if try_update and pk_set:
59                 # Determine whether a record with the primary key already exists.
60                 if (force_update or (not force_insert and
61                         manager.using(using).filter(pk=pk_val).exists())):
62@@ -536,13 +551,18 @@ class Model(object):
63                     order_value = manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count()
64                     self._order = order_value
65 
66+                if connection.features.distinguishes_insert_from_update:
67+                    add = True
68+                else:
69+                    add = not entity_exists
70+
71                 if not pk_set:
72                     if force_update:
73                         raise ValueError("Cannot force an update in save() with no primary key.")
74-                    values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection))
75+                    values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, add), connection=connection))
76                         for f in meta.local_fields if not isinstance(f, AutoField)]
77                 else:
78-                    values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection))
79+                    values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, add), connection=connection))
80                         for f in meta.local_fields]
81 
82                 record_exists = False
83@@ -566,8 +586,15 @@ class Model(object):
84 
85         # Signal that the save is complete
86         if origin and not meta.auto_created:
87+            if connection.features.distinguishes_insert_from_update:
88+                created = not record_exists
89+            else:
90+                created = not entity_exists
91             signals.post_save.send(sender=origin, instance=self,
92-                created=(not record_exists), raw=raw, using=using)
93+                created=created, raw=raw, using=using)
94+
95+        self._entity_exists = True
96+        self._original_pk = self.pk
97 
98 
99     save_base.alters_data = True
100@@ -580,6 +607,9 @@ class Model(object):
101         collector.collect([self])
102         collector.delete()
103 
104+        self._entity_exists = False
105+        self._original_pk = None
106+
107     delete.alters_data = True
108 
109     def _get_FIELD_display(self, field):
110diff --git a/django/db/models/query.py b/django/db/models/query.py
111index 26402ba..7a9bb45 100644
112--- a/django/db/models/query.py
113+++ b/django/db/models/query.py
114@@ -282,10 +282,10 @@ class QuerySet(object):
115                 if skip:
116                     row_data = row[index_start:aggregate_start]
117                     pk_val = row_data[pk_idx]
118-                    obj = model_cls(**dict(zip(init_list, row_data)))
119+                    obj = model_cls(**dict(zip(init_list, row_data), __entity_exists=True))
120                 else:
121                     # Omit aggregates in object creation.
122-                    obj = model(*row[index_start:aggregate_start])
123+                    obj = model(*row[index_start:aggregate_start], **{'__entity_exists': True})
124 
125                 # Store the source database of the object
126                 obj._state.db = db
127@@ -1197,9 +1197,9 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
128             obj = None
129         elif skip:
130             klass = deferred_class_factory(klass, skip)
131-            obj = klass(**dict(zip(init_list, fields)))
132+            obj = klass(__entity_exists=True, **dict(zip(init_list, fields)))
133         else:
134-            obj = klass(*fields)
135+            obj = klass(*fields, **{'__entity_exists': True})
136 
137     else:
138         # Load all fields on klass
139@@ -1215,7 +1215,7 @@ def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
140         if fields == (None,) * field_count:
141             obj = None
142         else:
143-            obj = klass(**dict(zip(field_names, fields)))
144+            obj = klass(__entity_exists=True, **dict(zip(field_names, fields)))
145 
146     # If an object was retrieved, set the database state.
147     if obj: