Code

Ticket #2259: 2259_r13817.diff

File 2259_r13817.diff, 11.4 KB (added by carljm, 4 years ago)
Line 
1diff --git a/django/contrib/admin/options.py b/django/contrib/admin/options.py
2index 3b6e2b7..f1b89ba 100644
3--- a/django/contrib/admin/options.py
4+++ b/django/contrib/admin/options.py
5@@ -600,7 +600,7 @@ class ModelAdmin(BaseModelAdmin):
6         """
7         Given a model instance save it to the database.
8         """
9-        obj.save()
10+        obj.save(original_pk=form.initial.get(self.model._meta.pk.name))
11 
12     def save_formset(self, request, form, formset, change):
13         """
14@@ -680,18 +680,19 @@ class ModelAdmin(BaseModelAdmin):
15         """
16         opts = obj._meta
17         pk_value = obj._get_pk_val()
18+        same_obj_url = "../%s/" % pk_value
19 
20         msg = _('The %(name)s "%(obj)s" was changed successfully.') % {'name': force_unicode(opts.verbose_name), 'obj': force_unicode(obj)}
21         if request.POST.has_key("_continue"):
22             self.message_user(request, msg + ' ' + _("You may edit it again below."))
23             if request.REQUEST.has_key('_popup'):
24-                return HttpResponseRedirect(request.path + "?_popup=1")
25+                return HttpResponseRedirect(same_obj_url + "?_popup=1")
26             else:
27-                return HttpResponseRedirect(request.path)
28+                return HttpResponseRedirect(same_obj_url)
29         elif request.POST.has_key("_saveasnew"):
30             msg = _('The %(name)s "%(obj)s" was added successfully. You may edit it again below.') % {'name': force_unicode(opts.verbose_name), 'obj': obj}
31             self.message_user(request, msg)
32-            return HttpResponseRedirect("../%s/" % pk_value)
33+            return HttpResponseRedirect(same_obj_url)
34         elif request.POST.has_key("_addanother"):
35             self.message_user(request, msg + ' ' + (_("You may add another %s below.") % force_unicode(opts.verbose_name)))
36             return HttpResponseRedirect("../add/")
37@@ -875,6 +876,7 @@ class ModelAdmin(BaseModelAdmin):
38             return self.add_view(request, form_url='../add/')
39 
40         ModelForm = self.get_form(request, obj)
41+
42         formsets = []
43         if request.method == 'POST':
44             form = ModelForm(request.POST, request.FILES, instance=obj)
45diff --git a/django/db/models/base.py b/django/db/models/base.py
46index b3deda1..1b13a88 100644
47--- a/django/db/models/base.py
48+++ b/django/db/models/base.py
49@@ -420,7 +420,7 @@ class Model(object):
50             return getattr(self, field_name)
51         return getattr(self, field.attname)
52 
53-    def save(self, force_insert=False, force_update=False, using=None):
54+    def save(self, force_insert=False, force_update=False, using=None, original_pk=None):
55         """
56         Saves the current instance. Override this in a subclass if you want to
57         control the saving process.
58@@ -428,15 +428,22 @@ class Model(object):
59         The 'force_insert' and 'force_update' parameters can be used to insist
60         that the "save" must be an SQL insert or update (or equivalent for
61         non-SQL backends), respectively. Normally, they should not be set.
62+
63+        The 'original_pk' argument allows updating the primary key of an
64+        existing row. Assuming the actual instance pk attribute has already
65+        been updated to the new value, passing in the original pk will update
66+        that row.
67         """
68         if force_insert and force_update:
69             raise ValueError("Cannot force both insert and updating in model saving.")
70-        self.save_base(using=using, force_insert=force_insert, force_update=force_update)
71+        if force_insert and original_pk is not None:
72+            raise ValueError("Cannot force insert and provide an original primary key.")
73+        self.save_base(using=using, force_insert=force_insert, force_update=force_update, original_pk=original_pk)
74 
75     save.alters_data = True
76 
77     def save_base(self, raw=False, cls=None, origin=None, force_insert=False,
78-            force_update=False, using=None):
79+            force_update=False, using=None, original_pk=None):
80         """
81         Does the heavy-lifting involved in saving. Subclasses shouldn't need to
82         override this method. It's separate from save() in order to hide the
83@@ -483,11 +490,19 @@ class Model(object):
84                 return
85 
86         if not meta.proxy:
87-            non_pks = [f for f in meta.local_fields if not f.primary_key]
88+            fields_to_update = meta.local_fields
89 
90             # First, try an UPDATE. If that doesn't update anything, do an INSERT.
91             pk_val = self._get_pk_val(meta)
92             pk_set = pk_val is not None
93+
94+            if original_pk is not None:
95+                if not pk_set:
96+                    raise ValueError("Cannot update existing row primary key to None")
97+                pk_val = original_pk
98+            else:
99+                fields_to_update = [f for f in fields_to_update if not f.primary_key]
100+
101             record_exists = True
102             manager = cls._base_manager
103             if pk_set:
104@@ -495,8 +510,11 @@ class Model(object):
105                 if (force_update or (not force_insert and
106                         manager.using(using).filter(pk=pk_val).exists())):
107                     # It does already exist, so do an UPDATE.
108-                    if force_update or non_pks:
109-                        values = [(f, None, (raw and getattr(self, f.attname) or f.pre_save(self, False))) for f in non_pks]
110+                    if force_update or fields_to_update:
111+                        values = [
112+                            (f, None, (raw and getattr(self, f.attname) or f.pre_save(self, False)))
113+                            for f in fields_to_update
114+                        ]
115                         rows = manager.using(using).filter(pk=pk_val)._update(values)
116                         if force_update and not rows:
117                             raise DatabaseError("Forced update did not affect any rows.")
118diff --git a/tests/regressiontests/admin_views/tests.py b/tests/regressiontests/admin_views/tests.py
119index 725369a..6de4b52 100644
120--- a/tests/regressiontests/admin_views/tests.py
121+++ b/tests/regressiontests/admin_views/tests.py
122@@ -326,6 +326,53 @@ class SaveAsTests(TestCase):
123         response = self.client.post('/test_admin/admin/admin_views/person/1/', post_data)
124         self.assertEqual(response.context['form_url'], '../add/')
125 
126+class UpdatePrimaryKeyTests(TestCase):
127+    fixtures = ['admin-views-users.xml', 'string-primary-key.xml']
128+    model_class = ModelWithStringPrimaryKey
129+    model_pk = """abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ 1234567890 -_.!~*'() ;/?:@&=+$, <>#%" {}|\^[]`"""
130+
131+    def setUp(self):
132+        self.client.login(username='super', password='secret')
133+        self.instance = self._get_by_pk(self.model_pk)
134+
135+    def tearDown(self):
136+        self.client.logout()
137+
138+    def _get_by_pk(self, pk_val):
139+        return self._manager.get(pk=pk_val)
140+
141+    @property
142+    def _manager(self):
143+        return self.model_class._default_manager
144+
145+    def _pk_name(self):
146+        return self.model_class._meta.pk.attname
147+
148+    def _change_url(self, pk=None):
149+        pk = pk or self.instance.pk
150+        return ('/test_admin/admin/admin_views/%s/%s/'
151+                % (self.model_class._meta.module_name,
152+                   iri_to_uri(quote(pk))))
153+
154+    def test_update_pk(self):
155+        post_data = {
156+            self._pk_name(): 'new pk value',
157+            }
158+        count = self._manager.count()
159+        response = self.client.post(self._change_url(), data=post_data)
160+        self.assertEqual(response.status_code, 302)
161+        self.assertEqual(self._manager.filter(pk='new pk value').count(), 1)
162+        self.assertEqual(self._manager.count(), count)
163+        self.assertFalse(self._manager.filter(pk=self.model_pk).exists())
164+
165+    def test_update_pk_continue_redirect(self):
166+        post_data = {
167+            self._pk_name(): 'new pk value',
168+            '_continue': 'Save and continue editing',
169+            }
170+        response = self.client.post(self._change_url(), data=post_data)
171+        self.assertRedirects(response, self._change_url('new pk value'))
172+
173 class CustomModelAdminTest(AdminViewBasicTest):
174     urlbit = "admin2"
175 
176diff --git a/tests/regressiontests/model_save/__init__.py b/tests/regressiontests/model_save/__init__.py
177new file mode 100644
178index 0000000..e69de29
179diff --git a/tests/regressiontests/model_save/models.py b/tests/regressiontests/model_save/models.py
180new file mode 100644
181index 0000000..c9848d0
182--- /dev/null
183+++ b/tests/regressiontests/model_save/models.py
184@@ -0,0 +1,12 @@
185+from django.db import models
186+
187+class Beverage(models.Model):
188+    id = models.IntegerField(primary_key=True)
189+    name = models.CharField(max_length=100)
190+
191+class Train(models.Model):
192+    name = models.CharField(max_length=100, primary_key=True)
193+
194+class TrainCar(models.Model):
195+    train = models.ForeignKey(Train)
196+    number = models.IntegerField()
197diff --git a/tests/regressiontests/model_save/tests.py b/tests/regressiontests/model_save/tests.py
198new file mode 100644
199index 0000000..2e6a026
200--- /dev/null
201+++ b/tests/regressiontests/model_save/tests.py
202@@ -0,0 +1,70 @@
203+from django.db import DatabaseError
204+from django.test import TestCase, TransactionTestCase
205+
206+from models import Beverage, Train, TrainCar
207+
208+class UpdatePKTests(TestCase):
209+    def setUp(self):
210+        self.b = Beverage.objects.create(id=1, name='Water')
211+
212+    def test_update_pk(self):
213+        self.b.id = 2
214+        self.b.save(original_pk=1)
215+        self.assertEqual(Beverage.objects.count(), 1)
216+        self.assertEqual(Beverage.objects.filter(id=2).count(), 1)
217+        self.assertEqual(Beverage.objects.filter(id=1).count(), 0)
218+
219+    def test_update_pk_to_none(self):
220+        """
221+        Trying to update a row to a PK of None is totally not legit.
222+
223+        """
224+        self.b.id = None
225+        self.assertRaises(ValueError, self.b.save, original_pk=1)
226+
227+    def test_bogus_original_pk(self):
228+        """
229+        Asking to update a row with a nonexistent original_pk results in a
230+        DatabaseError (same as combining force_update with a nonexistent pk).
231+
232+        """
233+        self.assertRaises(DatabaseError, self.b.save, original_pk=7)
234+
235+    def test_force_insert_and_original_pk(self):
236+        """
237+        Asking to update a row with an original_pk and force_insert is
238+        nonsensical and results in a ValueError.
239+
240+        """
241+        self.b.id = 2
242+        self.assertRaises(ValueError, self.b.save, original_pk=1, force_insert=True)
243+
244+    def test_force_update_and_original_pk(self):
245+        """
246+        Combining force_update and original_pk is just redundant, not
247+        problematic.
248+
249+        """
250+        self.b.id = 2
251+        self.b.save(original_pk=1, force_update=True)
252+        self.assertEqual(Beverage.objects.count(), 1)
253+        self.assertEqual(Beverage.objects.filter(id=2).count(), 1)
254+        self.assertEqual(Beverage.objects.filter(id=1).count(), 0)
255+
256+    def test_update_pk_cascade(self):
257+        """
258+        Updating the primary key of an object that is the referent of
259+        ForeignKeys needs to cascade the update to referring row, or
260+        referential integrity is broken.
261+
262+        XXX This test currently fails because we have neither ON UPDATE CASCADE
263+        support nor any emulation of it.
264+
265+        """
266+        train = Train.objects.create(name='California Zephyr')
267+        car = TrainCar.objects.create(train=train, number=12)
268+
269+        train.name = 'Empire Builder'
270+        train.save(original_pk='California Zephyr')
271+        car = TrainCar.objects.get(pk=car.pk)
272+        self.assertEqual(car.train.name, 'Empire Builder')