""" Tornado plot example from Brennan Williams """

# Major library imports
from numpy import arange, cos, linspace, pi, sin, ones

from enthought.chaco2.example_support import COLOR_PALETTE
from enthought.enable2.example_support import DemoFrame, demo_main

# Enthought library imports
from enthought.enable2.wx_backend.api import Window

# Chaco imports
from enthought.chaco2.api import ArrayDataSource, BarPlot, DataRange1D, LabelAxis, \
                                LinearMapper, OverlayPlotContainer, PlotAxis, PlotGrid, \
                                DataLabel

def make_curves(spec):
    """Build the two bar series that together form the tornado plot.

    Parameters
    ----------
    spec : object
        Must provide ``get_points()`` returning ``(index_points, value_points)``.
        ``value_points`` is scaled with ``*``, so it must be a numeric array
        (the demo frame supplies a NumPy array). ``index_points`` must be
        non-empty and support ``len``, ``min`` and ``max``.

        Side effect: ``spec.index_source`` and ``spec.value_source`` are set to
        the index/value data sources of the first ("low") series.

    Returns
    -------
    list of BarPlot
        ``[low_plot, high_plot]``. Both sets of bars start at a common middle
        value and extend to ``low_values`` and ``high_values`` respectively.
        The two plots share one index mapper and one value mapper, so they
        overlay exactly.
    """
    index_points, value_points = spec.get_points()
    size = len(index_points)

    # Every bar starts from the same central value. The low side is scaled by
    # 10000 and the high side by 20000, which makes the "tornado" asymmetric.
    middle_value = 2500000.0
    mid_values = middle_value * ones(size)
    low_values = mid_values - 10000.0 * value_points
    high_values = mid_values + 20000.0 * value_points

    spec.index_source = low_index = ArrayDataSource(index_points)
    spec.value_source = low_vals = ArrayDataSource(low_values, sort_order="none")
    high_index = ArrayDataSource(index_points)
    high_vals = ArrayDataSource(high_values, sort_order="none")
    starting_vals = ArrayDataSource(mid_values, sort_order="none")

    # Pad the index range by half a slot on each end so that the outermost
    # bars (width 0.8) are not clipped. For an index of 1..9 this gives
    # exactly 0.5..9.5.
    index_range = DataRange1D(low_index,
                              low=min(index_points) - 0.5,
                              high=max(index_points) + 0.5)
    index_mapper = LinearMapper(range=index_range)

    # The value range auto-fits so that both the low and high series are covered.
    value_range = DataRange1D(low_vals, high_vals, low_setting='auto',
                              high_setting='auto', tight_bounds=False)
    value_mapper = LinearMapper(range=value_range, tight_bounds=False)

    def _bar_plot(index_source, value_source, color):
        # Shared construction for both series; only data and colour differ.
        # NOTE(review): orientation='v' is assumed to put the index on the
        # vertical axis (horizontal bars) -- confirm against the BarPlot docs.
        return BarPlot(index=index_source, value=value_source,
                       value_mapper=value_mapper,
                       index_mapper=index_mapper,
                       starting_value=starting_vals,
                       line_color='black',
                       orientation='v',
                       fill_color=tuple(color),
                       bar_width=0.8, antialias=False)

    low_plot = _bar_plot(low_index, low_vals, COLOR_PALETTE[6])
    high_plot = _bar_plot(high_index, high_vals, COLOR_PALETTE[1])
    return [low_plot, high_plot]

class PlotFrame(DemoFrame):
    """Demo frame that displays the tornado plot built by ``make_curves``."""

    def get_points(self):
        """Return ``(index_points, value_points)`` for the plot.

        The index is the integers 1..9, one slot per category. The values are
        ``sin(x) + 2`` sampled at 9 evenly spaced points in [pi/4, 3*pi/2],
        so every value lies in [1, 3].
        """
        index = linspace(pi/4, 3*pi/2, 9)
        data = sin(index) + 2
        return (range(1, 10), data)

    def _create_window(self):
        """Assemble the plots and axes, and return the wx Window hosting them."""
        container = OverlayPlotContainer(bgcolor="white")
        self.container = container

        plots = make_curves(self)
        for plot in plots:
            plot.padding = 60
            container.add(plot)

        # The axes are attached to a single plot. The plots from make_curves
        # share their index/value mappers, so any of them would do. Name the
        # host explicitly rather than relying on the loop variable leaking out
        # of the for-loop, which would raise NameError if plots were empty.
        axis_host = plots[-1]

        bottom_axis = PlotAxis(axis_host, orientation='bottom')

        # Category labels sit at the same index positions the data uses.
        label_positions = self.get_points()[0]
        label_list = ['var a', 'var b', 'var c', 'var d', 'var e',
                      'var f', 'var g', 'var h', 'var i']
        vertical_axis = LabelAxis(axis_host, orientation='left',
                                  title='Categories',
                                  positions=label_positions,
                                  labels=label_list)
        vertical2_axis = LabelAxis(axis_host, orientation='right',
                                   positions=label_positions,
                                   labels=label_list)

        axis_host.underlays.append(vertical_axis)
        axis_host.underlays.append(vertical2_axis)
        axis_host.underlays.append(bottom_axis)

        return Window(self, -1, component=container)

if __name__ == "__main__":
    # Launch the demo in an 800x600 window when run as a script.
    frame_options = dict(size=(800, 600), title="Tornado plot")
    demo_main(PlotFrame, **frame_options)
