# Author: Prabhu Ramachandran <prabhu [at] aero . iitb . ac . in>
# Copyright (c) 2008, Enthought, Inc.
# License: BSD Style.

# Standard library imports.

# Enthought library imports.
from enthought.traits.api import Int, Instance, Str
from enthought.traits.ui.api import View, Group, Item
from enthought.tvtk.api import tvtk
from enthought.persistence import state_pickler

# Local imports.
from enthought.mayavi.core.common import error
from enthought.mayavi.core.pipeline_base import PipelineBase
from enthought.mayavi.core.module import Module
from enthought.mayavi.filters.optional import Optional
from enthought.mayavi.filters.mask_points import MaskPoints
from enthought.mayavi.filters.user_defined import UserDefined
from enthought.mayavi.components.actor2d import Actor2D
from enthought.mayavi.core.common import handle_children_state


################################################################################
# `Labels` class.
################################################################################ 
class Labels(Module):

    """
    Allows a user to label the current dataset or the current actor of
    the active module.

    The labeling pipeline built by `setup_pipeline` and wired by
    `update_pipeline` is::

        input -> mask (MaskPoints) -> visible_points (Optional,
        disabled by default) -> actor (Actor2D with a LabeledDataMapper)

    ``input`` is either the module manager's source or the input of the
    actor of the module referenced by ``object``.
    """

    # Used for persistence.
    __version__ = 0

    # The object which we are labeling.
    object = Instance(PipelineBase)

    # The label format string.  An empty string clears the mapper's
    # label format (see `_label_format_changed`).
    label_format = Str('', desc='the label format string')

    # Number of points to label (approximate: the mask keeps roughly
    # this many points).
    number_of_labels = Int(25, desc='the number of points to label')

    # The filter used for masking of the points.
    mask = Instance(MaskPoints)

    # Filter to select visible points.
    visible_points = Instance(Optional)

    # The 2D actor for the labels.
    actor = Instance(Actor2D)

    # The text property of the labels.
    property = Instance(tvtk.TextProperty)

    # The mapper for the labels.
    mapper = Instance(tvtk.LabeledDataMapper, args=())

    ########################################
    # Private traits.

    # The input used for the labeling.
    input = Instance(PipelineBase)

    # The id of the object in the modulemanager only used for
    # persistence.  Encoding: -2 = unset (fall back to the module
    # manager's source), -1 = the module manager's source,
    # >= 0 = index into the module manager's children.
    object_id = Int(-2)


    ########################################
    # View related traits.


    view = View(Group(Item(name='number_of_labels'),
                      Item(name='label_format'),
                      Item(name='mapper',
                           style='custom',
                           show_label=False,
                           resizable=True),
                      Item(name='mask',
                           style='custom',
                           resizable=True,
                           show_label=False),
                      label='Labels'
                      ),
                Group(
                      Item(name='visible_points',
                           style='custom',
                           resizable=True,
                           show_label=False),
                      label='VisiblePoints'
                      ),
                Group(Item(name='property',
                           style='custom',
                           show_label=False,
                           resizable=True),
                      label='TextProperty'
                     ),
                 resizable=True
                )

    ######################################################################
    # `object` interface.
    ######################################################################
    def __get_pure_state__(self):
        """Return the picklable state of this module.

        `object_id` is refreshed first so the labeled object can be found
        again on restore.  The live references `object`, `mapper` and
        `input` are dropped; the pipeline components are pickled
        explicitly instead.
        """
        self._compute_object_id()
        d = super(Labels, self).__get_pure_state__()
        for name in ('object', 'mapper', 'input'):
            d.pop(name, None)
        # Must pickle the components.
        d['components'] = self.components
        return d

    def __set_pure_state__(self, state):
        """Restore the components and traits from `state`, then rewire
        the pipeline (which re-derives `input` from `object_id`)."""
        handle_children_state(self.components, state.components)
        state_pickler.set_state(self, state)
        self.update_pipeline()

    ######################################################################
    # `Module` interface.
    ######################################################################
    def setup_pipeline(self):
        """Create the mask, visible-points filter, mapper and actor.

        Nothing is connected to an input here; `update_pipeline` does the
        wiring once a module manager is available.
        """
        mask = MaskPoints()
        mask.filter.set(generate_vertices=True, random_mode=True)
        self.mask = mask
        v = UserDefined(filter=tvtk.SelectVisiblePoints(),
                        name='VisiblePoints')
        # The visible-points selection is optional and off by default.
        self.visible_points = Optional(filter=v, enabled=False)
        mapper = tvtk.LabeledDataMapper()
        self.mapper = mapper
        self.actor = Actor2D(mapper=mapper)
        self.property = mapper.label_text_property
        # Re-render whenever any text property trait changes.
        self.property.on_trait_change(self.render)
        self.components = [self.mask, self.visible_points, self.actor]

    def update_pipeline(self):
        """Connect input -> mask -> visible_points -> actor and reapply
        the label count and format.

        Does nothing while the module has no module manager, or when no
        input to label could be found.
        """
        mm = self.module_manager
        if mm is None:
            return

        self._find_input() # Calculates self.input
        if self.input is None:
            # _find_input has already reported the problem; wiring a None
            # input into the mask filter would only fail further on.
            return
        self.mask.inputs = [self.input]
        self.visible_points.inputs = [self.mask]
        self.actor.inputs = [self.visible_points]
        self._number_of_labels_changed(self.number_of_labels)
        self._label_format_changed(self.label_format)

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _find_input(self):
        """Work out `self.input` from `object` or the persisted
        `object_id`.  Reports an error and leaves `input` as None when no
        input can be determined."""
        mm = self.module_manager
        if self.object is None:
            if self.object_id == -1:
                self.input = mm.source
            elif self.object_id > -1:
                obj = mm.children[self.object_id]
                if hasattr(obj, 'actor'):
                    # Don't fire _object_changed: it would call
                    # update_pipeline again while we are inside it.
                    self.set(object=obj, trait_change_notify=False)
                    self.input = obj.actor.inputs[0]
                else:
                    self.input = mm.source
        else:
            target = self.object
            if hasattr(target, 'module_manager'):
                # A module.
                if hasattr(target, 'actor'):
                    self.input = target.actor.inputs[0]
                else:
                    self.input = target.module_manager.source

        if self.input is None:
            if self.object_id == -2:
                self.input = mm.source
            else:
                error('No object to label!')
                return

    def _number_of_labels_changed(self, value):
        """Set the mask's on-ratio so roughly `value` points get labels."""
        if self.input is None:
            return
        f = self.mask.filter
        npts = self.input.outputs[0].number_of_points
        # on_ratio is an integer stride, so use floor division.  A
        # non-positive label count is clamped to 1 to avoid dividing by
        # zero; the ratio itself never drops below 1.
        f.on_ratio = max(npts // max(value, 1), 1)
        if self.mask.running:
            f.update()
            self.mask.data_changed = True

    def _label_format_changed(self, value):
        """Push the format string to the mapper and re-render.

        An empty string sets the mapper's label format to None (presumably
        letting VTK choose its default format; confirm against the tvtk
        version in use).
        """
        self.mapper.label_format = value if value else None
        self.render()

    def _object_changed(self, value):
        self.update_pipeline()

    def _compute_object_id(self):
        """Encode the labeled object as `object_id` for persistence.

        -1 means the module manager's source, >= 0 is the index of
        `object` among the module manager's children, and -2 means
        neither applies (including when there is no module manager yet).
        """
        mm = self.module_manager
        self.object_id = -2
        if mm is None:
            # Not attached to a pipeline yet: there is nothing to refer to.
            return
        if self.input is mm.source:
            self.object_id = -1
            return
        for index, child in enumerate(mm.children):
            if child is self.object:
                self.object_id = index
                return

    def _scene_changed(self, old, new):
        # The visible-points filter needs the renderer to test visibility.
        self.visible_points.filter.filter.renderer = new.renderer
        super(Labels, self)._scene_changed(old, new)
