from django.contrib.auth.models import User
from django.db.models.query import Prefetch
from django.test import TestCase

from fastfood import models

class PrefetchTestCase(TestCase):
    """Tests for nested ``Prefetch`` objects that store results via ``to_attr``."""

    def test_prefetch(self):
        """Chains -> restaurants -> favorite profiles are loaded in 3 queries.

        Fixture layout:

        * ``chain3`` (and therefore ``restaurant31``/``restaurant32`` and
          profile ``p4``) is excluded by the outer ``name__in`` filter.
        * ``p2`` is not a favorite, so ``restaurant12.profiles`` must be empty.
        * ``restaurant41`` has no profile at all, so its ``profiles`` is empty.
        """
        user = User.objects.create(username='test')
        chain1, chain2, chain3, chain4 = [
            models.FastfoodChain.objects.create(name='chain1'),
            models.FastfoodChain.objects.create(name='chain2'),
            models.FastfoodChain.objects.create(name='chain3'),
            models.FastfoodChain.objects.create(name='chain4'),
        ]
        restaurant11, restaurant12, restaurant21, restaurant31, restaurant32, restaurant41 = [
            models.Restaurant.objects.create(chain=chain1, name='restaurant11'),
            models.Restaurant.objects.create(chain=chain1, name='restaurant12'),
            models.Restaurant.objects.create(chain=chain2, name='restaurant21'),
            models.Restaurant.objects.create(chain=chain3, name='restaurant31'),
            models.Restaurant.objects.create(chain=chain3, name='restaurant32'),
            models.Restaurant.objects.create(chain=chain4, name='restaurant41'),
        ]
        p1, p2, p3, p4 = [
            models.RestaurantProfile.objects.create(restaurant=restaurant11, user=user, is_favorite=True),
            models.RestaurantProfile.objects.create(restaurant=restaurant12, user=user, is_favorite=False),
            models.RestaurantProfile.objects.create(restaurant=restaurant21, user=user, is_favorite=True),
            models.RestaurantProfile.objects.create(restaurant=restaurant31, user=user, is_favorite=True),
        ]
        # Three queries are expected: chains, their restaurants, and those
        # restaurants' favorite profiles.
        with self.assertNumQueries(3):
            # The explicit order_by() calls make the tuple unpacking and the
            # list comparisons below deterministic. Without an ordering the
            # database is free to return rows in any order. Ordering adds no
            # extra queries.
            chains = models.FastfoodChain.objects.filter(
                name__in=['chain1', 'chain2', 'chain4']
            ).order_by('name').prefetch_related(
                Prefetch(
                    'restaurant_set',
                    models.Restaurant.objects.order_by('name').prefetch_related(
                        Prefetch(
                            'restaurantprofile_set',
                            models.RestaurantProfile.objects.filter(
                                user=user,
                                is_favorite=True,
                            ),
                            to_attr='profiles',
                        )
                    ),
                    to_attr='restaurants',
                ),
            )
            result_chain_1, result_chain_2, result_chain_4 = list(chains)
            self.assertListEqual(result_chain_1.restaurants, [restaurant11, restaurant12])
            self.assertListEqual(result_chain_2.restaurants, [restaurant21])
            self.assertListEqual(result_chain_4.restaurants, [restaurant41])

            # Only favorite profiles survive the inner filter.
            self.assertListEqual(result_chain_1.restaurants[0].profiles, [p1])
            self.assertListEqual(result_chain_1.restaurants[1].profiles, [])
            self.assertListEqual(result_chain_2.restaurants[0].profiles, [p3])
            self.assertListEqual(result_chain_4.restaurants[0].profiles, [])
