Ticket #13790: 13790_CurrentSiteManager_patch.diff

File 13790_CurrentSiteManager_patch.diff, 6.0 KB (added by Gabriel Hurley, 14 years ago)
  • django/contrib/sites/managers.py

     
    44
    55class CurrentSiteManager(models.Manager):
    66    "Use this to limit objects to those associated with the current site."
    7     def __init__(self, field_name='site'):
     7    def __init__(self, field_name=None):
    88        super(CurrentSiteManager, self).__init__()
    99        self.__field_name = field_name
    1010        self.__is_validated = False
    11 
     11       
     12    def _validate_field_name(self):
     13        field_names = self.model._meta.get_all_field_names()
     14       
     15        # If a custom name is provided, make sure the field exists on the model
     16        if self.__field_name is not None and self.__field_name not in field_names:
     17            raise ValueError("%s couldn't find a field named %s in %s." % \
     18                (self.__class__.__name__, self.__field_name, self.model._meta.object_name))
     19       
     20        # Otherwise, see if there is a field called either 'site' or 'sites'
     21        else:
     22            for potential_name in ['site', 'sites']:
     23                if potential_name in field_names:
     24                    self.__field_name = potential_name
     25                    self.__is_validated = True
     26                    break
     27       
     28        # Now do a type check on the field (FK or M2M only)
     29        try:
     30            field = self.model._meta.get_field(self.__field_name)
     31            if not isinstance(field, (models.ForeignKey, models.ManyToManyField)):
     32                raise TypeError("%s must be a ForeignKey or ManyToManyField." %self.__field_name)
     33        except FieldDoesNotExist:
     34            raise ValueError("%s couldn't find a field named %s in %s." % \
     35                    (self.__class__.__name__, self.__field_name, self.model._meta.object_name))
     36        self.__is_validated = True
     37   
    1238    def get_query_set(self):
    1339        if not self.__is_validated:
    14             try:
    15                 self.model._meta.get_field(self.__field_name)
    16             except FieldDoesNotExist:
    17                 raise ValueError("%s couldn't find a field named %s in %s." % \
    18                     (self.__class__.__name__, self.__field_name, self.model._meta.object_name))
    19             self.__is_validated = True
     40            self._validate_field_name()
    2041        return super(CurrentSiteManager, self).get_query_set().filter(**{self.__field_name + '__id__exact': settings.SITE_ID})
  • tests/regressiontests/sites_framework/models.py

     
     1from django.contrib.sites.managers import CurrentSiteManager
     2from django.contrib.sites.models import Site
     3from django.db import models
     4
     5class AbstractArticle(models.Model):
     6    title = models.CharField(max_length=50)
     7   
     8    objects = models.Manager()
     9    on_site = CurrentSiteManager()
     10   
     11    class Meta:
     12        abstract = True
     13   
     14    def __unicode__(self):
     15        return self.title
     16
     17class SyndicatedArticle(AbstractArticle):
     18    sites = models.ManyToManyField(Site)
     19
     20class ExclusiveArticle(AbstractArticle):
     21    site = models.ForeignKey(Site)
     22   
     23class CustomArticle(AbstractArticle):
     24    places_this_article_should_appear = models.ForeignKey(Site)
     25   
     26    objects = models.Manager()
     27    on_site = CurrentSiteManager("places_this_article_should_appear")
     28
     29class InvalidArticle(AbstractArticle):
     30    site = models.ForeignKey(Site)
     31   
     32    objects = models.Manager()
     33    on_site = CurrentSiteManager("places_this_article_should_appear")
     34
     35class ConfusedArticle(AbstractArticle):
     36    site = models.IntegerField()
  • tests/regressiontests/sites_framework/tests.py

     
     1from django.conf import settings
     2from django.contrib.sites.models import Site
     3from django.test import TestCase
     4
     5from models import SyndicatedArticle, ExclusiveArticle, CustomArticle, InvalidArticle, ConfusedArticle
     6
     7class SitesFrameworkTestCase(TestCase):
     8    def setUp(self):
     9        Site.objects.get_or_create(id=settings.SITE_ID, domain="example.com", name="example.com")
     10        Site.objects.create(id=settings.SITE_ID+1, domain="example2.com", name="example2.com")
     11       
     12    def test_site_fk(self):
     13        article = ExclusiveArticle.objects.create(title="Breaking News!", site_id=settings.SITE_ID)
     14        self.assertEqual(ExclusiveArticle.on_site.all().get(), article)
     15   
     16    def test_sites_m2m(self):
     17        article = SyndicatedArticle.objects.create(title="Fresh News!")
     18        article.sites.add(Site.objects.get(id=settings.SITE_ID))
     19        article.sites.add(Site.objects.get(id=settings.SITE_ID+1))
     20        article2 = SyndicatedArticle.objects.create(title="More News!")
     21        article2.sites.add(Site.objects.get(id=settings.SITE_ID+1))
     22        self.assertEqual(SyndicatedArticle.on_site.all().get(), article)
     23       
     24    def test_custom_named_field(self):
     25        article = CustomArticle.objects.create(title="Tantalizing News!", places_this_article_should_appear_id=settings.SITE_ID)
     26        self.assertEqual(CustomArticle.on_site.all().get(), article)
     27   
     28    def test_invalid_name(self):
     29        article = InvalidArticle.objects.create(title="Bad News!", site_id=settings.SITE_ID)
     30        self.assertRaises(ValueError, InvalidArticle.on_site.all)
     31       
     32    def test_invalid_field_type(self):
     33        article = ConfusedArticle.objects.create(title="More Bad News!", site=settings.SITE_ID)
     34        self.assertRaises(TypeError, ConfusedArticle.on_site.all)
Back to Top