| 1 | from django.contrib.contenttypes import generic
|
|---|
| 2 | from django.db.models import signals
|
|---|
| 3 |
|
|---|
class ImprovedGenericForeignKey(generic.GenericForeignKey):
    """
    GenericForeignKey that copies the related object's primary key into
    ``fk_field`` just before the owning model instance is saved.

    This allows assigning an object to the field first and saving that
    object afterwards: when the owning instance is saved, the primary key
    that is available by then is written to ``fk_field``.

    If the assigned object still has no primary key when the owning
    instance is saved, ``ImprovedGenericForeignKey.IncompleteData`` is
    raised.
    """

    class IncompleteData(Exception):
        """
        Raised from ``instance_pre_save`` when the object assigned to the
        generic relation has no primary key yet.

        ``field_name`` is the name of the generic relation attribute.
        """

        message = 'Object assigned to field "%s" doesn\'t have a PK (save it first)!'

        def __init__(self, field_name):
            # Pass field_name on to Exception so that ``args`` is populated:
            # repr() then shows the field name, and code that rebuilds the
            # exception from ``args`` (copy/pickle via
            # BaseException.__reduce__) no longer fails with a missing
            # argument TypeError.
            super(ImprovedGenericForeignKey.IncompleteData, self).__init__(field_name)
            self.field_name = field_name

        def __str__(self):
            return self.message % self.field_name

    def contribute_to_class(self, cls, name, **kwargs):
        """
        Connect the pre-save hook for ``cls``, then let GenericForeignKey
        attach itself to the model class.

        Any extra keyword arguments are forwarded unchanged to the parent
        implementation.
        """
        # weak=False makes the signal dispatcher keep a strong reference to
        # this bound method, so the receiver is not dropped by garbage
        # collection.
        signals.pre_save.connect(self.instance_pre_save, sender=cls, weak=False)
        super(ImprovedGenericForeignKey, self).contribute_to_class(cls, name, **kwargs)

    def instance_pre_save(self, sender, instance, **kwargs):
        """
        ``pre_save`` receiver: ensure ``fk_field`` holds the primary key of
        the object currently assigned to this generic relation.

        Only ``fk_field`` is written here; the content-type field is not
        touched.

        :param sender: the model class being saved (supplied by the signal).
        :param instance: the model instance about to be saved.
        :raises ImprovedGenericForeignKey.IncompleteData: if an object is
            assigned but has no primary key yet.
        """
        # A primary key is already stored in fk_field: leave it alone.
        if getattr(instance, self.fk_field) is not None:
            return

        # NOTE(review): this read goes through the GenericForeignKey
        # descriptor, which may query the database when no object is cached
        # on the instance. Confirm for the Django version in use.
        related = getattr(instance, self.name)

        # Nothing assigned: leave fk_field as None. Enforcing a NOT NULL
        # constraint on the column, if there is one, is left to the database.
        if related is None:
            return

        related_pk = related.pk
        if related_pk is None:
            raise self.IncompleteData(self.name)

        setattr(instance, self.fk_field, related_pk)
|
|---|