Code

Ticket #6510: 6510.2.diff

File 6510.2.diff, 5.3 KB (added by SmileyChris, 4 years ago)
Line 
1diff --git a/django/template/__init__.py b/django/template/__init__.py
2index 7fb01f0..33ea0aa 100644
3--- a/django/template/__init__.py
4+++ b/django/template/__init__.py
5@@ -770,6 +770,7 @@ class Node(object):
6     # Set this to True for nodes that must be first in the template (although
7     # they can be preceded by text nodes.
8     must_be_first = False
9+    child_nodelists = ('nodelist',)
10 
11     def render(self, context):
12         "Return the node rendered as a string"
13@@ -783,8 +784,10 @@ class Node(object):
14         nodes = []
15         if isinstance(self, nodetype):
16             nodes.append(self)
17-        if hasattr(self, 'nodelist'):
18-            nodes.extend(self.nodelist.get_nodes_by_type(nodetype))
19+        for attr in self.child_nodelists:
20+            nodelist = getattr(self, attr, None)
21+            if nodelist:
22+                nodes.extend(nodelist.get_nodes_by_type(nodetype))
23         return nodes
24 
25 class NodeList(list):
26diff --git a/django/template/defaulttags.py b/django/template/defaulttags.py
27index 69afd84..4b0a4d1 100644
28--- a/django/template/defaulttags.py
29+++ b/django/template/defaulttags.py
30@@ -97,6 +97,8 @@ class FirstOfNode(Node):
31         return u''
32 
33 class ForNode(Node):
34+    child_nodelists = ('nodelist_loop', 'nodelist_empty')
35+
36     def __init__(self, loopvars, sequence, is_reversed, nodelist_loop, nodelist_empty=None):
37         self.loopvars, self.sequence = loopvars, sequence
38         self.is_reversed = is_reversed
39@@ -118,14 +120,6 @@ class ForNode(Node):
40         for node in self.nodelist_empty:
41             yield node
42 
43-    def get_nodes_by_type(self, nodetype):
44-        nodes = []
45-        if isinstance(self, nodetype):
46-            nodes.append(self)
47-        nodes.extend(self.nodelist_loop.get_nodes_by_type(nodetype))
48-        nodes.extend(self.nodelist_empty.get_nodes_by_type(nodetype))
49-        return nodes
50-
51     def render(self, context):
52         if 'forloop' in context:
53             parentloop = context['forloop']
54@@ -181,6 +175,8 @@ class ForNode(Node):
55         return nodelist.render(context)
56 
57 class IfChangedNode(Node):
58+    child_nodelists = ('nodelist_true', 'nodelist_false')
59+
60     def __init__(self, nodelist_true, nodelist_false, *varlist):
61         self.nodelist_true, self.nodelist_false = nodelist_true, nodelist_false
62         self._last_seen = None
63@@ -211,6 +207,8 @@ class IfChangedNode(Node):
64         return ''
65 
66 class IfEqualNode(Node):
67+    child_nodelists = ('nodelist_true', 'nodelist_false')
68+
69     def __init__(self, var1, var2, nodelist_true, nodelist_false, negate):
70         self.var1, self.var2 = var1, var2
71         self.nodelist_true, self.nodelist_false = nodelist_true, nodelist_false
72@@ -227,6 +225,8 @@ class IfEqualNode(Node):
73         return self.nodelist_false.render(context)
74 
75 class IfNode(Node):
76+    child_nodelists = ('nodelist_true', 'nodelist_false')
77+
78     def __init__(self, var, nodelist_true, nodelist_false=None):
79         self.nodelist_true, self.nodelist_false = nodelist_true, nodelist_false
80         self.var = var
81@@ -240,14 +240,6 @@ class IfNode(Node):
82         for node in self.nodelist_false:
83             yield node
84 
85-    def get_nodes_by_type(self, nodetype):
86-        nodes = []
87-        if isinstance(self, nodetype):
88-            nodes.append(self)
89-        nodes.extend(self.nodelist_true.get_nodes_by_type(nodetype))
90-        nodes.extend(self.nodelist_false.get_nodes_by_type(nodetype))
91-        return nodes
92-
93     def render(self, context):
94         if self.var.eval(context):
95             return self.nodelist_true.render(context)
96diff --git a/tests/regressiontests/templates/nodelist.py b/tests/regressiontests/templates/nodelist.py
97new file mode 100644
98index 0000000..89fac97
99--- /dev/null
100+++ b/tests/regressiontests/templates/nodelist.py
101@@ -0,0 +1,30 @@
102+from unittest import TestCase
103+from django.template.loader import get_template_from_string
104+from django.template import VariableNode
105+
106+
107+class NodelistTest(TestCase):
108+
109+    def test_for(self):
110+        source = '{% for i in 1 %}{{ a }}{% endfor %}'
111+        template = get_template_from_string(source)
112+        vars = template.nodelist.get_nodes_by_type(VariableNode)
113+        self.assertEqual(len(vars), 1)
114+
115+    def test_if(self):
116+        source = '{% if x %}{{ a }}{% endif %}'
117+        template = get_template_from_string(source)
118+        vars = template.nodelist.get_nodes_by_type(VariableNode)
119+        self.assertEqual(len(vars), 1)
120+
121+    def test_ifequal(self):
122+        source = '{% ifequal x y %}{{ a }}{% endifequal %}'
123+        template = get_template_from_string(source)
124+        vars = template.nodelist.get_nodes_by_type(VariableNode)
125+        self.assertEqual(len(vars), 1)
126+
127+    def test_ifchanged(self):
128+        source = '{% ifchanged x %}{{ a }}{% endifchanged %}'
129+        template = get_template_from_string(source)
130+        vars = template.nodelist.get_nodes_by_type(VariableNode)
131+        self.assertEqual(len(vars), 1)
132diff --git a/tests/regressiontests/templates/tests.py b/tests/regressiontests/templates/tests.py
133index 31c9e24..3e7b7fb 100644
134--- a/tests/regressiontests/templates/tests.py
135+++ b/tests/regressiontests/templates/tests.py
136@@ -24,6 +24,7 @@ from context import context_tests
137 from custom import custom_filters
138 from parser import token_parsing, filter_parsing, variable_parsing
139 from unicode import unicode_tests
140+from nodelist import NodelistTest
141 from smartif import *
142 
143 try: