Ticket #10720: patch_for_django_rev-10131.diff

File patch_for_django_rev-10131.diff, 25.3 KB (added by Aryeh Leib Taurog <vim@…>, 16 years ago)

Provides OrderedManyToManyField with convenient API

Line 
1Ordered Many-to-Many Relation Fields for Django
2
3Version: 1.0
4Date: Tue Mar 31 21:47:05 JDT 2009
5
6For more information, see http://www.aryehleib.com/MutableLists.html
7
8 Todo
9 ----
10
11 1. List-interface API to ordered ForeignKey relations
12 2. Doubly-ordered many-to-many relation fields
13 3. Symmetric, ordered self-referential many-to-many relations
14 4. Allow multiple occurences, i.e. (to, from, order) is unique,
15 not just (to, from)
16
17Index: django/db/models/fields/ordered.py
18===================================================================
19--- django/db/models/fields/ordered.py (revision 0)
20+++ django/db/models/fields/ordered.py (revision 0)
21@@ -0,0 +1,292 @@
22+from django.db import connection, transaction
23+from django.db.models.fields import related as relfields
24+from django.utils.functional import curry
25+from django.utils.mutable_list import ListMixin
26+from django.core.exceptions import ObjectDoesNotExist
27+"""
28+Provides Singly Ordered Many To Many Relation Field
29+"""
30+
31+class OrderedManyIndexError(ObjectDoesNotExist):
32+ "Out of range index on ordered many-to-many relation list"
33+
34+def create_ordered_many_related_manager(manager_base, through=None):
35+ """
36+ Creates a manager that subclasses manager_base
37+ to provide a set-style interface for the non-ordered side
38+ for partially ordered many-to-many related objects
39+ """
40+ ManyRelatedManager = relfields.create_many_related_manager(manager_base, through)
41+
42+ class OrderedManyRelatedManager(ManyRelatedManager):
43+ def __init__(self, model, *args, **kwargs):
44+ self.order_col_name = kwargs.pop('order_col_name')
45+ super(OrderedManyRelatedManager, self).__init__(model, *args, **kwargs)
46+
47+ def _add_one_item(self, target_pk, cursor=None):
48+ if cursor is None:
49+ cursor = connection.cursor()
50+
51+ # get the ordering for object to add and append this rel
52+ cursor.execute("SELECT MAX(%s) FROM %s WHERE %s = %%s" % \
53+ (self.order_col_name, self.join_table, self.target_col_name),
54+ [ target_pk ])
55+ result = cursor.fetchone()[0]
56+ if result is None:
57+ next_order = 0
58+ else:
59+ next_order = result + 1
60+
61+ # Add the ones that aren't there already
62+ cursor.execute("INSERT INTO %s (%s, %s, %s) VALUES (%%s, %%s, %%s)" % \
63+ (self.join_table, self.source_col_name, self.target_col_name,
64+ self.order_col_name),
65+ [self._pk_val, target_pk, next_order])
66+
67+ def _add_items(self, source_col_name, target_col_name, *objs):
68+ # join_table: name of the m2m link table
69+ # source_col_name: the PK colname in join_table for the source object
70+ # target_col_name: the PK colname in join_table for the target object
71+ # *objs - objects to add. Either object instances, or primary keys of object instances.
72+ assert source_col_name == self.source_col_name, 'source is "%s", should be "%s"' % \
73+ (source_col_name, self.source_col_name)
74+ assert target_col_name == self.target_col_name, 'target is "%s", should be "%s"' % \
75+ (target_col_name, self.target_col_name)
76+
77+ # If there aren't any objects, there is nothing to do.
78+ if objs:
79+ # Check that all the objects are of the right type
80+ new_ids = set()
81+ for obj in objs:
82+ if isinstance(obj, self.model):
83+ new_ids.add(obj._get_pk_val())
84+ else:
85+ new_ids.add(obj)
86+ # Add the newly created or already existing objects to the join table.
87+ # First find out which items are already added, to avoid adding them twice
88+ cursor = connection.cursor()
89+ cursor.execute("SELECT %s FROM %s WHERE %s = %%s AND %s IN (%s)" % \
90+ (target_col_name, self.join_table, source_col_name,
91+ target_col_name, ",".join(['%s'] * len(new_ids))),
92+ [self._pk_val] + list(new_ids))
93+ existing_ids = set([row[0] for row in cursor.fetchall()])
94+
95+ # Add the ones that aren't there already
96+ for obj_id in (new_ids - existing_ids):
97+ self._add_one_item(obj_id, cursor)
98+ transaction.commit_unless_managed()
99+
100+ return OrderedManyRelatedManager
101+
102+def create_reverse_ordered_many_related_manager(manager_base, through=None):
103+ """
104+ Creates a manager that subclasses manager_base
105+ as well as ListMixin to provide list-style interface
106+ to ordered many-to-many related objects
107+ """
108+ ManyRelatedManager = relfields.create_many_related_manager(manager_base, through)
109+
110+ class ReverseOrderedManyRelatedManager(ManyRelatedManager, ListMixin):
111+ _IndexError = OrderedManyIndexError
112+
113+ def __init__(self, model, *args, **kwargs):
114+ self.order_col_name = kwargs.pop('order_col_name')
115+ super(ReverseOrderedManyRelatedManager, self).__init__(model, *args, **kwargs)
116+ # TODO: dynamically get the pk type for related field,
117+ # instead of assuming it's an int/long
118+ self._allowed = (model, int, long)
119+
120+ def get_query_set(self):
121+ extras = { 'order_by' : ['%s.%s' % (self.join_table, self.order_col_name)], }
122+ return manager_base.get_query_set(self)._next_is_sticky().filter(**(self.core_filters)).extra(**extras)
123+
124+ def __len__(self):
125+ return self.count()
126+
127+ def __getitem__(self, index):
128+ self._begin()
129+ result = super(ReverseOrderedManyRelatedManager, self).__getitem__(index)
130+ self._end(commit=False)
131+ return result
132+
133+ def __setitem__(self, index, val):
134+ self._begin()
135+ super(ReverseOrderedManyRelatedManager, self).__setitem__(index, val)
136+ self._end(commit=True)
137+ __setitem__.alters_data = True
138+
139+ def __delitem__(self, index):
140+ self._begin()
141+ super(ReverseOrderedManyRelatedManager, self).__delitem__(index)
142+ self._end(commit=True)
143+ __delitem__.alters_data = True
144+
145+ def _begin(self):
146+ query = 'SELECT "id", %s FROM %s WHERE %s = %%s ORDER BY %s' % \
147+ (self.source_col_name, self.join_table,
148+ self.target_col_name, self.order_col_name)
149+ self._cursor = connection.cursor()
150+ self._cursor.execute(query, [self._pk_val])
151+ self._cache = self._cursor.fetchall()
152+
153+ def _end(self, commit=False):
154+ del self._cursor
155+ del self._cache
156+ if commit:
157+ transaction.commit_unless_managed()
158+
159+ def add(self, *objs):
160+ self.extend(objs)
161+ add.alters_data = True
162+
163+ def remove(self, *objs):
164+ for o in objs:
165+ ListMixin.remove(self, o)
166+ remove.alters_data = True
167+
168+ def clear(self):
169+ self[:] = []
170+ clear.alters_data = True
171+
172+ def _get_single_external(self, i):
173+ target_pk = self._cache[i][1]
174+ return super(ReverseOrderedManyRelatedManager, self).get(pk=target_pk)
175+
176+ def _set_single(self, i, obj):
177+ join_id = self._cache[i][0]
178+ query = 'UPDATE %s SET %s = %%s WHERE "id" = %%s' % \
179+ (self.join_table, self.source_col_name)
180+ self._cursor.execute(query, [obj.pk, join_id])
181+
182+ # need to use int as pk of related object for serialization
183+ # so we use a special class to distinguish indices into the cache
184+ class _PK(object):
185+ def __init__(self, pk):
186+ self.pk = pk
187+
188+ def _get_single_internal(self, i):
189+ return self._PK(i)
190+
191+ def _set_list(self, length, items):
192+ # Optimize this by re-using existing records?
193+ query = 'DELETE FROM %s WHERE %s = %%s' % \
194+ (self.join_table, self.target_col_name)
195+ self._cursor.execute(query, [self._pk_val])
196+
197+ vals_list = []
198+ my_pk = self._pk_val
199+ for i, item in enumerate(items):
200+ if isinstance(item, self._PK):
201+ vals_list.append((my_pk, self._cache[item.pk][1], i))
202+ elif isinstance(item, (int, long)):
203+ vals_list.append((my_pk, item, i))
204+ else:
205+ vals_list.append((my_pk, item.pk, i))
206+
207+ query = 'INSERT INTO %s (%s, %s, %s) VALUES (%%s, %%s, %%s)' % \
208+ (self.join_table, self.target_col_name,
209+ self.source_col_name, self.order_col_name)
210+ self._cursor.executemany(query, vals_list)
211+
212+ return ReverseOrderedManyRelatedManager
213+
214+class OrderedManyRelatedObjectsDescriptor(relfields.ManyRelatedObjectsDescriptor):
215+ def __get__(self, instance, instance_type=None):
216+ if instance is None:
217+ return self
218+
219+ # Dynamically create a class that subclasses the related
220+ # model's default manager.
221+ rel_model = self.related.model
222+ superclass = rel_model._default_manager.__class__
223+ OrderedRelatedManager = create_ordered_many_related_manager(superclass, self.related.field.rel.through)
224+
225+ qn = connection.ops.quote_name
226+ manager = OrderedRelatedManager(
227+ model=rel_model,
228+ core_filters={'%s__pk' % self.related.field.name: instance._get_pk_val()},
229+ instance=instance,
230+ symmetrical=False,
231+ join_table=qn(self.related.field.m2m_db_table()),
232+ source_col_name=qn(self.related.field.m2m_reverse_name()),
233+ target_col_name=qn(self.related.field.m2m_column_name()),
234+ order_col_name=qn(self.related.field.m2m_order_name()),
235+ )
236+
237+ return manager
238+
239+class ReverseOrderedManyRelatedObjectsDescriptor(relfields.ReverseManyRelatedObjectsDescriptor):
240+ def __get__(self, instance, instance_type=None):
241+ # probably should not allow this
242+ if instance is None:
243+ return self
244+
245+ # Dynamically create a class that subclasses the related
246+ # model's default manager.
247+ rel_model = self.field.rel.to
248+ superclass = rel_model._default_manager.__class__
249+ ReverseOrderedRelatedManager = create_reverse_ordered_many_related_manager(superclass, self.field.rel.through)
250+
251+ qn = connection.ops.quote_name
252+ manager = ReverseOrderedRelatedManager(
253+ model=rel_model,
254+ core_filters={'%s__pk' % self.field.related_query_name(): instance._get_pk_val()},
255+ instance=instance,
256+ symmetrical=False,
257+ join_table=qn(self.field.m2m_db_table()),
258+ source_col_name=qn(self.field.m2m_reverse_name()),
259+ target_col_name=qn(self.field.m2m_column_name()),
260+ order_col_name=qn(self.field.m2m_order_name()),
261+ )
262+
263+ return manager
264+
265+ def __set__(self, instance, val):
266+ if instance is None:
267+ raise AttributeError, "Manager must be accessed via instance"
268+ manager = self.__get__(instance)
269+ manager[:] = list(val)
270+
271+
272+class OrderedManyToManyField(relfields.ManyToManyField):
273+ ordered = 1
274+ def __init__(self, *args, **kwargs):
275+ kwargs['symmetrical'] = False
276+ if 'through' in kwargs:
277+ raise ValueError, "Use of 'through' model not allowed for OrderedManyToManyField"
278+
279+ super(OrderedManyToManyField, self).__init__(*args, **kwargs)
280+
281+ def contribute_to_class(self, cls, name):
282+ super(OrderedManyToManyField, self).contribute_to_class(cls, name)
283+ # Add the descriptor for the m2m relation
284+ setattr(cls, self.name, ReverseOrderedManyRelatedObjectsDescriptor(self))
285+
286+ # Set up the accessor for the m2m table name for the relation
287+ self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta)
288+
289+ if isinstance(self.rel.to, basestring):
290+ target = self.rel.to
291+ else:
292+ target = self.rel.to._meta.db_table
293+ cls._meta.duplicate_targets[self.column] = (target, "m2m")
294+
295+ def contribute_to_related_class(self, cls, related):
296+ # Set up the accessors for the column names on the m2m table
297+ self.m2m_column_name = curry(self._get_m2m_column_name, related)
298+ self.m2m_reverse_name = curry(self._get_m2m_reverse_name, related)
299+ self.m2m_order_name = curry(self._get_m2m_order_name, related)
300+
301+ setattr(cls, related.get_accessor_name(), OrderedManyRelatedObjectsDescriptor(related))
302+
303+ def _get_m2m_order_name(self, related):
304+ "Function that can be curried to provide the related column name for the m2m table"
305+ try:
306+ return self._m2m_order_name_cache
307+ except:
308+ self._m2m_order_name_cache = related.parent_model._meta.object_name.lower() + '_order'
309+ return self._m2m_order_name_cache
310+
311
312Property changes on: django/db/models/fields/ordered.py
313___________________________________________________________________
314Name: svn:executable
315 + *
316
317Index: django/db/models/__init__.py
318===================================================================
319--- django/db/models/__init__.py (revision 10131)
320+++ django/db/models/__init__.py (working copy)
321@@ -11,6 +11,7 @@
322 from django.db.models.fields.subclassing import SubfieldBase
323 from django.db.models.fields.files import FileField, ImageField
324 from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
325+from django.db.models.fields.ordered import OrderedManyToManyField
326 from django.db.models import signals
327
328 # Admin stages.
329Index: django/db/backends/creation.py
330===================================================================
331--- django/db/backends/creation.py (revision 10131)
332+++ django/db/backends/creation.py (working copy)
333@@ -223,6 +223,13 @@
334 style.SQL_FIELD(qn(field.rel.to._meta.pk.column)),
335 self.connection.ops.deferrable_sql())
336 ]
337+ if hasattr(field, 'ordered'):
338+ table_output.append(
339+ ' %s %s %s,' %
340+ (style.SQL_FIELD(qn(field.m2m_order_name())),
341+ style.SQL_COLTYPE(models.IntegerField().db_type()),
342+ style.SQL_KEYWORD('NOT NULL'),)
343+ )
344 deferred = []
345
346 return table_output, deferred
347Index: django/utils/mutable_list.py
348===================================================================
349--- django/utils/mutable_list.py (revision 0)
350+++ django/utils/mutable_list.py (revision 0)
351@@ -0,0 +1,309 @@
352+# Copyright (c) 2008-2009 Aryeh Leib Taurog, all rights reserved.
353+# Released under the New BSD license.
354+"""
355+This module contains a base type which provides list-style mutations
356+without specific data storage methods.
357+
358+See also http://www.aryehleib.com/MutableLists.html
359+
360+Author: Aryeh Leib Taurog.
361+"""
362+class ListMixin(object):
363+ """
364+ A base class which provides complete list interface.
365+ Derived classes must call ListMixin's __init__() function
366+ and implement the following:
367+
368+ function _get_single_external(self, i):
369+ Return single item with index i for general use.
370+ The index i will always satisfy 0 <= i < len(self).
371+
372+ function _get_single_internal(self, i):
373+ Same as above, but for use within the class [Optional]
374+ Note that if _get_single_internal and _get_single_internal return
375+ different types of objects, _set_list must distinguish
376+ between the two and handle each appropriately.
377+
378+ function _set_list(self, length, items):
379+ Recreate the entire object.
380+
381+ NOTE: items may be a generator which calls _get_single_internal.
382+ Therefore, it is necessary to cache the values in a temporary:
383+ temp = list(items)
384+ before clobbering the original storage.
385+
386+ function _set_single(self, i, value):
387+ Set the single item at index i to value [Optional]
388+ If left undefined, all mutations will result in rebuilding
389+ the object using _set_list.
390+
391+ function __len__(self):
392+ Return the length
393+
394+ int _minlength:
395+ The minimum legal length [Optional]
396+
397+ int _maxlength:
398+ The maximum legal length [Optional]
399+
400+ type or tuple _allowed:
401+ A type or tuple of allowed item types [Optional]
402+
403+ class _IndexError:
404+ The type of exception to be raise on invalid index [Optional]
405+ """
406+
407+ _minlength = 0
408+ _maxlength = None
409+ _IndexError = IndexError
410+
411+ ### Python initialization and special list interface methods ###
412+
413+ def __init__(self, *args, **kwargs):
414+ if not hasattr(self, '_get_single_internal'):
415+ self._get_single_internal = self._get_single_external
416+
417+ if not hasattr(self, '_set_single'):
418+ self._set_single = self._set_single_rebuild
419+ self._assign_extended_slice = self._assign_extended_slice_rebuild
420+
421+ super(ListMixin, self).__init__(*args, **kwargs)
422+
423+ def __getitem__(self, index):
424+ "Get the item(s) at the specified index/slice."
425+ if isinstance(index, slice):
426+ return [self._get_single_external(i) for i in xrange(*index.indices(len(self)))]
427+ else:
428+ index = self._checkindex(index)
429+ return self._get_single_external(index)
430+
431+ def __delitem__(self, index):
432+ "Delete the item(s) at the specified index/slice."
433+ if not isinstance(index, (int, long, slice)):
434+ raise TypeError("%s is not a legal index" % index)
435+
436+ # calculate new length and dimensions
437+ origLen = len(self)
438+ if isinstance(index, (int, long)):
439+ index = self._checkindex(index)
440+ indexRange = [index]
441+ else:
442+ indexRange = range(*index.indices(origLen))
443+
444+ newLen = origLen - len(indexRange)
445+ newItems = ( self._get_single_internal(i)
446+ for i in xrange(origLen)
447+ if i not in indexRange )
448+
449+ self._rebuild(newLen, newItems)
450+
451+ def __setitem__(self, index, val):
452+ "Set the item(s) at the specified index/slice."
453+ if isinstance(index, slice):
454+ self._set_slice(index, val)
455+ else:
456+ index = self._checkindex(index)
457+ self._check_allowed((val,))
458+ self._set_single(index, val)
459+
460+ def __iter__(self):
461+ "Iterate over the items in the list"
462+ for i in xrange(len(self)):
463+ yield self[i]
464+
465+ ### Special methods for arithmetic operations ###
466+ def __add__(self, other):
467+ 'add another list-like object'
468+ return self.__class__(list(self) + list(other))
469+
470+ def __radd__(self, other):
471+ 'add to another list-like object'
472+ return other.__class__(list(other) + list(self))
473+
474+ def __iadd__(self, other):
475+ 'add another list-like object to self'
476+ self.extend(list(other))
477+ return self
478+
479+ def __mul__(self, n):
480+ 'multiply'
481+ return self.__class__(list(self) * n)
482+
483+ def __rmul__(self, n):
484+ 'multiply'
485+ return self.__class__(list(self) * n)
486+
487+ def __imul__(self, n):
488+ 'multiply'
489+ if n <= 0:
490+ del self[:]
491+ else:
492+ cache = list(self)
493+ for i in range(n-1):
494+ self.extend(cache)
495+ return self
496+
497+ def __cmp__(self, other):
498+ 'cmp'
499+ slen = len(self)
500+ for i in range(slen):
501+ try:
502+ c = cmp(self[i], other[i])
503+ except IndexError:
504+ # must be other is shorter
505+ return 1
506+ else:
507+ # elements not equal
508+ if c: return c
509+
510+ return cmp(slen, len(other))
511+
512+ ### Public list interface Methods ###
513+ ## Non-mutating ##
514+ def count(self, val):
515+ "Standard list count method"
516+ count = 0
517+ for i in self:
518+ if val == i: count += 1
519+ return count
520+
521+ def index(self, val):
522+ "Standard list index method"
523+ for i in xrange(0, len(self)):
524+ if self[i] == val: return i
525+ raise ValueError('%s not found in object' % str(val))
526+
527+ ## Mutating ##
528+ def append(self, val):
529+ "Standard list append method"
530+ self[len(self):] = [val]
531+
532+ def extend(self, vals):
533+ "Standard list extend method"
534+ self[len(self):] = vals
535+
536+ def insert(self, index, val):
537+ "Standard list insert method"
538+ if not isinstance(index, (int, long)):
539+ raise TypeError("%s is not a legal index" % index)
540+ self[index:index] = [val]
541+
542+ def pop(self, index=-1):
543+ "Standard list pop method"
544+ result = self[index]
545+ del self[index]
546+ return result
547+
548+ def remove(self, val):
549+ "Standard list remove method"
550+ del self[self.index(val)]
551+
552+ def reverse(self):
553+ "Standard list reverse method"
554+ self[:] = self[-1::-1]
555+
556+ def sort(self, cmp=cmp, key=None, reverse=False):
557+ "Standard list sort method"
558+ if key:
559+ temp = [(key(v),v) for v in self]
560+ temp.sort(cmp=cmp, key=lambda x: x[0], reverse=reverse)
561+ self[:] = [v[1] for v in temp]
562+ else:
563+ temp = list(self)
564+ temp.sort(cmp=cmp, reverse=reverse)
565+ self[:] = temp
566+
567+ ### Private routines ###
568+ def _rebuild(self, newLen, newItems):
569+ if newLen < self._minlength:
570+ raise ValueError('Must have at least %d items' % self._minlength)
571+ if self._maxlength is not None and newLen > self._maxlength:
572+ raise ValueError('Cannot have more than %d items' % self._maxlength)
573+
574+ self._set_list(newLen, newItems)
575+
576+ def _set_single_rebuild(self, index, value):
577+ self._set_slice(slice(index, index + 1, 1), [value])
578+
579+ def _checkindex(self, index, correct=True):
580+ length = len(self)
581+ if 0 <= index < length:
582+ return index
583+ if correct and -length <= index < 0:
584+ return index + length
585+ raise self._IndexError('invalid index: %s' % str(index))
586+
587+ def _check_allowed(self, items):
588+ if hasattr(self, '_allowed'):
589+ if False in [isinstance(val, self._allowed) for val in items]:
590+ raise TypeError('Invalid type encountered in the arguments.')
591+
592+ def _set_slice(self, index, values):
593+ "Assign values to a slice of the object"
594+ try:
595+ iter(values)
596+ except TypeError:
597+ raise TypeError('can only assign an iterable to a slice')
598+
599+ self._check_allowed(values)
600+
601+ origLen = len(self)
602+ valueList = list(values)
603+ start, stop, step = index.indices(origLen)
604+
605+ # CAREFUL: index.step and step are not the same!
606+ # step will never be None
607+ if index.step is None:
608+ self._assign_simple_slice(start, stop, valueList)
609+ else:
610+ self._assign_extended_slice(start, stop, step, valueList)
611+
612+ def _assign_extended_slice_rebuild(self, start, stop, step, valueList):
613+ 'Assign an extended slice by rebuilding entire list'
614+ indexList = range(start, stop, step)
615+ # extended slice, only allow assigning slice of same size
616+ if len(valueList) != len(indexList):
617+ raise ValueError('attempt to assign sequence of size %d '
618+ 'to extended slice of size %d'
619+ % (len(valueList), len(indexList)))
620+
621+ # we're not changing the length of the sequence
622+ newLen = len(self)
623+ newVals = dict(zip(indexList, valueList))
624+ def newItems():
625+ for i in xrange(newLen):
626+ if i in newVals:
627+ yield newVals[i]
628+ else:
629+ yield self._get_single_internal(i)
630+
631+ self._rebuild(newLen, newItems())
632+
633+ def _assign_extended_slice(self, start, stop, step, valueList):
634+ 'Assign an extended slice by re-assigning individual items'
635+ indexList = range(start, stop, step)
636+ # extended slice, only allow assigning slice of same size
637+ if len(valueList) != len(indexList):
638+ raise ValueError('attempt to assign sequence of size %d '
639+ 'to extended slice of size %d'
640+ % (len(valueList), len(indexList)))
641+
642+ for i, val in zip(indexList, valueList):
643+ self._set_single(i, val)
644+
645+ def _assign_simple_slice(self, start, stop, valueList):
646+ 'Assign a simple slice; Can assign slice of any length'
647+ origLen = len(self)
648+ stop = max(start, stop)
649+ newLen = origLen - stop + start + len(valueList)
650+ def newItems():
651+ for i in xrange(origLen + 1):
652+ if i == start:
653+ for val in valueList:
654+ yield val
655+
656+ if i < origLen:
657+ if i < start or i >= stop:
658+ yield self._get_single_internal(i)
659+
660+ self._rebuild(newLen, newItems())
661
662Property changes on: django/utils/mutable_list.py
663___________________________________________________________________
664Name: svn:executable
665 + *
666
Back to Top