Ticket #5253: csvserializer.diff

File csvserializer.diff, 13.4 KB (added by Adam Schmideg, 8 years ago)

A working implementation with minor limitations

  • tests.py

     
     1
     2from cStringIO import StringIO
     3
     4from django.core.serializers.base import DeserializationError
     5from django.db import models
     6from django.test import TestCase
     7
     8from serializers import csvserializer
     9
     10class Author(models.Model):
     11    first_name = models.TextField(null=True, default=None)
     12    last_name = models.TextField()
     13
     14    def __unicode__(self):
     15        if self.first_name:
     16            return '%s %s' % (self.first_name, self.last_name)
     17        else:
     18            return self.last_name
     19
     20    class Meta:
     21        unique_together = (('first_name','last_name'),)
     22
     23class Book(models.Model):
     24    author = models.ForeignKey(Author)
     25    title = models.TextField()
     26    copies = models.IntegerField(default=0)
     27
     28    def __unicode__(self):
     29        return '%s: %s' % (self.author, self.title)
     30   
     31    class Meta:
     32        unique_together = (('author', 'title'),)
     33
     34class Loaner(models.Model):
     35    name = models.TextField(unique=True)
     36
     37class Card(models.Model):
     38    number = models.IntegerField()
     39    books = models.ManyToManyField(Book)
     40    owner = models.ForeignKey(Loaner)
     41
     42    def __unicode__(self):
     43        return '%i' % self.number
     44
     45class HalfEmpty(models.Model):
     46    nullable = models.CharField(maxlength=10, null=True)
     47    not_nullable = models.CharField(maxlength=10)
     48
     49
     50class CsvTest(TestCase):
     51
     52    def setUp(self):
     53        #csvserializer.nice_foreign_keys = False
     54        def save_list(list):
     55            for o in list:
     56                o.save()
     57
     58        self.authors = (
     59          Author(last_name='Dante'),
     60          Author(last_name='Milton', first_name='John'),
     61          Author(last_name='Milton', first_name='Different'),
     62        )
     63        save_list(self.authors)
     64        self.books = (
     65          Book(author=self.authors[0], title='Paradise'),
     66          Book(author=self.authors[0], title='Hell'),
     67          Book(author=self.authors[1], title='Paradise'),
     68        )
     69        save_list(self.books)
     70        self.loaners = (Loaner(name='Tom'), Loaner(name='Liz'))
     71        save_list(self.loaners)
     72        self.cards = (
     73          Card(number=555,
     74            owner=self.loaners[0]),
     75          Card(number=666,
     76            owner=self.loaners[1]),
     77        )
     78        save_list(self.cards)
     79
     80        self.cards[0].books = [self.books[0], self.books[1]]
     81        self.cards[1].books = [self.books[1]]
     82
     83    def tearDown(self):
     84        for mod in (Author, Book, Loaner, Card):
     85            mod.objects.all().delete()
     86
     87    def test_resolve_related(self):
     88        field = Book._meta.get_field('author')
     89        dante = self.authors[0]
     90        def assert_found(obj, field, value):
     91            self.assertEquals(obj,
     92              csvserializer.resolve_related(field, value))
     93        def assert_not_found(field, value):
     94            try:
     95                csvserializer.resolve_related(field, value)
     96                self.fail('Expected DeserializationError')
     97            except DeserializationError:
     98                pass
     99
     100        assert_found(dante, field, 1)
     101        assert_found(dante, field, {'last_name': 'Dante'})
     102        assert_found(dante, field, {'last_name': 'Dante', 'first_name': None})
     103        assert_not_found(field, 999)
     104        assert_not_found(field, {'last_name': 'Milton'})
     105
     106    def test_write_read(self):
     107
     108        def check(sequence):
     109            ser = csvserializer.Serializer()
     110            csv_result = ser.serialize(sequence)
     111
     112            ser = csvserializer.Deserializer(StringIO(csv_result))
     113            deserialized = [o for o in ser]
     114
     115            self.assertEquals(len(sequence), len(deserialized))
     116            for i in range(len(sequence)):
     117                self.assertEquals(sequence[i], deserialized[i].object)
     118
     119            for o in sequence:
     120                o.delete()
     121            for o in deserialized:
     122                o.save()
     123
     124        check(self.cards)
     125
     126
     127class CsvEmptyFieldTest(TestCase):
     128
     129    def test_it(self):
     130        half_empty = HalfEmpty(not_nullable='')
     131        half_empty.save()
     132        ser = csvserializer.Serializer()
     133        csv_result = ser.serialize([half_empty])
     134        lines = csv_result.splitlines()
     135        #self.assertEquals('1,,', lines[1])
     136
     137        ser = csvserializer.Deserializer(StringIO(csv_result))
     138        deserialized = [o for o in ser]
     139        new_half_empty = deserialized[0].object
     140        self.assertEquals(None, new_half_empty.nullable)
     141        self.assertEquals('', new_half_empty.not_nullable)
  • csvserializer.py

     
     1
     2"""
     3Does serialization to and from a single csv file using the standard csv
     4module.
     5
     6Features:
     7- Many-to-many data is serialized as a list of ids or dicts.
     8- Both None, and empty string is serialized as an empty string.  When
     9  deserializing, the empty string becomes None for nullable fields,
     10  empty string for not nullable fields.
     11- If nice_foreign_keys is True, foreign keys get serialized as possibly
     12  nested dictionaries of unique fields.  For example if we have
     13 
     14    class Article(Model):
     15        owner = ForeignKey(User)
     16        text = TextField()
     17
     18  It will be serialized as
     19 
     20    --- article: id,owner,text
     21    1,{'username': u'john'},Bla bla
     22
     23Limitations:
     24- Unicode data is not supported (yet).
     25- The nice_foreign_keys feature doesn't work yet with unique foreign
     26  keys.
     27- Comments are not supported in the csv file -- well, this is rather a
     28  limitation of the csv module.
     29
     30Configuration:
     31- To enable --format=csv for manage dumpdata/loaddata, add the following
     32  lines to your settings.py:
     33
     34    SERIALIZATION_MODULES = {'csv': 'path.to.csvserializer'}
     35
     36- To configure it, add the following lines, too:
     37
     38    CSV_SERIALIZER_OPTIONS = {
     39        'dialect': 'excel',
     40        'nice_foreign_keys': True,
     41    }
     42
     43    The given dialect should be registered to the csv module (defaults
     44    to the pre-defined 'excel').
     45"""
     46import csv
     47
     48from django.conf import settings
     49from django.core.serializers import base
     50from django.db import models
     51from django.db.models.fields import IntegerField, FloatField, Field
     52from django.utils.encoding import smart_unicode
     53
     54
     55#__all__ = ('Serializer', 'Deserializer')
     56
     57# options
     58dialect = 'excel'
     59nice_foreign_keys = True
     60header_start = '--- '
     61globals().update(getattr(settings, 'CSV_SERIALIZER_OPTIONS', {}))
     62
     63def resolve_related(field, value):
     64    """Return a model object for `field` represented as `value` where
     65    the representation was produced by Serializer.represent()."""
     66    if value is None:
     67        return None
     68    else:
     69        model = field.rel.to
     70        if isinstance(value, int):
     71            query = {'pk': value}
     72        elif isinstance(value, dict):
     73            query = {}
     74            # value dict may be recursive
     75            for k, v in value.iteritems():
     76                if isinstance(v, dict):
     77                    query[k] = resolve_related(model._meta.get_field(k), v)
     78                elif v is None:
     79                    query['%s__isnull' % k] = True
     80                else:
     81                    query[k] = v
     82        else:
     83            raise base.DeserializationError('Expected int or dict, got %s of %s' % \
     84              (value, type(value)))
     85        objects = model.objects.complex_filter(query)
     86        if len(objects) == 1:
     87            return objects[0]
     88        else:
     89            raise base.DeserializationError('%i instance(s) of %s for %s' % \
     90              (len(objects), model, value))
     91
     92def pk_or_none(obj):
     93    if obj:
     94        return obj._get_pk_val()
     95    else:
     96        return None
     97       
     98
     99class Serializer(base.Serializer):
     100
     101    def start_serialization(self):
     102        self.last_model = None
     103        self.output = csv.writer(self.stream, dialect=dialect)
     104
     105    def start_object(self, obj):
     106        if not hasattr(obj, "_meta"):
     107            raise base.SerializationError("Non-model object (%s) encountered during serialization" % type(obj))
     108        if self.last_model != obj._meta:
     109            self.last_model = obj._meta
     110            header = []
     111            for field in obj._meta.fields:
     112                header.append(field.name)
     113            for field in obj._meta.many_to_many:
     114                header.append(field.name)
     115            header[0] = '%s%s:%s' % (header_start, obj._meta, header[0])
     116            self.writerow(header)
     117        self.row = [self.tostring(obj._get_pk_val())]
     118
     119    def end_object(self, obj):
     120        self.writerow(self.row)
     121
     122    def handle_field(self, obj, field):
     123        if getattr(obj, field.name) is not None:
     124            value = self.get_string_value(obj, field)
     125            self.row.append(self.tostring(value))
     126        else:
     127            self.row.append('')
     128
     129    def handle_fk_field(self, obj, field):
     130        related = self.represent(getattr(obj, field.name))
     131        self.row.append(self.tostring(related))
     132
     133    def handle_m2m_field(self, obj, field):
     134        """A tuple of m2m representations as dicts or ids"""
     135        related = [self.represent(related) for related in getattr(obj, field.name).iterator()]
     136        self.row.append(self.tostring(tuple(related)))
     137
     138    def tostring(self, value):
     139        s = smart_unicode(value)
     140        return s or ''
     141
     142    def represent(self, related):
     143        """Represent a model object either as its pk as int, or as a
     144        dict of unique key-values pairs recursevily."""
     145        if related is None:
     146            return None
     147        elif nice_foreign_keys:
     148            # Find a compound key
     149            if related._meta.unique_together:
     150                dict = {}
     151                for field_name in related._meta.unique_together[0]:
     152                    field = related._meta.get_field(field_name)
     153                    value = getattr(related, field.name)
     154                    if isinstance(value, models.Model):
     155                        dict[field_name] = self.represent(value)
     156                    else:
     157                        dict[field_name] = value
     158                return dict
     159            # Find a unique key
     160            else:
     161                for field in related._meta.fields:
     162                    if field.unique:
     163                        return {field.name: getattr(related, field.name)}
     164        return related._get_pk_val()
     165
     166    def writerow(self, row):
     167        self.output.writerow(row)
     168
     169
     170class Deserializer(base.Deserializer):
     171
     172    def __iter__(self):
     173        for values in csv.reader(self.stream, dialect=dialect):
     174            if values:
     175                if values[0].startswith(header_start):
     176                    # Model
     177                    model, first_field = values[0].split(':', 2)
     178                    model = model[len(header_start):]
     179                    try:
     180                        self.model = models.get_model(*model.split("."))
     181                    except TypeError:
     182                        raise base.DeserializationError("No model %s in db" % model)
     183                    # Field names
     184                    self.field_names = [first_field] + values[1:]
     185                else:
     186                    # An object
     187                    meta = self.model._meta
     188                    data = {meta.pk.attname: meta.pk.to_python(values[0])}
     189                    m2m_data = {}
     190                    for i in range(1, len(values)):
     191                        name = self.field_names[i]
     192                        value = values[i]
     193                        if value == '':
     194                            value = None
     195                        field = meta.get_field(name)
     196                        if field.rel and isinstance(field.rel, models.ManyToManyRel):
     197                            if value:
     198                                m2m_data[field.name] = \
     199                                  [pk_or_none(resolve_related(field, v)) for v in eval(value)]
     200                            else:
     201                                m2m_data[field.name] = []
     202                        elif field.rel and isinstance(field.rel, models.ManyToOneRel):
     203                            if value:
     204                                data[field.attname] = pk_or_none(resolve_related(field, eval(value)))
     205                        else:
     206                            if value == '""':
     207                                value = ''
     208                            value = self.to_python(field, value)
     209                            data[field.attname] = value
     210                    yield base.DeserializedObject(self.model(**data), m2m_data)
     211       
     212    def to_python(self, field, value):
     213        """
     214        The to_python method of some fields are not implemented, so this
     215        is a workaround for them.
     216        """
     217        if value is None and not field.null:
     218            value = ''
     219        value = field.to_python(value)
     220        if value is None:
     221            return None
     222        # XXX isinstance(field, IntegerField) doesn't seem to work
     223        elif field.__class__.__name__.endswith('IntegerField'):
     224            return int(value)
     225        elif field.__class__.__name__.endswith('FloatField'):
     226            return float(value)
     227        else:
     228            return value
     229
Back to Top