Ticket #17258: 17258.thread-local-connections.3.diff

File 17258.thread-local-connections.3.diff, 10.4 KB (added by Anssi Kääriäinen, 8 years ago)

POC: check sharing of connections between threads

  • django/db/__init__.py

    diff --git a/django/db/__init__.py b/django/db/__init__.py
    index 8395468..8d68e83 100644
    a b  
     1from threading import local
    12from django.conf import settings
    23from django.core import signals
    34from django.core.exceptions import ImproperlyConfigured
    4 from django.db.utils import (ConnectionHandler, ConnectionRouter,
    5     load_backend, DEFAULT_DB_ALIAS, DatabaseError, IntegrityError)
     5from django.db.utils import (ConnectionHandler, ConnectionRouter, load_backend,
     6    DEFAULT_DB_ALIAS, DatabaseError, IntegrityError, SharedConnectionError)
    67
    78__all__ = ('backend', 'connection', 'connections', 'router', 'DatabaseError',
    89    'IntegrityError', 'DEFAULT_DB_ALIAS')
    router = ConnectionRouter(settings.DATABASE_ROUTERS) 
    2223# we manually create the dictionary from the settings, passing only the
    2324# settings that the database backends care about. Note that TIME_ZONE is used
    2425# by the PostgreSQL backends.
    25 # we load all these up for backwards compatibility, you should use
     26# We load all these up for backwards compatibility, you should use
    2627# connections['default'] instead.
    27 connection = connections[DEFAULT_DB_ALIAS]
     28class DefaultConnectionProxy(object):
     29    """
     30    Proxy for the thread-local default connection.
     31    """
     32    def __getattr__(self, item):
     33        return getattr(connections[DEFAULT_DB_ALIAS], item)
     34
     35    def __setattr__(self, name, value):
     36        return setattr(connections[DEFAULT_DB_ALIAS], name, value)
     37
     38connection = DefaultConnectionProxy()
    2839backend = load_backend(connection.settings_dict['ENGINE'])
    2940
    3041# Register an event that closes the database connection
  • django/db/backends/__init__.py

    diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
    index f2bde84..315b99b 100644
    a b try: 
    22    import thread
    33except ImportError:
    44    import dummy_thread as thread
    5 from threading import local
    65from contextlib import contextmanager
    76
    87from django.conf import settings
    9 from django.db import DEFAULT_DB_ALIAS
     8from django.db import DEFAULT_DB_ALIAS, SharedConnectionError
    109from django.db.backends import util
    1110from django.db.transaction import TransactionManagementError
    1211from django.utils.importlib import import_module
    1312from django.utils.timezone import is_aware
    1413
    1514
    16 class BaseDatabaseWrapper(local):
     15class BaseDatabaseWrapper(object):
    1716    """
    1817    Represents a database connection.
    1918    """
    class BaseDatabaseWrapper(local): 
    3433        self.transaction_state = []
    3534        self.savepoint_state = 0
    3635        self._dirty = None
     36        self.allow_thread_sharing = False
     37        self._thread_ident = None
    3738
    3839    def __eq__(self, other):
    3940        return self.alias == other.alias
    class BaseDatabaseWrapper(local): 
    4142    def __ne__(self, other):
    4243        return not self == other
    4344
     45    def _check_thread_sharing(self):
     46        if self.allow_thread_sharing:
     47            return
     48        if self._thread_ident is None:
     49            self._thread_ident = thread.get_ident()
     50            return
     51        if self._thread_ident <> thread.get_ident():
     52            raise SharedConnectionError(
     53                "This connection seems to be shared between threads. This "
     54                "is not allowed.")
     55
    4456    def _commit(self):
    4557        if self.connection is not None:
    4658            return self.connection.commit()
  • django/db/backends/postgresql_psycopg2/base.py

    diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py
    index 2f020f4..6a58b41 100644
    a b class DatabaseWrapper(BaseDatabaseWrapper): 
    153153    pg_version = property(_get_pg_version)
    154154
    155155    def _cursor(self):
     156        self._check_thread_sharing()
    156157        settings_dict = self.settings_dict
    157158        if self.connection is None:
    158159            if settings_dict['NAME'] == '':
  • django/db/backends/sqlite3/base.py

    diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
    index a610606..5e0a076 100644
    a b class DatabaseWrapper(BaseDatabaseWrapper): 
    231231        self.validation = BaseDatabaseValidation(self)
    232232
    233233    def _cursor(self):
     234        self._check_thread_sharing()
    234235        if self.connection is None:
    235236            settings_dict = self.settings_dict
    236237            if not settings_dict['NAME']:
  • django/db/utils.py

    diff --git a/django/db/utils.py b/django/db/utils.py
    index f0c13e3..7d820d1 100644
    a b  
    11import os
     2from threading import local
    23
    34from django.conf import settings
    45from django.core.exceptions import ImproperlyConfigured
    def load_backend(backend_name): 
    4647class ConnectionDoesNotExist(Exception):
    4748    pass
    4849
     50class SharedConnectionError(Exception):
     51    pass
    4952
    5053class ConnectionHandler(object):
    5154    def __init__(self, databases):
    5255        self.databases = databases
    53         self._connections = {}
     56        self._connections = local()
    5457
    5558    def ensure_defaults(self, alias):
    5659        """
    class ConnectionHandler(object): 
    7376            conn.setdefault(setting, None)
    7477
    7578    def __getitem__(self, alias):
    76         if alias in self._connections:
    77             return self._connections[alias]
     79        if hasattr(self._connections, alias):
     80            return getattr(self._connections, alias)
    7881
    7982        self.ensure_defaults(alias)
    8083        db = self.databases[alias]
    8184        backend = load_backend(db['ENGINE'])
    8285        conn = backend.DatabaseWrapper(db, alias)
    83         self._connections[alias] = conn
     86        setattr(self._connections, alias, conn)
    8487        return conn
    8588
     89    def __setitem__(self, key, value):
     90        setattr(self._connections, key, value)
     91
    8692    def __iter__(self):
    8793        return iter(self.databases)
    8894
  • tests/regressiontests/backends/tests.py

    diff --git a/tests/regressiontests/backends/tests.py b/tests/regressiontests/backends/tests.py
    index 936f010..9028e96 100644
    a b  
    33from __future__ import with_statement, absolute_import
    44
    55import datetime
     6import threading
    67
    78from django.conf import settings
    89from django.core.management.color import no_style
    910from django.db import (backend, connection, connections, DEFAULT_DB_ALIAS,
    10     IntegrityError, transaction)
     11    IntegrityError, transaction, SharedConnectionError)
    1112from django.db.backends.signals import connection_created
    1213from django.db.backends.postgresql_psycopg2 import version as pg_version
    1314from django.db.utils import ConnectionHandler, DatabaseError
    class ConnectionCreatedSignalTest(TestCase): 
    283284        connection_created.connect(receiver)
    284285        connection.close()
    285286        cursor = connection.cursor()
    286         self.assertTrue(data["connection"] is connection)
     287        self.assertTrue(data["connection"].connection is connection.connection)
    287288
    288289        connection_created.disconnect(receiver)
    289290        data.clear()
    class FkConstraintsTests(TransactionTestCase): 
    446447                        connection.check_constraints()
    447448            finally:
    448449                transaction.rollback()
     450
     451class ThreadTests(TestCase):
     452
     453    @unittest.skipIf(connection.vendor == 'sqlite',
     454                     "SQLite doesn't allow connection sharing between threads")
     455    def test_default_connection_thread_local(self):
     456        """
     457        Ensure that the default connection (i.e. django.db.connection) is
     458        different for each thread.
     459        Refs #17258.
     460        """
     461        connections_set = set()
     462        connection.cursor()
     463        connections_set.add(connection.connection)
     464        def runner():
     465            from django.db import connection
     466            connection.cursor()
     467            connections_set.add(connection.connection)
     468        for x in xrange(2):
     469            t = threading.Thread(target=runner)
     470            t.start()
     471            t.join()
     472        self.assertEquals(len(connections_set), 3)
     473        # Finish by closing the connections opened by the other threads (the
     474        # connection opened in the main thread will automatically be closed on
     475        # teardown).
     476        for conn in connections_set:
     477            if conn != connection.connection:
     478                conn.close()
     479
     480    @unittest.skipIf(connection.vendor == 'sqlite',
     481                     "SQLite doesn't allow connection sharing between threads")
     482    def test_connections_thread_local(self):
     483        """
     484        Ensure that the connections are different for each thread.
     485        Refs #17258.
     486        """
     487        connections_set = set()
     488        for conn in connections.all():
     489            connections_set.add(conn)
     490        def runner():
     491            from django.db import connections
     492            for conn in connections.all():
     493                connections_set.add(conn)
     494        for x in xrange(2):
     495            t = threading.Thread(target=runner)
     496            t.start()
     497            t.join()
     498        self.assertEquals(len(connections_set), 6)
     499        # Finish by closing the connections opened by the other threads (the
     500        # connection opened in the main thread will automatically be closed on
     501        # teardown).
     502        for conn in connections_set:
     503            if conn != connection:
     504                conn.close()
     505
     506    @unittest.skipIf(connection.vendor == 'sqlite',
     507                     "SQLite doesn't allow connection sharing between threads")
     508    def test_connection_passing(self):
     509        """
     510        Using the same connection in two separate threads is not allowed
     511        unless explicitly requested by the user. Refs #17258.
     512        """
     513        errors = []
     514        def runner(connection):
     515            # Usage is defined by taking a cursor.
     516            try:
     517                connection.cursor()
     518                connection.close()
     519            except SharedConnectionError, e:
     520                errors.append(str(e))
     521
     522        connection = connections[DEFAULT_DB_ALIAS]
     523        # Make sure the connection is claimed by this thread.
     524        connection.cursor()
     525        for x in xrange(2):
     526            t = threading.Thread(target=runner, args=(connection,))
     527            t.start()
     528            t.join()
     529        self.assertEquals(len(errors), 2)
     530
     531        errors = []
     532        connection.allow_thread_sharing = True
     533        for x in xrange(2):
     534            t = threading.Thread(target=runner, args=(connection,))
     535            t.start()
     536            t.join()
     537        self.assertEquals(len(errors), 0)
Back to Top