from __future__ import division, generators

import math, sys

from numerix import MLab, absolute, arange, array, asarray, ones, transpose, \
     log10, Float

import mlab
from artist import Artist
from axis import XTick, YTick, XAxis, YAxis
from cbook import iterable, is_string_like, flatten, enumerate, True, False
from colors import ColormapJet, Grayscale
from dates import SEC_PER_HOUR, SEC_PER_DAY, SEC_PER_WEEK
from ticker import YearLocator, MonthLocator, WeekdayLocator, \
     DayLocator, HourLocator, MinuteLocator, DateFormatter

from image import Image
from legend import Legend
from lines import Line2D, lineStyles, lineMarkers

from mlab import meshgrid
from matplotlib import rcParams
from patches import Rectangle, Circle, Polygon, bbox_artist
from table import Table
from text import Text, _process_text_args
from transforms import Bound2D, Transform
from font_manager import FontProperties

# build a string to constants dict for images

def _process_plot_format(fmt):
    """
    Process a matlab style color/line style format string.  Return a
    (linestyle, marker, color) tuple as a result of the processing.

    Example format strings include

    'ko'    : black circles
    '.b'    : blue dots
    'r--'   : red dashed lines

    Every character must be a key of lineStyles or lineMarkers, or
    one of the color letters 'bgrcmykw'.  The two character
    linestyles '--' and '-.' are recognized as well.

    linestyle and marker are None when the format string does not
    give them, except that linestyle falls back to
    rcParams['lines.linestyle'] when neither of them is given.  color
    falls back to rcParams['lines.color'].

    Raises ValueError if the string has an unrecognized character,
    two linestyle symbols or two marker symbols.
    """

    colors = ('b', 'g', 'r', 'c', 'm', 'y', 'k', 'w')

    linestyle = None
    marker = None
    color = rcParams['lines.color']

    # handle the multi char special cases and strip them from a
    # working copy, so that the error messages below can still quote
    # the format string exactly as the caller passed it
    remaining = fmt
    for style in ('--', '-.'):
        if remaining.find(style)>=0:
            if linestyle is not None:
                raise ValueError(
                    'Illegal format string "%s"; two linestyle symbols' % fmt)
            linestyle = style
            remaining = remaining.replace(style, '')

    for c in remaining:
        if c in lineStyles:
            if linestyle is not None:
                raise ValueError(
                    'Illegal format string "%s"; two linestyle symbols' % fmt)
            linestyle = c
        elif c in lineMarkers:
            if marker is not None:
                raise ValueError(
                    'Illegal format string "%s"; two marker symbols' % fmt)
            marker = c
        elif c in colors:
            color = c
        else:
            raise ValueError(
                'Unrecognized character %c in format string' % c)

    if linestyle is None and marker is None:
        linestyle = rcParams['lines.linestyle']
    return linestyle, marker, color

class _process_plot_var_args:
    """

    Process variable length arguments to the plot command, so that
    plot commands like the following are supported

      plot(t, s)
      plot(t1, s1, t2, s2)
      plot(t1, s1, 'ko', t2, s2)
      plot(t1, s1, 'ko', t2, s2, 'r--', t3, e3)

    an arbitrary number of x, y, fmt are allowed

    Calling the instance returns a generator which yields one Line2D
    (command 'plot') or one Polygon (command 'fill') per argument
    group.  Keyword arguments are applied to every yielded artist by
    calling its set_<key> method with the value.
    """

    colors = ['b','g','r','c','m','y','k']

    # if the default line color is one of the cycle colors, move it up
    # to the front of the queue; otherwise remember it so that it is
    # still used for the first line
    try: ind = colors.index(rcParams['lines.color'])
    except ValueError:
        firstColor = rcParams['lines.color']
    else:
        colors[0], colors[ind] = colors[ind], colors[0]
        firstColor = colors[0]

    Ncolors = len(colors)

    def __init__(self, dpi, bbox, transx, transy, command='plot'):
        """
        dpi, bbox, transx and transy are handed to every artist that
        is created; command is 'plot' or 'fill'
        """
        self.dpi = dpi
        self.bbox = bbox
        self.transx = transx
        self.transy = transy
        self.count = 0    # number of lines that took a default color
        self.command = command

    def __call__(self, *args, **kwargs):
        return self._grab_next_args(*args, **kwargs)

    def _apply_props(self, artist, props):
        "Call artist.set_<key>(val) for every key, val in the props dict"
        for key, val in props.items():
            getattr(artist, "set_%s"%key)(val)

    def set_lineprops(self, line, **kwargs):
        "Apply the keyword properties to line; only valid for 'plot'"
        assert self.command == 'plot', 'set_lineprops only works with "plot"'
        self._apply_props(line, kwargs)

    def set_patchprops(self, fill_poly, **kwargs):
        "Apply the keyword properties to fill_poly; only valid for 'fill'"
        assert self.command == 'fill', 'set_patchprops only works with "fill"'
        self._apply_props(fill_poly, kwargs)

    def _default_color(self):
        """
        Return the color for the next line without a format string.
        The first line gets firstColor, later ones cycle over colors.
        """
        if self.count==0:
            return self.firstColor
        return self.colors[int(self.count % self.Ncolors)]

    def _make_line(self, x, y, **props):
        "Build a Line2D in the transforms of this instance"
        return Line2D(self.dpi, self.bbox, x, y,
                      transx = self.transx,
                      transy = self.transy,
                      **props)

    def _make_polygon(self, x, y, **props):
        "Build a filled Polygon with vertices zip(x,y)"
        return Polygon(self.dpi, self.bbox,
                       zip(x,y),
                       fill=True,
                       transx = self.transx,
                       transy = self.transy,
                       **props)

    def _bad_command(self):
        "Raise ValueError for a command that is neither plot nor fill"
        raise ValueError('Unrecognized command "%s"' % self.command)

    def _plot_1_arg(self, y, **kwargs):
        "plot(y): plot y against its indices in the next default color"
        assert self.command == 'plot', 'fill needs at least 2 arguments'
        ret = self._make_line(arange(len(y)), y,
                              color=self._default_color())
        self.set_lineprops(ret, **kwargs)
        self.count += 1
        return ret

    def _plot_2_args(self, tup2, **kwargs):
        "Handle a (y, fmt) group or an (x, y) group"
        if is_string_like(tup2[1]):
            assert self.command == 'plot', 'fill needs at least 2 non-string arguments'
            y, fmt = tup2
            linestyle, marker, color = _process_plot_format(fmt)
            ret = self._make_line(
                arange(len(y)), y,
                color=color, linestyle=linestyle, marker=marker)
            self.set_lineprops(ret, **kwargs)
            return ret

        x, y = tup2
        if self.command == 'plot':
            # use the same default color sequence as _plot_1_arg so
            # that firstColor is honored for plot(x, y) too
            ret = self._make_line(x, y, color=self._default_color())
            self.set_lineprops(ret, **kwargs)
            self.count += 1
        elif self.command == 'fill':
            ret = self._make_polygon(x, y)
            self.set_patchprops(ret, **kwargs)
        else:
            self._bad_command()
        return ret

    def _plot_3_args(self, tup3, **kwargs):
        "Handle an (x, y, fmt) plot group or an (x, y, facecolor) fill group"
        if self.command == 'plot':
            x, y, fmt = tup3
            linestyle, marker, color = _process_plot_format(fmt)
            ret = self._make_line(
                x, y, color=color, linestyle=linestyle, marker=marker)
            self.set_lineprops(ret, **kwargs)
        elif self.command == 'fill':
            x, y, facecolor = tup3
            ret = self._make_polygon(x, y, facecolor=facecolor)
            self.set_patchprops(ret, **kwargs)
        else:
            self._bad_command()
        return ret

    def _grab_next_args(self, *args, **kwargs):
        """
        Generator which consumes args group by group and yields the
        artist built for each group.  Raises ValueError if exactly
        three arguments remain and the third is not a string.
        """
        remaining = args
        while len(remaining)>0:
            if len(remaining)==1:
                yield self._plot_1_arg(remaining[0], **kwargs)
                return
            if len(remaining)==2:
                yield self._plot_2_args(remaining, **kwargs)
                return
            if is_string_like(remaining[2]):
                yield self._plot_3_args(remaining[:3], **kwargs)
                remaining = remaining[3:]
            elif len(remaining)==3:
                raise ValueError('third arg must be a format string')
            else:
                yield self._plot_2_args(remaining[:2], **kwargs)
                remaining = remaining[2:]
        



        
class Axes(Artist):
    """
    Emulate matlab's axes command, creating axes with

       Axes(position=[left, bottom, width, height])

    where all the arguments are fractions in [0,1] which specify the
    fraction of the total figure window.  

    axisbg is the color of the axis background

    """

    def __init__(self, fig, position,
                 axisbg = rcParams['axes.facecolor'],
                 frameon = True):
        """
        fig is the figure the axes belongs to; its dpi is handed to
        the Artist base class.  position, axisbg and frameon are
        stored on the instance and cla() is called to build the rest
        of the axes state.
        """
        # placeholder bounding box; resize will update
        Artist.__init__(self, fig.dpi, Bound2D(0,0,1,1))

        self.figure = fig
        self._position = position
        self._frameon = frameon
        self._axisbg = axisbg

        # builds the x and y axis and everything that depends on them
        self.cla()


    def cla(self):
        """
        Clear the current axes: rebuild the x and y axis, the plot
        argument processors, the title and the frame patch, and empty
        the lists of lines, patches, text, tables and artists
        """
        self.xaxis = XAxis(self)
        self.yaxis = YAxis(self)
        for axis in (self.xaxis, self.yaxis):
            axis.build_artists()
        self.resize()  # compute bounding box

        # argument processors behind the plot and fill commands
        self._get_lines = _process_plot_var_args(
            self.dpi, self.bbox,
            self.xaxis.transData, self.yaxis.transData)
        self._get_patches_for_fill = _process_plot_var_args(
            self.dpi, self.bbox,
            self.xaxis.transData, self.yaxis.transData,
            command='fill')

        self._gridOn = rcParams['axes.grid']

        # containers for everything that gets added to the axes
        self._lines = []
        self._patches = []
        self._text = []     # text in axis coords
        self._tables = []
        self._artists = []
        self._legend = None
        self._image = None
        self._imExtent = None

        self.grid(self._gridOn)

        # empty title, placed at x=0.5, y=1.02 in the transAxis
        # transforms and not clipped
        titlefont = FontProperties(size=rcParams['axes.titlesize'])
        title = Text(
            self.dpi,
            self.bbox,
            x=0.5, y=1.02, text='',
            fontproperties=titlefont,
            verticalalignment='bottom',
            horizontalalignment='center',
            transx = self.xaxis.transAxis,
            transy = self.yaxis.transAxis,
            )
        self._title = title
        title.set_clip_on(False)

        # the frame: a unit rectangle in the transAxis transforms
        frame = Rectangle(
            self.dpi, self.bbox,
            xy=(0,0), width=1, height=1,
            facecolor=self._axisbg,
            edgecolor=rcParams['axes.edgecolor'],
            transx = self.xaxis.transAxis,
            transy = self.yaxis.transAxis,
            )
        self._axesPatch = frame
        frame.set_linewidth(rcParams['axes.linewidth'])
        frame.set_clip_on(False)

        self.axison = True


    def _pass_func(self, *args, **kwargs):
        pass

    def add_artist(self, a):
        """
        Append the artist a to the list of optional artists of the
        axes
        """
        artists = self._artists
        artists.append(a)

    def add_line(self, line):
        """
        Append line to the list of plot lines and update the x and y
        data limits with the line data
        """
        def in_data_units(vals, trans, transData):
            # values which are not in axis data units are transformed
            # to display and then back to data to get them in data
            # units
            if trans != transData:
                display = trans.positions(vals)
                vals = transData.inverse_positions(display)
            return vals

        xdata = in_data_units(
            line.get_xdata(), line.transx, self.xaxis.transData)
        ydata = in_data_units(
            line.get_ydata(), line.transy, self.yaxis.transData)

        self.xaxis.datalim.update(xdata)
        self.yaxis.datalim.update(ydata)
        self._lines.append(line)

    def add_patch(self, patch):
        """
        Append patch to the list of axes patches
        """
        patches = self._patches
        patches.append(patch)

    def add_table(self, tab):
        """
        Append the table instance tab to the list of axes tables
        """
        tables = self._tables
        tables.append(tab)

    


    def bar(self, left, height, width=0.8, bottom=0,
            color='b', yerr=None, xerr=None, capsize=3
            ):
        """
        BAR(left, height)
        BAR(left, height, width, bottom,
            color, yerr, xerr, capsize)

        Make a bar plot with one rectangle per element of left, with
        lower left corner at (left, bottom) and the given width and
        height.  A bar with negative height is drawn from
        bottom+height upward.

        left and height are sequences of equal length.  width and
        bottom can be scalars or sequences of len(left).  color can
        be a color string, a scalar, an RGB tuple (a length 3
        sequence is only taken as one color when there are not
        exactly 3 bars) or a sequence of len(left) colors.

        xerr and yerr, if not None, are passed to errorbar to put
        error bars, with caps of capsize, at the middle of the top of
        the bars.

        The x and y data limits are updated and the view is
        autoscaled.

        Return value is a list of the Rectangle patch instances
        """
        left = asarray(left)
        height = asarray(height)
        N = len(left)

        # repeat a single color for every bar
        if is_string_like(color) or not iterable(color):
            single_color = True
        else:
            single_color = len(color)==3 and N!=3
        if single_color:
            color = [color]*N

        # expand scalar bottom and width to one value per bar
        if iterable(bottom):
            bottom = asarray(bottom)
        else:
            bottom = array([bottom]*N, Float)
        if iterable(width):
            width = asarray(width)
        else:
            width = array([width]*N, Float)

        assert len(bottom)==N, 'bar arg bottom must be len(left)'
        assert len(width)==N, 'bar arg width must be len(left) or scalar'
        assert len(height)==N, 'bar arg height must be len(left) or scalar'
        assert len(color)==N, 'bar arg color must be len(left) or scalar'

        xlim = self.xaxis.datalim
        xlim.update(left)
        xlim.update(left+width)
        ylim = self.yaxis.datalim
        ylim.update(bottom)
        ylim.update(bottom+height)

        self.xaxis.autoscale_view()
        self.yaxis.autoscale_view()

        patches = []
        for x0, y0, w, h, c in zip(left, bottom, width, height, color):
            if h<0:
                # the rectangle needs a non negative height, so start
                # it at the tip of the bar
                y0 += h
                h = abs(h)
            rect = Rectangle(
                self.dpi, self.bbox,
                xy=(x0, y0), width=w, height=h,
                facecolor=c,
                transx = self.xaxis.transData,
                transy = self.yaxis.transData,
                )
            self._patches.append(rect)
            patches.append(rect)

        if xerr is not None or yerr is not None:
            symbols, barlines = self.errorbar(
                left+0.5*width, bottom+height,
                yerr=yerr, xerr=xerr,
                fmt='o', capsize=capsize)
            # black markers at the center of the error bars
            for line in symbols:
                line.set_markerfacecolor('k')
                line.set_markeredgecolor('k')
        return patches



    def clear(self):
        "Clear the axes by calling cla()"
        self.cla()
        
    def cohere(self, x, y, NFFT=256, Fs=2, detrend=mlab.detrend_none,
               window=mlab.window_hanning, noverlap=0):
        """
        Plot the coherence between x and y.  Coherence is the
        normalized cross spectral density

        Cxy = |Pxy|^2/(Pxx*Pyy)

        The coherence is computed by mlab.cohere, which receives the
        arguments NFFT, Fs, detrend, window and noverlap unchanged;
        see the docs for psd and csd for information about them.  The
        result is plotted against frequency, the axis labels are set
        and the grid is turned on.

        Returns the tuple Cxy, freqs, where freqs are the frequencies
        of the coherence vector

        Refs:
          Bendat & Piersol -- Random Data: Analysis and Measurement
            Procedures, John Wiley & Sons (1986)

        """
        result = mlab.cohere(x, y, NFFT, Fs, detrend, window, noverlap)
        cxy, freqs = result

        self.plot(freqs, cxy)
        self.set_xlabel('Frequency')
        self.set_ylabel('Coherence')
        self.grid(True)

        return cxy, freqs

    def csd(self, x, y, NFFT=256, Fs=2, detrend=mlab.detrend_none,
            window=mlab.window_hanning, noverlap=0):
        """
        The cross spectral density Pxy by Welch's average periodogram
        method.  The vectors x and y are divided into NFFT length
        segments.  Each segment is detrended by function detrend and
        windowed by function window.  noverlap gives the length of the
        overlap between segments.  The product of the direct FFTs of x and
        y are averaged over each segment to compute Pxy, with a scaling to
        correct for power loss due to windowing.  Fs is the sampling
        frequency.

        NFFT must be a power of 2

        detrend and window are functions, unlike in matlab where they are
        vectors.  For detrending you can use detrend_none, detrend_mean,
        detrend_linear or a custom function.  For windowing, you can use
        window_none, window_hanning, or a custom function

        The computation is done by mlab.csd.  10*log10(|Pxy|) is
        plotted against frequency, the axis labels are set, the grid
        is turned on and the y ticks are placed on a regular step
        over the y view limits.

        Returns the tuple Pxy, freqs.  Pxy is the cross spectrum (complex
        valued)

        Refs:
          Bendat & Piersol -- Random Data: Analysis and Measurement
            Procedures, John Wiley & Sons (1986)

        """

        pxy, freqs = mlab.csd(x, y, NFFT, Fs, detrend, window, noverlap)
        pxy.shape = len(freqs),
        # pxy is complex

        self.plot(freqs, 10*log10(absolute(pxy)))
        self.set_xlabel('Frequency')
        self.set_ylabel('Cross Spectrum Magnitude (dB)')
        self.grid(True)
        vmin, vmax = self.yaxis.viewlim.bounds()

        # the tick step is 10*int(log10(interval)); for an interval
        # below 10 that is zero or negative, which arange cannot use,
        # and log10 is undefined for an empty interval, so fall back
        # to a step of 1 in those cases
        intv = vmax-vmin
        step = 1
        if intv>0:
            decades = int(log10(intv))
            if decades>0:
                step = 10*decades

        ticks = arange(math.floor(vmin), math.ceil(vmax)+1, step)
        self.set_yticks(ticks)

        return pxy, freqs

        
    def _draw(self, renderer, *args, **kwargs):
        "Draw everything (plot lines, axes, labels)"

        if not ( self.xaxis.viewlim.defined() and
                 self.yaxis.viewlim.defined() ):
            self.update_viewlim()

        if self.axison:
            if self._frameon: self._axesPatch.draw(renderer)

        if self.axison:
            self.xaxis.draw(renderer)
            self.yaxis.draw(renderer)


        if self._image is not None:
            self._image.draw(renderer)

        for p in self._patches:
            p.draw(renderer)

        for line in self._lines:
            line.draw(renderer)

        for t in self._text:
            t.draw(renderer)

        self._title.draw(renderer)
        if 0: bbox_artist(self._title, renderer)
        # optional artists
        for a in self._artists:
            a.draw(renderer)


        if self._legend is not None:
            self._legend.draw(renderer)

        for table in self._tables:
            table.draw(renderer)


    def errorbar(self, x, y, yerr=None, xerr=None, fmt='b-', capsize=3):
        """
        Plot x versus y with error deltas in yerr and xerr.
        Vertical errorbars are plotted if yerr is not None
        Horizontal errorbars are plotted if xerr is not None

        xerr and yerr may be any of:
            a scalar or a one dimensional sequence - symmetric
              errorbars +/- value
            a two dimensional array err - asymmetric errorbars
              -err[0]/+err[1]

        x and y can be scalars too, which plots a single error bar
        at x, y.

        fmt is the plot format symbol for y.  if fmt is None, just
        plot the errorbars with no line symbols.  This can be useful
        for creating a bar plot with errorbars

        capsize is the half length of the error bar caps, in the
        units of the axis get_pts_transform

        Return value is a length 2 tuple.  The first element is the
        return value of plot for the y symbols (None if fmt is None).
        The second element is a list of error bar lines.
        """

        def as_array(val):
            # make sure the arg is an iterable array
            if iterable(val): return asarray(val)
            return asarray([val])

        x = as_array(x)
        y = as_array(y)
        if xerr is not None: xerr = as_array(xerr)
        if yerr is not None: yerr = as_array(yerr)

        symbols = None
        if fmt is not None:
            symbols = self.plot(x,y,fmt)
        barlines = []

        def make_cap(xdata, ydata, transx, transy):
            # black cap line which is not clipped to the data
            cap = Line2D( self.dpi, self.bbox,
                          xdata=xdata, ydata=ydata,
                          color='k',
                          transx = transx,
                          transy = transy,
                          )
            cap.set_data_clipping(False)
            return cap

        def get_ybar_cap(x, y):
            # horizontal cap at height y, centered on x
            xtrans = self.xaxis.get_pts_transform(
                x, self.xaxis.transData)
            return make_cap((-capsize, capsize), (y, y),
                            xtrans, self.yaxis.transData)

        def get_xbar_cap(x, y):
            # vertical cap at position x, centered on y
            ytrans = self.yaxis.get_pts_transform(
                y, self.yaxis.transData)
            return make_cap((x,x), (-capsize, capsize),
                            self.xaxis.transData, ytrans)

        def bar_ends(vals, err):
            # one dimensional errors are symmetric, otherwise err[0]
            # is subtracted and err[1] is added
            if len(err.shape) == 1:
                return vals-err, vals+err
            return vals-err[0], vals+err[1]

        # horizontal errorbars
        if xerr is not None:
            left, right = bar_ends(x, xerr)
            barlines.extend( self.hlines(y, x, left) )
            barlines.extend( self.hlines(y, x, right) )
            # bar caps
            for yval, lval, rval in zip(y, left, right):
                caps = get_xbar_cap(lval, yval), get_xbar_cap(rval, yval)
                for cap in caps:
                    self.add_line(cap)
                barlines.extend(caps)

        # vertical errorbars
        if yerr is not None:
            lower, upper = bar_ends(y, yerr)
            barlines.extend( self.vlines(x, y, lower) )
            barlines.extend( self.vlines(x, y, upper) )
            # bar caps
            for xval, uval, lval in zip(x, upper, lower):
                caps = get_ybar_cap(xval, uval), get_ybar_cap(xval, lval)
                for cap in caps:
                    self.add_line(cap)
                barlines.extend(caps)

        self.xaxis.autoscale_view()
        self.yaxis.autoscale_view()
        return (symbols, barlines)


    def fill(self, *args, **kwargs):
        """
        Emulate matlab's fill command.  *args is a variable length
        argument, allowing for multiple x,y pairs with an optional
        color format string.  For example, all of the following are
        legal, assuming a is the Axes instance:

          a.fill(x,y)            # plot polygon with vertices at x,y
          a.fill(x,y, 'b' )      # plot polygon with vertices at x,y in blue

        An arbitrary number of x, y, color groups can be specified, as in
          a.fill(x1, y1, 'g', x2, y2, 'r')

        Every polygon is added to the axes patches and the data
        limits are updated with its x and y limits; the view is
        autoscaled at the end.

        Returns a list of patches that were added.
        """
        added = []
        polygons = self._get_patches_for_fill(*args, **kwargs)
        for poly in polygons:
            self.add_patch( poly )
            self.xaxis.datalim.update(poly.get_xlim())
            self.yaxis.datalim.update(poly.get_ylim())
            added.append( poly )

        for axis in (self.xaxis, self.yaxis):
            axis.autoscale_view()
        return added
    
    def get_axis_bgcolor(self):
        """
        Return the background color of the axis
        """
        bgcolor = self._axisbg
        return bgcolor

    def get_child_artists(self):
        """
        Return a new list with the title, the frame patch, the x and
        y axis, the lines, patches and text of the axes and, if one
        is defined, the legend
        """
        children = [self._title, self._axesPatch, self.xaxis, self.yaxis]
        for group in (self._lines, self._patches, self._text):
            children.extend(group)
        legend = self._legend
        if legend is not None:
            children.append(legend)
        return children
    
    def get_frame(self):
        """
        Return the Rectangle patch that is the frame of the axes
        """
        frame = self._axesPatch
        return frame

    def get_legend(self):
        """
        Return the Legend instance of the axes; None if no legend is
        defined
        """
        legend = self._legend
        return legend


    def get_lines(self):
        "Return the list of plot lines of the axes (not a copy)"
        lines = self._lines
        return lines
    
    def get_xaxis(self):
        """
        Return the XAxis instance of the axes
        """
        xaxis = self.xaxis
        return xaxis

    def get_xgridlines(self):
        """
        Return the grid lines of the x axis, as given by its
        get_gridlines method
        """
        gridlines = self.xaxis.get_gridlines()
        return gridlines

    def get_xlim(self):
        """
        Return the x axis range xmin, xmax, ie the bounds of the x
        view limits
        """
        viewlim = self.xaxis.viewlim
        return viewlim.bounds()


    def get_xticklabels(self):
        """
        Return the tick labels of the x axis, as given by its
        get_ticklabels method
        """
        labels = self.xaxis.get_ticklabels()
        return labels

    def get_xticklines(self):
        """
        Return the tick lines of the x axis, as given by its
        get_ticklines method
        """
        ticklines = self.xaxis.get_ticklines()
        return ticklines
    

    def get_xticks(self):
        """
        Return the tick locations of the x axis, as given by its
        get_ticklocs method
        """
        locs = self.xaxis.get_ticklocs()
        return locs

    def get_yaxis(self):
        """
        Return the YAxis instance of the axes
        """
        yaxis = self.yaxis
        return yaxis

    def get_ylim(self):
        """
        Return the y axis range ymin, ymax, ie the bounds of the y
        view limits
        """
        viewlim = self.yaxis.viewlim
        return viewlim.bounds()

    def get_ygridlines(self):
        """
        Return the grid lines of the y axis, as given by its
        get_gridlines method
        """
        gridlines = self.yaxis.get_gridlines()
        return gridlines

    def get_yticklabels(self):
        """
        Return the tick labels of the y axis, as given by its
        get_ticklabels method
        """
        labels = self.yaxis.get_ticklabels()
        return labels

    def get_yticklines(self):
        """
        Return the tick lines of the y axis, as given by its
        get_ticklines method
        """
        ticklines = self.yaxis.get_ticklines()
        return ticklines

    def get_yticks(self):
        """
        Return the tick locations of the y axis, as given by its
        get_ticklocs method
        """
        locs = self.yaxis.get_ticklocs()
        return locs

    def grid(self, b):
        """
        Set the grids of the x and the y axis on or off; b is a
        boolean
        """
        for axis in (self.xaxis, self.yaxis):
            axis.grid(b)

    def hist(self, x, bins=10, normed=0):
        """
        Compute the histogram of x and plot it as a bar chart.  bins
        is either an integer number of bins or a sequence giving the
        bins.  x are the data to be binned.

        The histogram is computed by mlab.hist, which receives x,
        bins and normed unchanged.  If normed is true, the first
        element of the return tuple will be the counts normalized to
        form a probability distribution, ie, n/(len(x)*dbin)

        The bars are 0.9 times as wide as the distance between the
        first two bins.

        Returns n, bins, patches, where patches is the list returned
        by bar
        """
        counts, edges = mlab.hist(x, bins, normed)
        barwidth = 0.9*(edges[1]-edges[0])
        patches = self.bar(edges, counts, width=barwidth)
        return counts, edges, patches


    def set_frame_on(self, b):
        "Set with boolean b whether the axes rectangle patch is drawn"
        self._frameon = b
    def set_image_extent(self, xmin, xmax, ymin, ymax):
        """
        Set the data units of the image.  This is useful if you want to
        plot other things over the image, eg, lines or scatter

        The extent is stored, handed to the image if there is one,
        and the data and view limits of both axis are set to it
        """
        extent = xmin, xmax, ymin, ymax
        self._imExtent = extent
        if self._image is not None:
            self._image.set_data_extent(*extent)

        # data limits first, then view limits
        for name in ('datalim', 'viewlim'):
            getattr(self.xaxis, name).set_bounds(xmin, xmax)
            getattr(self.yaxis, name).set_bounds(ymin, ymax)
        #self.set_axis_off()

        
    def imshow(self, X, cmap = Grayscale(256)):
        """
        Display the image in array X to current axes.  X must be a
        float array

        If X is MxN, assume luminance (grayscale)
        If X is MxNx3, assume RGB
        If X is MxNx4, assume RGBA

        cmap is a colors.Colormap instance, or None to use X as it
        is; otherwise X is converted with cmap.array_as_rgb

        If no image extent has been set, the extent is set to
        0, numcols, 0, numrows of the image.

        An Image instance is returned

        the 0,0 index is the upper left of the image and the -1,-1
        index is the lower right
        """
        # a two dimensional X with no colormap or with the gray
        # colormap is flagged as grayscale on the image
        gray_cmap = cmap is None or cmap.color == "gray"
        if gray_cmap and len(X.shape) == 2:
            is_grayscale = 1
        else:
            is_grayscale = 0

        if cmap is not None:
            X = cmap.array_as_rgb(X)

        image = Image(self.dpi, self.bbox,
                      self.xaxis.viewlim, self.yaxis.viewlim)
        image.fromarray(X)
        image._im.is_grayscale = is_grayscale
        self._image = image

        if self._imExtent is None:
            numrows, numcols = image.get_size()
            # the number of rows is along the vertical axis
            self.set_image_extent(0, numcols, 0, numrows)
        else:
            # call this again so it can update image and viewlim accordingly
            self.set_image_extent(*self._imExtent)
        return image
        
    def in_axes(self, xwin, ywin):
        """
        Return true if xwin is in the x interval and ywin is in the y
        interval of the axes bounding box
        """
        inside_x = self.bbox.x.in_interval(xwin)
        if not inside_x:
            return inside_x
        return self.bbox.y.in_interval(ywin)

    def hlines(self, y, xmin, xmax, fmt='k-'):
        """
        plot horizontal lines at each y from xmin to xmax.  y, xmin
        and xmax can be scalars or sequences.  If xmin or xmax has a
        single value, that value is used for every y, else the
        extents of the lines are given element by element by xmin
        and xmax, which must then be len(y).

        fmt is a plot format string for the lines.

        Every line is added to the axes with add_line.

        Raises ValueError if xmin or xmax does not match the length
        of y.

        Returns a list of line instances that were added
        """
        linestyle, marker, color = _process_plot_format(fmt)

        # asarray turns a scalar into a rank-0 array, which has no
        # len, so wrap scalars in a list first
        if not iterable(y): y = [y]
        if not iterable(xmin): xmin = [xmin]
        if not iterable(xmax): xmax = [xmax]

        y = asarray(y)
        xmin = asarray(xmin)
        xmax = asarray(xmax)

        # repeat a single limit for every y
        if len(xmin)==1:
            xmin = xmin*ones(y.shape, y.typecode())
        if len(xmax)==1:
            xmax = xmax*ones(y.shape, y.typecode())

        if len(xmin)!=len(y):
            raise ValueError('xmin and y are unequal sized sequences')
        if len(xmax)!=len(y):
            raise ValueError('xmax and y are unequal sized sequences')

        lines = []
        for (thisY, thisMin, thisMax) in zip(y,xmin,xmax):
            line = Line2D(
                self.dpi, self.bbox,
                [thisMin, thisMax], [thisY, thisY],
                color=color, linestyle=linestyle, marker=marker,
                transx = self.xaxis.transData,
                transy = self.yaxis.transData,
                )
            self.add_line( line )
            lines.append(line)
        return lines


    def legend(self, *args, **kwargs):
        """
        Place a legend on the current axes at location loc.  Labels are a
        sequence of strings and loc can be a string or an integer
        specifying the legend location

        USAGE:

          Make a legend with existing lines

          >>> legend()

          legend by itself will try and build a legend using the label
          property of the lines.  You can set the label of a line by
          doing plot(x, y, label='my data') or
          line.set_label('my data')

          legend( LABELS )
          >>> legend( ('label1', 'label2', 'label3') )

          Make a legend for Line2D instances lines1, line2, line3
          legend( LINES, LABELS )
          >>> legend( (line1, line2, line3), ('label1', 'label2', 'label3') )

          Make a legend at LOC
          legend( LABELS, LOC )  or
          legend( LINES, LABELS, LOC )
          >>> legend( ('label1', 'label2', 'label3'), loc='upper left')
          >>> legend( (line1, line2, line3),
                      ('label1', 'label2', 'label3'),
                      loc=2)

        loc can be given as the last positional argument or as the
        keyword argument loc; a positional loc takes precedence.
        With two positional arguments the second one is only taken
        as loc if it is a string or an integer, so a 2-tuple loc
        must then be passed by keyword.  Other keyword arguments are
        ignored.

        The location codes are

          'best' : 0,          (currently not supported, defaults to upper right)
          'upper right'  : 1,  (default)
          'upper left'   : 2,
          'lower left'   : 3,
          'lower right'  : 4,
          'right'        : 5,
          'center left'  : 6,
          'center right' : 7,
          'lower center' : 8,
          'upper center' : 9,
          'center'       : 10,

        If none of these are suitable, loc can be a 2-tuple giving x,y
        in axes coords, ie,

          loc = 0, 1 is left top
          loc = 0.5, 0.5 is center, center

          and so on

        Raises RuntimeError for more than three positional arguments.

        The Legend instance is stored on the axes and returned.
        """

        # the documented keyword form, eg legend(labels, loc=2)
        loc = kwargs.get('loc', 1)
        if len(args)==0:
            labels = [line.get_label() for line in self._lines]
            lines = self._lines
        elif len(args)==1:
            # LABELS
            labels = args[0]
            # one axes line per label, in the order they were added
            lines = [line for line, label in zip(self._lines, labels)]
        elif len(args)==2:
            if is_string_like(args[1]) or isinstance(args[1], int):
                # LABELS, LOC
                labels, loc = args
                lines = [line for line, label in zip(self._lines, labels)]
            else:
                # LINES, LABELS
                lines, labels = args
        elif len(args)==3:
            # LINES, LABELS, LOC
            lines, labels, loc = args
        else:
            raise RuntimeError('Invalid arguments to legend')

        lines = flatten(lines)
        self._legend = Legend(self.dpi, self.bbox, lines, labels, loc)
        return self._legend

    def loglog(self, *args, **kwargs):
        """
        Make a loglog plot with log scaling on both the x and y axis.
        The args to loglog are the same as the args to plot.  See help
        plot for more info

        Returns whatever plot returns
        """
        for setscale in (self.set_xscale, self.set_yscale):
            setscale('log')
        return self.plot(*args, **kwargs)


    def panx(self, numsteps):
        """
        Pan the x axis numsteps (plus pan right, minus pan left) and
        clip every line in the axes to the new x view limits
        """
        xaxis = self.xaxis
        xaxis.pan(numsteps)
        lo, hi = xaxis.viewlim.bounds()
        for thisLine in self._lines:
            thisLine.set_xclip(lo, hi)

    def pany(self, numsteps):
        """
        Pan the y axis numsteps (plus pan up, minus pan down)
        """
        yaxis = self.yaxis
        yaxis.pan(numsteps)


    def pcolor(self, *args, **kwargs):
        """
        pcolor(C) - make a pseudocolor plot of matrix C

        pcolor(X, Y, C) - a pseudo color plot of C on the matrices X and Y

        pcolor(C, cmap=colormapInstance) - make a pseudocolor plot of
        matrix C with rectangle patches, using a custom colormap
        derived from matplotlib.colors.Colormap.  The colormap must be
        passed as a kwarg; the default is ColormapJet(256).

        Shading:

          The optional keyword arg shading ('flat' or 'faceted')
          determines whether a thin black edge is drawn around each
          pcolor square.  Default 'faceted'
             e.g.,
             pcolor(C, shading='flat')
             pcolor(X, Y, C, shading='faceted')

        One rectangle is made for each i,j with i < X.shape[0]-1 and
        j < X.shape[1]-1; it is colored by C[i,j] scaled between the
        min and max of C.  The axes grid is turned off and the data
        limits and view are updated from the extremes of X and Y.

        Returns a list of the patch objects that were added.  Raises
        RuntimeError unless called with one or three positional args.

        Note, the behavior of meshgrid in matlab is a bit
        counterintuitive for x and y arrays.  For example,

          x = arange(7)
          y = arange(5)
          X, Y = meshgrid(x,y)

          Z = rand( len(x), len(y))
          pcolor(X, Y, Z)

        will fail in matlab and matplotlib.  You will probably be
        happy with

         pcolor(X, Y, transpose(Z))

        Likewise, for nonsquare Z,

         pcolor(transpose(Z))

        will make the x and y axes in the plot agree with the numrows
        and numcols of Z
        """

        shading = kwargs.get('shading', 'faceted')

        nargs = len(args)
        if nargs==1:
            C = args[0]
            numRows, numCols = C.shape
            # build the coordinate matrices from the column and row indices
            X, Y = meshgrid(arange(numCols), arange(numRows) )
        elif nargs==3:
            X, Y, C = args
        else:
            raise RuntimeError('Illegal arguments to pcolor; see help(pcolor)')

        Nx, Ny = X.shape

        cmap = kwargs.get('cmap', ColormapJet(256))
        # color scaling limits: the extremes over the whole matrix
        cmin = MLab.min(MLab.min(C))
        cmax = MLab.max(MLab.max(C))

        faceted = (shading == 'faceted')
        added = []

        for i in range(Nx-1):
            for j in range(Ny-1):
                color = cmap.get_color(C[i,j], cmin, cmax)
                left, bottom = X[i,j], Y[i,j]
                # the cell spans to the next grid point in each direction
                rect = Rectangle(
                    self.dpi, self.bbox,
                    (left, bottom), X[i,j+1]-left, Y[i+1,j]-bottom,
                    transx = self.xaxis.transData,
                    transy = self.yaxis.transData,
                    )
                rect.set_facecolor(color)
                if faceted:
                    rect.set_linewidth(0.25)
                    rect.set_edgecolor('k')
                else:
                    # flat shading: the edge takes the face color
                    rect.set_edgecolor(color)
                self._patches.append(rect)
                added.append(rect)
        self.grid(0)

        xlimits = (MLab.min(MLab.min(X)), MLab.max(MLab.max(X)))
        ylimits = (MLab.min(MLab.min(Y)), MLab.max(MLab.max(Y)))
        self.xaxis.datalim.update(xlimits)
        self.yaxis.datalim.update(ylimits)
        self.xaxis.autoscale_view()
        self.yaxis.autoscale_view()
        return added


    def plot(self, *args, **kwargs):
        """
        Emulate matlab's plot command.  *args is a variable length
        argument, allowing for multiple x,y pairs with an optional
        format string.  For example, all of the following are legal,
        assuming a is the Axes instance:

          a.plot(x,y)            # plot Numeric arrays y vs x
          a.plot(x,y, 'bo')      # plot Numeric arrays y vs x with blue circles
          a.plot(y)              # plot y using x = arange(len(y))
          a.plot(y, 'r+')        # ditto with red plusses

        An arbitrary number of x, y, fmt groups can be specified, as in
          a.plot(x1, y1, 'g^', x2, y2, 'g-')

        The args and kwargs are handed to self._get_lines; each line it
        produces is added to the axes and both axes are then
        autoscaled.

        Returns a list of lines that were added
        """

        added = []
        # add each line as it is produced, keeping the original ordering
        for thisLine in self._get_lines(*args, **kwargs):
            self.add_line(thisLine)
            added.append(thisLine)
        for axis in (self.xaxis, self.yaxis):
            axis.autoscale_view()
        return added

    def plot_date(self, d, y, converter, fmt='bo', **kwargs):
        """
        plot_date(d, y, converter, fmt='bo', **kwargs)

        d is a sequence of dates; converter is a dates.DateConverter
        instance that converts your dates to seconds since the epoch for
        plotting.  y are the y values at those dates.  fmt is a plot
        format string.  kwargs are passed on to plot.  See plot for more
        information.

        After plotting, a major tick locator and a date formatter for
        the x axis are chosen from the time span of the data (years,
        months, weeks, days, hours or minutes).

        d must not be empty (checked with an assert).

        Returns the list of lines returned by plot
        """
        e = array([converter.epoch(thisd) for thisd in d])
        assert(len(e))
        lines = self.plot(e, y, fmt, **kwargs)

        # the span of the data in seconds, expressed in coarser units
        span   = max(e)-min(e)
        hours  = span/SEC_PER_HOUR
        days   = span/SEC_PER_DAY
        weeks  = span/SEC_PER_WEEK
        months = span/(SEC_PER_DAY*30)    # approx
        years  = span/(SEC_PER_WEEK*52)   # approx

        # use the coarsest unit that still spans more than numticks units
        numticks = 5
        if years>numticks:
            locator = YearLocator(int(years/numticks))
            datefmt = '%Y'
        elif months>numticks:
            locator = MonthLocator(int(months/numticks))
            datefmt = '%b %Y'
        elif weeks>numticks:
            locator = WeekdayLocator(0)
            datefmt = '%b %d'
        elif days>numticks:
            locator = DayLocator(0)
            datefmt = '%b %d'
        elif hours>numticks:
            locator = HourLocator(1)
            datefmt = '%H'
        else:
            locator = MinuteLocator(1)
            datefmt = '%H:%M'

        formatter = DateFormatter(datefmt)
        self.xaxis.set_major_locator(locator)
        # the formatter labels the major ticks; it is not a locator
        self.xaxis.set_major_formatter(formatter)
        return lines
            
        
    def psd(self, x, NFFT=256, Fs=2, detrend=mlab.detrend_none,
            window=mlab.window_hanning, noverlap=0):
        """
        The power spectral density by Welch's average periodogram method.
        The vector x is divided into NFFT length segments.  Each segment
        is detrended by function detrend and windowed by function window.
        noverlap gives the length of the overlap between segments.  The
        absolute(fft(segment))**2 of each segment are averaged to compute Pxx,
        with a scaling to correct for power loss due to windowing.  Fs is
        the sampling frequency.

        -- NFFT must be a power of 2

        -- detrend and window are functions, unlike in matlab where they
           are vectors.  For detrending you can use detrend_none,
           detrend_mean, detrend_linear or a custom function.  For
           windowing, you can use window_none, window_hanning, or a custom
           function

        -- if length x < NFFT, it will be zero padded to NFFT

        The computation itself is done by mlab.psd; this method plots
        the result, labels the axes, turns the grid on and sets the y
        ticks.

        Returns the tuple Pxx, freqs

        For plotting, the power is plotted as 10*log10(pxx) for decibels,
        though pxx itself is returned

        Refs:
          Bendat & Piersol -- Random Data: Analysis and Measurement
            Procedures, John Wiley & Sons (1986)

        """
        pxx, freqs = mlab.psd(x, NFFT, Fs, detrend, window, noverlap)
        # reshape pxx in place to a 1D array, one value per frequency
        pxx.shape = len(freqs),

        self.plot(freqs, 10*log10(pxx))
        self.set_xlabel('Frequency')
        self.set_ylabel('Power Spectrum Magnitude (dB)')
        self.grid(True)

        vmin, vmax = self.yaxis.viewlim.bounds()
        intv = vmax-vmin
        # a zero-width view has no usable tick step; leave the ticks alone
        if intv > 0:
            step = 10*int(log10(intv))
            # int() truncates to 0 for ranges under a decade, and arange
            # cannot take a zero step, so fall back to 1 dB ticks
            if step <= 0:
                step = 1
            ticks = arange(math.floor(vmin), math.ceil(vmax)+1, step)
            self.set_yticks(ticks)

        return pxx, freqs

    def get_position(self):
        """
        Return the stored axes position (self._position)
        """
        position = self._position
        return position

    def set_position(self, pos):
        """
        Store pos = left, bottom, width, height (relative 0,1 coords)
        as the axes position, then call resize so the bounding box is
        recomputed from it
        """
        self._position = pos
        self.resize()
        
    def resize(self):
        """
        Recompute the axes bounding box from the figure bounding box
        and the relative position stored in self._position
        """
        figLeft, figBottom, figWidth, figHeight = self.figure.bbox.get_bounds()
        # the offsets and sizes in _position are fractions of the figure box
        left = figLeft + self._position[0]*figWidth
        bottom = figBottom + self._position[1]*figHeight
        width = figWidth * self._position[2]
        height = figHeight * self._position[3]
        self.bbox.set_bounds(left, bottom, width, height)
        
    def set_axis_off(self):
        "Set the axison flag to False"
        self.axison = False

    def set_axis_on(self):
        "Set the axison flag to True"
        self.axison = True

    def scatter(self, x, y, s=None, c='b'):
        """
        Make a scatter plot of x versus y, drawing one circle patch
        per point.

        s is a size (in data coords) used as the circle radius; it can
        be either a scalar or an array of the same length as x.  If s
        is None a default size of 1.5% of the y range is used.

        c is a color and can be a single color format string (or
        other non-iterable value), used for every point, or a
        length(x) array of intensities which will be mapped by the
        colormap jet.

        Raises ValueError if c or s do not end up with the same length
        as x.  Returns the list of patches that were added.
        """

        if is_string_like(c) or not iterable(c):
            # one color shared by every point
            c = [c]*len(x)
        else:
            c = ColormapJet(1000).get_colors(c)

        if s is None:
            defaultSize = abs(0.015*(max(y)-min(y)))
            s = [defaultSize]*len(x)
        elif not iterable(s):
            s = [s]*len(x)

        if len(c)!=len(x):
            raise ValueError('c and x are not equal lengths')
        if len(s)!=len(x):
            raise ValueError('s and x are not equal lengths')

        added = []
        for thisX, thisY, radius, facecolor in zip(x,y,s,c):
            circ = Circle(
                self.dpi, self.bbox, (thisX, thisY),
                radius=radius,
                transx = self.xaxis.transData,
                transy = self.yaxis.transData,
                )
            circ.set_facecolor(facecolor)
            self._patches.append(circ)
            added.append(circ)

        self.xaxis.datalim.update(x)
        self.yaxis.datalim.update(y)
        self.xaxis.autoscale_view()
        self.yaxis.autoscale_view()
        return added

    def semilogx(self, *args, **kwargs):
        """
        Make a semilog plot with log scaling on the x axis.  The args
        to semilogx are the same as the args to plot.  See help plot
        for more info

        Returns whatever plot returns
        """
        self.set_xscale('log')
        return self.plot(*args, **kwargs)


    def semilogy(self, *args, **kwargs):
        """
        Make a semilog plot with log scaling on the y axis.  The args
        to semilogy are the same as the args to plot.  See help plot
        for more info

        Returns whatever plot returns
        """
        self.set_yscale('log')
        return self.plot(*args, **kwargs)


    def set_axis_bgcolor(self, color):
        "Store color as the axes background color (self._axisbg)"
        self._axisbg = color

                                
    def set_title(self, label, fontdict=None, **kwargs):
        """
        Set the title of the axes to label

        fontdict and kwargs are passed through _process_text_args,
        starting from an empty dictionary, and the resulting properties
        are applied to the title.  No alignment or font defaults are
        applied here.

        See the text docstring for information of how override and the
        optional args work

        Returns the title instance (self._title)
        """
        # NOTE(review): a defaults dict used to be built here and then
        # discarded before use; it is removed rather than applied, so the
        # title is updated exactly as before -- confirm that is intended
        self._title.set_text(label)
        override = _process_text_args({}, fontdict, **kwargs)
        self._title.update_properties(override)
        return self._title


    def set_xlabel(self, xlabel, fontdict=None, **kwargs):
        """
        Set the text of the xaxis label to xlabel

        fontdict and kwargs are passed through _process_text_args,
        starting from an empty dictionary, and the resulting properties
        are applied to the label.  See the text docstring for
        information of how override and the optional args work

        Returns the xaxis label instance
        """
        axisLabel = self.xaxis.get_label()
        axisLabel.set_text(xlabel)
        axisLabel.update_properties(
            _process_text_args({}, fontdict, **kwargs))
        return axisLabel

    def set_xlim(self, v):
        """
        Set the view limits for the xaxis; v = [xmin, xmax]

        Also flags the x view limits as explicitly set and clips every
        line in the axes to the new limits
        """
        lo, hi = v

        xaxis = self.xaxis
        xaxis.viewlim.set_bounds(lo, hi)
        xaxis.viewlimset = True
        for thisLine in self._lines:
            thisLine.set_xclip(lo, hi)
        
    def set_xscale(self, value):
        "Set the scale of the xaxis to value, eg 'log'"
        xaxis = self.xaxis
        xaxis.set_scale(value)

    def set_xticklabels(self, labels, fontdict=None, **kwargs):
        """
        Set the xtick labels with the list of strings labels; fontdict
        and kwargs are passed on to the xaxis.

        Returns what the xaxis returns (a list of axis text instances)
        """
        ticklabels = self.xaxis.set_ticklabels(labels, fontdict, **kwargs)
        return ticklabels

    def set_xticks(self, ticks):
        "Set the x ticks with list of ticks; returns what the xaxis returns"
        result = self.xaxis.set_ticks(ticks)
        return result
        

    def set_ylabel(self, ylabel, fontdict=None, **kwargs):
        """
        Set the text of the yaxis label to ylabel

        fontdict and kwargs are passed through _process_text_args,
        starting from an empty dictionary, and the resulting properties
        are applied to the label.  No alignment, rotation or font
        defaults are applied here.

        See the text docstring for information of how override and
        the optional args work

        Returns the yaxis label instance
        """
        axisLabel = self.yaxis.get_label()
        axisLabel.set_text(ylabel)
        axisLabel.update_properties(
            _process_text_args({}, fontdict, **kwargs))
        return axisLabel

    def set_ylim(self, v):
        """
        Set the view limits for the yaxis; v = [ymin, ymax]

        Also flags the y view limits as explicitly set
        """
        lo, hi = v
        yaxis = self.yaxis
        yaxis.viewlim.set_bounds(lo, hi)
        yaxis.viewlimset = True
        
    def set_yscale(self, value):
        "Set the scale of the yaxis to value, eg 'log'"
        yaxis = self.yaxis
        yaxis.set_scale(value)


    def set_yticklabels(self, labels, fontdict=None, **kwargs):
        """
        Set the ytick labels with the list of strings labels; fontdict
        and kwargs are passed on to the yaxis.

        Returns what the yaxis returns (a list of Text instances)
        """
        ticklabels = self.yaxis.set_ticklabels(labels, fontdict, **kwargs)
        return ticklabels
        
    def set_yticks(self, ticks):
        "Set the y ticks with list of ticks; returns what the yaxis returns"
        result = self.yaxis.set_ticks(ticks)
        return result

    def specgram(self, x, NFFT=256, Fs=2, detrend=mlab.detrend_none,
                 window=mlab.window_hanning, noverlap=128,
                 cmap = ColormapJet(256)):
        """
        Compute a spectrogram of data in x and show it as an image.
        Data are split into NFFT length segments and the PSD of each
        section is computed.  The windowing function window is applied
        to each segment, and the amount of overlap of each segment is
        specified with noverlap

        See help(psd) for information on the other arguments

        The computation is done by mlab.specgram.  The image shows
        10*log10(Pxx), flipped top to bottom, drawn with the colormap
        cmap, and the image extent is set to 0..max(bins) by
        0..max(freqs).

        return value is Pxx, freqs, bins, im

        bins are the time points the spectrogram is calculated over
        freqs is an array of frequencies
        Pxx is the array of power returned by mlab.specgram
        im is the image returned by imshow
        """
        Pxx, freqs, bins = mlab.specgram(
            x, NFFT, Fs, detrend, window, noverlap)

        # power in decibels, flipped top to bottom for display
        decibels = mlab.flipud(10*log10(Pxx))
        im = self.imshow(decibels, cmap)
        self.set_image_extent(0, max(bins), 0, max(freqs))
        return Pxx, freqs, bins, im

    def table(self,
        cellText=None, cellColours=None,
        cellLoc='right', colWidths=None,
        rowLabels=None, rowColours=None, rowLoc='left',
        colLabels=None, colColours=None, colLoc='center',
        loc='bottom', bbox=None):
        """
        Create a table and add it to the axes.  Returns a table
        instance.  For finer grained control over tables, use the
        Table class and add it to the axes with add_table.

        cellText is a sequence of rows, each a sequence of cell
        strings; cellColours has the same layout and gives the cell
        face colors.  At least one of the two must be given: missing
        text defaults to empty strings and missing colours to 'w'.

        colWidths gives the width of each column; the default shares
        the width equally between the columns.

        rowLabels/rowColours and colLabels/colColours describe the
        optional label column and header row.  If only the colours are
        given the labels are empty strings; if only the labels are
        given the colours are 'w'.  When there is a header row it is
        row 0 and the cells start at row 1.

        cellLoc, rowLoc and colLoc are passed as loc to the cells of
        the respective kind; loc and bbox are passed to Table.

        Raises ValueError if neither cellText nor cellColours is
        given.  The dimensions of the other arguments are checked
        with asserts.

        Thanks to John Gill for providing the class and table.
        """

        # Check we have some cellText
        if cellText is None:
            if cellColours is None:
                raise ValueError(
                    'table requires at least one of cellText or cellColours')
            # assume just colours are needed: one empty string per cell
            rows = len(cellColours)
            cols = len(cellColours[0])
            cellText = [[''] * cols] * rows

        rows = len(cellText)
        cols = len(cellText[0])
        for row in cellText:
            assert len(row) == cols

        if cellColours is not None:
            assert len(cellColours) == rows
            for row in cellColours:
                assert len(row) == cols
        else:
            # a string of 'w' characters indexes like a row of 'w' colours
            cellColours = ['w' * cols] * rows

        # Set colwidths if not given
        if colWidths is None:
            colWidths = [1.0/cols] * cols

        # Check row labels: one label and one colour per row
        rowLabelWidth = 0
        if rowLabels is None:
            if rowColours is not None:
                rowLabels = [''] * rows
                rowLabelWidth = colWidths[0]
        elif rowColours is None:
            rowColours = 'w' * rows

        if rowLabels is not None:
            assert len(rowLabels) == rows

        # Check column labels: one label and one colour per column
        if colLabels is None:
            if colColours is not None:
                colLabels = [''] * cols
        elif colColours is None:
            colColours = 'w' * cols

        # the header row, when present, takes row 0 and pushes the cells down
        offset = 0
        if colLabels is not None:
            assert len(colLabels) == cols
            offset = 1

        # Now create the table
        table = Table(self, loc, bbox)
        height = table._approx_text_height()

        # Add the cells
        for row in range(rows):
            for col in range(cols):
                table.add_cell(row+offset, col,
                               width=colWidths[col], height=height,
                               text=cellText[row][col],
                               facecolor=cellColours[row][col],
                               loc=cellLoc)
        # Do column labels
        if colLabels is not None:
            for col in range(cols):
                table.add_cell(0, col,
                               width=colWidths[col], height=height,
                               text=colLabels[col], facecolor=colColours[col],
                               loc=colLoc)

        # Do row labels; they go in column -1
        if rowLabels is not None:
            for row in range(rows):
                table.add_cell(row+offset, -1,
                               width=rowLabelWidth, height=height,
                               text=rowLabels[row], facecolor=rowColours[row],
                               loc=rowLoc)
            # no explicit width was chosen above, so let the table size it
            if rowLabelWidth == 0:
                table.auto_set_column_width(-1)

        self.add_table(table)
        return table

    
    def text(self, x, y, text, fontdict=None, **kwargs):
        """
        Add text to axis at location x,y (data coords)

        The text properties start from the defaults

          'fontproperties'      : FontProperties(),
          'verticalalignment'   : 'bottom',
          'horizontalalignment' : 'left',
          'transx'              : self.xaxis.transData,
          'transy'              : self.yaxis.transData,

        These defaults, fontdict and **kwargs are combined by
        _process_text_args and the result is applied to the new Text
        instance.  fontdict is a dictionary to override the defaults,
        and **kwargs can in turn be used to override the override, as
        in

          a.text(x,y,label, fontproperties=FontProperties(size=12))

        which will have verticalalignment=bottom and
        horizontalalignment=left but will have a fontsize of 12

        transx and transy specify that text is in data coords,
        alternatively, you can specify text in axis coords (0,0 lower
        left and 1,1 upper right).  The example below places text in
        the center of the axes

        ax = subplot(111)
        text(0.5, 0.5,'matplotlib',
             horizontalalignment='center',
             verticalalignment='center',
             transx = ax.xaxis.transAxis,
             transy = ax.yaxis.transAxis,
        )

        The Text instance is appended to self._text and returned
        """
        defaults = {
            'fontproperties': FontProperties(),
            'verticalalignment' : 'bottom',
            'horizontalalignment' : 'left',
            'transx' : self.xaxis.transData,
            'transy' : self.yaxis.transData,
            }
        props = _process_text_args(defaults, fontdict, **kwargs)

        t = Text(dpi=self.dpi, bbox=self.bbox, x=x, y=y, text=text)
        t.update_properties(props)

        self._text.append(t)
        return t
    
    def update_viewlim(self):
        """
        Update the view limits with all the data in self: the x and y
        data of every line and the data extent of every patch
        """
        for thisLine in self._lines:
            self.xaxis.viewlim.update(thisLine.get_xdata())
            self.yaxis.viewlim.update(thisLine.get_ydata())

        for patch in self._patches:
            left, bottom, width, height = patch.get_data_extent().get_bounds()
            # turn the left/bottom/width/height extent into min,max pairs
            self.xaxis.viewlim.update( (left, left+width) )
            self.yaxis.viewlim.update( (bottom, bottom+height) )


    def vlines(self, x, ymin, ymax, color='k'):
        """
        Plot vertical lines at each x from ymin to ymax.  ymin or ymax
        can be scalars or len(x) arrays.  If they are scalars (or
        length one sequences), then the respective values are
        constant, else the heights of the lines are determined by ymin
        and ymax

        The lines are solid and drawn in color.

        Raises ValueError if ymin or ymax is neither a scalar, a length
        one sequence nor the same length as x.

        Returns a list of lines that were added
        """

        # wrap scalars so they become length one arrays below; len() is
        # not defined for the rank-0 array that asarray makes of a scalar
        if not iterable(ymin):
            ymin = [ymin]
        if not iterable(ymax):
            ymax = [ymax]

        x = asarray(x)
        ymin = asarray(ymin)
        ymax = asarray(ymax)

        # expand a single value to one value per x
        if len(ymin)==1:
            ymin = ymin*ones(x.shape, x.typecode())
        if len(ymax)==1:
            ymax = ymax*ones(x.shape, x.typecode())

        if len(ymin)!=len(x):
            raise ValueError('ymin and x are unequal sized sequences')
        if len(ymax)!=len(x):
            raise ValueError('ymax and x are unequal sized sequences')

        # one (ymin, ymax) row per x
        Y = transpose(array([ymin, ymax]))
        lines = []
        for thisX, thisY in zip(x,Y):
            line = Line2D(
                self.dpi, self.bbox,
                [thisX, thisX], thisY, color=color, linestyle='-',
                transx = self.xaxis.transData,
                transy = self.yaxis.transData,
                )
            self.add_line(line)
            lines.append(line)
        return lines


    def zoomx(self, numsteps):
        """
        Zoom in on the x axis numsteps (plus for zoom in, minus for
        zoom out) and clip every line in the axes to the new x view
        limits
        """
        xaxis = self.xaxis
        xaxis.zoom(numsteps)
        lo, hi = xaxis.viewlim.bounds()
        for thisLine in self._lines:
            thisLine.set_xclip(lo, hi)

    def zoomy(self, numsteps):
        """
        Zoom in on the y axis numsteps (plus for zoom in, minus for zoom out)
        """
        yaxis = self.yaxis
        yaxis.zoom(numsteps)

class Subplot(Axes):
    """
    Emulate matlab's subplot command, creating axes with

      Subplot(fig, numRows, numCols, plotNum)

    where plotNum=1 is the first plot number and increasing plotNums
    fill rows first.  max(plotNum)==numRows*numCols

    You can leave out the commas if numRows, numCols and plotNum are
    all less than 10, as in

      Subplot(fig, 211)    # 2 rows, 1 column, first (upper) plot

    The instance records its zero based grid position in rowNum and
    colNum and the grid size in numRows and numCols.
    """

    def __init__(self, fig, *args, **kwargs):
        """
        fig is passed on to Axes.  *args is either the three numbers
        numRows, numCols, plotNum or a single 3 digit value holding
        the same three numbers.  **kwargs are passed on to Axes.

        Raises ValueError if the arguments are malformed or plotNum is
        not in the range 1..numRows*numCols.
        """
        if len(args)==1:
            s = str(args[0])
            if len(s) != 3:
                raise ValueError('Argument to subplot must be a 3 digits long')
            rows, cols, num = map(int, s)
        elif len(args)==3:
            rows, cols, num = args
        else:
            raise ValueError('Illegal argument to subplot')

        total = rows*cols
        num -= 1    # convert from matlab to python indexing ie num in range(0,total)
        # a negative index would be placed outside the grid by divmod below
        if num < 0:
            raise ValueError('Subplot number must be at least 1')
        if num >= total:
            raise ValueError('Subplot number exceeds total subplots')

        # the region of the figure (relative coords) shared by all subplots
        left, right = .125, .9
        bottom, top = .11, .9
        rat = 0.2             # ratio of fig to separator for multi row/col figs
        totWidth = right-left
        totHeight = top-bottom

        # rows axes plus rows-1 separators, each rat*figH high, fill totHeight
        figH = totHeight/(rows + rat*(rows-1))
        sepH = rat*figH

        # likewise for the columns and totWidth
        figW = totWidth/(cols + rat*(cols-1))
        sepW = rat*figW

        rowNum, colNum = divmod(num, cols)

        # rows are counted from the top of the region, columns from the left
        figBottom = top - (rowNum+1)*figH - rowNum*sepH
        figLeft = left + colNum*(figW + sepW)

        Axes.__init__(self, fig, [figLeft, figBottom, figW, figH], **kwargs)

        self.rowNum = rowNum
        self.colNum = colNum
        self.numRows = rows
        self.numCols = cols

    def is_first_col(self):
        "Return whether the subplot is in the first (left) column"
        return self.colNum==0

    def is_first_row(self):
        "Return whether the subplot is in the first (top) row"
        return self.rowNum==0

    def is_last_row(self):
        "Return whether the subplot is in the last (bottom) row"
        return self.rowNum==self.numRows-1


    def is_last_col(self):
        "Return whether the subplot is in the last (right) column"
        return self.colNum==self.numCols-1
