Django

Code

Ticket #2650: xml_serializer.py

File xml_serializer.py, 8.0 kB (added by ben.khoo@calytrix.com, 2 years ago)

Modification to the XML serializer to correctly encode the tag

Line 
1 """
2 XML serializer.
3 """
4
5 from django.conf import settings
6 from django.core.serializers import base
7 from django.db import models
8 from django.utils.xmlutils import SimplerXMLGenerator
9 from xml.dom import pulldom
10
class Serializer(base.Serializer):
    """
    Writes a QuerySet out as an XML document whose root element is
    <django-objects>, with one <object> element per serialized object.
    """

    def start_serialization(self):
        """
        Begin the output: build the XML generator on the output stream, then
        emit the document start and the opening <django-objects> tag.
        """
        output = self.stream
        charset = self.options.get("encoding", settings.DEFAULT_CHARSET)
        self.xml = SimplerXMLGenerator(output, charset)
        self.xml.startDocument()
        self.xml.startElement("django-objects", {"version" : "1.0"})

    def end_serialization(self):
        """
        Finish the output: close <django-objects> and end the document.
        """
        generator = self.xml
        generator.endElement("django-objects")
        generator.endDocument()

    def start_object(self, obj):
        """
        Open the <object> element for ``obj``, carrying its primary key and
        model as attributes.

        Raises base.SerializationError if ``obj`` has no ``_meta`` attribute.
        """
        if not hasattr(obj, "_meta"):
            raise base.SerializationError("Non-model object (%s) encountered during serialization" % type(obj))

        pk_text = str(obj._get_pk_val())
        model_text = str(obj._meta)
        object_attrs = {
            "pk"    : pk_text,
            "model" : model_text,
        }
        self.xml.startElement("object", object_attrs)

    def end_object(self, obj):
        """
        Close the <object> element once all of its fields have been written.
        """
        self.xml.endElement("object")

    def handle_field(self, obj, field):
        """
        Write one non-relational field as a <field> element with ``name``
        and ``type`` attributes.  The body is the field's string value, or an
        empty <None> element when that value is None.
        """
        field_attrs = {
            "name" : field.name,
            "type" : field.get_internal_type()
        }
        self.xml.startElement("field", field_attrs)

        # The string form of the value comes from the serializer base class;
        # a missing value is written as an explicit <None> marker element.
        text = self.get_string_value(obj, field)
        if text is None:
            self.xml.addQuickElement("None")
        else:
            self.xml.characters(str(text))

        self.xml.endElement("field")

    def handle_fk_field(self, obj, field):
        """
        Write a ForeignKey as a relational <field> element whose body is the
        related object's primary key, or an empty <None> element when there
        is no related object.
        """
        self._start_relational_field(field)
        target = getattr(obj, field.name)
        if target is None:
            self.xml.addQuickElement("None")
        else:
            self.xml.characters(str(target._get_pk_val()))
        self.xml.endElement("field")

    def handle_m2m_field(self, obj, field):
        """
        Write a ManyToManyField as a relational <field> element containing
        one <object pk="..."> reference per related object.  Only the
        primary keys are written, not the related objects' data.
        """
        self._start_relational_field(field)
        related_manager = getattr(obj, field.name)
        for related in related_manager.iterator():
            reference = {"pk" : str(related._get_pk_val())}
            self.xml.addQuickElement("object", attrs=reference)
        self.xml.endElement("field")

    def _start_relational_field(self, field):
        """
        Open the <field> element shared by ForeignKey and ManyToMany output,
        with the field name, the relation class name and the target model
        as attributes.
        """
        relation = field.rel
        relation_attrs = {
            "name" : field.name,
            "rel"  : relation.__class__.__name__,
            "to"   : str(relation.to._meta),
        }
        self.xml.startElement("field", relation_attrs)
102
class Deserializer(base.Deserializer):
    """
    Deserialize XML.

    Walks a pulldom event stream and converts every <object> element into a
    DeserializedObject.  A <field> whose child is a <None> element (the
    marker the serializer in this file writes for a missing value) is read
    back as None instead of being treated as empty text.
    """

    def __init__(self, stream_or_string, **options):
        """
        Set up the deserializer: read the "encoding" option (falling back to
        settings.DEFAULT_CHARSET) and open a pulldom parse over self.stream.
        """
        super(Deserializer, self).__init__(stream_or_string, **options)
        self.encoding = self.options.get("encoding", settings.DEFAULT_CHARSET)
        self.event_stream = pulldom.parse(self.stream)

    def next(self):
        """
        Return the DeserializedObject for the next <object> element in the
        stream; raise StopIteration once the stream is exhausted.
        """
        for event, node in self.event_stream:
            if event == "START_ELEMENT" and node.nodeName == "object":
                # Pull the whole subtree in so the node can be walked as DOM.
                self.event_stream.expandNode(node)
                return self._handle_object(node)
        raise StopIteration

    def _handle_object(self, node):
        """
        Convert an <object> node to a DeserializedObject.

        Raises base.DeserializationError if the node lacks a usable 'model'
        or 'pk' attribute, or if one of its <field> children has no 'name'.
        """
        # Look up the model using the model loading mechanism. If this fails, bail.
        Model = self._get_model_from_node(node, "model")

        # Start building a data dictionary from the object.  If the node is
        # missing the pk attribute, bail.
        pk = node.getAttribute("pk")
        if not pk:
            raise base.DeserializationError("<object> node is missing the 'pk' attribute")
        data = {Model._meta.pk.name : pk}

        # Also start building a dict of m2m data (this is saved as
        # {m2m_accessor_attribute : [list_of_related_objects]})
        m2m_data = {}

        # Deserialize each field.
        for field_node in node.getElementsByTagName("field"):
            # A field without a name cannot be matched to the model, bail.
            field_name = field_node.getAttribute("name")
            if not field_name:
                raise base.DeserializationError("<field> node is missing the 'name' attribute")

            # Get the field from the Model.  Any error raised by the lookup
            # for an unknown field name is left to propagate.
            field = Model._meta.get_field(field_name)

            # As is usually the case, relation fields get the special treatment.
            if field.rel and isinstance(field.rel, models.ManyToManyRel):
                m2m_data[field.name] = self._handle_m2m_field_node(field_node)
            elif field.rel and isinstance(field.rel, models.ManyToOneRel):
                data[field.name] = self._handle_fk_field_node(field_node)
            elif self._is_none_node(field_node):
                # Explicit <None> marker: keep the value as None rather than
                # pushing an empty string through to_python().
                data[field.name] = None
            else:
                data[field.name] = field.to_python(self._node_text(field_node))

        # Return a DeserializedObject so that the m2m data has a place to live.
        return base.DeserializedObject(Model(**data), m2m_data)

    def _handle_fk_field_node(self, node):
        """
        Handle a <field> node for a ForeignKey.

        Returns None when the node holds the <None> marker; otherwise returns
        the related object fetched by primary key.
        """
        # Resolve the model first so a missing/invalid 'to' attribute is
        # still reported even for a null reference.
        RelatedModel = self._get_model_from_node(node, "to")
        if self._is_none_node(node):
            return None
        # A failed lookup is not caught here: whatever objects.get() raises
        # for a missing related object propagates to the caller.
        return RelatedModel.objects.get(pk=self._node_text(node))

    def _handle_m2m_field_node(self, node):
        """
        Handle a <field> node for a ManyToManyField.

        Returns the values of an in_bulk() lookup on the pks listed in the
        node's <object> children.
        """
        # Load the related model
        RelatedModel = self._get_model_from_node(node, "to")

        # Look up all the related objects. Using the in_bulk() lookup ensures
        # that missing related objects don't cause an exception
        related_ids = [c.getAttribute("pk").encode(self.encoding) for c in node.getElementsByTagName("object")]
        return RelatedModel._default_manager.in_bulk(related_ids).values()

    def _get_model_from_node(self, node, attr):
        """
        Helper to look up a model from a <object model=...> or a <field
        rel=... to=...> node.

        Raises base.DeserializationError if the attribute is absent or does
        not resolve to a model.
        """
        model_identifier = node.getAttribute(attr)
        if not model_identifier:
            raise base.DeserializationError(
                "<%s> node is missing the required '%s' attribute" \
                    % (node.nodeName, attr))
        try:
            Model = models.get_model(*model_identifier.split("."))
        except TypeError:
            # Wrong number of dotted parts for get_model(); treated the same
            # as an unknown model below.
            Model = None
        if Model is None:
            raise base.DeserializationError(
                "<%s> node has invalid model identifier: '%s'" % \
                    (node.nodeName, model_identifier))
        return Model

    def _is_none_node(self, node):
        """
        Helper: True if ``node`` has a direct child element named "None".
        """
        for child in node.childNodes:
            if child.nodeType == child.ELEMENT_NODE and child.nodeName == "None":
                return True
        return False

    def _node_text(self, node):
        """
        Helper: the node's inner text, stripped and encoded with self.encoding.
        """
        return getInnerText(node).strip().encode(self.encoding)
204
205
def getInnerText(node):
    """
    Get all the inner text of a DOM node (recursively).

    Text and CDATA children contribute their data, element children
    contribute their own inner text, and every other node type (comments,
    processing instructions, ...) is skipped.  Returns the pieces joined
    into one string, in document order.
    """
    # inspired by http://mail.python.org/pipermail/xml-sig/2005-March/011022.html
    inner_text = []
    for child in node.childNodes:
        if child.nodeType in (child.TEXT_NODE, child.CDATA_SECTION_NODE):
            inner_text.append(child.data)
        elif child.nodeType == child.ELEMENT_NODE:
            # The recursive call returns a string, so append it whole;
            # extend() would split it into single characters first.
            inner_text.append(getInnerText(child))
    return "".join(inner_text)