Ticket #13790: 13790_CurrentSiteManager_patch.diff
File 13790_CurrentSiteManager_patch.diff, 6.0 KB (added by , 14 years ago) |
---|
-
django/contrib/sites/managers.py
4 4 5 5 class CurrentSiteManager(models.Manager): 6 6 "Use this to limit objects to those associated with the current site." 7 def __init__(self, field_name= 'site'):7 def __init__(self, field_name=None): 8 8 super(CurrentSiteManager, self).__init__() 9 9 self.__field_name = field_name 10 10 self.__is_validated = False 11 11 12 def _validate_field_name(self): 13 field_names = self.model._meta.get_all_field_names() 14 15 # If a custom name is provided, make sure the field exists on the model 16 if self.__field_name is not None and self.__field_name not in field_names: 17 raise ValueError("%s couldn't find a field named %s in %s." % \ 18 (self.__class__.__name__, self.__field_name, self.model._meta.object_name)) 19 20 # Otherwise, see if there is a field called either 'site' or 'sites' 21 else: 22 for potential_name in ['site', 'sites']: 23 if potential_name in field_names: 24 self.__field_name = potential_name 25 self.__is_validated = True 26 break 27 28 # Now do a type check on the field (FK or M2M only) 29 try: 30 field = self.model._meta.get_field(self.__field_name) 31 if not isinstance(field, (models.ForeignKey, models.ManyToManyField)): 32 raise TypeError("%s must be a ForeignKey or ManyToManyField." %self.__field_name) 33 except FieldDoesNotExist: 34 raise ValueError("%s couldn't find a field named %s in %s." % \ 35 (self.__class__.__name__, self.__field_name, self.model._meta.object_name)) 36 self.__is_validated = True 37 12 38 def get_query_set(self): 13 39 if not self.__is_validated: 14 try: 15 self.model._meta.get_field(self.__field_name) 16 except FieldDoesNotExist: 17 raise ValueError("%s couldn't find a field named %s in %s." % \ 18 (self.__class__.__name__, self.__field_name, self.model._meta.object_name)) 19 self.__is_validated = True 40 self._validate_field_name() 20 41 return super(CurrentSiteManager, self).get_query_set().filter(**{self.__field_name + '__id__exact': settings.SITE_ID}) -
tests/regressiontests/sites_framework/models.py
1 from django.contrib.sites.managers import CurrentSiteManager 2 from django.contrib.sites.models import Site 3 from django.db import models 4 5 class AbstractArticle(models.Model): 6 title = models.CharField(max_length=50) 7 8 objects = models.Manager() 9 on_site = CurrentSiteManager() 10 11 class Meta: 12 abstract = True 13 14 def __unicode__(self): 15 return self.title 16 17 class SyndicatedArticle(AbstractArticle): 18 sites = models.ManyToManyField(Site) 19 20 class ExclusiveArticle(AbstractArticle): 21 site = models.ForeignKey(Site) 22 23 class CustomArticle(AbstractArticle): 24 places_this_article_should_appear = models.ForeignKey(Site) 25 26 objects = models.Manager() 27 on_site = CurrentSiteManager("places_this_article_should_appear") 28 29 class InvalidArticle(AbstractArticle): 30 site = models.ForeignKey(Site) 31 32 objects = models.Manager() 33 on_site = CurrentSiteManager("places_this_article_should_appear") 34 35 class ConfusedArticle(AbstractArticle): 36 site = models.IntegerField() -
tests/regressiontests/sites_framework/tests.py
1 from django.conf import settings 2 from django.contrib.sites.models import Site 3 from django.test import TestCase 4 5 from models import SyndicatedArticle, ExclusiveArticle, CustomArticle, InvalidArticle, ConfusedArticle 6 7 class SitesFrameworkTestCase(TestCase): 8 def setUp(self): 9 Site.objects.get_or_create(id=settings.SITE_ID, domain="example.com", name="example.com") 10 Site.objects.create(id=settings.SITE_ID+1, domain="example2.com", name="example2.com") 11 12 def test_site_fk(self): 13 article = ExclusiveArticle.objects.create(title="Breaking News!", site_id=settings.SITE_ID) 14 self.assertEqual(ExclusiveArticle.on_site.all().get(), article) 15 16 def test_sites_m2m(self): 17 article = SyndicatedArticle.objects.create(title="Fresh News!") 18 article.sites.add(Site.objects.get(id=settings.SITE_ID)) 19 article.sites.add(Site.objects.get(id=settings.SITE_ID+1)) 20 article2 = SyndicatedArticle.objects.create(title="More News!") 21 article2.sites.add(Site.objects.get(id=settings.SITE_ID+1)) 22 self.assertEqual(SyndicatedArticle.on_site.all().get(), article) 23 24 def test_custom_named_field(self): 25 article = CustomArticle.objects.create(title="Tantalizing News!", places_this_article_should_appear_id=settings.SITE_ID) 26 self.assertEqual(CustomArticle.on_site.all().get(), article) 27 28 def test_invalid_name(self): 29 article = InvalidArticle.objects.create(title="Bad News!", site_id=settings.SITE_ID) 30 self.assertRaises(ValueError, InvalidArticle.on_site.all) 31 32 def test_invalid_field_type(self): 33 article = ConfusedArticle.objects.create(title="More Bad News!", site=settings.SITE_ID) 34 self.assertRaises(TypeError, ConfusedArticle.on_site.all)