commit 660a401179d860f15d99f115ddd349038e7ade4c (HEAD, refs/heads/acrefoot-django-patch-1.6)
Author: acrefoot <acrefoot@zulip.com>
Date: Mon Jan 13 11:31:05 2014 -0500
bulk_create sets primary keys on created objects
Addresses #19527
diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
index 2dbb8b3..86ee001 100644
|
a
|
b
|
class BaseDatabaseOperations(object):
|
| 835 | 835 | """ |
| 836 | 836 | return cursor.fetchone()[0] |
| 837 | 837 | |
| | 838 | def fetch_returned_insert_ids(self, cursor): |
| | 839 | """ |
| | 840 | Given a cursor object that has just performed an INSERT...RETURNING |
| | 841 | statement into a table that has an auto-incrementing ID, returns the |
| | 842 | list of newly created IDs. |
| | 843 | """ |
| | 844 | return [item[0] for item in cursor.fetchall()] |
| | 845 | |
| 838 | 846 | def field_cast_sql(self, db_type, internal_type): |
| 839 | 847 | """ |
| 840 | 848 | Given a column type (e.g. 'BLOB', 'VARCHAR'), and an internal type |
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 48d295c..c70de14 100644
|
a
|
b
|
class QuerySet(object):
|
| 400 | 400 | self._batched_insert(objs_with_pk, fields, batch_size) |
| 401 | 401 | if objs_without_pk: |
| 402 | 402 | fields = [f for f in fields if not isinstance(f, AutoField)] |
| 403 | | self._batched_insert(objs_without_pk, fields, batch_size) |
| | 403 | ids = self._batched_insert(objs_without_pk, fields, batch_size) |
| | 404 | for (obj_without_pk, pk) in zip(objs_without_pk, ids): |
| | 405 | obj_without_pk.pk = pk |
| 404 | 406 | |
| 405 | 407 | return objs |
| 406 | 408 | |
| … |
… |
class QuerySet(object):
|
| 928 | 930 | return |
| 929 | 931 | ops = connections[self.db].ops |
| 930 | 932 | batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1)) |
| | 933 | ret = [] |
| 931 | 934 | for batch in [objs[i:i + batch_size] |
| 932 | 935 | for i in range(0, len(objs), batch_size)]: |
| 933 | | self.model._base_manager._insert(batch, fields=fields, |
| 934 | | using=self.db) |
| | 936 | if connections[self.db].features.can_return_id_from_insert: |
| | 937 | if len(objs) > 1: |
| | 938 | ret.extend(self.model._base_manager._insert(batch, fields=fields, |
| | 939 | using=self.db, return_id=True)) |
| | 940 | else: |
| | 941 | ret.append(self.model._base_manager._insert(batch, fields=fields, |
| | 942 | using=self.db, return_id=True)) |
| | 943 | else: |
| | 944 | self.model._base_manager._insert(batch, fields=fields, |
| | 945 | using=self.db) |
| | 946 | return ret |
| 935 | 947 | |
| 936 | 948 | def _clone(self, klass=None, setup=False, **kwargs): |
| 937 | 949 | if klass is None: |
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 41bba93..203804e 100644
|
a
|
b
|
class SQLInsertCompiler(SQLCompiler):
|
| 841 | 841 | values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs] |
| 842 | 842 | params = [[]] |
| 843 | 843 | fields = [None] |
| 844 | | can_bulk = (not any(hasattr(field, "get_placeholder") for field in fields) and |
| 845 | | not self.return_id and self.connection.features.has_bulk_insert) |
| 846 | | |
| 847 | | if can_bulk: |
| | 844 | can_bulk_and_not_return_ids = (not any(hasattr(field, "get_placeholder") for field in fields) and |
| | 845 | not self.return_id and self.connection.features.has_bulk_insert) |
| | 846 | # If not all of these conditions are met, fall back to doing a bunch of single inserts |
| | 847 | can_bulk_and_return_ids = self.return_id and self.connection.features.can_return_id_from_insert \ |
| | 848 | and (not any(hasattr(field, "get_placeholder") for field in fields)) \ |
| | 849 | and self.connection.features.has_bulk_insert |
| | 850 | if can_bulk_and_not_return_ids or can_bulk_and_return_ids: |
| 848 | 851 | placeholders = [["%s"] * len(fields)] |
| 849 | 852 | else: |
| 850 | 853 | placeholders = [ |
| … |
… |
class SQLInsertCompiler(SQLCompiler):
|
| 853 | 856 | ] |
| 854 | 857 | # Oracle Spatial needs to remove some values due to #10888 |
| 855 | 858 | params = self.connection.ops.modify_insert_params(placeholders, params) |
| 856 | | if self.return_id and self.connection.features.can_return_id_from_insert: |
| | 859 | if can_bulk_and_return_ids: |
| | 860 | result.append(self.connection.ops.bulk_insert_sql(fields, len(values))) |
| | 861 | r_fmt, r_params = self.connection.ops.return_insert_id() |
| | 862 | if r_fmt: |
| | 863 | col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) |
| | 864 | result.append(r_fmt % col) |
| | 865 | params += r_params |
| | 866 | return [(" ".join(result), tuple([v for val in values for v in val]))] |
| | 867 | elif self.return_id and self.connection.features.can_return_id_from_insert: |
| 857 | 868 | params = params[0] |
| 858 | 869 | col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) |
| 859 | 870 | result.append("VALUES (%s)" % ", ".join(placeholders[0])) |
| … |
… |
class SQLInsertCompiler(SQLCompiler):
|
| 864 | 875 | result.append(r_fmt % col) |
| 865 | 876 | params += r_params |
| 866 | 877 | return [(" ".join(result), tuple(params))] |
| 867 | | if can_bulk: |
| | 878 | elif can_bulk_and_not_return_ids: |
| 868 | 879 | result.append(self.connection.ops.bulk_insert_sql(fields, len(values))) |
| 869 | 880 | return [(" ".join(result), tuple(v for val in values for v in val))] |
| 870 | 881 | else: |
| … |
… |
class SQLInsertCompiler(SQLCompiler):
|
| 874 | 885 | ] |
| 875 | 886 | |
| 876 | 887 | def execute_sql(self, return_id=False): |
| 877 | | assert not (return_id and len(self.query.objs) != 1) |
| | 888 | assert not (return_id and len(self.query.objs) != 1 and |
| | 889 | not self.connection.features.can_return_id_from_insert) |
| 878 | 890 | self.return_id = return_id |
| 879 | 891 | cursor = self.connection.cursor() |
| 880 | 892 | for sql, params in self.as_sql(): |
| 881 | 893 | cursor.execute(sql, params) |
| 882 | 894 | if not (return_id and cursor): |
| 883 | 895 | return |
| 884 | | if self.connection.features.can_return_id_from_insert: |
| | 896 | if self.connection.features.can_return_id_from_insert and len(self.query.objs) > 1: |
| | 897 | return self.connection.ops.fetch_returned_insert_ids(cursor) |
| | 898 | if self.connection.features.can_return_id_from_insert and len(self.query.objs) == 1: |
| 885 | 899 | return self.connection.ops.fetch_returned_insert_id(cursor) |
| 886 | 900 | return self.connection.ops.last_insert_id(cursor, |
| 887 | 901 | self.query.get_meta().db_table, self.query.get_meta().pk.column) |
diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
index e15e55d..aee4ddf 100644
|
a
|
b
|
This has a number of caveats though:
|
| 1624 | 1624 | ``post_save`` signals will not be sent. |
| 1625 | 1625 | * It does not work with child models in a multi-table inheritance scenario. |
| 1626 | 1626 | * If the model's primary key is an :class:`~django.db.models.AutoField` it |
| 1627 | | does not retrieve and set the primary key attribute, as ``save()`` does. |
| | 1627 | does not retrieve and set the primary key attribute, as ``save()`` does, |
| | 1628 | unless the database backend supports the ``can_return_id_from_insert`` |
| | 1629 | feature (currently ``postgresql_psycopg2`` and ``oracle``). |
| 1628 | 1630 | |
| 1629 | 1631 | The ``batch_size`` parameter controls how many objects are created in single |
| 1630 | 1632 | query. The default is to create all objects in one batch, except for SQLite |
diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py
index 367cbde..f9efa2e 100644
|
a
|
b
|
class BulkCreateTests(TestCase):
|
| 165 | 165 | TwoFields.objects.all().delete() |
| 166 | 166 | with self.assertNumQueries(1): |
| 167 | 167 | TwoFields.objects.bulk_create(objs, len(objs)) |
| | 168 | |
| | 169 | @skipUnlessDBFeature('can_return_id_from_insert') |
| | 170 | def test_set_pk_and_efficiency(self): |
| | 171 | countries = [] |
| | 172 | with self.assertNumQueries(1): |
| | 173 | countries = Country.objects.bulk_create(self.data) |
| | 174 | self.assertEqual(len(countries), 4) |
| | 175 | self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0]) |
| | 176 | self.assertEqual(Country.objects.get(pk=countries[1].pk), countries[1]) |
| | 177 | self.assertEqual(Country.objects.get(pk=countries[2].pk), countries[2]) |
| | 178 | self.assertEqual(Country.objects.get(pk=countries[3].pk), countries[3]) |