Ticket #19527: django_patch_bulk_create_sets_pk-22198e6c.patch

File django_patch_bulk_create_sets_pk-22198e6c.patch, 8.5 KB (added by acrefoot@…, 10 years ago)

A simple implementation of bulk_create that sets primary keys on the inserted objects. Includes a test and slightly modified docs. Only tested on postgresql_psycopg2, so I need someone to test it on the Oracle backend. Applies to 22198e6c (current master).

  • django/db/backends/__init__.py

    commit 660a401179d860f15d99f115ddd349038e7ade4c (HEAD, refs/heads/acrefoot-django-patch-1.6)
    Author: acrefoot <acrefoot@zulip.com>
    Date:   Mon Jan 13 11:31:05 2014 -0500
    
        bulk_create sets primary keys on created objects
        
        Addresses #19527
    
    diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
    index 2dbb8b3..86ee001 100644
    a b class BaseDatabaseOperations(object):  
    835835        """
    836836        return cursor.fetchone()[0]
    837837
     838    def fetch_returned_insert_ids(self, cursor):
     839        """
     840        Given a cursor object that has just performed an INSERT...RETURNING
     841        statement into a table that has an auto-incrementing ID, returns the
     842        list of newly created IDs.
     843        """
     844        return [item[0] for item in cursor.fetchall()]
     845
    838846    def field_cast_sql(self, db_type, internal_type):
    839847        """
    840848        Given a column type (e.g. 'BLOB', 'VARCHAR'), and an internal type
  • django/db/models/query.py

    diff --git a/django/db/models/query.py b/django/db/models/query.py
    index 48d295c..c70de14 100644
    a b class QuerySet(object):  
    400400                    self._batched_insert(objs_with_pk, fields, batch_size)
    401401                if objs_without_pk:
    402402                    fields = [f for f in fields if not isinstance(f, AutoField)]
    403                     self._batched_insert(objs_without_pk, fields, batch_size)
     403                    ids = self._batched_insert(objs_without_pk, fields, batch_size)
     404                    for (obj_without_pk, pk) in zip(objs_without_pk, ids):
     405                        obj_without_pk.pk = pk
    404406
    405407        return objs
    406408
    class QuerySet(object):  
    928930            return
    929931        ops = connections[self.db].ops
    930932        batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
     933        ret = []
    931934        for batch in [objs[i:i + batch_size]
    932935                      for i in range(0, len(objs), batch_size)]:
    933             self.model._base_manager._insert(batch, fields=fields,
    934                                              using=self.db)
     936            if connections[self.db].features.can_return_id_from_insert:
     937                if len(objs) > 1:
     938                    ret.extend(self.model._base_manager._insert(batch, fields=fields,
     939                                                                using=self.db, return_id=True))
     940                else:
     941                    ret.append(self.model._base_manager._insert(batch, fields=fields,
     942                                                                using=self.db, return_id=True))
     943            else:
     944                self.model._base_manager._insert(batch, fields=fields,
     945                                                 using=self.db)
     946        return ret
    935947
    936948    def _clone(self, klass=None, setup=False, **kwargs):
    937949        if klass is None:
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 41bba93..203804e 100644
    a b class SQLInsertCompiler(SQLCompiler):  
    841841            values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs]
    842842            params = [[]]
    843843            fields = [None]
    844         can_bulk = (not any(hasattr(field, "get_placeholder") for field in fields) and
    845             not self.return_id and self.connection.features.has_bulk_insert)
    846 
    847         if can_bulk:
     844        can_bulk_and_not_return_ids = (not any(hasattr(field, "get_placeholder") for field in fields) and
     845                                      not self.return_id and self.connection.features.has_bulk_insert)
     846        # If not all of these conditions are met, fall back to doing a bunch of single inserts
     847        can_bulk_and_return_ids = self.return_id and self.connection.features.can_return_id_from_insert \
     848                                  and (not any(hasattr(field, "get_placeholder") for field in fields)) \
     849                                  and self.connection.features.has_bulk_insert
     850        if can_bulk_and_not_return_ids or can_bulk_and_return_ids:
    848851            placeholders = [["%s"] * len(fields)]
    849852        else:
    850853            placeholders = [
    class SQLInsertCompiler(SQLCompiler):  
    853856            ]
    854857            # Oracle Spatial needs to remove some values due to #10888
    855858            params = self.connection.ops.modify_insert_params(placeholders, params)
    856         if self.return_id and self.connection.features.can_return_id_from_insert:
     859        if can_bulk_and_return_ids:
     860            result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
     861            r_fmt, r_params = self.connection.ops.return_insert_id()
     862            if r_fmt:
     863                col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
     864                result.append(r_fmt % col)
     865                params += r_params
     866            return [(" ".join(result), tuple([v for val in values for v in val]))]
     867        elif self.return_id and self.connection.features.can_return_id_from_insert:
    857868            params = params[0]
    858869            col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
    859870            result.append("VALUES (%s)" % ", ".join(placeholders[0]))
    class SQLInsertCompiler(SQLCompiler):  
    864875                result.append(r_fmt % col)
    865876                params += r_params
    866877            return [(" ".join(result), tuple(params))]
    867         if can_bulk:
     878        elif can_bulk_and_not_return_ids:
    868879            result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
    869880            return [(" ".join(result), tuple(v for val in values for v in val))]
    870881        else:
    class SQLInsertCompiler(SQLCompiler):  
    874885            ]
    875886
    876887    def execute_sql(self, return_id=False):
    877         assert not (return_id and len(self.query.objs) != 1)
     888        assert not (return_id and len(self.query.objs) != 1 and
     889                    not self.connection.features.can_return_id_from_insert)
    878890        self.return_id = return_id
    879891        cursor = self.connection.cursor()
    880892        for sql, params in self.as_sql():
    881893            cursor.execute(sql, params)
    882894        if not (return_id and cursor):
    883895            return
    884         if self.connection.features.can_return_id_from_insert:
     896        if self.connection.features.can_return_id_from_insert and len(self.query.objs) > 1:
     897            return self.connection.ops.fetch_returned_insert_ids(cursor)
     898        if self.connection.features.can_return_id_from_insert and len(self.query.objs) == 1:
    885899            return self.connection.ops.fetch_returned_insert_id(cursor)
    886900        return self.connection.ops.last_insert_id(cursor,
    887901                self.query.get_meta().db_table, self.query.get_meta().pk.column)
  • docs/ref/models/querysets.txt

    diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
    index e15e55d..aee4ddf 100644
    a b This has a number of caveats though:  
    16241624  ``post_save`` signals will not be sent.
    16251625* It does not work with child models in a multi-table inheritance scenario.
    16261626* If the model's primary key is an :class:`~django.db.models.AutoField` it
    1627   does not retrieve and set the primary key attribute, as ``save()`` does.
     1627  does not retrieve and set the primary key attribute, as ``save()`` does,
     1628  unless the db backend also has the feature ``can_return_id_from_insert``
     1629  (currently ``postgresql_psycopg2``, ``oracle``).
    16281630
    16291631The ``batch_size`` parameter controls how many objects are created in single
    16301632query. The default is to create all objects in one batch, except for SQLite
  • tests/bulk_create/tests.py

    diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py
    index 367cbde..f9efa2e 100644
    a b class BulkCreateTests(TestCase):  
    165165        TwoFields.objects.all().delete()
    166166        with self.assertNumQueries(1):
    167167            TwoFields.objects.bulk_create(objs, len(objs))
     168
     169    @skipUnlessDBFeature('can_return_id_from_insert')
     170    def test_set_pk_and_efficiency(self):
     171        countries = []
     172        with self.assertNumQueries(1):
     173            countries = Country.objects.bulk_create(self.data)
     174        self.assertEqual(len(countries), 4)
     175        self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0])
     176        self.assertEqual(Country.objects.get(pk=countries[1].pk), countries[1])
     177        self.assertEqual(Country.objects.get(pk=countries[2].pk), countries[2])
     178        self.assertEqual(Country.objects.get(pk=countries[3].pk), countries[3])
Back to Top