Ticket #16174: cbv-formpreview0.diff

File cbv-formpreview0.diff, 10.2 KB (added by Ryan Kaskel, 13 years ago)
  • django/contrib/formtools/preview.py

    diff --git a/django/contrib/formtools/preview.py b/django/contrib/formtools/preview.py
    index b4cdeba..b32d7d0 100644
    a b  
    11"""
    22Formtools Preview application.
    33"""
    4 
    5 try:
    6     import cPickle as pickle
    7 except ImportError:
    8     import pickle
    9 
    104from django.conf import settings
    11 from django.http import Http404
    125from django.shortcuts import render_to_response
    13 from django.template.context import RequestContext
    146from django.utils.crypto import constant_time_compare
    157from django.contrib.formtools.utils import form_hmac
     8from django.views.generic import FormView
    169
    1710AUTO_ID = 'formtools_%s' # Each form here uses this as its auto_id parameter.
    1811
    19 class FormPreview(object):
     12class FormPreview(FormView):
    2013    preview_template = 'formtools/preview.html'
    2114    form_template = 'formtools/form.html'
    2215
    2316    # METHODS SUBCLASSES SHOULDN'T OVERRIDE ###################################
    2417
    25     def __init__(self, form):
     18    def __init__(self, form_class, *args, **kwargs):
     19        super(FormPreview, self).__init__(*args, **kwargs)
    2620        # form should be a Form class, not an instance.
    27         self.form, self.state = form, {}
     21        self.form_class = form_class
     22        self.state = {}
    2823
    2924    def __call__(self, request, *args, **kwargs):
    30         stage = {'1': 'preview', '2': 'post'}.get(request.POST.get(self.unused_name('stage')), 'preview')
     25        return self.dispatch(request, *args, **kwargs)
     26
     27    def dispatch(self, request, *args, **kwargs):
     28        self.preview_stage = 'preview'
     29        self.post_stage = 'post'
     30        stages = {'1': self.preview_stage, '2': self.post_stage}
     31
     32        posted_stage = request.POST.get(self.unused_name('stage'))
     33        self.stage = stages.get(posted_stage, 'preview')
     34
     35        # For backwards compatiblity
    3136        self.parse_params(*args, **kwargs)
    32         try:
    33             method = getattr(self, stage + '_' + request.method.lower())
    34         except AttributeError:
    35             raise Http404
    36         return method(request)
     37
     38        return super(FormPreview, self).dispatch(request, *args, **kwargs)
    3739
    3840    def unused_name(self, name):
    3941        """
    class FormPreview(object):  
    4547        """
    4648        while 1:
    4749            try:
    48                 f = self.form.base_fields[name]
     50                self.form_class.base_fields[name]
    4951            except KeyError:
    5052                break # This field name isn't being used by the form.
    5153            name += '_'
    5254        return name
    5355
    54     def preview_get(self, request):
     56    def _get_context_data(self, form):
     57        """ For backwards compatiblity. """
     58        context = self.get_context_data()
     59        context.update(self.get_context(self.request, form))
     60        return context
     61
     62    def get(self, request, *args, **kwargs):
    5563        "Displays the form"
    56         f = self.form(auto_id=self.get_auto_id(), initial=self.get_initial(request))
    57         return render_to_response(self.form_template,
    58             self.get_context(request, f),
    59             context_instance=RequestContext(request))
     64        form_class = self.get_form_class()
     65        form = self.get_form(form_class)
     66        context = self._get_context_data(form)
     67        self.template_name = self.form_template
     68        return self.render_to_response(context)
     69
     70    def _check_security_hash(self, token, form):
     71        expected = self.security_hash(self.request, form)
     72        return constant_time_compare(token, expected)
     73
     74    def post(self, request, *args, **kwargs):
     75        """ Validates the POST data. If valid, displays the preview
     76        page or calls done, depending on the stage. Else, redisplays
     77        form. """
     78        form_class = self.get_form_class()
     79        form = self.get_form(form_class)
     80        if form.is_valid():
     81            return self.form_valid(form)
     82        else:
     83            return self.form_invalid(form)
    6084
    6185    def preview_post(self, request):
    62         "Validates the POST data. If valid, displays the preview page. Else, redisplays form."
    63         f = self.form(request.POST, auto_id=self.get_auto_id())
    64         context = self.get_context(request, f)
    65         if f.is_valid():
    66             self.process_preview(request, f, context)
     86        """ For backwards compatibility. failed_hash calls this method by
     87        default. """
     88        self.stage = self.preview_stage
     89        return self.post(request)
     90
     91    def form_valid(self, form):
     92        context = self._get_context_data(form)
     93        if self.stage == self.preview_stage:
     94            self.process_preview(self.request, form, context)
    6795            context['hash_field'] = self.unused_name('hash')
    68             context['hash_value'] = self.security_hash(request, f)
    69             return render_to_response(self.preview_template, context, context_instance=RequestContext(request))
     96            context['hash_value'] = self.security_hash(self.request, form)
     97            self.template_name = self.preview_template
     98            return self.render_to_response(context)
    7099        else:
    71             return render_to_response(self.form_template, context, context_instance=RequestContext(request))
     100            form_hash = self.request.POST.get(self.unused_name('hash'), '')
     101            if not self._check_security_hash(form_hash, form):
     102                return self.failed_hash(self.request) # Security hash failed.
     103            return self.done(self.request, form.cleaned_data)
    72104
    73     def _check_security_hash(self, token, request, form):
    74         expected = self.security_hash(request, form)
    75         return constant_time_compare(token, expected)
    76 
    77     def post_post(self, request):
    78         "Validates the POST data. If valid, calls done(). Else, redisplays form."
    79         f = self.form(request.POST, auto_id=self.get_auto_id())
    80         if f.is_valid():
    81             if not self._check_security_hash(request.POST.get(self.unused_name('hash'), ''),
    82                                              request, f):
    83                 return self.failed_hash(request) # Security hash failed.
    84             return self.done(request, f.cleaned_data)
    85         else:
    86             return render_to_response(self.form_template,
    87                 self.get_context(request, f),
    88                 context_instance=RequestContext(request))
     105    def form_invalid(self, form):
     106        context = self._get_context_data(form)
     107        self.template_name = self.form_template
     108        return render_to_response(context)
    89109
    90110    # METHODS SUBCLASSES MIGHT OVERRIDE IF APPROPRIATE ########################
    91111
    class FormPreview(object):  
    96116        """
    97117        return AUTO_ID
    98118
    99     def get_initial(self, request):
     119    def get_initial(self, request=None):
    100120        """
    101121        Takes a request argument and returns a dictionary to pass to the form's
    102122        ``initial`` kwarg when the form is being created from an HTTP get.
    103123        """
    104         return {}
     124        return self.initial
    105125
    106126    def get_context(self, request, form):
    107127        "Context for template rendering."
    108         return {'form': form, 'stage_field': self.unused_name('stage'), 'state': self.state}
    109 
     128        context = {
     129            'form': form,
     130            'stage_field': self.unused_name('stage'),
     131            'state': self.state
     132        }
     133        return context
     134
     135    def get_form_kwargs(self):
     136        """ This is overriden to maintain backward compatibility and pass
     137        the request to to get_initial. """
     138        kwargs = {
     139            'initial': self.get_initial(self.request),
     140            'auto_id': self.get_auto_id()
     141        }
     142        if self.request.method in ('POST', 'PUT'):
     143            kwargs.update({
     144                'data': self.request.POST,
     145                'files': self.request.FILES,
     146            })
     147        return kwargs
    110148
    111149    def parse_params(self, *args, **kwargs):
    112150        """
  • django/contrib/formtools/tests/__init__.py

    diff --git a/django/contrib/formtools/tests/__init__.py b/django/contrib/formtools/tests/__init__.py
    index 7084386..8c5fbb5 100644
    a b warnings.filterwarnings('ignore', category=PendingDeprecationWarning,  
    1717
    1818success_string = "Done was called!"
    1919
     20
    2021class TestFormPreview(preview.FormPreview):
    2122    def get_context(self, request, form):
    2223        context = super(TestFormPreview, self).get_context(request, form)
    2324        context.update({'custom_context': True})
    2425        return context
    2526
     27    def get_context_data(self, **kwargs):
     28        context = super(TestFormPreview, self).get_context_data(**kwargs)
     29        context['more_custom_context'] = True
     30        return context
     31
    2632    def get_initial(self, request):
    2733        return {'field1': 'Works!'}
    2834
    class PreviewTests(TestCase):  
    6773        stage = self.input % 1
    6874        self.assertContains(response, stage, 1)
    6975        self.assertEqual(response.context['custom_context'], True)
     76        self.assertEqual(response.context['more_custom_context'], True)
    7077        self.assertEqual(response.context['form'].initial, {'field1': 'Works!'})
    7178
    7279    def test_form_preview(self):
    class PreviewTests(TestCase):  
    8693        stage = self.input % 2
    8794        self.assertContains(response, stage, 1)
    8895
     96        # Check that the correct context was passed to the template
     97        self.assertEqual(response.context['custom_context'], True)
     98        self.assertEqual(response.context['more_custom_context'], True)
     99
    89100    def test_form_submit(self):
    90101        """
    91102        Test contrib.formtools.preview form submittal.
    class PreviewTests(TestCase):  
    140151        response = self.client.post('/preview/', self.test_data)
    141152        self.assertEqual(response.content, success_string)
    142153
    143 
    144154    def test_form_submit_bad_hash(self):
    145155        """
    146156        Test contrib.formtools.preview form submittal does not proceed
    class PreviewTests(TestCase):  
    154164        self.assertNotEqual(response.content, success_string)
    155165        hash = utils.form_hmac(TestForm(self.test_data)) + "bad"
    156166        self.test_data.update({'hash': hash})
    157         response = self.client.post('/previewpreview/', self.test_data)
     167        response = self.client.post('/preview/', self.test_data)
     168        self.assertTemplateUsed(response, 'formtools/preview.html')
    158169        self.assertNotEqual(response.content, success_string)
    159170
    160171
Back to Top