commit 660a401179d860f15d99f115ddd349038e7ade4c (HEAD, refs/heads/acrefoot-django-patch-1.6)
Author: acrefoot <acrefoot@zulip.com>
Date:   Mon Jan 13 11:31:05 2014 -0500

    bulk_create sets primary keys on created objects

    Addresses #19527
diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
index 2dbb8b3..86ee001 100644
--- a/django/db/backends/__init__.py
+++ b/django/db/backends/__init__.py
@@ -835,6 +835,14 @@ class BaseDatabaseOperations(object):
         """
         return cursor.fetchone()[0]

+    def fetch_returned_insert_ids(self, cursor):
+        """
+        Given a cursor object that has just performed an INSERT...RETURNING
+        statement into a table that has an auto-incrementing ID, returns the
+        list of newly created IDs.
+        """
+        return [item[0] for item in cursor.fetchall()]
+
     def field_cast_sql(self, db_type, internal_type):
         """
         Given a column type (e.g. 'BLOB', 'VARCHAR'), and an internal type
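The new hook assumes the backend can append a RETURNING clause to a multi-row
INSERT and read every generated key back from the same cursor. A minimal sketch
of that behaviour at the driver level, assuming a hypothetical "country" table
with a serial "id" column and placeholder connection details:

    import psycopg2  # assumed driver; dsn and table are illustrative only

    conn = psycopg2.connect("dbname=example")
    cur = conn.cursor()
    # One statement inserts several rows and returns all of their new ids.
    cur.execute(
        'INSERT INTO country (name) VALUES (%s), (%s) RETURNING id',
        ["Iceland", "Portugal"],
    )
    ids = [row[0] for row in cur.fetchall()]  # same extraction as fetch_returned_insert_ids()
    conn.commit()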
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 48d295c..c70de14 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -400,7 +400,9 @@ class QuerySet(object):
                     self._batched_insert(objs_with_pk, fields, batch_size)
                 if objs_without_pk:
                     fields = [f for f in fields if not isinstance(f, AutoField)]
-                    self._batched_insert(objs_without_pk, fields, batch_size)
+                    ids = self._batched_insert(objs_without_pk, fields, batch_size)
+                    for (obj_without_pk, pk) in zip(objs_without_pk, ids):
+                        obj_without_pk.pk = pk

         return objs

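With this change the objects handed back by bulk_create() carry their new
primary keys on backends that can return ids from an insert. A sketch of the
intended usage, assuming a simple Country model with name/iso_two_letter
fields like the one in the bulk_create test app and a RETURNING-capable
backend such as postgresql_psycopg2:

    from bulk_create.models import Country  # assumed test model

    countries = Country.objects.bulk_create([
        Country(name="Netherlands", iso_two_letter="NL"),
        Country(name="Germany", iso_two_letter="DE"),
    ])
    # On postgresql_psycopg2/oracle each pk is now populated; elsewhere it stays None.
    print([c.pk for c in countries])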
@@ -928,10 +930,20 @@ class QuerySet(object):
             return
         ops = connections[self.db].ops
         batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
+        ret = []
         for batch in [objs[i:i + batch_size]
                       for i in range(0, len(objs), batch_size)]:
-            self.model._base_manager._insert(batch, fields=fields,
-                                             using=self.db)
+            if connections[self.db].features.can_return_id_from_insert:
+                if len(objs) > 1:
+                    ret.extend(self.model._base_manager._insert(batch, fields=fields,
+                                                                using=self.db, return_id=True))
+                else:
+                    ret.append(self.model._base_manager._insert(batch, fields=fields,
+                                                                using=self.db, return_id=True))
+            else:
+                self.model._base_manager._insert(batch, fields=fields,
+                                                 using=self.db)
+        return ret

     def _clone(self, klass=None, setup=False, **kwargs):
         if klass is None:
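_batched_insert() still slices the object list into batches, but now collects
the ids each batch returns so bulk_create() can zip them back onto the objects
in order. The slicing itself is plain list arithmetic; a standalone sketch with
placeholder values:

    objs = list(range(7))      # stand-ins for model instances
    batch_size = 3
    batches = [objs[i:i + batch_size]
               for i in range(0, len(objs), batch_size)]
    assert batches == [[0, 1, 2], [3, 4, 5], [6]]
    # Extending ret batch by batch preserves this order, which is what lets
    # zip(objs_without_pk, ids) in bulk_create() pair each object with its id.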
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 41bba93..203804e 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -841,10 +841,13 @@ class SQLInsertCompiler(SQLCompiler):
             values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs]
             params = [[]]
             fields = [None]
-        can_bulk = (not any(hasattr(field, "get_placeholder") for field in fields) and
-            not self.return_id and self.connection.features.has_bulk_insert)
-
-        if can_bulk:
+        can_bulk_and_not_return_ids = (not any(hasattr(field, "get_placeholder") for field in fields) and
+            not self.return_id and self.connection.features.has_bulk_insert)
+        # If not all of these conditions are met, fall back to doing a bunch of single inserts
+        can_bulk_and_return_ids = self.return_id and self.connection.features.can_return_id_from_insert \
+            and (not any(hasattr(field, "get_placeholder") for field in fields)) \
+            and self.connection.features.has_bulk_insert
+        if can_bulk_and_not_return_ids or can_bulk_and_return_ids:
             placeholders = [["%s"] * len(fields)]
         else:
             placeholders = [
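When can_bulk_and_return_ids holds, as_sql() is expected to produce a single
multi-row statement with the backend's "return id" suffix appended, rather
than one statement per object. Roughly the shape on PostgreSQL (table and
column names are illustrative; exact quoting is up to the backend):

    # One (sql, params) pair in the list that as_sql() returns.
    sql = ('INSERT INTO "app_country" ("name", "iso_two_letter") '
           'VALUES (%s, %s), (%s, %s) RETURNING "app_country"."id"')
    params = ("Netherlands", "NL", "Germany", "DE")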
@@ -853,7 +856,15 @@ class SQLInsertCompiler(SQLCompiler):
             ]
             # Oracle Spatial needs to remove some values due to #10888
             params = self.connection.ops.modify_insert_params(placeholders, params)
-        if self.return_id and self.connection.features.can_return_id_from_insert:
+        if can_bulk_and_return_ids:
+            result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
+            r_fmt, r_params = self.connection.ops.return_insert_id()
+            if r_fmt:
+                col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
+                result.append(r_fmt % col)
+                params += r_params
+            return [(" ".join(result), tuple([v for val in values for v in val]))]
+        elif self.return_id and self.connection.features.can_return_id_from_insert:
             params = params[0]
             col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
             result.append("VALUES (%s)" % ", ".join(placeholders[0]))
@@ -864,7 +875,7 @@ class SQLInsertCompiler(SQLCompiler):
                 result.append(r_fmt % col)
                 params += r_params
             return [(" ".join(result), tuple(params))]
-        if can_bulk:
+        elif can_bulk_and_not_return_ids:
             result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
             return [(" ".join(result), tuple(v for val in values for v in val))]
         else:
@@ -874,14 +885,17 @@ class SQLInsertCompiler(SQLCompiler):
             ]

     def execute_sql(self, return_id=False):
-        assert not (return_id and len(self.query.objs) != 1)
+        assert not (return_id and len(self.query.objs) != 1 and
+                    not self.connection.features.can_return_id_from_insert)
         self.return_id = return_id
         cursor = self.connection.cursor()
         for sql, params in self.as_sql():
             cursor.execute(sql, params)
         if not (return_id and cursor):
             return
-        if self.connection.features.can_return_id_from_insert:
+        if self.connection.features.can_return_id_from_insert and len(self.query.objs) > 1:
+            return self.connection.ops.fetch_returned_insert_ids(cursor)
+        if self.connection.features.can_return_id_from_insert and len(self.query.objs) == 1:
             return self.connection.ops.fetch_returned_insert_id(cursor)
         return self.connection.ops.last_insert_id(cursor,
                 self.query.get_meta().db_table, self.query.get_meta().pk.column)
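execute_sql() now has three return shapes when return_id is requested: a list
of ids for a multi-object insert on a RETURNING-capable backend, a single id
for a one-object insert there, and last_insert_id() otherwise. A toy helper
(hypothetical, not part of the patch) mirroring how _batched_insert() above
consumes those shapes:

    def collect_ids(inserted, many):
        """Mimic how _batched_insert() accumulates what _insert() returns."""
        ret = []
        if many:
            ret.extend(inserted)   # list of ids, e.g. [11, 12, 13]
        else:
            ret.append(inserted)   # single id, e.g. 11
        return ret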
diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
index e15e55d..aee4ddf 100644
--- a/docs/ref/models/querysets.txt
+++ b/docs/ref/models/querysets.txt
@@ -1624,7 +1624,9 @@ This has a number of caveats though:
   ``post_save`` signals will not be sent.
 * It does not work with child models in a multi-table inheritance scenario.
 * If the model's primary key is an :class:`~django.db.models.AutoField` it
-  does not retrieve and set the primary key attribute, as ``save()`` does.
+  does not retrieve and set the primary key attribute, as ``save()`` does,
+  unless the database backend supports the ``can_return_id_from_insert``
+  feature (currently ``postgresql_psycopg2`` and ``oracle``).

 The ``batch_size`` parameter controls how many objects are created in single
 query. The default is to create all objects in one batch, except for SQLite
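For code that must run on several backends, the documented caveat can be
checked at runtime instead of assumed. A hedged sketch using the public
connection feature flag (model, variable, and alias names are placeholders):

    from django.db import connections

    created = Country.objects.bulk_create(new_countries)
    if connections['default'].features.can_return_id_from_insert:
        id_map = {c.pk: c for c in created}   # pks are populated on this backend
    else:
        id_map = None                         # fall back to re-querying by natural key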
diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py
index 367cbde..f9efa2e 100644
--- a/tests/bulk_create/tests.py
+++ b/tests/bulk_create/tests.py
@@ -165,3 +165,14 @@ class BulkCreateTests(TestCase):
         TwoFields.objects.all().delete()
         with self.assertNumQueries(1):
             TwoFields.objects.bulk_create(objs, len(objs))
+
+    @skipUnlessDBFeature('can_return_id_from_insert')
+    def test_set_pk_and_efficiency(self):
+        countries = []
+        with self.assertNumQueries(1):
+            countries = Country.objects.bulk_create(self.data)
+        self.assertEqual(len(countries), 4)
+        self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0])
+        self.assertEqual(Country.objects.get(pk=countries[1].pk), countries[1])
+        self.assertEqual(Country.objects.get(pk=countries[2].pk), countries[2])
+        self.assertEqual(Country.objects.get(pk=countries[3].pk), countries[3])