Index: django/core/template/__init__.py
===================================================================
--- django/core/template/__init__.py	(revision 1785)
+++ django/core/template/__init__.py	(working copy)
@@ -775,7 +775,7 @@
             raise
         return self.encode_output(output)
 
-def generic_tag_compiler(params, defaults, name, node_class, parser, token):
+def generic_tag_compiler(params, defaults, name, node_class, parser, token, takes_context=False, takes_block=False):
     "Returns a template.Node subclass."
     bits = token.contents.split()[1:]
     bmax = len(params)
@@ -787,6 +787,12 @@
         else:
             message = "%s takes between %s and %s arguments" % (name, bmin, bmax)
         raise TemplateSyntaxError, message
+    if takes_context:
+        node_class = curry(node_class, takes_context=takes_context)
+    if takes_block:
+        nodelist = parser.parse(('end' + name,))
+        parser.delete_first_token()
+        node_class = curry(node_class, block_nodelist=nodelist)
     return node_class(bits)
 
 class Library(object):
@@ -842,18 +848,32 @@
         self.filters[func.__name__] = func
         return func
 
-    def simple_tag(self,func):
-        (params, xx, xxx, defaults) = getargspec(func)
+    def simple_tag(self, compile_function=None, takes_block=False, takes_context=False):
+        if compile_function == None:
+            return curry(self.simple_tag_function, takes_block=takes_block, takes_context=takes_context)
+        elif callable(compile_function):
+            return self.simple_tag_function(compile_function, takes_block=takes_block, takes_context=takes_context)
+        else:
+            raise InvalidTemplateLibrary, "Unsupported argument to Library.simple_tag: (%r)", (compile_function,)
 
+    def simple_tag_function(self, func, takes_block=False, takes_context=False):
         class SimpleNode(Node):
-            def __init__(self, vars_to_resolve):
+            def __init__(self, vars_to_resolve, takes_context=False, block_nodelist=None):
                 self.vars_to_resolve = vars_to_resolve
+                self.takes_context, self.block_nodelist = takes_context, block_nodelist
 
             def render(self, context):
                 resolved_vars = [resolve_variable(var, context) for var in self.vars_to_resolve]
-                return func(*resolved_vars)
+                if self.block_nodelist:
+                    resolved_vars.insert(0, self.block_nodelist)
+                if self.takes_context:
+                    resolved_vars.insert(0, context)
+                rendered = func(*resolved_vars)
+                return rendered or ''
 
-        compile_func = curry(generic_tag_compiler, params, defaults, func.__name__, SimpleNode)
+        (params, xx, xxx, defaults) = getargspec(func)
+        taken_args = sum([takes_block, takes_context])
+        compile_func = curry(generic_tag_compiler, params[taken_args:], defaults, func.__name__, SimpleNode, takes_block=takes_block, takes_context=takes_context)
         compile_func.__doc__ = func.__doc__
         self.tag(func.__name__, compile_func)
         return func
