Code

Ticket #5253: csvserializer.diff

File csvserializer.diff, 13.4 KB (added by Adam Schmideg, 7 years ago)

A working implementation with minor limitations

Line 
1Index: tests.py
2===================================================================
3--- tests.py    (revision 0)
4+++ tests.py    (revision 0)
5@@ -0,0 +1,141 @@
6+
7+from cStringIO import StringIO
8+
9+from django.core.serializers.base import DeserializationError
10+from django.db import models
11+from django.test import TestCase
12+
13+from serializers import csvserializer
14+
15+class Author(models.Model):
16+    first_name = models.TextField(null=True, default=None)
17+    last_name = models.TextField()
18+
19+    def __unicode__(self):
20+        if self.first_name:
21+            return '%s %s' % (self.first_name, self.last_name)
22+        else:
23+            return self.last_name
24+
25+    class Meta:
26+        unique_together = (('first_name','last_name'),)
27+
28+class Book(models.Model):
29+    author = models.ForeignKey(Author)
30+    title = models.TextField()
31+    copies = models.IntegerField(default=0)
32+
33+    def __unicode__(self):
34+        return '%s: %s' % (self.author, self.title)
35+   
36+    class Meta:
37+        unique_together = (('author', 'title'),)
38+
39+class Loaner(models.Model):
40+    name = models.TextField(unique=True)
41+
42+class Card(models.Model):
43+    number = models.IntegerField()
44+    books = models.ManyToManyField(Book)
45+    owner = models.ForeignKey(Loaner)
46+
47+    def __unicode__(self):
48+        return '%i' % self.number
49+
50+class HalfEmpty(models.Model):
51+    nullable = models.CharField(maxlength=10, null=True)
52+    not_nullable = models.CharField(maxlength=10)
53+
54+
55+class CsvTest(TestCase):
56+
57+    def setUp(self):
58+        #csvserializer.nice_foreign_keys = False
59+        def save_list(list):
60+            for o in list:
61+                o.save()
62+
63+        self.authors = (
64+          Author(last_name='Dante'),
65+          Author(last_name='Milton', first_name='John'),
66+          Author(last_name='Milton', first_name='Different'),
67+        )
68+        save_list(self.authors)
69+        self.books = (
70+          Book(author=self.authors[0], title='Paradise'),
71+          Book(author=self.authors[0], title='Hell'),
72+          Book(author=self.authors[1], title='Paradise'),
73+        )
74+        save_list(self.books)
75+        self.loaners = (Loaner(name='Tom'), Loaner(name='Liz'))
76+        save_list(self.loaners)
77+        self.cards = (
78+          Card(number=555,
79+            owner=self.loaners[0]),
80+          Card(number=666,
81+            owner=self.loaners[1]),
82+        )
83+        save_list(self.cards)
84+
85+        self.cards[0].books = [self.books[0], self.books[1]]
86+        self.cards[1].books = [self.books[1]]
87+
88+    def tearDown(self):
89+        for mod in (Author, Book, Loaner, Card):
90+            mod.objects.all().delete()
91+
92+    def test_resolve_related(self):
93+        field = Book._meta.get_field('author')
94+        dante = self.authors[0]
95+        def assert_found(obj, field, value):
96+            self.assertEquals(obj,
97+              csvserializer.resolve_related(field, value))
98+        def assert_not_found(field, value):
99+            try:
100+                csvserializer.resolve_related(field, value)
101+                self.fail('Expected DeserializationError')
102+            except DeserializationError:
103+                pass
104+
105+        assert_found(dante, field, 1)
106+        assert_found(dante, field, {'last_name': 'Dante'})
107+        assert_found(dante, field, {'last_name': 'Dante', 'first_name': None})
108+        assert_not_found(field, 999)
109+        assert_not_found(field, {'last_name': 'Milton'})
110+
111+    def test_write_read(self):
112+
113+        def check(sequence):
114+            ser = csvserializer.Serializer()
115+            csv_result = ser.serialize(sequence)
116+
117+            ser = csvserializer.Deserializer(StringIO(csv_result))
118+            deserialized = [o for o in ser]
119+
120+            self.assertEquals(len(sequence), len(deserialized))
121+            for i in range(len(sequence)):
122+                self.assertEquals(sequence[i], deserialized[i].object)
123+
124+            for o in sequence:
125+                o.delete()
126+            for o in deserialized:
127+                o.save()
128+
129+        check(self.cards)
130+
131+
132+class CsvEmptyFieldTest(TestCase):
133+
134+    def test_it(self):
135+        half_empty = HalfEmpty(not_nullable='')
136+        half_empty.save()
137+        ser = csvserializer.Serializer()
138+        csv_result = ser.serialize([half_empty])
139+        lines = csv_result.splitlines()
140+        #self.assertEquals('1,,', lines[1])
141+
142+        ser = csvserializer.Deserializer(StringIO(csv_result))
143+        deserialized = [o for o in ser]
144+        new_half_empty = deserialized[0].object
145+        self.assertEquals(None, new_half_empty.nullable)
146+        self.assertEquals('', new_half_empty.not_nullable)
147Index: csvserializer.py
148===================================================================
149--- csvserializer.py    (revision 0)
150+++ csvserializer.py    (revision 0)
151@@ -0,0 +1,229 @@
152+
153+"""
154+Does serialization to and from a single csv file using the standard csv
155+module.
156+
157+Features:
158+- Many-to-many data is serialized as a list of ids or dicts.
159+- Both None, and empty string is serialized as an empty string.  When
160+  deserializing, the empty string becomes None for nullable fields,
161+  empty string for not nullable fields.
162+- If nice_foreign_keys is True, foreign keys get serialized as possibly
163+  nested dictionaries of unique fields.  For example if we have
164
165+    class Article(Model):
166+        owner = ForeignKey(User)
167+        text = TextField()
168+
169+  It will be serialized as
170+
171+    --- article: id,owner,text
172+    1,{'username': u'john'},Bla bla
173+
174+Limitations:
175+- Unicode data is not supported (yet).
176+- The nice_foreign_keys feature doesn't work yet with unique foreign
177+  keys.
178+- Comments are not supported in the csv file -- well, this is rather a
179+  limitation of the csv module.
180+
181+Configuration:
182+- To enable --format=csv for manage dumpdata/loaddata, add the following
183+  lines to your settings.py:
184+
185+    SERIALIZATION_MODULES = {'csv': 'path.to.csvserializer'}
186+
187+- To configure it, add the following lines, too:
188+
189+    CSV_SERIALIZER_OPTIONS = {
190+        'dialect': 'excel',
191+        'nice_foreign_keys': True,
192+    }
193+
194+    The given dialect should be registered to the csv module (defaults
195+    to the pre-defined 'excel').
196+"""
197+import csv
198+
199+from django.conf import settings
200+from django.core.serializers import base
201+from django.db import models
202+from django.db.models.fields import IntegerField, FloatField, Field
203+from django.utils.encoding import smart_unicode
204+
205+
206+#__all__ = ('Serializer', 'Deserializer')
207+
208+# options
209+dialect = 'excel'
210+nice_foreign_keys = True
211+header_start = '--- '
212+globals().update(getattr(settings, 'CSV_SERIALIZER_OPTIONS', {}))
213+
214+def resolve_related(field, value):
215+    """Return a model object for `field` represented as `value` where
216+    the representation was produced by Serializer.represent()."""
217+    if value is None:
218+        return None
219+    else:
220+        model = field.rel.to
221+        if isinstance(value, int):
222+            query = {'pk': value}
223+        elif isinstance(value, dict):
224+            query = {}
225+            # value dict may be recursive
226+            for k, v in value.iteritems():
227+                if isinstance(v, dict):
228+                    query[k] = resolve_related(model._meta.get_field(k), v)
229+                elif v is None:
230+                    query['%s__isnull' % k] = True
231+                else:
232+                    query[k] = v
233+        else:
234+            raise base.DeserializationError('Expected int or dict, got %s of %s' % \
235+              (value, type(value)))
236+        objects = model.objects.complex_filter(query)
237+        if len(objects) == 1:
238+            return objects[0]
239+        else:
240+            raise base.DeserializationError('%i instance(s) of %s for %s' % \
241+              (len(objects), model, value))
242+
243+def pk_or_none(obj):
244+    if obj:
245+        return obj._get_pk_val()
246+    else:
247+        return None
248+       
249+
250+class Serializer(base.Serializer):
251+
252+    def start_serialization(self):
253+        self.last_model = None
254+        self.output = csv.writer(self.stream, dialect=dialect)
255+
256+    def start_object(self, obj):
257+        if not hasattr(obj, "_meta"):
258+            raise base.SerializationError("Non-model object (%s) encountered during serialization" % type(obj))
259+        if self.last_model != obj._meta:
260+            self.last_model = obj._meta
261+            header = []
262+            for field in obj._meta.fields:
263+                header.append(field.name)
264+            for field in obj._meta.many_to_many:
265+                header.append(field.name)
266+            header[0] = '%s%s:%s' % (header_start, obj._meta, header[0])
267+            self.writerow(header)
268+        self.row = [self.tostring(obj._get_pk_val())]
269+
270+    def end_object(self, obj):
271+        self.writerow(self.row)
272+
273+    def handle_field(self, obj, field):
274+        if getattr(obj, field.name) is not None:
275+            value = self.get_string_value(obj, field)
276+            self.row.append(self.tostring(value))
277+        else:
278+            self.row.append('')
279+
280+    def handle_fk_field(self, obj, field):
281+        related = self.represent(getattr(obj, field.name))
282+        self.row.append(self.tostring(related))
283+
284+    def handle_m2m_field(self, obj, field):
285+        """A tuple of m2m representations as dicts or ids"""
286+        related = [self.represent(related) for related in getattr(obj, field.name).iterator()]
287+        self.row.append(self.tostring(tuple(related)))
288+
289+    def tostring(self, value):
290+        s = smart_unicode(value)
291+        return s or ''
292+
293+    def represent(self, related):
294+        """Represent a model object either as its pk as int, or as a
295+        dict of unique key-values pairs recursevily."""
296+        if related is None:
297+            return None
298+        elif nice_foreign_keys:
299+            # Find a compound key
300+            if related._meta.unique_together:
301+                dict = {}
302+                for field_name in related._meta.unique_together[0]:
303+                    field = related._meta.get_field(field_name)
304+                    value = getattr(related, field.name)
305+                    if isinstance(value, models.Model):
306+                        dict[field_name] = self.represent(value)
307+                    else:
308+                        dict[field_name] = value
309+                return dict
310+            # Find a unique key
311+            else:
312+                for field in related._meta.fields:
313+                    if field.unique:
314+                        return {field.name: getattr(related, field.name)}
315+        return related._get_pk_val()
316+
317+    def writerow(self, row):
318+        self.output.writerow(row)
319+
320+
321+class Deserializer(base.Deserializer):
322+
323+    def __iter__(self):
324+        for values in csv.reader(self.stream, dialect=dialect):
325+            if values:
326+                if values[0].startswith(header_start):
327+                    # Model
328+                    model, first_field = values[0].split(':', 2)
329+                    model = model[len(header_start):]
330+                    try:
331+                        self.model = models.get_model(*model.split("."))
332+                    except TypeError:
333+                        raise base.DeserializationError("No model %s in db" % model)
334+                    # Field names
335+                    self.field_names = [first_field] + values[1:]
336+                else:
337+                    # An object
338+                    meta = self.model._meta
339+                    data = {meta.pk.attname: meta.pk.to_python(values[0])}
340+                    m2m_data = {}
341+                    for i in range(1, len(values)):
342+                        name = self.field_names[i]
343+                        value = values[i]
344+                        if value == '':
345+                            value = None
346+                        field = meta.get_field(name)
347+                        if field.rel and isinstance(field.rel, models.ManyToManyRel):
348+                            if value:
349+                                m2m_data[field.name] = \
350+                                  [pk_or_none(resolve_related(field, v)) for v in eval(value)]
351+                            else:
352+                                m2m_data[field.name] = []
353+                        elif field.rel and isinstance(field.rel, models.ManyToOneRel):
354+                            if value:
355+                                data[field.attname] = pk_or_none(resolve_related(field, eval(value)))
356+                        else:
357+                            if value == '""':
358+                                value = ''
359+                            value = self.to_python(field, value)
360+                            data[field.attname] = value
361+                    yield base.DeserializedObject(self.model(**data), m2m_data)
362+       
363+    def to_python(self, field, value):
364+        """
365+        The to_python method of some fields are not implemented, so this
366+        is a workaround for them.
367+        """
368+        if value is None and not field.null:
369+            value = ''
370+        value = field.to_python(value)
371+        if value is None:
372+            return None
373+        # XXX isinstance(field, IntegerField) doesn't seem to work
374+        elif field.__class__.__name__.endswith('IntegerField'):
375+            return int(value)
376+        elif field.__class__.__name__.endswith('FloatField'):
377+            return float(value)
378+        else:
379+            return value
380+