Ticket #19527: django_patch_bulk_create_sets_pk-1.5.1.patch

File django_patch_bulk_create_sets_pk-1.5.1.patch, 6.5 KB (added by acrefoot@…, 12 years ago)

A simple implementation of a bulk_create that sets primary keys on inserted objects. Only tested with psychopg2. Applies to git tag "1.5.1".

  • django/db/backends/__init__.py

    diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
    index 1decce0..3626ec0 100644
    a b class BaseDatabaseOperations(object):  
    576576        """
    577577        return cursor.fetchone()[0]
    578578
     579    def fetch_returned_insert_ids(self, cursor):
     580        """
     581        Given a cursor object that has just performed an INSERT...RETURNING
     582        statement into a table that has an auto-incrementing ID, returns the
     583        list of newly created IDs.
     584        """
     585        return [item[0] for item in cursor.fetchall()]
     586
    579587    def field_cast_sql(self, db_type):
    580588        """
    581589        Given a column type (e.g. 'BLOB', 'VARCHAR'), returns the SQL necessary
  • django/db/models/query.py

    diff --git a/django/db/models/query.py b/django/db/models/query.py
    index 952739e..0119170 100644
    a b class QuerySet(object):  
    441441                    self._batched_insert(objs_with_pk, fields, batch_size)
    442442                if objs_without_pk:
    443443                    fields= [f for f in fields if not isinstance(f, AutoField)]
    444                     self._batched_insert(objs_without_pk, fields, batch_size)
     444                    ids = self._batched_insert(objs_without_pk, fields, batch_size)
     445                    for i in range(len(ids)):
     446                        objs_without_pk[i].pk = ids[i]
    445447            if forced_managed:
    446448                transaction.commit(using=self.db)
    447449            else:
    class QuerySet(object):  
    896898            return
    897899        ops = connections[self.db].ops
    898900        batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
     901        ret = []
    899902        for batch in [objs[i:i+batch_size]
    900903                      for i in range(0, len(objs), batch_size)]:
    901             self.model._base_manager._insert(batch, fields=fields,
    902                                              using=self.db)
     904            if connections[self.db].features.can_return_id_from_insert:
     905                if len(objs) > 1:
     906                    ret.extend(self.model._base_manager._insert(batch, fields=fields,
     907                                                                using=self.db, return_id=True))
     908                else:
     909                    assert(len(objs) is 1)
     910                    ret.append(self.model._base_manager._insert(batch, fields=fields,
     911                                                                using=self.db, return_id=True))
     912            else:
     913                self.model._base_manager._insert(batch, fields=fields,
     914                                                 using=self.db)
     915        return ret
    903916
    904917    def _clone(self, klass=None, setup=False, **kwargs):
    905918        if klass is None:
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 7ea4cd7..61dd816 100644
    a b class SQLInsertCompiler(SQLCompiler):  
    897897            values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs]
    898898            params = [[]]
    899899            fields = [None]
    900         can_bulk = (not any(hasattr(field, "get_placeholder") for field in fields) and
    901             not self.return_id and self.connection.features.has_bulk_insert)
    902 
    903         if can_bulk:
     900        can_bulk_and_not_return_ids = (not any(hasattr(field, "get_placeholder") for field in fields) and
     901                                      not self.return_id and self.connection.features.has_bulk_insert)
     902        # If not all of these conditions are met, fall back to doing a bunch of single inserts
     903        can_bulk_and_return_ids = self.return_id and self.connection.features.can_return_id_from_insert \
     904                                  and (not any(hasattr(field, "get_placeholder") for field in fields)) \
     905                                  and self.connection.features.has_bulk_insert
     906        if can_bulk_and_not_return_ids or can_bulk_and_return_ids:
    904907            placeholders = [["%s"] * len(fields)]
    905908        else:
    906909            placeholders = [
    class SQLInsertCompiler(SQLCompiler):  
    909912            ]
    910913            # Oracle Spatial needs to remove some values due to #10888
    911914            params = self.connection.ops.modify_insert_params(placeholders, params)
    912         if self.return_id and self.connection.features.can_return_id_from_insert:
     915        if can_bulk_and_return_ids:
     916            result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
     917            r_fmt, r_params = self.connection.ops.return_insert_id()
     918            if r_fmt:
     919                col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
     920                result.append(r_fmt % col)
     921                params += r_params
     922            return [(" ".join(result), tuple([v for val in values for v in val]))]
     923        elif self.return_id and self.connection.features.can_return_id_from_insert:
    913924            params = params[0]
    914925            col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
    915926            result.append("VALUES (%s)" % ", ".join(placeholders[0]))
    class SQLInsertCompiler(SQLCompiler):  
    920931                result.append(r_fmt % col)
    921932                params += r_params
    922933            return [(" ".join(result), tuple(params))]
    923         if can_bulk:
     934        elif can_bulk_and_not_return_ids:
    924935            result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
    925936            return [(" ".join(result), tuple([v for val in values for v in val]))]
    926937        else:
    class SQLInsertCompiler(SQLCompiler):  
    930941            ]
    931942
    932943    def execute_sql(self, return_id=False):
    933         assert not (return_id and len(self.query.objs) != 1)
     944        assert not (return_id and len(self.query.objs) != 1 and
     945                    not self.connection.features.can_return_id_from_insert)
    934946        self.return_id = return_id
    935947        cursor = self.connection.cursor()
    936948        for sql, params in self.as_sql():
    937949            cursor.execute(sql, params)
    938950        if not (return_id and cursor):
    939951            return
    940         if self.connection.features.can_return_id_from_insert:
     952        if self.connection.features.can_return_id_from_insert and len(self.query.objs) > 1:
     953            return self.connection.ops.fetch_returned_insert_ids(cursor)
     954        if self.connection.features.can_return_id_from_insert and len(self.query.objs) is 1:
    941955            return self.connection.ops.fetch_returned_insert_id(cursor)
    942956        return self.connection.ops.last_insert_id(cursor,
    943957                self.query.model._meta.db_table, self.query.model._meta.pk.column)
Back to Top