# Source code for sortedm2m.fields

# -*- coding: utf-8 -*-
import django
from django.conf import settings
from django.db import router
from django.db import transaction
from django.db import models
from django.db.models import signals
from django.db.models.fields.related import ManyToManyField as _ManyToManyField
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
from django.utils import six
from django.utils.functional import cached_property, curry

from .compat import create_forward_many_to_many_manager
from .compat import get_foreignkey_field_kwargs
from .compat import get_model_name, get_rel, get_rel_to, add_lazy_relation
from .forms import SortedMultipleChoiceField

SORT_VALUE_FIELD_NAME = 'sort_value'


try:
    # Use Django's own transaction context manager whenever it is available.
    atomic = transaction.atomic
except AttributeError:
    # Django 1.5 support: ``transaction.atomic`` is missing there, so we
    # mock the atomic context manager with a stand-in that does nothing.
    class atomic(object):
        """No-op context manager used when ``transaction.atomic`` is absent.

        It accepts and ignores any constructor arguments, yields ``None`` on
        entry and never suppresses exceptions on exit.
        """

        def __init__(self, *args, **kwargs):
            # Arguments (e.g. ``using=db``) are accepted only for call
            # compatibility; nothing is stored.
            return None

        def __enter__(self):
            return None

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Returning ``None`` lets any in-flight exception propagate.
            return None


def create_sorted_many_related_manager(superclass, rel, *args, **kwargs):
    """Build a related-manager class that keeps related objects ordered.

    All arguments are forwarded unchanged to
    ``create_forward_many_to_many_manager`` to obtain the base manager class.
    The returned subclass orders its querysets by the sort column of the
    intermediary model (``rel.through._sort_field_name``) and, when rows are
    inserted through ``_add_items``, assigns increasing sort values so the
    order in which objects were passed is preserved.

    :param superclass: manager class handed to the base manager factory.
    :param rel: the relation object; its ``through`` model must expose
        ``_sort_field_name``.
    :returns: the ``SortedRelatedManager`` class (not an instance).
    """
    RelatedManager = create_forward_many_to_many_manager(
        superclass, rel, *args, **kwargs)

    class SortedRelatedManager(RelatedManager):
        def get_queryset(self):
            """Return the related objects ordered by the sort column.

            If the instance carries a prefetch cache entry for this relation
            that cached result is returned as is.
            """
            # We use ``extra`` method here because we have no other access to
            # the extra sorting field of the intermediary model. The fields
            # are hidden for joins because we set ``auto_created`` on the
            # intermediary's meta options.
            try:
                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
            except (AttributeError, KeyError):
                # Django 1.5 support.
                if not hasattr(RelatedManager, 'get_queryset'):
                    queryset = super(SortedRelatedManager, self).get_query_set()
                else:
                    queryset = super(SortedRelatedManager, self).get_queryset()
                return queryset.extra(order_by=['%s.%s' % (
                    rel.through._meta.db_table,
                    rel.through._sort_field_name,
                )])

        # Old (Django 1.5) spelling of the same method.
        get_query_set = get_queryset

        if not hasattr(RelatedManager, '_get_fk_val'):
            # Only defined when the base manager does not provide
            # ``_get_fk_val``; ``_add_items`` below reads ``self._fk_val``.
            @property
            def _fk_val(self):
                # Django 1.5 support.
                if not hasattr(self, 'related_val'):
                    return self._pk_val
                return self.related_val[0]

        def get_prefetch_queryset(self, instances, queryset=None):
            """Delegate to the base prefetch method and force sorted order.

            The first element of the base result (the queryset) gets its
            ``extra_order_by`` replaced with the intermediary sort column;
            the remaining elements of the result tuple are passed through.
            """
            # Django 1.5 support. The method name changed since.
            if django.VERSION < (1, 6):
                result = super(SortedRelatedManager, self).get_prefetch_query_set(instances)
            # Django 1.6 support. The queryset parameter was not supported.
            elif django.VERSION < (1, 7):
                result = super(SortedRelatedManager, self).get_prefetch_queryset(instances)
            else:
                result = super(SortedRelatedManager, self).get_prefetch_queryset(instances, queryset)
            queryset = result[0]
            queryset.query.extra_order_by = [
                '%s.%s' % (
                    rel.through._meta.db_table,
                    rel.through._sort_field_name,
                )]
            return (queryset,) + result[1:]

        # Old (Django 1.5) spelling of the same method.
        get_prefetch_query_set = get_prefetch_queryset

        def set(self, objs, **kwargs):
            """Replace the related objects, always clearing first."""
            # Choosing to clear first will ensure the order is maintained.
            kwargs['clear'] = True
            super(SortedRelatedManager, self).set(objs, **kwargs)
        set.alters_data = True

        def _add_items(self, source_field_name, target_field_name, *objs):
            """Insert intermediary rows for ``objs`` in the order given.

            :param source_field_name: the PK fieldname in the join table for
                the source object.
            :param target_field_name: the PK fieldname in the join table for
                the target object.
            :param objs: objects to add; either instances of ``self.model``
                or raw primary key values.
            :raises ValueError: if the router disallows the relation, or the
                foreign key value of an instance is ``None``.
            :raises TypeError: if a model instance of the wrong class is
                passed.
            """
            from django.db.models import Max, Model

            # If there aren't any objects, there is nothing to do.
            if not objs:
                return

            # Django uses a set here, we need to use a list to keep the
            # correct ordering.
            new_ids = []
            for obj in objs:
                if isinstance(obj, self.model):
                    if not router.allow_relation(obj, self.instance):
                        raise ValueError(
                            'Cannot add "%r": instance is on database "%s", value is on database "%s"' %
                            (obj, self.instance._state.db, obj._state.db)
                        )
                    if hasattr(self, '_get_fk_val'):  # Django>=1.5
                        fk_val = self._get_fk_val(obj, target_field_name)
                        if fk_val is None:
                            raise ValueError('Cannot add "%r": the value for field "%s" is None' %
                                             (obj, target_field_name))
                        new_ids.append(fk_val)
                    else:  # Django<1.5
                        new_ids.append(obj.pk)
                elif isinstance(obj, Model):
                    raise TypeError(
                        "'%s' instance expected, got %r" %
                        (self.model._meta.object_name, obj)
                    )
                else:
                    new_ids.append(obj)

            db = router.db_for_write(self.through, instance=self.instance)
            manager = self.through._default_manager.using(db)
            # Targets that already have a row for this source object.
            existing_ids = set(
                manager
                .values_list(target_field_name, flat=True)
                .filter(**{
                    source_field_name: self._fk_val,
                    '%s__in' % target_field_name: new_ids,
                }))

            # Keep the first occurrence of every id that is not related yet.
            # Seeding ``seen`` with the existing ids drops *every* copy of an
            # already-related id; removing only one copy would leave a
            # duplicate behind and break the unique constraint on insert.
            seen = set(existing_ids)
            unique_ids = []
            for pk in new_ids:
                if pk not in seen:
                    seen.add(pk)
                    unique_ids.append(pk)
            new_ids = unique_ids
            new_ids_set = set(new_ids)

            send_signal = self.reverse or source_field_name == self.source_field_name

            if send_signal:
                # Don't send the signal when we are inserting the
                # duplicate data row for symmetrical reverse entries.
                signals.m2m_changed.send(sender=rel.through, action='pre_add',
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=new_ids_set, using=db)

            # Add the ones that aren't there already
            with atomic(using=db):
                fk_val = self._fk_val
                source_queryset = manager.filter(**{'%s_id' % source_field_name: fk_val})
                sort_field_name = self.through._sort_field_name
                # Continue numbering after the highest sort value already
                # stored for this source object (0 when there is none).
                sort_value_max = source_queryset.aggregate(max=Max(sort_field_name))['max'] or 0

                manager.bulk_create([
                    self.through(**{
                        '%s_id' % source_field_name: fk_val,
                        '%s_id' % target_field_name: pk,
                        sort_field_name: sort_value_max + i + 1,
                    })
                    for i, pk in enumerate(new_ids)
                ])

            if send_signal:
                # Don't send the signal when we are inserting the
                # duplicate data row for symmetrical reverse entries.
                signals.m2m_changed.send(sender=rel.through, action='post_add',
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=new_ids_set, using=db)

    return SortedRelatedManager


try:
    from django.db.models.fields.related_descriptors import ManyToManyDescriptor

    class SortedManyToManyDescriptor(ManyToManyDescriptor):
        """Descriptor whose related manager is the sorted manager class."""

        def __init__(self, field):
            # The parent descriptor is initialised with the field's relation
            # object rather than the field itself.
            remote_field = field.remote_field
            super(SortedManyToManyDescriptor, self).__init__(remote_field)

        @cached_property
        def related_manager_cls(self):
            manager_base = self.rel.model._default_manager.__class__
            # This is the new `reverse` argument (which ironically should
            # be False)
            return create_sorted_many_related_manager(
                manager_base,
                self.rel,
                reverse=False,
            )
except ImportError:
    # Django 1.8 support
    from django.db.models.fields.related import ReverseManyRelatedObjectsDescriptor

    class SortedManyToManyDescriptor(ReverseManyRelatedObjectsDescriptor):
        """Descriptor whose related manager is the sorted manager class."""

        @cached_property
        def related_manager_cls(self):
            manager_base = get_rel_to(self.field)._default_manager.__class__
            relation = get_rel(self.field)
            return create_sorted_many_related_manager(manager_base, relation)


class SortedManyToManyField(_ManyToManyField):
    '''
    Providing a many to many relation that remembers the order of related
    objects.

    Accept a boolean ``sorted`` attribute which specifies if relation is
    ordered or not. Default is set to ``True``. If ``sorted`` is set to
    ``False`` the field will behave exactly like django's ``ManyToManyField``.

    Accept a class ``base_class`` attribute which specifies the base class of
    the intermediate model. It allows to customize the intermediate model.

    Accept a ``sort_value_field_name`` keyword argument naming the integer
    column on the intermediate model that stores the position. Default is
    ``SORT_VALUE_FIELD_NAME``.
    '''
    def __init__(self, to, sorted=True, base_class=None, **kwargs):
        self.sorted = sorted
        # Popped so the parent constructor never sees this custom option.
        self.sort_value_field_name = kwargs.pop(
            'sort_value_field_name',
            SORT_VALUE_FIELD_NAME)

        # Base class of through model
        self.base_class = base_class

        super(SortedManyToManyField, self).__init__(to, **kwargs)
        if self.sorted:
            # Overwrite whatever the parent stored with exactly what the
            # caller passed (``None`` when omitted).
            # NOTE(review): presumably this discards text that some Django
            # versions append to many-to-many help texts -- confirm.
            self.help_text = kwargs.get('help_text', None)

    def deconstruct(self):
        '''
        Return the ``(name, path, args, kwargs)`` tuple of the parent class
        with the custom ``sort_value_field_name`` and ``sorted`` options
        added to ``kwargs`` when they differ from their defaults.
        ``base_class`` is not persisted here.
        '''
        # We have to persist custom added options in the ``kwargs``
        # dictionary. For readability only non-default values are stored.
        name, path, args, kwargs = super(SortedManyToManyField, self).deconstruct()
        # Compare by value: two equal strings are not guaranteed to be the
        # same object, so an identity check could report a default name as
        # customised.
        if self.sort_value_field_name != SORT_VALUE_FIELD_NAME:
            kwargs['sort_value_field_name'] = self.sort_value_field_name
        if self.sorted is not True:
            kwargs['sorted'] = self.sorted
        return name, path, args, kwargs

    def contribute_to_class(self, cls, name, **kwargs):
        '''
        Attach the field to ``cls``. Unsorted fields defer entirely to the
        parent implementation; sorted fields create the intermediate model
        (when needed) and install ``SortedManyToManyDescriptor``.
        '''
        if not self.sorted:
            return super(SortedManyToManyField, self).contribute_to_class(cls, name, **kwargs)

        # To support multiple relations to self, it's useful to have a non-None
        # related name on symmetrical relations for internal reasons. The
        # concept doesn't make a lot of sense externally ("you want me to
        # specify *what* on my non-reversible relation?!"), so we set it up
        # automatically. The funky name reduces the chance of an accidental
        # clash.
        rel = get_rel(self)
        rel_to = get_rel_to(self)
        if rel.symmetrical and (rel_to == "self" or rel_to == cls._meta.object_name):
            rel.related_name = "%s_rel_+" % name

        # ``super(_ManyToManyField, self)`` starts the lookup *after*
        # ``ManyToManyField`` in the MRO, so its own ``contribute_to_class``
        # is skipped and the steps it would perform are done below instead.
        super(_ManyToManyField, self).contribute_to_class(cls, name, **kwargs)

        # The intermediate m2m model is not auto created if:
        #  1) There is a manually specified intermediate, or
        #  2) The class owning the m2m field is abstract.
        if not rel.through and not cls._meta.abstract:
            rel.through = self.create_intermediate_model(cls)

        # Add the descriptor for the m2m relation
        setattr(cls, self.name, SortedManyToManyDescriptor(self))

        # Set up the accessor for the m2m table name for the relation
        self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta)

        # Populate some necessary rel arguments so that cross-app relations
        # work correctly.
        if isinstance(rel.through, six.string_types):
            def resolve_through_model(field, model, cls):
                get_rel(field).through = model
            add_lazy_relation(cls, self, rel.through, resolve_through_model)

        if hasattr(cls._meta, 'duplicate_targets'):  # Django<1.5
            if isinstance(rel_to, six.string_types):
                target = rel_to
            else:
                target = rel_to._meta.db_table
            cls._meta.duplicate_targets[self.column] = (target, "m2m")

    def get_internal_type(self):
        '''Report the field as a plain ``ManyToManyField``.'''
        return 'ManyToManyField'

    def formfield(self, **kwargs):
        '''
        Build the form field. Sorted fields default ``form_class`` to
        ``SortedMultipleChoiceField``; explicit ``kwargs`` take precedence.
        '''
        defaults = {}
        if self.sorted:
            defaults['form_class'] = SortedMultipleChoiceField
        defaults.update(kwargs)
        return super(SortedManyToManyField, self).formfield(**defaults)

    ###############################
    # Intermediate Model Creation #
    ###############################

    def get_intermediate_model_name(self, klass):
        '''Return ``<OwnerObjectName>_<field name>``.'''
        return '%s_%s' % (klass._meta.object_name, self.name)

    def get_intermediate_model_meta_class(self, klass, from_field_name,
                                          to_field_name,
                                          sort_value_field_name):
        '''
        Build the ``Meta`` class of the intermediate model: table name,
        ``managed`` flag, ``auto_created`` marker, a unique constraint on the
        two foreign keys and default ordering by the sort column.
        '''
        managed = True
        to_model = get_rel_to(self)
        if isinstance(to_model, six.string_types):
            if to_model != RECURSIVE_RELATIONSHIP_CONSTANT:
                # The target is not resolved yet: ``managed`` stays ``True``
                # for now and is corrected by this callback later.
                def set_managed(field, model, cls):
                    get_rel(field).through._meta.managed = model._meta.managed or cls._meta.managed
                add_lazy_relation(klass, self, to_model, set_managed)
            else:
                managed = klass._meta.managed
        else:
            managed = klass._meta.managed or to_model._meta.managed

        options = {
            'db_table': self._get_m2m_db_table(klass._meta),
            'managed': managed,
            'auto_created': klass,
            'app_label': klass._meta.app_label,
            'db_tablespace': klass._meta.db_tablespace,
            'unique_together': ((from_field_name, to_field_name),),
            'ordering': (sort_value_field_name,),
            'verbose_name': '%(from)s-%(to)s relationship' % {'from': from_field_name, 'to': to_field_name},
            'verbose_name_plural': '%(from)s-%(to)s relationships' % {'from': from_field_name, 'to': to_field_name},
        }
        # Django 1.6 support.
        if hasattr(self.model._meta, 'apps'):
            options.update({
                'apps': self.model._meta.apps,
            })
        return type(str('Meta'), (object,), options)

    def get_rel_to_model_and_object_name(self, klass):
        '''
        Return ``(to_model, to_object_name)`` for the relation target.

        ``to_model`` is the model class, ``klass`` itself for a recursive
        relation, or the unresolved string reference. ``to_object_name`` is
        the target's object name (the last dotted part for a string
        reference).
        '''
        rel_to = get_rel_to(self)
        if isinstance(rel_to, six.string_types):
            if rel_to != RECURSIVE_RELATIONSHIP_CONSTANT:
                to_model = rel_to
                to_object_name = to_model.split('.')[-1]
            else:
                to_model = klass
                to_object_name = to_model._meta.object_name
        else:
            to_model = rel_to
            to_object_name = to_model._meta.object_name
        return to_model, to_object_name

    def get_intermediate_model_sort_value_field(self, klass):
        '''Return ``(field_name, field)`` for the integer sort column.'''
        field_name = self.sort_value_field_name
        field = models.IntegerField(default=0)
        return field_name, field

    def get_intermediate_model_from_field(self, klass):
        '''
        Return ``(field_name, field)`` for the foreign key pointing at the
        owning model. Relations to the same model use a ``from_`` prefix so
        the name cannot clash with the ``to_`` side.
        '''
        name = self.get_intermediate_model_name(klass)

        _, to_object_name = self.get_rel_to_model_and_object_name(klass)

        if get_rel_to(self) == RECURSIVE_RELATIONSHIP_CONSTANT or to_object_name == klass._meta.object_name:
            field_name = 'from_%s' % to_object_name.lower()
        else:
            field_name = get_model_name(klass)

        field = models.ForeignKey(klass, related_name='%s+' % name,
                                  **get_foreignkey_field_kwargs(self))
        return field_name, field

    def get_intermediate_model_to_field(self, klass):
        '''
        Return ``(field_name, field)`` for the foreign key pointing at the
        related model. Relations to the same model use a ``to_`` prefix.
        '''
        name = self.get_intermediate_model_name(klass)

        to_model, to_object_name = self.get_rel_to_model_and_object_name(klass)

        if get_rel_to(self) == RECURSIVE_RELATIONSHIP_CONSTANT or to_object_name == klass._meta.object_name:
            field_name = 'to_%s' % to_object_name.lower()
        else:
            field_name = to_object_name.lower()

        field = models.ForeignKey(to_model, related_name='%s+' % name,
                                  **get_foreignkey_field_kwargs(self))
        return field_name, field

    def create_intermediate_model_from_attrs(self, klass, attrs):
        '''
        Create the intermediate model class from ``attrs``. When
        ``base_class`` is set it is placed before ``models.Model`` in the
        bases.
        '''
        name = self.get_intermediate_model_name(klass)
        base_classes = (self.base_class, models.Model) if self.base_class else (models.Model,)

        return type(str(name), base_classes, attrs)

    def create_intermediate_model(self, klass):
        '''
        Construct and return the intermediate model for ``klass``.

        Besides the three fields and ``Meta``, the class records the chosen
        field names in ``_sort_field_name``, ``_from_field_name`` and
        ``_to_field_name``.
        '''
        # Construct and return the new class.
        from_field_name, from_field = self.get_intermediate_model_from_field(klass)
        to_field_name, to_field = self.get_intermediate_model_to_field(klass)
        sort_value_field_name, sort_value_field = self.get_intermediate_model_sort_value_field(klass)
        meta = self.get_intermediate_model_meta_class(klass,
                                                      from_field_name,
                                                      to_field_name,
                                                      sort_value_field_name)

        attrs = {
            'Meta': meta,
            '__module__': klass.__module__,
            from_field_name: from_field,
            to_field_name: to_field,
            sort_value_field_name: sort_value_field,
            '_sort_field_name': sort_value_field_name,
            '_from_field_name': from_field_name,
            '_to_field_name': to_field_name,
        }
        return self.create_intermediate_model_from_attrs(klass, attrs)


# Add introspection rules for South database migrations
# See http://south.aeracode.org/docs/customfields.html
try:
    import south
except ImportError:
    # South is optional; ``None`` marks it as unavailable for the check below.
    south = None


if south is not None and 'south' in settings.INSTALLED_APPS:
    from south.modelsinspector import add_introspection_rules

    # Tell South how to freeze the custom ``sorted`` keyword argument.
    _sorted_field_rule = (
        (SortedManyToManyField,),
        [],
        {"sorted": ["sorted", {"default": True}]},
    )
    add_introspection_rules(
        [_sorted_field_rule],
        [r'^sortedm2m\.fields\.SortedManyToManyField'],
    )

    # Monkeypatch South M2M actions to create the sorted through model.
    # FIXME: This doesn't detect if you changed the sorted argument to the field.
    import south.creator.actions
    from south.creator.freezer import model_key

    class AddM2M(south.creator.actions.AddM2M):
        """South action that emits a sorted through table for sorted fields.

        Fields that are not a sorted ``SortedManyToManyField`` fall back to
        the stock South behaviour.
        """

        SORTED_FORWARDS_TEMPLATE = '''
        # Adding SortedM2M table for field %(field_name)s on '%(model_name)s'
        db.create_table(%(table_name)r, (
            ('id', models.AutoField(verbose_name='ID', primary_key=True, auto_created=True)),
            (%(left_field)r, models.ForeignKey(orm[%(left_model_key)r], null=False)),
            (%(right_field)r, models.ForeignKey(orm[%(right_model_key)r], null=False)),
            (%(sort_field)r, models.IntegerField())
        ))
        db.create_unique(%(table_name)r, [%(left_column)r, %(right_column)r])'''

        def console_line(self):
            field = self.field
            if not (isinstance(field, SortedManyToManyField) and field.sorted):
                return super(AddM2M, self).console_line()
            meta = self.model._meta
            return " + Added SortedM2M table for %s on %s.%s" % (
                field.name,
                meta.app_label,
                meta.object_name,
            )

        def forwards_code(self):
            field = self.field
            if not (isinstance(field, SortedManyToManyField) and field.sorted):
                return super(AddM2M, self).forwards_code()

            left_column = field.m2m_column_name()
            right_column = field.m2m_reverse_name()
            context = {
                "model_name": self.model._meta.object_name,
                "field_name": field.name,
                "table_name": field.m2m_db_table(),
                # The field names are the column names without the trailing
                # ``_id`` part.
                "left_field": left_column[:-3],
                "left_column": left_column,
                "left_model_key": model_key(self.model),
                "right_field": right_column[:-3],
                "right_column": right_column,
                "right_model_key": model_key(get_rel_to(field)),
                "sort_field": field.sort_value_field_name,
            }
            return self.SORTED_FORWARDS_TEMPLATE % context

    class DeleteM2M(AddM2M):
        """Inverse of ``AddM2M``: forwards and backwards code are swapped."""

        def console_line(self):
            meta = self.model._meta
            return " - Deleted M2M table for %s on %s.%s" % (
                self.field.name,
                meta.app_label,
                meta.object_name,
            )

        def forwards_code(self):
            return AddM2M.backwards_code(self)

        def backwards_code(self):
            return AddM2M.forwards_code(self)

    # Install the replacements on South's actions module.
    south.creator.actions.AddM2M = AddM2M
    south.creator.actions.DeleteM2M = DeleteM2M