Ticket #17788: batch_bulk_insert.diff

File batch_bulk_insert.diff, 9.2 KB (added by akaariai, 3 years ago)
  • django/db/backends/__init__.py

    diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
    index 7674f5c..48ad30c 100644
    a b class BaseDatabaseOperations(object): 
    858858        conn = ' %s ' % connector
    859859        return conn.join(sub_expressions)
    860860
     861    def get_batch_size(self, fields, objs):
     862        """
     863        Returns the maximum allowed batch size for the backend. The fields
     864        will be repeated for each object in the batch. You can use the fields
     865        together with the objs to approximate the size of the query.
     866
     867        Use an arbitrarily picked value of 100000 for the default batch size.
     868        Above this amount backends will have problems with either query string
     869        size, or they will use an excessive amount of memory.
     870        """
     871        return 100000
     872
    861873class BaseDatabaseIntrospection(object):
    862874    """
    863875    This class encapsulates all backend-specific introspection utilities
  • django/db/backends/sqlite3/base.py

    diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
    index 0b19442..0fa1a27 100644
    a b class DatabaseOperations(BaseDatabaseOperations): 
    211211        res.extend(["UNION SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1))
    212212        return " ".join(res)
    213213
     214    def get_batch_size(self, fields, objs):
     215        """
     216        SQLite has a limit of 1000 variables per query. Intentionally
     217        leave a couple of unused variables.
     218        """
     219        return max(900 / len(fields), 1)
     220
    214221class DatabaseWrapper(BaseDatabaseWrapper):
    215222    vendor = 'sqlite'
    216223    # SQLite requires LIKE statements to include an ESCAPE clause if the value
  • django/db/models/query.py

    diff --git a/django/db/models/query.py b/django/db/models/query.py
    index 08fda78..c567b18 100644
    a b class QuerySet(object): 
    386386        """
    387387        # So this case is fun. When you bulk insert you don't get the primary
    388388        # keys back (if it's an autoincrement), so you can't insert into the
    389         # child tables which references this. There are two workarounds, 1)
     389        # child tables which references this. There are three workarounds, 1)
    390390        # this could be implemented if you didn't have an autoincrement pk,
    391391        # and 2) you could do it by doing O(n) normal inserts into the parent
    392392        # tables to get the primary keys back, and then doing a single bulk
    393         # insert into the childmost table. We're punting on these for now
    394         # because they are relatively rare cases.
     393        # insert into the childmost table. 3) you could do it if the database
     394        # implements RETURNING id queries (Oracle, PostgreSQL).  We're punting
     395        # on these for now because they are relatively rare cases. Still,
     396        # fixing this would allow "generic" use of bulk create.
    395397        if self.model._meta.parents:
    396398            raise ValueError("Can't bulk create an inherited model")
    397399        if not objs:
    class QuerySet(object): 
    399401        self._for_write = True
    400402        connection = connections[self.db]
    401403        fields = self.model._meta.local_fields
     404        def batched_insert(objs, fields, using):
     405            if not objs:
     406                return
     407            batch_size = connection.ops.get_batch_size(fields, objs)
     408            self.model._base_manager._insert(objs[0:batch_size], fields=fields, using=self.db)
     409            batched_insert(objs[batch_size:], fields, using)
     410
    402411        if not transaction.is_managed(using=self.db):
    403412            transaction.enter_transaction_management(using=self.db)
    404413            forced_managed = True
    class QuerySet(object): 
    407416        try:
    408417            if (connection.features.can_combine_inserts_with_and_without_auto_increment_pk
    409418                and self.model._meta.has_auto_field):
    410                 self.model._base_manager._insert(objs, fields=fields, using=self.db)
     419                batched_insert(objs, fields, self.db)
    411420            else:
    412421                objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
    413422                if objs_with_pk:
    414                     self.model._base_manager._insert(objs_with_pk, fields=fields, using=self.db)
     423                    batched_insert(objs_with_pk, fields, self.db)
    415424                if objs_without_pk:
    416                     self.model._base_manager._insert(objs_without_pk, fields=[f for f in fields if not isinstance(f, AutoField)], using=self.db)
     425                    batched_insert(objs_without_pk, [f for f in fields if not isinstance(f, AutoField)],
     426                                   self.db)
    417427            if forced_managed:
    418428                transaction.commit(using=self.db)
    419429            else:
  • tests/regressiontests/bulk_create/models.py

    diff --git a/tests/regressiontests/bulk_create/models.py b/tests/regressiontests/bulk_create/models.py
    index a4c611d..db72b8b 100644
    a b class Pizzeria(Restaurant): 
    1818    pass
    1919
    2020class State(models.Model):
    21     two_letter_code = models.CharField(max_length=2, primary_key=True)
    22  No newline at end of file
     21    two_letter_code = models.CharField(max_length=2, primary_key=True)
     22
     23class TwoFields(models.Model):
     24    f1 = models.CharField(max_length=100)
     25    f2 = models.CharField(max_length=100)
  • tests/regressiontests/bulk_create/tests.py

    diff --git a/tests/regressiontests/bulk_create/tests.py b/tests/regressiontests/bulk_create/tests.py
    index 0fa142b..5faaed8 100644
    a b from __future__ import with_statement, absolute_import 
    22
    33from operator import attrgetter
    44
    5 from django.test import TestCase, skipUnlessDBFeature
     5from django.db import connection
     6from django.test import TestCase, skipUnlessDBFeature, skipIfDBFeature
     7from django.test.utils import override_settings
    68
    7 from .models import Country, Restaurant, Pizzeria, State
     9from .models import Country, Restaurant, Pizzeria, State, TwoFields
    810
    911
    1012class BulkCreateTests(TestCase):
    class BulkCreateTests(TestCase): 
    5658            ])
    5759        self.assertQuerysetEqual(State.objects.order_by("two_letter_code"), [
    5860            "CA", "IL", "ME", "NY",
    59         ], attrgetter("two_letter_code"))
    60  No newline at end of file
     61        ], attrgetter("two_letter_code"))
     62
     63    def test_large_batch(self):
     64        with override_settings(DEBUG=True):
     65            Restaurant.objects.bulk_create([
     66                   Restaurant(name=str(i)) for i in range(0, 1001)
     67                ])
     68            self.assertTrue(len(connection.queries) < 10)
     69        self.assertEqual(Restaurant.objects.count(), 1001)
     70        self.assertQuerysetEqual(
     71            Restaurant.objects.filter(name__in=['999', '1000', '1001']).order_by('name'),
     72            ['1000', '999'], attrgetter('name'))
     73        # We happen to know SQLite is going to do the cut at 900 objs. Test that boundary.
     74        self.assertQuerysetEqual(
     75            Restaurant.objects.filter(name__in=['899', '900', '901']).order_by('name'),
     76            ['899', '900', '901'], attrgetter('name'))
     77
     78    def test_large_batch_multifield(self):
     79        with override_settings(DEBUG=True):
     80            TwoFields.objects.bulk_create([
     81                   TwoFields(f1=str(i), f2=str(i)) for i in range(0, 1001)
     82                ])
     83            self.assertTrue(len(connection.queries) < 10)
     84        self.assertEqual(TwoFields.objects.count(), 1001)
     85        # Test boundaries
     86        self.assertQuerysetEqual(
     87            TwoFields.objects.filter(f1__in=['0', '1']).order_by('f1'),
     88            ['0', '1'], attrgetter('f1'))
     89        self.assertQuerysetEqual(
     90            TwoFields.objects.filter(f1__in=['500', '501']).order_by('f1'),
     91            ['500', '501'], attrgetter('f1'))
     92        self.assertQuerysetEqual(
     93            TwoFields.objects.filter(f1__in=['999', '1000', '1001']).order_by('f1'),
     94            ['1000', '999'], attrgetter('f1'))
     95
     96    @skipIfDBFeature('can_combine_inserts_with_and_without_auto_increment_pk')
     97    def test_large_batch_mixed(self):
     98        # mixed pk + non-pk.
     99        with override_settings(DEBUG=True):
     100            TwoFields.objects.bulk_create([
     101                   TwoFields(id=i if i % 2 == 0 else None, f1=str(i), f2=str(i)) for i in range(1000, 2001)
     102                ])
     103            self.assertTrue(len(connection.queries) < 10)
     104        self.assertEqual(TwoFields.objects.count(), 1001)
     105        self.assertQuerysetEqual(TwoFields.objects.filter(id__in=[1998, 1999, 2000]).order_by('id'), [
     106            1998, 2000
     107        ], attrgetter('id'))
  • tests/regressiontests/queries/tests.py

    diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py
    index ded3e8f..ed71be8 100644
    a b class ConditionalTests(BaseQuerysetTest): 
    18071807        # Test that the "in" lookup works with lists of 1000 items or more.
    18081808        Number.objects.all().delete()
    18091809        numbers = range(2500)
    1810         for num in numbers:
    1811             _ = Number.objects.create(num=num)
     1810        Number.objects.bulk_create(Number(num=num) for num in numbers)
    18121811        self.assertEqual(
    18131812            Number.objects.filter(num__in=numbers[:1000]).count(),
    18141813            1000
Back to Top