Ticket #28668: on_conflict_ignore.patch

File on_conflict_ignore.patch, 43.0 KB (added by Дилян Палаузов, 7 years ago)

The on_conflict_ignore.patch adds an on_conflict='ignore' option to QuerySet.bulk_create() and lets bulk_create() optionally send post_save signals when the ID of an inserted object is known to it. On PostgreSQL, bulk_create() can additionally retrieve the IDs of the newly inserted objects when on_conflict='ignore' is used and, with a second query, find the PKs of the supplied objs that were already in the database. Where the backend supports it (so not Oracle), on_conflict='ignore' is also passed to bulk_create() in ManyRelatedManager._add_items() (defined by create_forward_many_to_many_manager() in django.db.models.fields.related_descriptors), making ManyRelatedManager.add() thread-safe and possibly resolving #19544.
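
For illustration, a minimal usage sketch of the new parameters (a sketch only, using the CountryUnique model that the patch adds to tests/bulk_create/models.py):

    from bulk_create.models import CountryUnique  # has a unique iso_two_letter field

    objs = [
        CountryUnique(name='Bulgaria', iso_two_letter='BG'),
        CountryUnique(name='Germany', iso_two_letter='DE'),
    ]
    # Rows that would violate the unique constraint are silently skipped
    # instead of raising IntegrityError.
    CountryUnique.objects.bulk_create(objs, on_conflict='ignore')

    # On PostgreSQL >= 9.5, all_ids=True issues a second query so that objs
    # which already existed also get their pk set; send_post_save=True then
    # sends post_save for every object whose pk is known afterwards.
    CountryUnique.objects.bulk_create(
        [CountryUnique(name='Greece', iso_two_letter='GR'),
         CountryUnique(name='Germany', iso_two_letter='DE')],
        on_conflict='ignore', all_ids=True, send_post_save=True)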

  • django/db/backends/base/features.py

    diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py
    a b class BaseDatabaseFeatures:  
    250250    # Does the backend support keyword parameters for cursor.callproc()?
    251251    supports_callproc_kwargs = False
    252252
     253    # Does the backend support ignoring constraint or uniqueness errors during
     254    # inserting?
     255    supports_on_conflict_ignore = True
     256
    253257    def __init__(self, connection):
    254258        self.connection = connection
    255259
  • django/db/backends/base/operations.py

    diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py
    a b class BaseDatabaseOperations:  
    652652
    653653    def window_frame_range_start_end(self, start=None, end=None):
    654654        return self.window_frame_rows_start_end(start, end)
     655
     656    def insert_statement(self, on_conflict=None):
     657        return 'INSERT INTO'
     658
     659    def on_conflict_postfix(self, on_conflict=None):
     660        return ''
  • django/db/backends/mysql/operations.py

    diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py
    a b class DatabaseOperations(BaseDatabaseOperations):  
    269269            ) % {'lhs': lhs_sql, 'rhs': rhs_sql}, lhs_params * 2 + rhs_params * 2
    270270        else:
    271271            return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), rhs_params + lhs_params
     272
     273    def insert_statement(self, on_conflict=None):
     274        if on_conflict == 'ignore':
     275            return 'INSERT IGNORE INTO'
     276
     277        return super().insert_statement(on_conflict)
  • django/db/backends/oracle/features.py

    diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py
    a b class DatabaseFeatures(BaseDatabaseFeatures):  
    5454    """
    5555    supports_callproc_kwargs = True
    5656    supports_over_clause = True
     57    supports_on_conflict_ignore = False
    5758    max_query_params = 2**16 - 1
  • new file django/db/backends/postgresql/compiler.py

    diff --git a/django/db/backends/postgresql/compiler.py b/django/db/backends/postgresql/compiler.py
    new file mode 100644
    - +  
     1from django.db.models.sql import compiler
     2
     3
     4SQLCompiler = compiler.SQLCompiler
     5SQLDeleteCompiler = compiler.SQLDeleteCompiler
     6SQLUpdateCompiler = compiler.SQLUpdateCompiler
     7SQLAggregateCompiler = compiler.SQLAggregateCompiler
     8
     9
     10class SQLInsertCompiler(compiler.SQLInsertCompiler):
     11    def as_sql(self):
     12        """
     13        Create queries that work like ``INSERT INTO .. ON CONFLICT DO NOTHING RETURNING *``
     14        but return the same amount of rows as in the input, setting ``NULL`` in place of the
     15        primary key on already existing rows.  The cited query does not return anything for
     16        rows that were already in the database.  The drawback is that the PostgreSQL sequence
     17        counter always increases by the number of rows in the input, irrespective of the
     18        number of actually inserted rows.  Requires PostgreSQL >= 9.5.
     19
     20        This creates a query like:
     21
     22        .. code-block:: sql
     23
     24            WITH
     25                r AS (SELECT * FROM (VALUES (...), (...) -- the ellipses are substituted with the values to be inserted
     26                    ) AS g(...)), -- the ellipsis is substituted with the corresponding column names, excluding the PK
     27                s AS (INSERT INTO (table name) (...) -- the same field names
     28                    SELECT * FROM r ON CONFLICT DO NOTHING RETURNING *)
     29            SELECT s.pk FROM r LEFT JOIN s USING (...); -- again the same field names, but this time excluding
     30            -- the fields ignored for the comparison in all_ids
     31
     32        ``r`` is a table containing the values that are going to be inserted having the correct column names.  It does
     33        not contain primary keys.  ``s`` is a table containing the rows that could successfully be inserted without
     34        conflicts.  The rows in ``s`` have the primary key set.  The final ``SELECT`` matches all values that were
     35        supposed to be inserted with the values that were actually inserted.  It creates a table having as many
     36        elements as ``r``, but the primary key is set only on the rows that were inserted.  In the remaining rows,
     37        that existed before the query, the primary key is not set.
     38
     39        There seems to be no simpler way to achieve what the first sentence of this docstring says.
     40        """
     41        fields = self.query.fields
     42        if (
     43                self.return_id and fields and self.connection.features.is_postgresql_9_5 and
     44                self.query.on_conflict == 'ignore'
     45        ):
     46            qn = self.quote_name_unless_alias
     47            opts = self.query.get_meta()
     48            if isinstance(self.return_id, list):
     49                compare_columns = [qn(field.column) for field in fields if field.column not in self.return_id]
     50            else:
     51                compare_columns = [qn(field.column) for field in fields]
     52            return [("WITH r AS (SELECT * FROM(VALUES (" + "),(".join(
     53                ",".join("%s" for f in fields) for obj in self.query.objs
     54            ) + ")) AS g(" + ",".join(qn(field.column) for field in fields) + "))," +
     55                " s AS (INSERT INTO " + qn(opts.db_table) + " (" + ", ".join(
     56                    qn(field.column) for field in fields) +
     57                ") SELECT * FROM r ON CONFLICT DO NOTHING RETURNING *) SELECT s." +
     58                qn(opts.pk.column) + " FROM r LEFT JOIN s USING (" + ", ".join(compare_columns) + ")",
     59                tuple(p for ps in self.assemble_as_sql(fields, [
     60                    [self.prepare_value(field, self.pre_save_val(
     61                        field, obj)) for field in fields] for obj in self.query.objs
     62                ])[1] for p in ps))]
     63        return super().as_sql()
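
    For illustration, a sketch (not the compiler's literal output) of the query shape described in the docstring above, for a hypothetical table country(id serial primary key, iso varchar unique), run through a raw psycopg2 cursor:

        # Sketch only; assumes a psycopg2 connection `conn` and the table above.
        sql = """
            WITH r AS (SELECT * FROM (VALUES (%s), (%s)) AS g(iso)),
                 s AS (INSERT INTO country (iso)
                       SELECT * FROM r ON CONFLICT DO NOTHING RETURNING *)
            SELECT s.id FROM r LEFT JOIN s USING (iso)
        """
        with conn.cursor() as cur:
            cur.execute(sql, ['BG', 'GR'])
            # One row per input value: the new primary key for inserted values,
            # None for values that already existed in the table.
            print(cur.fetchall())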
  • django/db/backends/postgresql/features.py

    diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py
    a b class DatabaseFeatures(BaseDatabaseFeatures):  
    5959    has_brin_index_support = is_postgresql_9_5
    6060    has_jsonb_agg = is_postgresql_9_5
    6161    has_gin_pending_list_limit = is_postgresql_9_5
     62    supports_on_conflict_ignore = is_postgresql_9_5
  • django/db/backends/postgresql/operations.py

    diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py
    a b from django.db.backends.base.operations import BaseDatabaseOperations  
    77
    88class DatabaseOperations(BaseDatabaseOperations):
    99    cast_char_field_without_max_length = 'varchar'
     10    compiler_module = "django.db.backends.postgresql.compiler"
    1011
    1112    def unification_cast_sql(self, output_field):
    1213        internal_type = output_field.get_internal_type()
    class DatabaseOperations(BaseDatabaseOperations):  
    258259                'and FOLLOWING.'
    259260            )
    260261        return start_, end_
     262
     263    def on_conflict_postfix(self, on_conflict=None):
     264        if on_conflict == 'ignore':
     265            return 'ON CONFLICT DO NOTHING'
     266
     267        return super().on_conflict_postfix(on_conflict)
  • django/db/backends/sqlite3/operations.py

    diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py
    a b class DatabaseOperations(BaseDatabaseOperations):  
    296296        if internal_type == 'TimeField':
    297297            return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params
    298298        return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params
     299
     300    def insert_statement(self, on_conflict=None):
     301        if on_conflict == 'ignore':
     302            return 'INSERT OR IGNORE INTO'
     303
     304        return super().insert_statement(on_conflict)
  • django/db/models/fields/related_descriptors.py

    diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py
    a b def create_forward_many_to_many_manager(superclass, rel, reverse):  
    10851085                            '%s_id' % target_field_name: obj_id,
    10861086                        })
    10871087                        for obj_id in new_ids
    1088                     ])
     1088                    ], on_conflict='ignore' if connections[db].features.supports_on_conflict_ignore else None)
    10891089
    10901090                    if self.reverse or source_field_name == self.source_field_name:
    10911091                        # Don't send the signal when we are inserting the
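
    For illustration, the race this change addresses (hypothetical models, not part of the patch): two threads can both pass the existence check in _add_items() and attempt the same through-table INSERT; with on_conflict='ignore' the losing INSERT is skipped by the database instead of raising IntegrityError.

        import threading

        from django.db import connection
        from myapp.models import Article, Tag  # hypothetical M2M: Article.tags

        article = Article.objects.create(title='x')
        tag = Tag.objects.create(name='django')

        def worker():
            try:
                # Both threads may find the through-row missing and INSERT it;
                # the duplicate INSERT is now ignored by the database.
                article.tags.add(tag)
            finally:
                connection.close()  # each thread opened its own connection

        threads = [threading.Thread(target=worker) for _ in range(2)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        assert article.tags.count() == 1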
  • django/db/models/query.py

    diff --git a/django/db/models/query.py b/django/db/models/query.py
    a b  
    11"""
    22The main QuerySet implementation. This provides the public API for the ORM.
    33"""
    4 
    54import copy
     5import functools
    66import operator
    77import warnings
    88from collections import OrderedDict, namedtuple
    from functools import lru_cache  
    1010
    1111from django.conf import settings
    1212from django.core import exceptions
     13from django.contrib.postgres.fields import CIText
    1314from django.db import (
    1415    DJANGO_VERSION_PICKLE_KEY, IntegrityError, connections, router,
    1516    transaction,
    1617)
    17 from django.db.models import DateField, DateTimeField, sql
     18from django.db.models import DateField, DateTimeField, signals, sql
    1819from django.db.models.constants import LOOKUP_SEP
    1920from django.db.models.deletion import Collector
    2021from django.db.models.expressions import F
    from django.db.models.fields import AutoField  
    2223from django.db.models.functions import Trunc
    2324from django.db.models.query_utils import FilteredRelation, InvalidQuery, Q
    2425from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
     26from django.db.utils import NotSupportedError
    2527from django.utils import timezone
    2628from django.utils.deprecation import RemovedInDjango30Warning
    2729from django.utils.functional import cached_property, partition
    class QuerySet:  
    417419            if obj.pk is None:
    418420                obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
    419421
    420     def bulk_create(self, objs, batch_size=None):
     422    def bulk_create(self, objs, batch_size=None, on_conflict=None, send_post_save=False, all_ids=None):
    421423        """
    422424        Insert each of the instances into the database. Do *not* call
    423         save() on each of the instances, do not send any pre/post_save
     425        save() on each of the instances, do not send any pre_save
    424426        signals, and do not set the primary key attribute if it is an
    425         autoincrement field (except if features.can_return_ids_from_bulk_insert=True).
    426         Multi-table models are not supported.
     427        autoincrement field (except if features.can_return_ids_from_bulk_insert=True, or both all_ids is not None and
     428        postgresql >= 9.5 is used). Multi-table models are not supported.
    427429        """
    428430        # When you bulk insert you don't get the primary keys back (if it's an
    429431        # autoincrement, except if can_return_ids_from_bulk_insert=True), so
    class QuerySet:  
    447449                raise ValueError("Can't bulk create a multi-table inherited model")
    448450        if not objs:
    449451            return objs
    450         self._for_write = True
    451452        connection = connections[self.db]
     453        if on_conflict:
     454            on_conflict = on_conflict.lower()
     455            if on_conflict != 'ignore':
     456                raise ValueError("'%s' is an invalid value for on_conflict. Allowed values: 'ignore'" % on_conflict)
     457            if not connections[self.db].features.supports_on_conflict_ignore:
     458                raise NotSupportedError('This database backend does not support ON CONFLICT IGNORE')
     459            if all_ids is not None and not getattr(connection.features, 'is_postgresql_9_5', False):
     460                raise NotSupportedError('all_ids can be set only when Postgresql >= 9.5 is used')
     461        elif all_ids is not None:
     462            raise ValueError('all_ids can be used only with on_conflict')
     463
     464        self._for_write = True
    452465        fields = self.model._meta.concrete_fields
    453466        objs = list(objs)
    454467        self._populate_pk_values(objs)
    455468        with transaction.atomic(using=self.db, savepoint=False):
    456469            objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
    457470            if objs_with_pk:
    458                 self._batched_insert(objs_with_pk, fields, batch_size)
     471                self._batched_insert(objs_with_pk, fields, batch_size, on_conflict=on_conflict, return_id=False)
    459472            if objs_without_pk:
    460473                fields = [f for f in fields if not isinstance(f, AutoField)]
    461                 ids = self._batched_insert(objs_without_pk, fields, batch_size)
    462                 if connection.features.can_return_ids_from_bulk_insert:
    463                     assert len(ids) == len(objs_without_pk)
     474                return_id = (
     475                    connection.features.can_return_ids_from_bulk_insert and not on_conflict or
     476                    getattr(connection.features, 'is_postgresql_9_5', False) and all_ids is not None
     477                )
     478                if return_id and isinstance(all_ids, list):  # stores the fields that shall be ...
     479                    return_id = all_ids   # ... ignored when comparing objects for equality
     480                ids = self._batched_insert(objs_without_pk, fields, batch_size, on_conflict=on_conflict,
     481                                           return_id=return_id)
     482                if (
     483                    connection.features.can_return_ids_from_bulk_insert and on_conflict != 'ignore' or
     484                    getattr(connection.features, 'is_postgresql_9_5', False) and (
     485                        all_ids is True or isinstance(all_ids, list))
     486                ):
     487                        assert len(ids) == len(objs_without_pk)
    464488                for obj_without_pk, pk in zip(objs_without_pk, ids):
    465489                    obj_without_pk.pk = pk
    466490                    obj_without_pk._state.adding = False
    467491                    obj_without_pk._state.db = self.db
    468492
     493            else:
     494                return_id = False
     495        if send_post_save or return_id:
     496            objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs_without_pk + objs_with_pk)
     497            if send_post_save:
     498                already_sent = set()  # In case objs contains the same element twice
     499                for obj in objs_with_pk:
     500                    if obj.pk not in already_sent:
     501                        signals.post_save.send(sender=obj.__class__, instance=obj,
     502                                               created=True, raw=False, using=self.db)
     503                    already_sent.add(obj.pk)
     504
     505            if all_ids and objs_without_pk and getattr(connection.features, 'is_postgresql_9_5', False):
     506                all_ids = [] if all_ids is True else all_ids
     507                # f.attname in obj.__dict__ and f != obj._meta.pk means the field is not deferred and not the primary key
     508                obj0 = objs_without_pk[0]
     509                fields = [f.attname for f in obj0._meta.concrete_fields if f.attname
     510                          in obj0.__dict__ and f != obj0._meta.pk and f.attname not in all_ids]
     511                q = [Q(**{f.attname: getattr(obj, f.attname) for f in obj._meta.concrete_fields if f.attname in
     512                          obj.__dict__ and f != obj._meta.pk and f.attname not in all_ids}) for obj in objs_without_pk]
     513                if q:
     514                    output = self.filter(functools.reduce(Q.__or__, q)).values(*fields, obj0._meta.pk.attname)
     515                    for obj in objs_without_pk:
     516                        for o in output:
     517                            if all((getattr(obj, f).lower() == o[f].lower()) if isinstance(
     518                                    obj._meta.get_field(f), CIText) else (getattr(obj, f) == o[f]) for f in fields):
     519                                obj.pk = o[obj0._meta.pk.attname]
     520                                break
    469521        return objs
    470522
    471523    def get_or_create(self, defaults=None, **kwargs):
    class QuerySet:  
    11081160    # PRIVATE METHODS #
    11091161    ###################
    11101162
    1111     def _insert(self, objs, fields, return_id=False, raw=False, using=None):
     1163    def _insert(self, objs, fields, return_id=False, raw=False, using=None, on_conflict=None):
    11121164        """
    11131165        Insert a new record for the given model. This provides an interface to
    11141166        the InsertQuery class and is how Model.save() is implemented.
    class QuerySet:  
    11161168        self._for_write = True
    11171169        if using is None:
    11181170            using = self.db
    1119         query = sql.InsertQuery(self.model)
     1171        query = sql.InsertQuery(self.model, on_conflict=on_conflict)
    11201172        query.insert_values(fields, objs, raw=raw)
    11211173        return query.get_compiler(using=using).execute_sql(return_id)
    11221174    _insert.alters_data = True
    11231175    _insert.queryset_only = False
    11241176
    1125     def _batched_insert(self, objs, fields, batch_size):
     1177    def _batched_insert(self, objs, fields, batch_size, on_conflict, return_id):
    11261178        """
    11271179        Helper method for bulk_create() to insert objs one batch at a time.
    11281180        """
    1129         ops = connections[self.db].ops
    1130         batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
     1181        batch_size = batch_size or max(connections[self.db].ops.bulk_batch_size(fields, objs), 1)
    11311182        inserted_ids = []
    11321183        for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
    1133             if connections[self.db].features.can_return_ids_from_bulk_insert:
    1134                 inserted_id = self._insert(item, fields=fields, using=self.db, return_id=True)
     1184            if return_id:
     1185                inserted_id = self._insert(item, fields=fields, using=self.db,
     1186                                           return_id=return_id, on_conflict=on_conflict)
    11351187                if isinstance(inserted_id, list):
    11361188                    inserted_ids.extend(inserted_id)
    11371189                else:
    11381190                    inserted_ids.append(inserted_id)
    11391191            else:
    1140                 self._insert(item, fields=fields, using=self.db)
     1192                self._insert(item, fields=fields, using=self.db, on_conflict=on_conflict)
    11411193        return inserted_ids
    11421194
    11431195    def _chain(self, **kwargs):
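
    For illustration, a sketch of the send_post_save behaviour implemented above, again using the CountryUnique test model; the signal fires only for objects whose pk is known after the insert:

        from django.db.models.signals import post_save

        from bulk_create.models import CountryUnique

        def handler(sender, instance, created, **kwargs):
            print('post_save:', instance.pk, created)

        post_save.connect(handler, sender=CountryUnique)
        # On a backend with can_return_ids_from_bulk_insert (e.g. PostgreSQL),
        # both objects get a pk, so two post_save signals are sent.
        CountryUnique.objects.bulk_create(
            [CountryUnique(iso_two_letter='BG'), CountryUnique(iso_two_letter='DE')],
            send_post_save=True,
        )
        post_save.disconnect(handler, sender=CountryUnique)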
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    a b class SQLInsertCompiler(SQLCompiler):  
    12091209        # going to be column names (so we can avoid the extra overhead).
    12101210        qn = self.connection.ops.quote_name
    12111211        opts = self.query.get_meta()
    1212         result = ['INSERT INTO %s' % qn(opts.db_table)]
     1212        insert_statement = self.connection.ops.insert_statement(on_conflict=self.query.on_conflict)
     1213        result = ['%s %s' % (insert_statement, qn(opts.db_table))]
    12131214        fields = self.query.fields or [opts.pk]
    12141215        result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
    12151216
    class SQLInsertCompiler(SQLCompiler):  
    12311232
    12321233        placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
    12331234
     1235        on_conflict_postfix = self.connection.ops.on_conflict_postfix(on_conflict=self.query.on_conflict)
     1236
    12341237        if self.return_id and self.connection.features.can_return_id_from_insert:
    12351238            if self.connection.features.can_return_ids_from_bulk_insert:
    12361239                result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
    class SQLInsertCompiler(SQLCompiler):  
    12381241            else:
    12391242                result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
    12401243                params = [param_rows[0]]
     1244            if on_conflict_postfix:
     1245                result.append(on_conflict_postfix)
    12411246            col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
    12421247            r_fmt, r_params = self.connection.ops.return_insert_id()
    12431248            # Skip empty r_fmt to allow subclasses to customize behavior for
    class SQLInsertCompiler(SQLCompiler):  
    12491254
    12501255        if can_bulk:
    12511256            result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
     1257            if on_conflict_postfix:
     1258                result.append(on_conflict_postfix)
    12521259            return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
    12531260        else:
     1261            if on_conflict_postfix:
     1262                result.append(on_conflict_postfix)
    12541263            return [
    12551264                (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
    12561265                for p, vals in zip(placeholder_rows, param_rows)
  • django/db/models/sql/subqueries.py

    diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
    a b class UpdateQuery(Query):  
    169169class InsertQuery(Query):
    170170    compiler = 'SQLInsertCompiler'
    171171
    172     def __init__(self, *args, **kwargs):
     172    def __init__(self, *args, on_conflict=None, **kwargs):
    173173        super().__init__(*args, **kwargs)
    174174        self.fields = []
    175175        self.objs = []
     176        self.on_conflict = on_conflict
    176177
    177178    def insert_values(self, fields, objs, raw=False):
    178179        self.fields = fields
  • docs/ref/models/querysets.txt

    diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
    a b exists in the database, an :exc:`~django.db.IntegrityError` is raised.  
    19951995``bulk_create()``
    19961996~~~~~~~~~~~~~~~~~
    19971997
    1998 .. method:: bulk_create(objs, batch_size=None)
     1998.. method:: bulk_create(objs, batch_size=None, on_conflict=None, send_post_save=False, all_ids=None)
    19991999
    20002000This method inserts the provided list of objects into the database in an
    20012001efficient manner (generally only 1 query, no matter how many objects there
    are)::  
    20092009This has a number of caveats though:
    20102010
    20112011* The model's ``save()`` method will not be called, and the ``pre_save`` and
    2012   ``post_save`` signals will not be sent.
     2012  signal will not be sent.
     2013* If ``send_post_save`` is True, the ``post_save`` signal will be sent for the objs whose ID
     2014  is known to Django.  The IDs are known to Django if ``on_conflict=None`` is used on
     2015  backends with ``features.can_return_ids_from_bulk_insert=True``, or if ``on_conflict='ignore'``,
     2016  PostgreSQL >= 9.5 and ``all_ids is not None`` are used simultaneously.
    20132017* It does not work with child models in a multi-table inheritance scenario.
    20142018* If the model's primary key is an :class:`~django.db.models.AutoField` it
    20152019  does not retrieve and set the primary key attribute, as ``save()`` does,
    The ``batch_size`` parameter controls how many objects are created in a single  
    20352039query. The default is to create all objects in one batch, except for SQLite
    20362040where the default is such that at most 999 variables per query are used.
    20372041
     2042On database backends that support it, the ``on_conflict`` parameter controls
     2043how the database handles conflicting data, i.e. rows containing duplicate UNIQUE
     2044values. Currently this parameter can be ``None``, in which case an
     2045:exc:`~django.db.IntegrityError` will be raised, or ``'ignore'``, in which case
     2046the database will ignore any failing rows. Using ``'ignore'`` disables
     2047setting the primary key attribute in the returned list and on ``objs`` on backends
     2048that support ``on_conflict='ignore'``.
     2049
     2050If anything other than PostgreSQL >= 9.5 is used, the ``all_ids`` parameter is ignored.
     2051
     2052If PostgreSQL >= 9.5 and ``on_conflict='ignore'`` are used:
     2053
     2054* If ``all_ids`` is not ``None``, the underlying PostgreSQL sequence is incremented unnecessarily
     2055  for each object that was already in the database, and the newly created objects have their ``pk`` set.
     2056  If ``all_ids`` is ``None``, the returned objs that were created do not have their ``pk`` set.
     2057* If ``all_ids`` is True, a second query is sent to the database which retrieves the IDs
     2058  of those objs that existed prior to calling ``bulk_create()``.  The query matches all
     2059  provided fields of the supplied objs.
     2060* If ``all_ids`` is False, the second query is not sent.  The difference between False and
     2061  None for ``all_ids`` is that False sets the IDs of the inserted objs and possibly increments
     2062  the PostgreSQL sequence counter unnecessarily, whereas ``all_ids=None`` does neither.
     2063* If ``all_ids`` is a non-empty list, all fields mentioned in that list are
     2064  ignored in the latter query when considering objects for equality.  This differs from
     2065  ``all_ids=True``, where all fields are compared::
     2066
     2067    >>> from django.db import models
     2068
     2069    >>> class T(models.Model):
     2070    ...    d = models.DateTimeField(default=django.utils.timezone.now)
     2071    ...    n = models.IntegerField(unique=True)
     2072
     2073    >>> T.objects.bulk_create([T(n=1), T(n=1)], on_conflict='ignore', all_ids=True)
     2074    # Now the database contains one object with n=1 and a timestamp when the first
     2075    # constructor was called.  The returned list has two objects, and the second object
     2076    # has no pk set.  The reason is that the second T(n=1) has d with a timestamp that
     2077    # is different from the timestamp of the first T(n=1), and querying the database
     2078    # for the second T object returned no results.  Even though only the first object
     2079    # is inserted into the database, and not the second, the corresponding PostgreSQL
     2080    # sequence is increased by two.
     2081
     2082    >>> T.objects.bulk_create([T(n=1), T(n=1)], on_conflict='ignore', all_ids=['d'])
     2083    # Now the database will check if there is an object with n=1 and ignore the d field.
     2084    # The pk field of each element in the list will be set. The Postgresql sequence is
     2085    # increased by two and it does not matter if T(n=1) was in the database before the call.
     2086
     2087.. versionchanged:: 2.1
     2088
     2089    The ``on_conflict``, ``send_post_save`` and ``all_ids`` parameters were added.
     2090
    20382091``count()``
    20392092~~~~~~~~~~~
    20402093
  • docs/releases/2.1.txt

    diff --git a/docs/releases/2.1.txt b/docs/releases/2.1.txt
    a b Models  
    221221* :meth:`.QuerySet.order_by` and :meth:`distinct(*fields) <.QuerySet.distinct>`
    222222  now support using field transforms.
    223223
     224* The new ``on_conflict``, ``send_post_save`` and ``all_ids`` parameters of
     225  :meth:`~django.db.models.query.QuerySet.bulk_create` control how the database
     226  handles rows that fail constraint checking, whether to send ``post_save`` signals
     227  for the created objects, and, for PostgreSQL >= 9.5, whether and how to retrieve
     228  the IDs of the objects that failed the constraint checks.
     229
     230* ``RelatedManager.add()`` is now thread-safe on SQLite, PostgreSQL >= 9.5 and MySQL.
     231
    224232Requests and Responses
    225233~~~~~~~~~~~~~~~~~~~~~~
    226234
  • tests/bulk_create/models.py

    diff --git a/tests/bulk_create/models.py b/tests/bulk_create/models.py
    a b class Country(models.Model):  
    1717    description = models.TextField()
    1818
    1919
     20class CountryUnique(models.Model):
     21    name = models.CharField(max_length=255)
     22    iso_two_letter = models.CharField(max_length=2, unique=True)
     23    description = models.TextField()
     24
     25
    2026class ProxyCountry(Country):
    2127    class Meta:
    2228        proxy = True
  • tests/bulk_create/tests.py

    diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py
    a b  
    11from operator import attrgetter
    22
    3 from django.db import connection
    4 from django.db.models import FileField, Value
     3from django.db import IntegrityError, connection
     4from django.db.models import FileField, Value, signals
     5from django.db.models.fields import AutoField
    56from django.db.models.functions import Lower
     7from django.db.utils import NotSupportedError
    68from django.test import (
    79    TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature,
    810)
    911
    1012from .models import (
    11     Country, NoFields, NullableFields, Pizzeria, ProxyCountry,
     13    Country, CountryUnique, NoFields, NullableFields, Pizzeria, ProxyCountry,
    1214    ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry, Restaurant,
    1315    State, TwoFields,
    1416)
    class BulkCreateTests(TestCase):  
    233235        self.assertEqual(len(countries), 1)
    234236        self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0])
    235237
     238    def test_insert_single_item_that_is_present(self):
     239        CountryUnique.objects.bulk_create([self.data[0]])
     240        with self.assertRaises(IntegrityError):
     241            CountryUnique.objects.bulk_create([self.data[0]])
     242
    236243    @skipUnlessDBFeature('can_return_ids_from_bulk_insert')
    237244    def test_set_pk_and_query_efficiency(self):
    238245        with self.assertNumQueries(1):
    class BulkCreateTests(TestCase):  
    252259        # Objects save via bulk_create() and save() should have equal state.
    253260        self.assertEqual(country_nl._state.adding, country_be._state.adding)
    254261        self.assertEqual(country_nl._state.db, country_be._state.db)
     262
     263    @skipIfDBFeature("supports_on_conflict_ignore")
     264    def test_on_conflict_value_error(self):
     265        message = 'This database backend does not support ON CONFLICT IGNORE'
     266        with self.assertRaises(NotSupportedError, message=message):
     267            TwoFields.objects.bulk_create(self.data, on_conflict='ignore')
     268
     269    @skipUnlessDBFeature("supports_on_conflict_ignore")
     270    def test_on_conflict_ignore(self):
     271        data = [
     272            TwoFields(f1=1, f2=1),
     273            TwoFields(f1=2, f2=2),
     274            TwoFields(f1=3, f2=3)
     275        ]
     276        TwoFields.objects.bulk_create(data)
     277        self.assertEqual(TwoFields.objects.count(), 3)
     278
     279        conflicting_objects = [
     280            TwoFields(f1=2, f2=2),
     281            TwoFields(f1=3, f2=3)
     282        ]
     283        TwoFields.objects.bulk_create([conflicting_objects[0]], on_conflict='ignore')
     284        TwoFields.objects.bulk_create(conflicting_objects, on_conflict='ignore')
     285        self.assertEqual(TwoFields.objects.count(), 3)
     286        self.assertIsNone(conflicting_objects[0].pk)
     287        self.assertIsNone(conflicting_objects[1].pk)
     288
     289        new_object = TwoFields(f1=4, f2=4)
     290        TwoFields.objects.bulk_create(conflicting_objects + [new_object], on_conflict='ignore')
     291        self.assertEqual(TwoFields.objects.count(), 4)
     292        self.assertIsNone(new_object.pk)
     293
     294        with self.assertRaises(IntegrityError):
     295            TwoFields.objects.bulk_create(conflicting_objects)
     296
     297    def test_on_conflict_invalid(self):
     298        message = "'test' is an invalid value for on_conflict. Allowed values: 'ignore'"
     299        with self.assertRaises(ValueError, message=message):
     300            Country.objects.bulk_create(self.data, on_conflict='test')
     301
     302    @skipUnlessDBFeature("supports_on_conflict_ignore")
     303    def test_on_conflict_case_insensitive(self):
     304        with self.assertNumQueries(1):
     305            Country.objects.bulk_create(self.data, on_conflict='IGNORE')
     306        self.assertEqual(Country.objects.count(), 4)
     307
     308    def test_on_conflict_unset_all_ids_set(self):
     309        with self.assertRaises(ValueError, message='all_ids can be used only with on_conflict'):
     310            Country.objects.bulk_create(self.data, all_ids=True)
     311
     312    @skipIfDBFeature('is_postgresql_9_5')
     313    def test_on_conflict_ignore_all_ids_invalid(self):
     314        message = 'all_ids can be set only when Postgresql >= 9.5 is used'
     315        for all_ids in (False, True, ['description']):
     316            with self.assertRaises(NotSupportedError, message=message):
     317                Country.objects.bulk_create(self.data, on_conflict='ignore', all_ids=all_ids)
     318
     319    @skipUnlessDBFeature('is_postgresql_9_5')
     320    def test_on_conflict_ignore_all_ids_false(self):
     321        data = [
     322            TwoFields(f1=1, f2=1),
     323            TwoFields(f1=2, f2=2),
     324            TwoFields(f1=3, f2=3)
     325        ]
     326        TwoFields.objects.bulk_create(data)
     327        self.assertEqual(TwoFields.objects.count(), 3)
     328
     329        new_object = TwoFields(f1=4, f2=4)
     330        with self.assertNumQueries(2):
     331            TwoFields.objects.bulk_create(data + [new_object], on_conflict='ignore', all_ids=False)
     332        self.assertEqual(TwoFields.objects.count(), 4)
     333        self.assertIsNotNone(new_object.pk)
     334        new_object_duplicate_1 = TwoFields(f1=5, f2=5)
     335        new_object_duplicate_2 = TwoFields(f1=5, f2=5)
     336        with self.assertNumQueries(2):
     337            TwoFields.objects.bulk_create(data + [new_object_duplicate_1, new_object_duplicate_2],
     338                                          on_conflict='ignore', all_ids=False)
     339        self.assertEqual(TwoFields.objects.count(), 5)
     340        self.assertEqual(new_object_duplicate_1.pk, new_object_duplicate_2.pk)
     341        self.assertIsNotNone(new_object_duplicate_1.pk)
     342
     343    @skipUnlessDBFeature('is_postgresql_9_5')
     344    def test_on_conflict_ignore_all_ids_none(self):
     345        """Verify that the IDs do not grow exorbitantly when on_conflict='ignore', all_ids=None is used"""
     346        data = [CountryUnique(iso_two_letter='BG'), CountryUnique(iso_two_letter='DE')]
     347        x = CountryUnique.objects.bulk_create(data)
     348        CountryUnique.objects.bulk_create(data + [CountryUnique(iso_two_letter='UK')], on_conflict='ignore')
     349        self.assertEqual(x[1].id + 2, CountryUnique.objects.create(iso_two_letter='US').pk)
     350
     351    @skipUnlessDBFeature('is_postgresql_9_5')
     352    def test_on_conflict_ignore_all_ids_true(self):
     353        x = CountryUnique(iso_two_letter='BG')
     354        y = CountryUnique(iso_two_letter='GR')
     355        CountryUnique.objects.bulk_create([x, y])
     356        data = [
     357            CountryUnique(iso_two_letter='BG', description='Between Romania and Turkey'),
     358            y
     359        ]
     360        t_gr = y.pk
     361        y.pk = None
     362        with self.assertNumQueries(2):
     363            ret = CountryUnique.objects.bulk_create(data, on_conflict='ignore', all_ids=True)
     364        self.assertEqual([x.pk for x in ret], [None, t_gr])
     365
     366    @skipUnlessDBFeature('is_postgresql_9_5')
     367    def test_on_conflict_ignore_all_ids_list(self):
     368        x = CountryUnique(iso_two_letter='BG')
     369        y = CountryUnique(iso_two_letter='GR')
     370        CountryUnique.objects.bulk_create([x, y])
     371        t_bg, t_gr = x.pk, y.pk
     372        x.pk = y.pk = None
     373        data = [
     374            CountryUnique(iso_two_letter='BG', description='Between Romania and Turkey'),
     375            y
     376        ]
     377        with self.assertNumQueries(2):
     378            ret = CountryUnique.objects.bulk_create(data, on_conflict='ignore', all_ids=['description'])
     379        self.assertEqual([x.pk for x in ret], [t_bg, t_gr])
     380
     381    @skipUnlessDBFeature("supports_on_conflict_ignore")
     382    def test__batched_insert_on_conflict_ignore_return_id_false(self):
     383        fields = [f for f in CountryUnique._meta.concrete_fields if not isinstance(f, AutoField)]
     384        x = CountryUnique.objects.all()._batched_insert(
     385            [CountryUnique(iso_two_letter='BG'), CountryUnique(iso_two_letter='GR')],
     386            fields, batch_size=None, on_conflict='ignore', return_id=False)
     387        self.assertEqual(x, [])
     388        x = CountryUnique.objects.all()._batched_insert(
     389            [CountryUnique(iso_two_letter='BG'), CountryUnique(iso_two_letter='GR')],
     390            fields, batch_size=None, on_conflict='ignore', return_id=False)
     391        self.assertEqual(x, [])
     392
     393    @skipUnlessDBFeature('is_postgresql_9_5')
     394    def test__batched_insert_on_conflict_ignore_return_id_true(self):
     395        fields = [f for f in CountryUnique._meta.concrete_fields if not isinstance(f, AutoField)]
     396        x = CountryUnique.objects.all()._batched_insert(
     397            [CountryUnique(iso_two_letter='BG'), CountryUnique(iso_two_letter='GR')],
     398            fields, batch_size=None, on_conflict='ignore', return_id=True)
     399        self.assertTrue(x[0] is not None and x[1] is not None)
     400        x = CountryUnique.objects.all()._batched_insert(
     401            [CountryUnique(iso_two_letter='BG'), CountryUnique(iso_two_letter='GR')],
     402            fields, batch_size=None, on_conflict='ignore', return_id=True)
     403        self.assertTrue(x[0] is None and x[1] is None)
     404
     405    def test__batched_insert_on_conflict_none_return_id_false(self):
     406        fields = [f for f in CountryUnique._meta.concrete_fields if not isinstance(f, AutoField)]
     407        x = CountryUnique.objects.all()._batched_insert(
     408            [CountryUnique(iso_two_letter='BG'), CountryUnique(iso_two_letter='GR')],
     409            fields, batch_size=None, on_conflict=None, return_id=False)
     410        self.assertEqual(x, [])
     411
     412    @skipUnlessDBFeature('can_return_ids_from_bulk_insert')
     413    def test__batched_insert_on_conflict_none_return_id_true(self):
     414        fields = [f for f in CountryUnique._meta.concrete_fields if not isinstance(f, AutoField)]
     415        x = CountryUnique.objects.all()._batched_insert(
     416            [CountryUnique(iso_two_letter='BG'), CountryUnique(iso_two_letter='GR')],
     417            fields, batch_size=None, on_conflict=None, return_id=True)
     418        self.assertTrue(x[0] is not None and x[1] is not None)
     419
     420    @skipUnlessDBFeature('can_return_ids_from_bulk_insert')
     421    def test_insert_twice_the_same_item(self):
     422        with self.assertNumQueries(1), self.assertRaises(IntegrityError):
     423            CountryUnique.objects.bulk_create([CountryUnique(iso_two_letter='DE'), CountryUnique(iso_two_letter='DE')])
     424
     425
     426class BulkCreatePostSaveSignalTests(TestCase):
     427    """Tests bulk_create(objs, send_post_save=True)"""
     428    objs = [
     429        CountryUnique(name="United States of America", iso_two_letter="US"),
     430        CountryUnique(name="The Netherlands", iso_two_letter="NL"),
     431        CountryUnique(name="Germany", iso_two_letter="DE"),
     432        CountryUnique(name="Czech Republic", iso_two_letter="CZ")
     433    ]
     434
     435    def setUp(self):
     436        # Save up the number of connected signals so that we can check at the
     437        # end that all the signals we register get properly unregistered (#9989)
     438        self.received_signals = []
     439        signals.post_save.connect(self.post_save_handler, weak=False)
     440        self.pre_signals = len(signals.post_save.receivers)
     441
     442    def tearDown(self):
     443        # All our signals got disconnected properly.
     444        post_signals = len(signals.post_save.receivers)
     445        signals.post_save.disconnect(self.post_save_handler)
     446        self.assertEqual(self.pre_signals, post_signals)
     447
     448    def post_save_handler(self, **kwargs):
     449        self.received_signals.append('post_save')
     450
     451    @skipUnlessDBFeature('can_return_ids_from_bulk_insert')
     452    def test_bulk_create_post_save_signal_objs_without_pk(self):
     453        CountryUnique.objects.bulk_create(self.objs, send_post_save=True)
     454        self.assertEqual(len(self.received_signals), 4)
     455
     456    @skipUnlessDBFeature('can_return_ids_from_bulk_insert')
     457    def test_bulk_create_post_save_signals_objs_with_pk_and_without_pk(self):
     458        objs = [
     459            CountryUnique(id=10, name="United States of America", iso_two_letter="US"),
     460            CountryUnique(name="The Netherlands", iso_two_letter="NL"),
     461            CountryUnique(id=13, name="Germany", iso_two_letter="DE"),
     462            CountryUnique(name="Czech Republic", iso_two_letter="CZ")
     463        ]
     464        x = CountryUnique.objects.bulk_create(objs, send_post_save=True)
     465        self.assertTrue(x[0].pk == 10 and x[2].id == 13 and len(x) == 4)
     466        self.assertEqual(len(self.received_signals), 4)
     467
     468    def test_bulk_create_post_save_signals_objs_with_pk(self):
     469        objs = [
     470            CountryUnique(id=1, name="United States of America", iso_two_letter="US"),
     471            CountryUnique(id=2, name="The Netherlands", iso_two_letter="NL"),
     472            CountryUnique(id=3, name="Germany", iso_two_letter="DE"),
     473            CountryUnique(id=4, name="Czech Republic", iso_two_letter="CZ")
     474        ]
     475
     476        x = CountryUnique.objects.bulk_create(objs, send_post_save=True)
     477        self.assertEqual([y.id for y in x], [1, 2, 3, 4])
     478        self.assertEqual(len(self.received_signals), 4)
     479
     480    # From now on, tests bulk_create(objs, send_post_save=True, on_conflict='ignore')
     481    @skipUnlessDBFeature('supports_on_conflict_ignore')
     482    def test_bulk_create_post_save_signals_ignore(self):
     483        CountryUnique.objects.bulk_create(self.objs)
     484        self.received_signals = []
     485        data = [
     486            CountryUnique(name="Greece", iso_two_letter="GR"),
     487            CountryUnique(name="The Netherlands", iso_two_letter="NL"),
     488            CountryUnique(name="Germany", iso_two_letter="DE")
     489        ]
     490        x = CountryUnique.objects.bulk_create(data, send_post_save=True, on_conflict='ignore')
     491        self.assertEqual([y.id for y in x], [None, None, None])
     492        self.assertEqual(len(self.received_signals), 0)
     493
     494    @skipUnlessDBFeature('is_postgresql_9_5')
     495    def test_bulk_create_post_save_signals_ignore_2(self):
     496        CountryUnique.objects.bulk_create(self.objs)
     497        self.received_signals = []
     498        data = [
     499            CountryUnique(name="Greece", iso_two_letter="GR"),
     500            CountryUnique(name="The Netherlands", iso_two_letter="NL"),
     501            CountryUnique(name="Germany", iso_two_letter="DE")
     502        ]
     503        x = CountryUnique.objects.bulk_create(data, send_post_save=True, on_conflict='ignore', all_ids=False)
     504        self.assertEqual([bool(y.id) for y in x], [True, False, False])
     505        self.assertEqual(len(self.received_signals), 1)
     506
     507    @skipUnlessDBFeature('is_postgresql_9_5')
     508    def test_bulk_create_post_save_signals_ignore_3(self):
     509        """Tests when bulk_create gets the same object twice"""
     510        CountryUnique.objects.bulk_create(self.objs)
     511        self.received_signals = []
     512        data = [
     513            CountryUnique(name="Greece", iso_two_letter="GR"),
     514            CountryUnique(name="Greece", iso_two_letter="GR"),
     515            CountryUnique(name="The Netherlands", iso_two_letter="NL"),
     516            CountryUnique(name="Germany", iso_two_letter="DE")
     517        ]
     518        x = CountryUnique.objects.bulk_create(data, send_post_save=True, on_conflict='ignore', all_ids=False)
     519        self.assertEqual([bool(y.id) for y in x], [True, True, False, False])
     520        self.assertEqual(len(self.received_signals), 1)
     521
     522    @skipUnlessDBFeature('is_postgresql_9_5')
     523    def test_bulk_create_post_save_signals_ignore_4(self):
     524        """Tests when bulk_create on_conflict='ignore' with all_ids being a list"""
     525        CountryUnique.objects.bulk_create(self.objs)
     526        self.received_signals = []
     527        data = [
     528            CountryUnique(name="Greece", iso_two_letter="GR", description="Contains Acropolis"),
     529            CountryUnique(name="Greece", iso_two_letter="GR", description="Contains Athen"),
     530            CountryUnique(name="The Netherlands", iso_two_letter="NL"),
     531            CountryUnique(name="Germany", iso_two_letter="DE")
     532        ]
     533
     534        x = CountryUnique.objects.bulk_create(data, send_post_save=True, on_conflict='ignore',
     535                                              all_ids=['description'])
     536        self.assertTrue(all(y.id for y in x) and len(x) == 4)
     537        self.assertEqual(len(self.received_signals), 1)
     538
     539    @skipUnlessDBFeature('is_postgresql_9_5')
     540    def test_bulk_create_post_save_signals_ignore_5(self):
     541        """Tests bulk_create on_conflict='ignore' and all_ids=True"""
     542        CountryUnique.objects.bulk_create(self.objs)
     543        self.received_signals = []
     544        data = [
     545            CountryUnique(name="Greece", iso_two_letter="GR", description="Contains Acropolis"),
     546            CountryUnique(name="The Netherlands", iso_two_letter="NL"),
     547            CountryUnique(name="Germany", iso_two_letter="DE")
     548        ]
     549
     550        x = CountryUnique.objects.bulk_create(data, send_post_save=True, on_conflict='ignore', all_ids=True)
     551        self.assertTrue(all(y.id for y in x) and len(x) == 3)
     552        self.assertEqual(len(self.received_signals), 1)