Ticket #2650: xml_serializer.py

File xml_serializer.py, 8.0 KB (added by ben.khoo@…, 13 years ago)

Modification to the XML serializer to correctly encode the tag

Line 
1"""
2XML serializer.
3"""
4
5from django.conf import settings
6from django.core.serializers import base
7from django.db import models
8from django.utils.xmlutils import SimplerXMLGenerator
9from xml.dom import pulldom
10
class Serializer(base.Serializer):
    """
    Serializes a QuerySet to XML.

    The document produced has a single ``<django-objects version="1.0">``
    root containing one ``<object pk="..." model="...">`` element per
    serialized object, each holding one ``<field>`` element per field.
    ``self.stream`` and ``self.options`` are read here but not assigned in
    this class -- presumably they are set up by ``base.Serializer``.
    """

    def start_serialization(self):
        """
        Start serialization -- open the XML document and the root element.
        """
        # Fall back to the project-wide charset when no "encoding" option
        # was supplied.
        encoding = self.options.get("encoding", settings.DEFAULT_CHARSET)
        self.xml = SimplerXMLGenerator(self.stream, encoding)
        self.xml.startDocument()
        self.xml.startElement("django-objects", {"version": "1.0"})

    def end_serialization(self):
        """
        End serialization -- close the root element and end the document.
        """
        self.xml.endElement("django-objects")
        self.xml.endDocument()

    def start_object(self, obj):
        """
        Called as each object is handled; opens its <object> element.

        Raises base.SerializationError if ``obj`` has no ``_meta``
        attribute (i.e. it does not look like a model instance).
        """
        if not hasattr(obj, "_meta"):
            raise base.SerializationError("Non-model object (%s) encountered during serialization" % type(obj))

        self.xml.startElement("object", {
            "pk": str(obj._get_pk_val()),
            "model": str(obj._meta),
        })

    def end_object(self, obj):
        """
        Called after handling all fields for an object; closes <object>.
        """
        self.xml.endElement("object")

    def handle_field(self, obj, field):
        """
        Called to handle each field on an object (except for ForeignKeys and
        ManyToManyFields)
        """
        self.xml.startElement("field", {
            "name": field.name,
            "type": field.get_internal_type(),
        })

        # Get a "string version" of the object's data (this is handled by the
        # serializer base class).  None is handled specially.
        self._write_value(self.get_string_value(obj, field))

        self.xml.endElement("field")

    def handle_fk_field(self, obj, field):
        """
        Called to handle a ForeignKey (we need to treat them slightly
        differently from regular fields).

        Only the primary key of the related object is written, or a <None>
        element when the relation is unset.
        """
        self._start_relational_field(field)
        related = getattr(obj, field.name)
        self._write_value(None if related is None else related._get_pk_val())
        self.xml.endElement("field")

    def handle_m2m_field(self, obj, field):
        """
        Called to handle a ManyToManyField. Related objects are only
        serialized as references to the object's PK (i.e. the related *data*
        is not dumped, just the relation).
        """
        self._start_relational_field(field)
        for relobj in getattr(obj, field.name).iterator():
            self.xml.addQuickElement("object", attrs={"pk": str(relobj._get_pk_val())})
        self.xml.endElement("field")

    def _start_relational_field(self, field):
        """
        Helper to output the <field> element for relational fields
        """
        self.xml.startElement("field", {
            "name": field.name,
            "rel": field.rel.__class__.__name__,
            "to": str(field.rel.to._meta),
        })

    def _write_value(self, value):
        """
        Helper to write a field's content: the value as character data, or
        an empty <None> element when the value is None.
        """
        if value is not None:
            self.xml.characters(str(value))
        else:
            self.xml.addQuickElement("None")
102
class Deserializer(base.Deserializer):
    """
    Deserialize XML.

    Pulls ``<object>`` elements off the input one at a time and turns each
    into a ``base.DeserializedObject``. ``self.stream`` and ``self.options``
    are read here but not assigned in this class -- presumably they are set
    up by ``base.Deserializer.__init__``.
    """

    def __init__(self, stream_or_string, **options):
        super(Deserializer, self).__init__(stream_or_string, **options)
        # Text pulled out of the DOM is encoded with this charset before it
        # is handed to the model layer.
        self.encoding = self.options.get("encoding", settings.DEFAULT_CHARSET)
        self.event_stream = pulldom.parse(self.stream)

    def next(self):
        """
        Return the next deserialized object from the stream.

        Raises StopIteration once no further <object> elements remain.
        """
        for event, node in self.event_stream:
            if event == "START_ELEMENT" and node.nodeName == "object":
                # Pull the whole <object> subtree into the DOM so that its
                # children can be inspected.
                self.event_stream.expandNode(node)
                return self._handle_object(node)
        raise StopIteration

    def _handle_object(self, node):
        """
        Convert an <object> node to a DeserializedObject.

        Raises base.DeserializationError when the node lacks a "pk" or
        valid "model" attribute, or when a <field> lacks a "name".
        """
        # Look up the model using the model loading mechanism. If this fails, bail.
        Model = self._get_model_from_node(node, "model")

        # Start building a data dictionary from the object.  If the node is
        # missing the pk attribute, bail.
        pk = node.getAttribute("pk")
        if not pk:
            raise base.DeserializationError("<object> node is missing the 'pk' attribute")
        data = {Model._meta.pk.name: pk}

        # Also start building a dict of m2m data (this is saved as
        # {m2m_accessor_attribute : [list_of_related_objects]})
        m2m_data = {}

        # Deserialize each field.
        for field_node in node.getElementsByTagName("field"):
            # If the field is missing the name attribute, bail (are you
            # sensing a pattern here?)
            field_name = field_node.getAttribute("name")
            if not field_name:
                raise base.DeserializationError("<field> node is missing the 'name' attribute")

            # Get the field from the Model. This will raise a
            # FieldDoesNotExist if, well, the field doesn't exist, which will
            # be propagated correctly.
            field = Model._meta.get_field(field_name)

            # As is usually the case, relation fields get the special treatment.
            if field.rel and isinstance(field.rel, models.ManyToManyRel):
                m2m_data[field.name] = self._handle_m2m_field_node(field_node)
            elif field.rel and isinstance(field.rel, models.ManyToOneRel):
                data[field.name] = self._handle_fk_field_node(field_node)
            elif self._is_none_node(field_node):
                # The serializer writes an empty <None> element for null
                # values; map it back to None rather than feeding an empty
                # string to the field's to_python().
                data[field.name] = None
            else:
                text = getInnerText(field_node).strip()
                data[field.name] = field.to_python(text.encode(self.encoding))

        # Return a DeserializedObject so that the m2m data has a place to live.
        return base.DeserializedObject(Model(**data), m2m_data)

    def _handle_fk_field_node(self, node):
        """
        Handle a <field> node for a ForeignKey

        Returns the related object, or None when the field was serialized
        as <None> or the referenced object cannot be found.
        """
        # A null foreign key is serialized as an empty <None> element; there
        # is nothing to look up in that case.
        if self._is_none_node(node):
            return None

        # Try to set the foreign key by looking up the foreign related object.
        # If it doesn't exist, set the field to None (which might trigger
        # validation error, but that's expected).
        RelatedModel = self._get_model_from_node(node, "to")
        related_pk = getInnerText(node).strip().encode(self.encoding)
        try:
            # Use the default manager, as the m2m handler does, so models
            # without an "objects" attribute still work.
            return RelatedModel._default_manager.get(pk=related_pk)
        except RelatedModel.DoesNotExist:
            return None

    def _handle_m2m_field_node(self, node):
        """
        Handle a <field> node for a ManyToManyField
        """
        # Load the related model
        RelatedModel = self._get_model_from_node(node, "to")

        # Look up all the related objects. Using the in_bulk() lookup ensures
        # that missing related objects don't cause an exception
        related_ids = [
            child.getAttribute("pk").encode(self.encoding)
            for child in node.getElementsByTagName("object")
        ]
        return RelatedModel._default_manager.in_bulk(related_ids).values()

    def _get_model_from_node(self, node, attr):
        """
        Helper to look up a model from a <object model=...> or a <field
        rel=... to=...> node.

        Raises base.DeserializationError when the attribute is missing or
        does not resolve to a model.
        """
        model_identifier = node.getAttribute(attr)
        if not model_identifier:
            raise base.DeserializationError(
                "<%s> node is missing the required '%s' attribute"
                % (node.nodeName, attr))
        try:
            Model = models.get_model(*model_identifier.split("."))
        except TypeError:
            # The identifier did not split into the arguments get_model()
            # accepts; treat it like an unknown model.
            Model = None
        if Model is None:
            raise base.DeserializationError(
                "<%s> node has invalid model identifier: '%s'"
                % (node.nodeName, model_identifier))
        return Model

    def _is_none_node(self, node):
        """
        Helper to detect a <field> whose only content is a <None> element.

        Whitespace-only text around the <None> element is tolerated; any
        other element or non-blank text means the field holds a real value.
        """
        found_none = False
        for child in node.childNodes:
            if child.nodeType == child.ELEMENT_NODE:
                if child.nodeName != "None":
                    return False
                found_none = True
            elif child.nodeType in (child.TEXT_NODE, child.CDATA_SECTION_NODE):
                if child.data.strip():
                    return False
        return found_none
204
205
def getInnerText(node):
    """
    Get all the inner text of a DOM node (recursively).

    Text and CDATA children contribute their data, element children
    contribute their own inner text, and every other node type (comments,
    processing instructions, ...) is skipped. Returns the pieces joined
    into a single string, in document order.
    """
    # inspired by http://mail.python.org/pipermail/xml-sig/2005-March/011022.html
    inner_text = []
    for child in node.childNodes:
        if child.nodeType == child.TEXT_NODE or child.nodeType == child.CDATA_SECTION_NODE:
            inner_text.append(child.data)
        elif child.nodeType == child.ELEMENT_NODE:
            # The recursive call returns one string; append it whole rather
            # than extend()-ing the list with its individual characters.
            inner_text.append(getInnerText(child))
    return "".join(inner_text)
Back to Top