1 | import datetime
|
---|
2 | import warnings
|
---|
3 |
|
---|
4 | from django.db import models
|
---|
5 | from django.db.models.signals import post_delete, post_save, pre_delete
|
---|
6 |
|
---|
7 | try:
|
---|
8 | from django.utils.timezone import now
|
---|
9 | except ImportError:
|
---|
10 | now = datetime.datetime.now
|
---|
11 |
|
---|
# Python 2/3 compatibility shim: ``basestring`` is a Python 2 builtin. On
# Python 3 the lookup on the right-hand side raises NameError, and the name
# is bound instead to the tuple of Python 3 string types, which works the
# same way as the second argument of ``isinstance``.
try:
    basestring = basestring  # noqa: F821 -- probe for the Python 2 builtin
except NameError:
    basestring = (str, bytes)
|
---|
17 |
|
---|
18 |
|
---|
class PositionField(models.IntegerField):
    """
    An integer field holding an item's 0-based position within its collection.

    A *collection* is the set of rows that share the same values for the field
    names given in ``collection`` (the whole table when ``collection`` is None).
    Saving a position shifts the other members of the collection to make room
    for it; deleting a row, or moving it to another collection, shifts the rows
    after it down by one. Negative positions are resolved like Python list
    indices (``-1`` is the last position) and out-of-range positions are
    clamped to the valid range.

    Per-save bookkeeping is stored on the model instance being saved or deleted
    (attribute names prefixed with ``_<field name>_``), never on the field
    itself, because a single field object is shared by every instance of the
    model (and by every thread handling those instances).
    """

    def __init__(self, verbose_name=None, name=None, default=-1, collection=None, parent_link=None, unique_for_field=None, unique_for_fields=None, *args, **kwargs):
        """
        Create the field.

        :param default: position used when none is assigned; the default of
            ``-1`` places the item at the end of its collection.
        :param collection: a field name, or a sequence of field names, whose
            values partition rows into independently ordered collections.
        :param parent_link: name of a relation field on the model; when set,
            collections are queried on that relation's target model.
        :param unique_for_field: deprecated alias for ``collection``.
        :param unique_for_fields: deprecated alias for ``collection``.
        :raises TypeError: if ``unique`` is passed, or if more than one of
            ``collection``, ``unique_for_field`` and ``unique_for_fields`` is given.
        """
        if 'unique' in kwargs:
            raise TypeError("%s can't have a unique constraint." % self.__class__.__name__)
        super(PositionField, self).__init__(verbose_name, name, default=default, *args, **kwargs)

        # Backwards-compatibility mess begins here.
        if collection is not None and unique_for_field is not None:
            raise TypeError("'collection' and 'unique_for_field' are incompatible arguments.")

        if collection is not None and unique_for_fields is not None:
            raise TypeError("'collection' and 'unique_for_fields' are incompatible arguments.")

        if unique_for_field is not None:
            warnings.warn("The 'unique_for_field' argument is deprecated. Please use 'collection' instead.", DeprecationWarning, stacklevel=2)
            if unique_for_fields is not None:
                raise TypeError("'unique_for_field' and 'unique_for_fields' are incompatible arguments.")
            collection = unique_for_field

        if unique_for_fields is not None:
            warnings.warn("The 'unique_for_fields' argument is deprecated. Please use 'collection' instead.", DeprecationWarning, stacklevel=2)
            collection = unique_for_fields
        # Backwards-compatibility mess ends here.

        # A single field name is normalized to a one-element tuple so the rest
        # of the class can always iterate over ``self.collection``.
        if isinstance(collection, basestring):
            collection = (collection,)
        self.collection = collection
        self.parent_link = parent_link

    def get_cache_name(self):
        """Return the instance attribute name holding the ``(current, updated)`` position pair."""
        return '_%s_cache' % self.name

    def _state_name(self, suffix):
        """Return the instance attribute name used for this field's per-save bookkeeping ``suffix``."""
        return '_%s_%s' % (self.name, suffix)

    def contribute_to_class(self, cls, name):
        """
        Attach the field to ``cls`` and connect the signal handlers that keep positions consistent.

        :raises TypeError: if the field is part of one of ``cls._meta.unique_together``.
        """
        super(PositionField, self).contribute_to_class(cls, name)
        for constraint in cls._meta.unique_together:
            if self.name in constraint:
                raise TypeError("%s can't be part of a unique constraint." % self.__class__.__name__)
        # Sibling positions are shifted with QuerySet.update(), which does not
        # apply ``auto_now``; these fields are therefore set explicitly.
        # NOTE(review): only fields already added to ``cls`` are visible here, so
        # ``auto_now`` fields declared after this field may be missed -- confirm.
        self.auto_now_fields = [field for field in cls._meta.fields if getattr(field, 'auto_now', False)]
        # Install the field itself as the class attribute so __get__/__set__
        # intercept reads and writes of the position.
        setattr(cls, self.name, self)
        pre_delete.connect(self.prepare_delete, sender=cls)
        post_delete.connect(self.update_on_delete, sender=cls)
        post_save.connect(self.update_on_save, sender=cls)

    def pre_save(self, model_instance, add):
        """
        Compute the position value to write for ``model_instance``.

        Resolves negative positions against the collection size and clamps
        out-of-range ones. If the instance has been moved to another collection,
        the gap it leaves in the old collection is closed here, and the move is
        recorded on the instance for ``update_on_save``.

        :returns: the normalized position to store.
        """
        collection_changed = False
        if not add and self.collection is not None:
            try:
                previous_instance = type(model_instance)._default_manager.get(pk=model_instance.pk)
            except models.ObjectDoesNotExist:
                # Not in the database yet despite having a pk: treat it as an insert.
                add = True
            else:
                for field_name in self.collection:
                    attname = model_instance._meta.get_field(field_name).attname
                    if getattr(previous_instance, attname) != getattr(model_instance, attname):
                        collection_changed = True
                        break
                if collection_changed:
                    self.remove_from_collection(previous_instance)

        # Stored on the instance (not the field) so concurrent or interleaved
        # saves of other instances cannot overwrite it before post_save reads it.
        setattr(model_instance, self._state_name('collection_changed'), collection_changed)

        cache_name = self.get_cache_name()
        current, updated = getattr(model_instance, cache_name)

        # A moved item has no position in its new collection yet.
        if collection_changed:
            current = None

        if add:
            if updated is None:
                updated = current
            current = None

        # existing instance, position not modified; no cleanup required
        if current is not None and updated is None:
            return current

        # if updated is still unknown set the object to the last position,
        # either it is a new object or collection has been changed
        if updated is None:
            updated = -1

        collection_count = self.get_collection(model_instance).count()
        # An item already in the collection is included in the count, so its
        # highest valid index is count - 1; an incoming item may take index count.
        max_position = collection_count if current is None else collection_count - 1
        min_position = 0

        # new instance; appended; no cleanup required on post_save
        if add and (updated == -1 or updated >= max_position):
            setattr(model_instance, cache_name, (max_position, None))
            return max_position

        if min_position <= updated <= max_position:
            # positive position; valid index
            position = updated
        elif updated > max_position:
            # positive position; invalid index
            position = max_position
        elif abs(updated) <= max_position + 1:
            # negative position; valid index. Add 1 to max_position so this
            # behaves like a negative list index: -1 means the last position.
            position = max_position + 1 + updated
        else:
            # negative position; invalid index
            position = min_position

        # instance inserted; cleanup required on post_save
        setattr(model_instance, cache_name, (current, position))
        return position

    def __get__(self, instance, owner):
        """
        Return the pending position if one was assigned, otherwise the stored one.

        :raises AttributeError: when accessed on the model class instead of an instance.
        """
        if instance is None:
            raise AttributeError("%s must be accessed via instance." % self.name)
        current, updated = getattr(instance, self.get_cache_name())
        return current if updated is None else updated

    def __set__(self, instance, value):
        """
        Record ``value`` as the instance's position.

        When no position cache exists yet, ``value`` becomes the current
        (stored) position; afterwards it is kept as the pending position that
        ``pre_save`` applies. ``None`` is replaced by the field default.

        :raises AttributeError: when called without an instance.
        """
        if instance is None:
            raise AttributeError("%s must be accessed via instance." % self.name)
        if value is None:
            value = self.default
        cache_name = self.get_cache_name()
        try:
            current, updated = getattr(instance, cache_name)
        except AttributeError:
            current, updated = value, None
        else:
            updated = value

        instance.__dict__[self.name] = value  # Django 1.10 fix for deferred fields
        setattr(instance, cache_name, (current, updated))

    def get_collection(self, instance):
        """
        Return a queryset of every row in the same collection as ``instance``.

        Nullable collection fields whose value is None are matched with
        ``__isnull`` rather than equality. When ``parent_link`` is set, the
        queryset is built on the model that relation points to.
        """
        filters = {}
        if self.collection is not None:
            for field_name in self.collection:
                field = instance._meta.get_field(field_name)
                field_value = getattr(instance, field.attname)
                if field.null and field_value is None:
                    filters['%s__isnull' % field.name] = True
                else:
                    filters[field.name] = field_value
        model = type(instance)
        if self.parent_link is not None:
            link = model._meta.get_field(self.parent_link)
            # Django >= 1.9 exposes the target as ``remote_field.model``;
            # ``rel.to`` only exists on older versions (removed in Django 2.0).
            remote_field = getattr(link, 'remote_field', None)
            model = remote_field.model if remote_field is not None else link.rel.to
        return model._default_manager.filter(**filters)

    def get_next_sibling(self, instance):
        """
        Returns the next sibling of this instance.

        That is the row in the same collection with the smallest position
        greater than the instance's stored position. Returns None when there is
        no such row or when the instance has no stored position (position cache
        missing or unset).
        """
        try:
            current = getattr(instance, self.get_cache_name())[0]
        except AttributeError:
            # e.g. the position was never loaded; keep the best-effort behavior.
            return None
        if current is None:
            return None
        later_siblings = (
            self.get_collection(instance)
            .filter(**{'%s__gt' % self.name: current})
            .order_by(self.name)
        )
        found = list(later_siblings[:1])
        return found[0] if found else None

    def _auto_now_updates(self):
        """Return ``update()`` kwargs that set every tracked auto_now field to the current time."""
        if not self.auto_now_fields:
            return {}
        right_now = now()
        return {field.name: right_now for field in self.auto_now_fields}

    def _close_gap(self, queryset, position):
        """Decrement the position of every row in ``queryset`` placed after ``position``."""
        updates = self._auto_now_updates()
        updates[self.name] = models.F(self.name) - 1
        queryset.filter(**{'%s__gt' % self.name: position}).update(**updates)

    def remove_from_collection(self, instance):
        """
        Removes a positioned item from the collection.

        Shifts every later sibling of ``instance`` down by one, using the
        instance's stored (current) position.
        """
        queryset = self.get_collection(instance)
        current = getattr(instance, self.get_cache_name())[0]
        self._close_gap(queryset, current)

    def prepare_delete(self, sender, instance, **kwargs):
        """pre_delete handler: remember the next sibling's pk for ``update_on_delete``."""
        next_sibling = self.get_next_sibling(instance)
        # Field-specific attribute name so several PositionFields on one model
        # do not overwrite each other's state.
        setattr(
            instance,
            self._state_name('next_sibling_pk'),
            None if next_sibling is None else next_sibling.pk,
        )

    def update_on_delete(self, sender, instance, **kwargs):
        """
        post_delete handler: close the gap left by the deleted ``instance``.

        The collection is taken from the next sibling recorded by
        ``prepare_delete``; if that sibling no longer exists, nothing is shifted.
        """
        state_name = self._state_name('next_sibling_pk')
        next_sibling_pk = getattr(instance, state_name, None)
        # Compare with None explicitly: a primary key of 0 is valid.
        if next_sibling_pk is None:
            return
        try:
            next_sibling = type(instance)._default_manager.get(pk=next_sibling_pk)
        except models.ObjectDoesNotExist:
            # NOTE(review): presumably removed by the same delete (e.g. a
            # cascade), in which case there is no gap left to close.
            next_sibling = None
        if next_sibling is not None:
            current = getattr(instance, self.get_cache_name())[0]
            self._close_gap(self.get_collection(next_sibling), current)
        setattr(instance, state_name, None)

    def update_on_save(self, sender, instance, created, **kwargs):
        """
        post_save handler: shift siblings to make room for the saved position.

        Reads the ``(current, updated)`` pair left by ``pre_save`` and the
        collection-change flag recorded on the instance, updates the other rows
        of the collection in a single query, then marks the saved position as
        the instance's stored one.
        """
        state_name = self._state_name('collection_changed')
        # Defaults to None when pre_save did not run for this field (e.g. the
        # position was excluded from update_fields), so no stale flag is used.
        collection_changed = getattr(instance, state_name, None)
        setattr(instance, state_name, None)

        current, updated = getattr(instance, self.get_cache_name())

        if current is None:
            current = 0

        if updated is None and not collection_changed:
            return None

        queryset = self.get_collection(instance).exclude(pk=instance.pk)
        updates = self._auto_now_updates()

        if updated is None and created:
            updated = -1

        if created or collection_changed:
            # increment positions gte updated or node moved from another collection
            queryset = queryset.filter(**{'%s__gte' % self.name: updated})
            updates[self.name] = models.F(self.name) + 1
        elif updated > current:
            # decrement positions gt current and lte updated
            queryset = queryset.filter(**{'%s__gt' % self.name: current, '%s__lte' % self.name: updated})
            updates[self.name] = models.F(self.name) - 1
        else:
            # increment positions lt current and gte updated
            queryset = queryset.filter(**{'%s__lt' % self.name: current, '%s__gte' % self.name: updated})
            updates[self.name] = models.F(self.name) + 1

        queryset.update(**updates)
        setattr(instance, self.get_cache_name(), (updated, None))
|
---|