#TODO: fix yticks and ylabels on y zoom
from __future__ import generators
from __future__ import division

import distutils.sysconfig
import re, math, os, sys, time

import pygtk
pygtk.require('2.0')
import gtk
from gtk import gdk
from Numeric import arange, array, ones, zeros, logical_and, \
     nonzero, take, Float

from utils import CmpTrue, Range, indices, iterable, \
     is_string_like, flatten
from lines import Line2D_Dispatcher, SolidLine2D, \
     Line2D, DottedLine2D
from colors import ColorDispatcher

# Directory holding the button icon pixmaps, resolved relative to the
# Python installation prefix:  <prefix>/share/pixmaps
PIXMAP_PATH = os.path.join(distutils.sysconfig.PREFIX,
                           'share',
                           'pixmaps')

def _add_button_icon_pixmap(gtk_button, pixmap, orientation='left'):
    """
    Add pixmap to button functions borrowed from a post by Valery
    Febvre at
    http://www.mail-archive.com/pygtk@daa.com.au/msg03639.html
    """
    
    gtk_button.realize()
    label = gtk.Label(gtk_button.get_children()[0].get())
    gtk_button.remove(gtk_button.get_children()[0])

    if orientation is None:
        box = gtk.HBox(spacing=0)
        box.pack_start(pixmap, gtk.FALSE, gtk.FALSE, 0)        

    if orientation in ('left', 'right'):
        box = gtk.HBox(spacing=5)
    elif orientation in ('top', 'bottom'):
        box = gtk.VBox(spacing=5)
    if orientation in ('left', 'top'):
        box.pack_start(pixmap, gtk.FALSE, gtk.FALSE, 0)
        box.pack_start(label, gtk.FALSE, gtk.FALSE, 0)
    elif orientation in ('right', 'bottom'):
        box.pack_start(label, gtk.FALSE, gtk.FALSE, 0)
        box.pack_start(pixmap, gtk.FALSE, gtk.FALSE, 0)

    hbox = gtk.HBox()
    if box is not None:
        hbox.pack_start(box, gtk.TRUE, gtk.FALSE, 0)
    hbox.show_all()
    gtk_button.add(hbox)

def _add_button_icon(gtk_button, file, orientation='left'):
    """
    Load the XPM image at path `file` and install it as the icon of
    gtk_button, placed according to orientation (see
    _add_button_icon_pixmap).
    """
    # realize first so that gtk_button.window is available for the pixmap
    gtk_button.realize()
    pix, pixMask = gtk.create_pixmap_from_xpm(gtk_button.window, None, file)
    icon = gtk.Image()
    icon.set_from_pixmap(pix, pixMask)
    _add_button_icon_pixmap(gtk_button, icon, orientation)



class AxisRange(Range):
    """
    A Range object containing the axis limits with support for
    labels, ticks and autoscaling.
    """

    def __init__(self, valMin=CmpTrue(), valMax=CmpTrue()):
        """
        A Range object containing the axis limits with support for
        labels, ticks and autoscaling.  valMin and valMax are passed
        straight to Range.__init__.
        """
        Range.__init__(self, valMin, valMax)
        self._ticks = []
        # Matches a formatted number whose mantissa ends in zeros:
        # group 1 is the mantissa without its trailing zeros (and without
        # a then-dangling '.'), group 2 the optional exponent.  The
        # mantissa excludes 'e' so zeros belonging to the exponent are
        # never stripped ('1.235e+10' must not become '1.235e+1').
        self._zeroRgx = re.compile(r'^([^e]*?)\.?0+(e[+-]\d+)?$')
        self._tickFmtFunc = self._tickformat
        self._autoscale = 1

        # Memoizing is not used for efficiency here, but to ensure that
        # repeated calls to the autoscaling funcs with the same range
        # return the same answer.
        self._memoize = {}

    def _get_decade_ticks(self, numTicks):
        """
        Snap the range to round values (set_decade_range) and return
        evenly spaced ticks across it, at most about numTicks of them.
        """
        exponent = self.set_decade_range()
        scale = 10**(-exponent)
        d = self.get_distance()
        # a whole number of 10**exponent units, large enough that the
        # range holds fewer than numTicks steps
        step = ((scale*d)//numTicks+1)/scale

        valMin, valMax = self.get_range()

        # Add 1e-3*d to valMax so a tick is put on valMax even if it is
        # marginally above valMax with respect to the width d of the plot
        return arange(valMin, valMax+(1e-3*d), step)

    def get_labels(self):
        """
        Return the tick labels: the list given to set_labels if any,
        else the formatted ticks with every other label blanked out.
        """
        try:
            return self._labelsSet
        except AttributeError:
            pass
        labels = [self._tickFmtFunc(tick) for tick in self.get_ticks()]
        # turn off every other label to avoid crowding
        for i in range(1, len(labels), 2):
            labels[i] = ''
        return labels

    def get_ticks(self, numTicks=10):
        """
        Return the tick locations: the ticks given to set_ticks if any;
        [] if the limits have not been set yet (get_distance raises
        RuntimeError); decade ticks when autoscaling; otherwise
        numTicks equal steps from min to max.
        """
        # if it's defined from the set, return it
        try:
            return self._ticksSet
        except AttributeError:
            pass

        # if the data limits haven't been set, can't compute ticks
        try:
            self.get_distance()
        except RuntimeError:
            return []

        # OK, all systems go for automated tick computation.
        if self._autoscale:
            return self._get_decade_ticks(numTicks)
        xmin, xmax = self.get_range()
        step = (xmax-xmin)/numTicks
        return arange(xmin, xmax+step, step)

    def set_autoscale(self, b=1):
        "Turn autoscaling on (b true, discarding set ticks) or off (b false)"
        if b:
            try:
                del self._ticksSet
            except AttributeError:
                pass
        self._autoscale = b

    def set_decade_range(self):
        """
        When autoscaling, widen the limits outward to multiples of
        10**exponent and return exponent.  Return None, leaving the
        limits alone, when autoscaling is off.
        """
        if not self._autoscale:
            return None

        # handle the special case of flat data by making the axis range
        # large compared to the data range so it looks flat on a plot.
        if self.get_distance() == 0.0:
            self._xmin -= 1
            self._xmax += 1

        exponent, remainder = divmod(math.log10(self.get_distance()), 1)
        # 10**0.84 is about 6.9: ranges whose leading digits are below
        # that use a one-decade-finer rounding unit
        if remainder < 0.84:
            exponent -= 1
        scale = 10**(-exponent)
        xmin, xmax = self.get_min(), self.get_max()

        try:
            valMin, valMax = self._memoize[(xmin, xmax)]
        except KeyError:
            valMin = math.floor(scale*xmin)/scale
            valMax = math.ceil(scale*xmax)/scale
            self._memoize[(xmin, xmax)] = (valMin, valMax)

        self._xmin, self._xmax = valMin, valMax
        return exponent

    def set_labels(self, labels):
        "Use the fixed list labels for the tick labels"
        self._labelsSet = labels

    def set_tick_format_func(self, func):
        """
        Set the function used to do tick formatting.  func has
        signature string = func(val)
        """
        self._tickFmtFunc = func

    def set_ticks(self, ticks):
        "Use the fixed tick locations ticks; turns autoscaling off"
        self.set_autoscale(0)
        self._ticksSet = ticks

    def set_tick_format(self, s):
        """
        Set the format string to do tick formatting.  s is a python
        format string
        """
        self._tickFmtFunc = lambda x: s % x

    def _tickformat(self, val):
        """
        Default tick formatter: integers below 1e4 as ints, very small
        or large values in exponential notation, others with 3
        decimals; trailing mantissa zeros and '+' signs are removed,
        eg 2.5 -> '2.5', 1.5e5 -> '1.5e05', 1.235e10 -> '1.235e10'.
        """
        # if the number is not too big and it's an int, format it as an int
        if abs(val) < 1e4 and val == int(val):
            return '%d' % val

        # use exponential formatting for really big or small numbers,
        # else use float
        if abs(val) < 1e-3 or abs(val) > 1e4:
            fmt = '%1.3e'
        else:
            fmt = '%1.3f'
        s = fmt % val

        # strip trailing mantissa zeros, keep any exponent, remove '+'
        m = self._zeroRgx.match(s)
        if m:
            s = m.group(1)
            if m.group(2) is not None:
                s += m.group(2)
        return s.replace('+', '')

    def set_range(self, range):
        "Set the limits to range[0], range[1]"
        # this function turns autoscaling off as a side effect
        self.set_autoscale(0)
        self._xmin, self._xmax = range[0], range[1]


def _process_plot_format(fmt):
    """
    Process a matlab style color/line style format string.  Return a
    (lineStyle, color) tuple as a result of the processing.  Default
    values are (solidLine, Blue).  Example format strings include

    'ko'    : black circles
    '.b'    : blue dots
    'r--'   : red dashed lines

    When several style (or several color) characters are given, the
    last one wins.  Raises ValueError on a character that is neither a
    line style nor a color.

    See Line2D_Dispatcher and ColorDispatcher for more info.
    """
    styles = Line2D_Dispatcher()
    colors = ColorDispatcher()   # one dispatcher, reused for every lookup

    LineClass = styles['-']
    color = colors.get('b')

    # Handle the multi-char line styles first and strip them from the
    # string so the rest can be processed one character at a time.
    for multi in ('--', '-.'):
        if multi in fmt:
            LineClass = styles[multi]
            fmt = fmt.replace(multi, '')

    for c in fmt:
        if styles.has_key(c):
            LineClass = styles[c]
        elif colors.has_key(c):
            color = colors.get(c)
        else:
            raise ValueError('Unrecognized character %c in format string' % c)
    return LineClass, color


class _process_plot_var_args:
    """
    Process variable length arguments to the plot command, so that
    plot commands like the following are supported

      plot(t, s)
      plot(t1, s1, t2, s2)
      plot(t1, s1, 'ko', t2, s2)
      plot(t1, s1, 'ko', t2, s2, 'r--', t3, e3)

    an arbitrary number of x,y,fmt are allowed.  Calling an instance
    with the argument tuple returns a generator of line instances.
    """
    def __call__(self, *args):
        return self._grab_next_args(*args)

    def _plot_1_arg(self, y):
        "Plot y against its index with a solid line"
        return SolidLine2D(arange(len(y)), y)

    def _plot_2_args(self, tup2):
        "Plot a (y, fmt) pair against the index of y, or an (x, y) pair solid"
        if is_string_like(tup2[1]):
            y, fmt = tup2
            LineStyleClass, color = _process_plot_format(fmt)
            return LineStyleClass(x=arange(len(y)),
                                  y=y,
                                  color=color)
        x, y = tup2
        return SolidLine2D(x, y)

    def _plot_3_args(self, tup3):
        "Plot an (x, y, fmt) triple"
        x, y, fmt = tup3
        LineStyleClass, color = _process_plot_format(fmt)
        return LineStyleClass(x, y, color=color)

    def _grab_next_args(self, args):
        """
        Generator yielding one line per argument group in the sequence
        args.  Raises ValueError when exactly three args remain and the
        third is not a format string.
        """
        remaining = args
        while len(remaining) > 0:
            if len(remaining) == 1:
                yield self._plot_1_arg(remaining[0])
                return
            if len(remaining) == 2:
                yield self._plot_2_args(remaining)
                return
            if len(remaining) == 3:
                if not is_string_like(remaining[2]):
                    raise ValueError('third arg must be a format string')
                yield self._plot_3_args(remaining)
                return
            # Four or more left: take x, y, fmt when a format string
            # follows the pair, else just the pair.  Always index
            # `remaining` (not `args`), otherwise later groups are
            # mis-parsed and the loop never advances.
            if is_string_like(remaining[2]):
                yield self._plot_3_args(remaining[:3])
                remaining = remaining[3:]
            else:
                yield self._plot_2_args(remaining[:2])
                remaining = remaining[2:]
        
    
class Axes:
    """
    Emulate matlab's axes command, creating axes with

      Axes(position=[left, bottom, width, height])

    where all the arguments are fractions in [0,1] which specify the
    fraction of the total figure window.

    Data limits are tracked in _xdataRange/_ydataRange, the displayed
    limits (with ticks and labels) in the AxisRange objects
    _xaxisRange/_yaxisRange, and the pixel extent of the axes in
    _winXlim/_winYlim.
    """

    _colors = ColorDispatcher()

    def __init__(self, position,
                 bg=ColorDispatcher().get('g')):
        """
        position : [left, bottom, width, height] as fractions of the
                   figure window
        bg       : background color used for the axis text
        """
        self._position = position
        self._bg = bg
        self._gridState = 0
        self._lines = []
        self._axText = []
        self._get_lines = _process_plot_var_args()

        self._xdataRange = Range()
        self._ydataRange = Range()
        self._xaxisRange = AxisRange()
        self._yaxisRange = AxisRange()
        self._winXlim = Range()
        self._winYlim = Range()

        # Axis, tick and grid line lists; rebuilt by _set_axis_lines and
        # emptied by clear().  Initialised here so they always exist.
        self._xaxis = []
        self._yaxis = []
        self._xticks = []
        self._yticks = []
        self._xgrid = []
        self._ygrid = []

        # Install get_xticks, set_xticklabels, get_ylim, ...  A plain
        # loop rather than map() so it also runs where map() is lazy.
        for c in ('x', 'y'):
            self._generate_axis_boilerplate(c)

    def _pass_func(self, *args, **kwargs):
        "Do nothing; accepts and ignores any arguments"
        pass

    def add_line(self, line):
        """
        Add a line to the list of plot lines, updating the data and
        axis ranges with its x and y data.
        """
        self._xaxisRange.update(line.get_x())
        self._yaxisRange.update(line.get_y())
        self._xdataRange.update(line.get_x())
        self._ydataRange.update(line.get_y())
        self._lines.append(line)

    def autoscale(self):
        """
        Reset the axis limits to the data limits, turn autoscaling back
        on and redraw.
        """
        for line in self._lines:
            line.flush_clip()
        self._xaxisRange.set_range(self._xdataRange.get_range())
        self._yaxisRange.set_range(self._ydataRange.get_range())
        self._xaxisRange.set_autoscale(1)
        self._yaxisRange.set_autoscale(1)
        self.draw()

    def set_autoscale(self, b):
        "Turn autoscaling of both axes on (b true) or off (b false)"
        self._xaxisRange.set_autoscale(b)
        self._yaxisRange.set_autoscale(b)

    def configure_event(self, widget, event):
        """
        Reset the window params after a configure event.

        Recomputes the pixel extent of the axes from the position
        fractions and the current window size, then redraws.  Window y
        grows downward, so self._bottom > self._top.
        """
        if widget is not None: self._widget = widget
        if event is not None: self._event = event

        width, height = self._widget.window.get_size()
        self._left = self._position[0] * width
        self._bottom = (1-self._position[1]) * height
        self._width = self._position[2] * width
        self._height = self._position[3] * height
        self._right = self._left + self._width
        self._top = self._bottom - self._height

        # set the new axis information
        self._winXlim = Range(self._left, self._right)
        self._winYlim = Range(self._bottom, self._top)

        self.draw(self._widget)
        return gtk.TRUE

    def in_axes(self, xwin, ywin):
        "Return 1 if window point (xwin, ywin) lies in the axes rectangle (edges included), else 0"
        if xwin < self._left or xwin > self._right:
            return 0
        if ywin > self._bottom or ywin < self._top:
            return 0
        return 1

    def draw(self, widget=None):
        """
        Draw everything (plot lines, axes, labels).

        widget defaults to the last widget seen; if there is none yet,
        nothing is drawn and gtk.TRUE is returned.
        """
        if widget is None:
            try:
                widget = self._widget
            except AttributeError:
                return gtk.TRUE

        self._widget = widget
        self._set_axis_lines()
        # clear the plotting area to white
        widget.window.draw_rectangle(widget.get_style().white_gc, gtk.TRUE,
                                     self._left, self._top,
                                     self._width, self._height)

        self._draw_lines(widget)
        self._draw_axes(widget)
        self._draw_labels(widget)

    def clear(self, widget=None, event=None):
        "Remove all plot, axis, tick and grid lines and redraw"
        if widget is not None: self._widget = widget
        if event is not None: self._event = event

        self._xaxis = []
        self._yaxis = []
        self._xticks = []
        self._yticks = []
        self._xgrid = []
        self._ygrid = []
        self._lines = []
        self.configure_event(self._widget, self._event)

    def _draw_axes(self, widget):
        "Draw the axis lines, ticks and grid lines"
        lines = flatten([self._xaxis, self._yaxis,
                         self._xticks, self._yticks,
                         self._xgrid, self._ygrid])
        for line in lines:
            gc = widget.window.new_gc()
            line.draw(widget, gc, self.transform)

    def _draw_labels(self, widget):
        "Draw the axis and tick labels"
        xlocs = self._transformx(self.get_xticks())
        xlabels = self.get_xticklabels()

        # erase the text from the previous draw before rebuilding it
        for axt in self._axText:
            axt.erase()
        self._axText = []

        # x tick labels hang just below the axes; remember the largest
        # label bottom so the xlabel can be placed beneath them
        xlabelMax = None
        for loc, label in zip(xlocs, xlabels):
            axt = AxisText(x=loc, y=self._bottom+5,
                           text=label, bg=self._bg,
                           valign='top', halign='center')
            self._axText.append(axt)
            b = axt.get_bottom_top()[0]
            if xlabelMax is None or b > xlabelMax:
                xlabelMax = b
        if xlabelMax is None:
            # no tick labels: put the xlabel where they would have started
            xlabelMax = self._bottom+5

        # y tick labels, right aligned just left of the axes
        ylocs = self._transformy(self.get_yticks())
        ylabels = self.get_yticklabels()
        for loc, label in zip(ylocs, ylabels):
            self._axText.append(AxisText(x=self._left-2, y=loc,
                                         text=label, bg=self._bg,
                                         valign='center', halign='right'))

        # NOTE: the label stored by set_ylabel is not drawn here.

        try: label = self._xlabel
        except AttributeError: pass
        else:
            self._axText.append(AxisText(x=(self._left+self._right)/2,
                                         y=xlabelMax+2,
                                         text=label, bg=self._bg,
                                         valign='top', halign='center'))

        try: label = self._title
        except AttributeError: pass
        else:
            self._axText.append(AxisText(x=(self._left+self._right)/2,
                                         y=self._top-2,
                                         text=label, bg=self._bg,
                                         valign='bottom', halign='center'))

        for axt in self._axText:
            axt.draw(widget)

    def _draw_lines(self, widget):
        "Draw the plot lines, clipped to the axes rectangle"
        clip = (self._left, self._top, self._width, self._height)
        for line in self._lines:
            gc = widget.window.new_gc()
            gc.set_clip_rectangle(clip)
            line.draw(widget, gc, self.transform)

    def _generate_axis_boilerplate(self, c):
        """
        Generate the getters and setters which emulate matlab style
        axis properties (eg, xticks, xticklabels, yticks, yticklabels)
        by binding them to methods of the AxisRange object
        self._<c>axisRange:

          get_<c>ticklabels / set_<c>ticklabels -> get_labels / set_labels
          get_<c>ticks      / set_<c>ticks      -> get_ticks  / set_ticks
          get_<c>lim                            -> get_range

        c is the char specifying the axis to gen the boilerplate for,
        eg, 'x' or 'y'.
        """
        axisRange = getattr(self, '_%caxisRange' % c)
        for s in ('get', 'set'):
            setattr(self, '%s_%cticklabels' % (s, c),
                    getattr(axisRange, '%s_labels' % s))
            setattr(self, '%s_%cticks' % (s, c),
                    getattr(axisRange, '%s_ticks' % s))
        setattr(self, 'get_%clim' % c, axisRange.get_range)

    def get_lines(self, type=Line2D):
        """
        Get all lines of type type, where type is Line2D (all lines)
        or a derived class, eg, CircleLine2D

        You can use this function to set properties of several plot
        lines at once, as in the following

            a1.plot(t1, s1, 'gs', t1, e1, 'bo', t1, p1)
            def fmt_line(l):
               l.set_line_width(2)
               l.set_size(10)
               l.set_fill(1)
            map(fmt_line, a1.get_lines(SymbolLine2D))
        """
        return [line for line in self._lines if isinstance(line, type)]

    def get_xtick_size(self):
        "Get the size of the xticks in pixels (default 5)"
        try: return self._xtickSizeSet
        except AttributeError: return 5

    def _get_tick_sizes_in_data_coords(self):
        """
        I want the tick sizes to be independent of the size of the
        axis.  Since the axis lines and ticks are in data
        coordinates, I compute the data coordinates of a fixed
        tick size (admittedly, a bit backwards).

        Returns (xtickLen, ytickLen): xticks run along y, so their
        length is in y data units, and vice versa.
        """
        ytickSize = self.get_ytick_size()
        dWin = self._winXlim.get_distance()
        dData = self._xaxisRange.get_distance()
        ytickLen = abs(dData/dWin*ytickSize)

        xtickSize = self.get_xtick_size()
        dWin = self._winYlim.get_distance()
        dData = self._yaxisRange.get_distance()
        xtickLen = abs(dData/dWin*xtickSize)
        return (xtickLen, ytickLen)

    def get_ytick_size(self):
        "Get the size of the yticks in pixels (default 5)"
        try: return self._ytickSizeSet
        except AttributeError: return 5

    def hlines(self, y, xmin, xmax, fmt='k-'):
        """
        Plot horizontal lines at each y from xmin to xmax.  y may be a
        scalar, a sequence or a Numeric array.  xmin or xmax can be
        scalars or len(y) sequences.  If they are scalars, then the
        respective values are constant, else the widths of the lines
        are determined by xmin and xmax.  fmt is a plot format string.

        Raises ValueError if xmin or xmax differ from y in length.
        """
        LineClass, color = _process_plot_format(fmt)

        if not iterable(y):
            y = array([y], Float)
        else:
            try:
                y.shape
            except AttributeError:
                # a plain sequence: convert it element-wise (wrapping it
                # in another list would give a 1 x len(y) 2-D array)
                y = array(y, Float)

        o = ones(y.shape)
        if not iterable(xmin): xmin = xmin*o
        if not iterable(xmax): xmax = xmax*o
        if len(xmin) != len(y):
            raise ValueError('xmin and y are unequal sized sequences')
        if len(xmax) != len(y):
            raise ValueError('xmax and y are unequal sized sequences')
        for thisY, thisMin, thisMax in zip(y, xmin, xmax):
            line = LineClass([thisMin, thisMax], [thisY, thisY],
                             color=color)
            self.add_line(line)

    def plot(self, *args):
        """
        Emulate matlab's plot command.  *args is a variable length
        argument, allowing for multiple x,y pairs with an optional
        format string.  For example, all of the following are legal,
        assuming a is the Axis instance:

          a.plot(x,y)            # plot Numeric arrays y vs x
          a.plot(x,y, 'bo')      # plot Numeric arrays y vs x with blue circles
          a.plot(y)              # plot y using x = arange(len(y))
          a.plot(y, 'r+')        # ditto with red plusses

        An arbitrary number of x, y, fmt groups can be specified, as in
          a.plot(x1, y1, 'g^', x2, y2, 'r-')
        """
        for line in self._get_lines(args):
            self.add_line(line)

    def _set_axis_lines(self):
        """
        Set the location and extent of the axis lines, ticks and grid
        lines.  The resultant lines are assigned to the attributes
        _xaxis, _yaxis, _xticks, _yticks, _xgrid and _ygrid, each of
        which is a list of Line2D instances in data coordinates.
        Returns early, leaving them unchanged, if either axis range
        raises RuntimeError (eg, no limits set yet).
        """
        axColor = self._colors.get('k')
        try:
            self._xaxisRange.set_decade_range()
        except RuntimeError:
            return
        try:
            self._yaxisRange.set_decade_range()
        except RuntimeError:
            return

        xmin, xmax = self.get_xlim()
        ymin, ymax = self.get_ylim()

        # the frame: bottom and top edges, then left and right edges
        self._xaxis = [
            SolidLine2D([xmin, xmax], [ymin, ymin], color=axColor),
            SolidLine2D([xmin, xmax], [ymax, ymax], color=axColor),
            ]
        self._yaxis = [
            SolidLine2D([xmin, xmin], [ymin, ymax], color=axColor),
            SolidLine2D([xmax, xmax], [ymin, ymax], color=axColor),
            ]

        # ticks point inward from both opposite edges; grid lines span
        # the whole axes
        xtickLen, ytickLen = self._get_tick_sizes_in_data_coords()
        ytick1 = array([xmin, xmin+ytickLen], Float)
        ytick2 = array([xmax-ytickLen, xmax], Float)
        ygrid = array([xmin, xmax], Float)

        xtick1 = array([ymin, ymin+xtickLen], Float)
        xtick2 = array([ymax-xtickLen, ymax], Float)
        xgrid = array([ymin, ymax], Float)

        xticks = self.get_xticks()
        yticks = self.get_yticks()

        # per tick: two tick marks in _?ticks, one dotted line in _?grid
        self._xticks = []
        self._xgrid = []
        for tick in xticks:
            x = array([tick, tick], Float)
            self._xticks.append(SolidLine2D(x, xtick1, color=axColor))
            self._xticks.append(SolidLine2D(x, xtick2, color=axColor))
            grid = DottedLine2D(x, xgrid, color=axColor)
            grid.set_spacing(5)
            self._xgrid.append(grid)

        self._yticks = []
        self._ygrid = []
        for tick in yticks:
            y = array([tick, tick], Float)
            self._yticks.append(SolidLine2D(ytick1, y, color=axColor))
            self._yticks.append(SolidLine2D(ytick2, y, color=axColor))
            grid = DottedLine2D(ygrid, y, color=axColor)
            grid.set_spacing(5)
            self._ygrid.append(grid)

    def set_background_color(self, color):
        "Set the background color used for the axis text"
        self._bg = color

    def set_title(self, label):
        "Set the title drawn above the axes"
        self._title = label

    def set_xlabel(self, label):
        "Set the label for the xaxis"
        self._xlabel = label

    def set_xlim(self, v):
        """
        Set the limits for the xaxis; v = [xmin, xmax].  Turns x
        autoscaling off and sets the x clip of every plot line.
        """
        xmin, xmax = v
        self._xaxisRange.set_range([xmin, xmax])
        for line in self._lines:
            line.set_xclip(xmin, xmax)

    def set_ylim(self, v):
        """
        Set the limits for the yaxis; v = [ymin, ymax].  Turns y
        autoscaling off and sets the y clip of every plot line.
        """
        ymin, ymax = v
        self._yaxisRange.set_range([ymin, ymax])

        # I set the gc clip to be just outside the actual range so
        # that the flat, artifactual lines caused by the fact that the
        # x data clip is done first will be drawn outside the gc clip
        # rectangle.  5% is an arbitrary factor chosen so that only a
        # fraction of unnecessary data is plotted, since the data
        # clipping is done for plot efficiency.  See _set_clip in
        # lines.py for more info.  [ Note: now that I have disabled y
        # clipping for connected lines in lines.py, this hack is no
        # longer needed, but I'm going to preserve it since I may want
        # to re-enable y clipping for connected lines and I can afford
        # the small performance hit. ]
        offset = 0.05*(ymax-ymin)
        for line in self._lines:
            line.set_yclip(ymin-offset, ymax+offset)

    def set_xtick_size(self, size):
        "Set the size of the x ticks in pixels"
        self._xtickSizeSet = size

    def set_ylabel(self, label):
        "Set the label for the yaxis (stored; not drawn by _draw_labels)"
        self._ylabel = label

    def set_ytick_size(self, size):
        "Set the size of the y ticks in pixels"
        self._ytickSizeSet = size

    def transform(self, x, y):
        """
        A wrapper func to pass to the Line2D draw instance.  The
        return tuple xt, yt are the x, y values transformed into window
        coordinates
        """
        return (self._transformx(x), self._transformy(y))

    def transform_win_to_data(self, xwin, ywin):
        "Inverse of transform: window coords -> (x, y) data coords"
        return (self._transform_win_to_x(xwin),
                self._transform_win_to_y(ywin))

    def _transform_win_to_x(self, win):
        """
        win is a scalar or numpy array in window coords; transform it
        into a same size array in x coords
        """
        valMin, valMax = self.get_xlim()
        winMin, winMax = self._winXlim.get_range()
        return ((valMax-valMin)/(winMax - winMin))*(win-winMin) + valMin

    def _transform_win_to_y(self, win):
        """
        win is a scalar or numpy array in window coords; transform it
        into a same size array in y coords
        """
        valMin, valMax = self.get_ylim()
        winMin, winMax = self._winYlim.get_range()
        return ((valMax-valMin)/(winMax - winMin))*(win-winMin) + valMin

    def _transformx(self, val):
        "Transform x data (val can be a scalar or numpy array) into window coords"
        valMin, valMax = self.get_xlim()
        winMin, winMax = self._winXlim.get_range()
        return ((winMax - winMin)/(valMax-valMin))*(val-valMin) + winMin

    def _transformy(self, val):
        "Transform y data (val can be a scalar or numpy array) into window coords"
        valMin, valMax = self.get_ylim()
        winMin, winMax = self._winYlim.get_range()
        return ((winMax - winMin)/(valMax-valMin))*(val-valMin) + winMin

    def vlines(self, x, ymin, ymax, fmt='k-'):
        """
        Plot vertical lines at each x from ymin to ymax.  x may be a
        scalar, a sequence or a Numeric array.  ymin or ymax can be
        scalars or len(x) sequences.  If they are scalars, then the
        respective values are constant, else the heights of the lines
        are determined by ymin and ymax.  fmt is a plot format string.

        Raises ValueError if ymin or ymax differ from x in length.
        """
        LineClass, color = _process_plot_format(fmt)

        # same input handling as hlines (a scalar x used to fail on len())
        if not iterable(x):
            x = array([x], Float)
        else:
            try:
                x.shape
            except AttributeError:
                x = array(x, Float)

        o = ones(x.shape)
        if not iterable(ymin): ymin = ymin*o
        if not iterable(ymax): ymax = ymax*o
        if len(ymin) != len(x):
            raise ValueError('ymin and x are unequal sized sequences')
        if len(ymax) != len(x):
            raise ValueError('ymax and x are unequal sized sequences')
        for thisX, thisMin, thisMax in zip(x, ymin, ymax):
            line = LineClass([thisX, thisX], [thisMin, thisMax],
                             color=color)
            self.add_line(line)


class Subplot(Axes):
    """
    Emulate matlab's subplot command, creating axes with

      Subplot(numRows, numCols, plotNum)

    where plotNum=1 is the first plot number and increasing plotNums
    fill rows first.  max(plotNum)==numRows*numCols

    You can leave out the commas if numRows, numCols and plotNum are
    all single digits, as in

      Subplot(211)    # 2 rows, 1 column, first (upper) plot

    Raises ValueError for malformed arguments or a plotNum outside
    1..numRows*numCols.
    """

    def __init__(self, *args):
        if len(args) == 1:
            s = str(*args)
            if len(s) != 3:
                raise ValueError('Argument to subplot must be a 3 digits long')
            rows, cols, num = map(int, s)
        elif len(args) == 3:
            rows, cols, num = args
        else:
            raise ValueError('Illegal argument to subplot')
        total = rows*cols
        num -= 1    # convert from matlab to python indexing ie num in range(0,total)
        if num < 0:
            raise ValueError('Subplot number must be at least 1')
        if num >= total:
            raise ValueError('Subplot number exceeds total subplots')

        # figure-fraction box that all the subplots share
        left, right = .11, .9
        bottom, top = .11, .9
        rat = 0.2   # separator size as a fraction of a subplot's size
        totWidth = right-left
        totHeight = top-bottom

        figH = totHeight/(rows + rat*(rows-1))
        sepH = rat*figH

        figW = totWidth/(cols + rat*(cols-1))
        sepW = rat*figW

        rowNum, colNum = divmod(num, cols)

        figBottom = top - (rowNum+1)*figH - rowNum*sepH
        # horizontal spacing must use the column separator sepW; using
        # sepH pushed columns past the right edge when rows != cols
        figLeft = left + colNum*(figW + sepW)
        Axes.__init__(self, [figLeft, figBottom, figW, figH])

        self.rowNum = rowNum
        self.colNum = colNum
        self.numRows = rows
        self.numCols = cols

    def is_first_col(self):
        return self.colNum == 0

    def is_first_row(self):
        return self.rowNum == 0

    def is_last_row(self):
        return self.rowNum == self.numRows-1

    def is_last_col(self):
        return self.colNum == self.numCols-1


class Button_Navigation(gtk.Button):
    """
    Base class for the navigation buttons: a gtk.Button holding a
    sequence of axes plus the subset of them ("active") that its
    actions are applied to.  Initially every axes is active.
    """
    def __init__(self, axes, label):
        gtk.Button.__init__(self, label)
        self._axes = axes
        self._ind = range(len(axes))
        self._active = [axes[k] for k in self._ind]

    def set_active(self, ind):
        "Make the axes at indices ind the active ones"
        self._ind = ind
        self._active = [self._axes[k] for k in ind]

class Button_XAxisLim(Button_Navigation):
    "Navigation button whose scroll handler pans the x range of the active axes"
    def __init__(self, axes, label):
        Button_Navigation.__init__(self, axes, label)

    def scroll(self, widget, event):
        """
        Shift each active axes' x limits by a tenth of their span:
        toward larger x on SCROLL_UP, smaller x on SCROLL_DOWN.  The
        limits are re-applied and the axes redrawn in every case.
        """
        for ax in self._active:
            lim = array(ax.get_xlim(), Float)
            shift = (lim[1] - lim[0])/10
            direction = event.direction
            if direction == gdk.SCROLL_UP:
                lim += shift
            elif direction == gdk.SCROLL_DOWN:
                lim -= shift
            ax.set_xlim(lim)
            ax.draw()
        return gtk.TRUE

class Button_YAxisLim(Button_Navigation):
    "Navigation button whose scroll handler pans the y range of the active axes"
    def __init__(self, axes, label):
        Button_Navigation.__init__(self, axes, label)

    def scroll(self, widget, event):
        """
        Shift each active axes' y limits by a tenth of their span:
        toward larger y on SCROLL_UP, smaller y on SCROLL_DOWN.  The
        limits are re-applied and the axes redrawn in every case.
        """
        for ax in self._active:
            lim = array(ax.get_ylim(), Float)
            shift = (lim[1] - lim[0])/10
            direction = event.direction
            if direction == gdk.SCROLL_UP:
                lim += shift
            elif direction == gdk.SCROLL_DOWN:
                lim -= shift
            ax.set_ylim(lim)
            ax.draw()
        return gtk.TRUE

class Button_XAxisScale(Button_Navigation):
    "Navigation button whose scroll handler zooms the x range of the active axes"
    def __init__(self, axes, label):
        Button_Navigation.__init__(self, axes, label)

    def zoom(self, widget, event):
        """
        Move both x limits of every active axes by a tenth of the span:
        inward (zoom in) on SCROLL_UP, outward (zoom out) on
        SCROLL_DOWN, then redraw.  Other scroll directions are ignored.
        """
        if event.direction == gdk.SCROLL_UP:
            sign = -1
        elif event.direction == gdk.SCROLL_DOWN:
            sign = 1
        else:
            # eg horizontal scrolling: previously `sign` was left unbound
            # and the handler raised UnboundLocalError
            return gtk.TRUE

        for axis in self._active:
            xlim = array(axis.get_xlim(), Float)
            d = (xlim[1]-xlim[0])/10
            axis.set_xlim([xlim[0]-sign*d, xlim[1]+sign*d])
            axis.draw()
        return gtk.TRUE

class Button_AxisAutoScale(Button_Navigation):
    "Navigation button that autoscales the active axes"
    def __init__(self, axes, label):
        Button_Navigation.__init__(self, axes, label)

    def autoscale(self, event):
        "Call autoscale() on every active axes; always returns gtk.TRUE"
        for ax in self._active:
            ax.autoscale()
        return gtk.TRUE

class Button_YAxisScale(Button_Navigation):
    "Navigation button whose scroll handler zooms the y range of the active axes"
    def __init__(self, axes, label):
        Button_Navigation.__init__(self, axes, label)

    def zoom(self, widget, event):
        """
        Move both y limits of every active axes by a tenth of the span:
        inward (zoom in) on SCROLL_UP, outward (zoom out) on
        SCROLL_DOWN, then redraw.  Other scroll directions are ignored.
        """
        if event.direction == gdk.SCROLL_UP:
            sign = -1
        elif event.direction == gdk.SCROLL_DOWN:
            sign = 1
        else:
            # eg horizontal scrolling: previously `sign` was left unbound
            # and the handler raised UnboundLocalError
            return gtk.TRUE

        for axis in self._active:
            ylim = array(axis.get_ylim(), Float)
            d = (ylim[1]-ylim[0])/10
            axis.set_ylim([ylim[0]-sign*d, ylim[1]+sign*d])
            axis.draw()
        return gtk.TRUE


class Dialog_MeasureTool(gtk.Dialog):
    """
    Small dialog with two labels: the data position under the mouse and the
    x/y delta of a measurement.  Callers push new values through
    update_position and update_delta.
    """
    def __init__(self):
        gtk.Dialog.__init__(self)
        self.set_title("Axis measurement tool")
        self.vbox.set_spacing(1)
        tips = gtk.Tooltips()

        self.posFmt =   'Position: x=%1.4f y=%1.4f'
        self.deltaFmt = 'Delta   : x=%1.4f y=%1.4f'

        self.positionLabel = self._add_label(
            self.posFmt, tips, "Move the mouse to data point over axis")
        self.deltaLabel = self._add_label(
            self.deltaFmt, tips,
            "Left click and hold while dragging mouse to measure "
            "delta x and delta y")

        self.show()

    def _add_label(self, fmt, tips, tip):
        "Pack a label showing fmt % (0, 0) into the dialog, attach tip, return it."
        lbl = gtk.Label(fmt % (0, 0))
        self.vbox.pack_start(lbl)
        lbl.show()
        tips.set_tip(lbl, tip)
        return lbl

    def update_position(self, x, y):
        "Show (x, y) on the position label."
        self.positionLabel.set_text(self.posFmt % (x, y))

    def update_delta(self, dx, dy):
        "Show (dx, dy) on the delta label."
        self.deltaLabel.set_text(self.deltaFmt % (dx, dy))


class Dialog_NavigationTool(gtk.Dialog):
    """
    Dialog with one check box per axis plus wheel-driven buttons that pan
    and zoom the selected axes, and a button that autoscales them.

    Toggling a check box sends the list of selected axis indices to every
    navigation button via its set_active method.
    """
    def __init__(self, axes):
        gtk.Dialog.__init__(self)
        self.axes = axes
        self.set_title("Navigation tool")

        self.vbox.set_spacing(1)
        tooltips = gtk.Tooltips()
        self._save_buttons = []

        self._checked = {}   # axis index -> true while that axis is selected
        self._buttons = []   # navigation buttons that act on selected axes

        hbox = gtk.HBox(spacing=3)
        self.vbox.pack_start(hbox)
        hbox.show()

        label = gtk.Label('Axis #: ')
        hbox.pack_start(label)
        label.show()

        # One check box per axis, labelled with its 1-based number.  Using
        # range() here: zip() cannot iterate the imported `indices` name.
        for i in range(len(self.axes)):
            label = '%d' % (i+1)
            button = gtk.CheckButton(label)
            button.connect('toggled', self.checked_event, label)
            button.set_active(gtk.TRUE)
            hbox.pack_start(button, gtk.TRUE, gtk.TRUE, 0)
            button.show()
            self._checked[i] = 1

        # (button class, label, signal, handler method, tooltip, icon file)
        specs = (
            (Button_XAxisLim, "x range", "scroll_event", 'scroll',
             "Use the wheel mouse to scroll the x axis", 'lrarrow.xpm'),
            (Button_XAxisScale, "x zoom", "scroll_event", 'zoom',
             "Use the wheel mouse to zoom the x axis", 'magnify_x.xpm'),
            (Button_YAxisLim, "y range", "scroll_event", 'scroll',
             "Use the wheel mouse to scroll the y axis", 'udarrow.xpm'),
            (Button_YAxisScale, "y zoom", "scroll_event", 'zoom',
             "Use the wheel mouse to zoom the y axis", 'magnify_y.xpm'),
            (Button_AxisAutoScale, "refresh", "clicked", 'autoscale',
             "Click to restore the default axis scaling", 'recycle.xpm'),
            )
        for (cls, text, signal, handler, tip, icon) in specs:
            button = cls(self.axes, text)
            button.connect(signal, getattr(button, handler))
            hbox.pack_start(button, expand=gtk.TRUE, fill=gtk.TRUE)
            tooltips.set_tip(button, tip)
            button.show()
            _add_button_icon(button, os.path.join(PIXMAP_PATH, icon), None)
            self._buttons.append(button)

        self.show()

    def checked_event(self, widget, data=None):
        """
        'toggled' handler for the axis check boxes.

        data is the check box label, i.e. the 1-based axis number as a
        string.  Records the new state, then passes the sorted list of
        selected 0-based axis indices to every navigation button.
        """
        self._checked[int(data)-1] = widget.get_active()

        selected = [key for (key, on) in self._checked.items() if on]
        selected.sort()   # dict order is arbitrary; keep it deterministic

        for button in self._buttons:
            button.set_active(selected)


    
    
class Figure(gtk.DrawingArea):
    """
    Drawing area that owns a list of axes and paints them on a grey
    background.

    Left-button presses and pointer motion feed the optional measure tool.
    A right click opens a context menu with the navigation tool, the
    measure tool and "Save Figure", which writes the last exposed contents
    as PNG.
    """
    def __init__(self, size=(600, 500)):
        gtk.DrawingArea.__init__(self)
        self.axes = []
        self._lastDir = os.getcwd()   # directory the save dialog opens in
        self._create_context_menu()
        self.set_size_request(size[0], size[1])

        self.connect('expose_event', self.expose_event)
        self.connect('configure_event', self.configure_event)
        self.connect('motion_notify_event', self.motion_notify_event)
        self.connect('button_press_event', self.button_press_event)
        self.connect('button_release_event', self.button_release_event)

        self.set_events(gdk.EXPOSURE_MASK |
                        gdk.LEAVE_NOTIFY_MASK |
                        gdk.BUTTON_PRESS_MASK |
                        gdk.BUTTON_RELEASE_MASK |
                        gdk.POINTER_MOTION_MASK)

        # Set by configure_event; expose_event configures lazily until then.
        self.isConfigured = 0

    def add_axis(self, a):
        "Append axes a to the list drawn by this figure."
        self.axes.append(a)

    def button_press_event(self, widget, event):
        """
        Left click: remember the data coordinates of the press in the first
        axes containing it.  These are the start point for the measure
        tool's delta.  Right click: pop up the context menu.
        """
        if event.button == 1:
            for a in self.axes:
                if a.in_axes(event.x, event.y):
                    self._in_x, self._in_y = \
                        a.transform_win_to_data(event.x, event.y)
                    break

        elif event.button == 3:
            # Pass the triggering button and timestamp, as GTK expects for
            # menus popped up from a button press.
            self._context_menu.popup(None, None, None,
                                     event.button, event.time)

        return gtk.TRUE

    def button_release_event(self, widget, event):
        "Left-button release forgets the measurement start point."
        if event.button == 1:
            try: del self._in_x, self._in_y
            except AttributeError: pass   # press was outside every axes
        return gtk.TRUE

    def motion_notify_event(self, widget, event):
        """
        While a measure dialog is open, report the data coordinates under the
        pointer for the first axes containing it.  During a left-button drag,
        also report the offset from the press point.
        """
        try: dialog = self.measureDialog
        except AttributeError: return gtk.TRUE

        for a in self.axes:
            if not a.in_axes(event.x, event.y): continue
            x, y = a.transform_win_to_data(event.x, event.y)
            dialog.update_position(x, y)

            try: x0, y0 = self._in_x, self._in_y
            except AttributeError: pass   # no drag in progress
            else: dialog.update_delta(x - x0, y - y0)
            break

        return gtk.TRUE

    def clear(self, widget=None, event=None):
        """
        Remove all axes and repaint.  widget/event, when given, replace the
        remembered ones used for the repaint.
        """
        if widget is not None: self._widget = widget
        if event is not None: self._event = event
        self.axes = []
        self.configure_event(self._widget, self._event)

    def configure_event(self, widget, event):
        """
        (Re)allocate colors and record the window size.  Give every axes the
        grey background and forward the configure to it, then redraw.
        """
        if widget is not None: self._widget = widget
        if event is not None: self._event = event
        cmap = self._widget.get_colormap()
        self.width, self.height = self._widget.window.get_size()
        self.grey = cmap.alloc_color(197*255,202*255,197*255)
        self.black = cmap.alloc_color(0,0,0)
        for axis in self.axes:
            axis.set_background_color(self.grey)
            axis.configure_event(self._widget, self._event)
        self.draw(self._widget)
        self.isConfigured = 1
        return gtk.TRUE

    def _create_measure_dialog(self, *args):
        "Open the measure tool; motion events update it until it is closed."
        dialog = Dialog_MeasureTool()
        # Drop our reference when the user closes it, so motion events
        # stop updating labels of a destroyed widget.
        dialog.connect('destroy', self._forget_measure_dialog)
        self.measureDialog = dialog

    def _forget_measure_dialog(self, dialog):
        "destroy handler: forget dialog if it is still the current one."
        if getattr(self, 'measureDialog', None) is dialog:
            del self.measureDialog

    def _create_navigation_dialog(self, *args):
        "Build the axis navigation dialog box"
        self.navigationDialog = Dialog_NavigationTool(self.axes)

    def _create_print_dialog(self, *args):
        "Ask for a file name and save the figure there as PNG."
        def print_ok(button):
            fname = self._filew.get_filename()
            self._filew.destroy()
            # Reopen the next save dialog in the directory just used.
            self._lastDir = os.path.dirname(fname) or self._lastDir
            self.print_figure(fname)

        def print_cancel(button):
            self._filew.destroy()

        self._filew = gtk.FileSelection(title='Save the figure')
        self._filew.set_filename(self._lastDir + os.sep)
        self._filew.ok_button.connect("clicked", print_ok)
        self._filew.cancel_button.connect("clicked", print_cancel)
        self._filew.show()

    def draw(self, widget):
        """
        Fill the window with the grey background, then draw every axes.
        widget, when not None, replaces the remembered widget.
        """
        if widget is not None: self._widget = widget

        # Use the resolved widget's window; widget itself may be None.
        win = self._widget.window
        gc = win.new_gc()
        gc.foreground = self.grey
        win.draw_rectangle(gc, gtk.TRUE, 0, 0, self.width, self.height)
        for axis in self.axes:
            axis.draw(self._widget)

    def expose_event(self, widget=None, event=None):
        """
        Repaint (configuring first if needed), then snapshot the window into
        self.pixbuf for print_figure.
        """
        if widget is not None: self._widget = widget
        if event is not None: self._event = event

        if not self.isConfigured:
            self.configure_event(self._widget, self._event)

        self.draw(self._widget)
        self._snapshot(self._widget)
        return gtk.TRUE

    def focus_in_event(self, widget, event):
        "Refresh the pixbuf snapshot used by print_figure."
        self._snapshot(widget)

    def _snapshot(self, widget):
        "Copy the widget window's current contents into self.pixbuf."
        win = widget.window
        width, height = win.get_size()
        self.pixbuf = gtk.gdk.Pixbuf(gtk.gdk.COLORSPACE_RGB, 0, 8,
                                     width, height)
        self.pixbuf.get_from_drawable(win, win.get_colormap(),
                                      0, 0, 0, 0, width, height)

    def get_axes(self):
        "Return the list of axes drawn by this figure."
        return self.axes

    def _create_context_menu(self):
        "Build the right-click menu; each item's label is passed to its callback."
        self._context_menu = gtk.Menu()
        entries = (("Show navigation tool", self._create_navigation_dialog),
                   ("Show measure tool", self._create_measure_dialog),
                   ("Save Figure", self._create_print_dialog))
        for (label, callback) in entries:
            item = gtk.MenuItem(label)
            self._context_menu.append(item)
            item.connect("activate", callback, label)
            item.show()

    def print_figure(self, filename):
        "Print figure to filename; png only"
        self.pixbuf.save(filename, "png")




class AxisText:
    """
    Handles storing and drawing of text in window coordinates.

    (x, y) is the anchor point.  halign ('left', 'center', 'right') and
    valign ('bottom', 'center', 'top') choose which part of the text sits on
    it.  Text is drawn with fg and erased by redrawing it with bg.
    """
    def __init__(self, x, y, text,
                 fg=ColorDispatcher().get('k'),
                 bg=ColorDispatcher().get('w'),
                 valign='bottom',
                 halign='left',
                 font=gtk.load_font(
         "-*-times-bold-r-normal--14-100-100-100-*-*-*-*"
        )
                 ):
        self._x, self._y = x, y
        self._fg, self._bg = fg, bg
        self._text = text
        self._valign, self._halign = valign, halign
        self._font = font

    def draw(self, widget=None):
        """
        Draw the text in the foreground color.

        widget is remembered for later calls and for erase.  If widget is
        None, the remembered widget is used.  If there is none yet, nothing
        is drawn and gtk.TRUE is returned.
        """
        if widget is None:
            try: widget = self._widget
            except AttributeError: return gtk.TRUE
        self._widget = widget
        self._render(self._fg)

    def erase(self):
        "Overdraw the text in the background color (no-op before first draw)."
        try: self._widget
        except AttributeError: return
        self._render(self._bg)

    def _render(self, color):
        "Draw the text at its aligned position on self._widget using color."
        gc = self._widget.window.new_gc()
        gc.foreground = color
        ox, oy = self._compute_offsets()
        # True division makes the offsets floats; draw_text takes ints.
        self._widget.window.draw_text(self._font, gc,
                                      int(self._x+ox), int(self._y+oy),
                                      self._text)

    def set_text(self, text):
        "Replace the text, erasing the old one and drawing the new one."
        self.erase()
        self._text = text
        self.draw()

    def _compute_offsets(self):
        'Return the (x,y) offsets to compensate for the alignment specifications'

        if self._halign=='center':
            offsetx = -self._font.string_width(self._text)/2
        elif self._halign=='right':
            offsetx = -self._font.string_width(self._text)
        else:
            offsetx = 0

        if self._valign=='center':
            offsety = self._font.string_height(self._text)/2
        elif self._valign=='top':
            offsety = self._font.string_height(self._text)
        else:
            offsety = 0

        return (offsetx, offsety)

    def get_left_right(self):
        "get the left, right boundaries of the text in win coords"
        ox,oy = self._compute_offsets()
        return self._x + ox, self._x + ox + \
               self._font.string_width(self._text)

    def get_bottom_top(self):
        "get the bottom, top boundaries of the text in win coords"
        ox,oy = self._compute_offsets()
        return self._y + oy, self._y + oy - \
               self._font.string_height(self._text)

if __name__ == '__main__':
    # Running this module directly shows the subplot demo.
    from _simple_demo import subplot_demo
    subplot_demo()
    
