"""
API changes


- Artist
    * __init__ takes a DPI instance and a Bound2D instance which is
      the bounding box of the artist in display coords
    * get_window_extent returns a Bound2D instance
    * set_size is removed; replaced by bbox and dpi
    * the clip_gc method is removed.  Artists now clip themselves with
      their box
    * added _clipOn boolean attribute.  If True, gc clip to bbox.
    
- AxisTextBase
    * Initialized with a transx, transy which are Transform instances
    * set_drawing_area removed
    * get_left_right and get_top_bottom are replaced by get_window_extent

- Line2D Patches now take transx, transy
    * Initialized with a transx, transy which are Transform instances

- Patches
   * Initialized with a transx, transy which are Transform instances

- FigureBase attribute dpi is a DPI instance rather than a scalar, and
  the new attribute bbox is a Bound2D in display coords.  I got rid of
  the left, width, height, etc... attributes.  These are now
  accessible as, for example, bbox.x.min is left, bbox.x.interval() is
  width, bbox.y.max is top, etc...

- GcfBase attribute pagesize renamed to figsize

- Axes
    * removed figbg attribute
    * added fig instance to __init__
    * resizing is handled by figure call to resize.
     
- Subplot
    * added fig instance to __init__

- Renderer methods for patches now take gcEdge and gcFace instances.
  gcFace=None takes the place of filled=False

- True and False symbols provided by cbook in a python2.3 compatible
  way

- new module transforms supplies Bound1D, Bound2D and Transform
  instances and more

- Changes to the matlab helpers API

  * _matlab_helpers.GcfBase is renamed to Gcf.  Backends no longer
    need to derive from this class.  Instead, they provide a factory
    function new_figure_manager(num, figsize, dpi).  The destroy
    method of the Gcf-derived classes in the backends is moved to the
    derived FigureManager.

  * FigureManagerBase moved to backend_bases

  * Gcf.get_all_figwins renamed to Gcf.get_all_fig_managers

Jeremy:

  Make sure to set self._reset = False in AxisTextWX._set_font.  This
  was something missing in my backend code.
  
KNOWN BUGS

  - DONE - Circle needs to provide window extent -- see scatter demo

  - DONE - axes patch edge on gtk backend not complete

  - DONE - autoscale on y not satisfactory -- see histogram demo

  - with multiple subplots there is a small gap between the bottom of
    the xtick1 line and the axes patch edge. see subplot_demo

  - DONE 2003-11-14 -  lowest line position is off in legend_demo

  - the markers for circ and square in the gtk backend do not have the
    same center.  See the stock_demo for example.

  - DONE 2003-11-16 - GD port of new API not done

  - DONE 2003-11-16 - xticklabels in PS backend need clip

  - x and y ticklabels overlap at origin
  
OUTSTANDING ISSUES

  - rationalize DPI between backends - config file?

  - let the user pass the default font to the axes so that the text
    labels, tick labels, etc, can all be updated


"""

from __future__ import division, generators

import math, re, sys

import MLab
from Numeric import arange, array, asarray, ones, zeros, \
     nonzero, take, Float, transpose, log10

from cbook import iterable, is_string_like, flatten, enumerate, True, False
from transforms import Bound1D, Bound2D, Transform, bound2d_all,\
     identity, logwarn, pow10, transform_bound2d, inverse_transform_bound2d
from transforms import Centimeter, Dots, Inches, Millimeter, Points,\
     TransformSize, RRef

from mlab import linspace, meshgrid
from lines import Line2D, lineStyles, lineMarkers
from patches import Rectangle, Circle, bbox_artist
from artist import Artist
from colormap import ColormapJet
import backends

import mlab  #so I can override hist, psd, etc...


class Tick(Artist):
    """
    Abstract base class for the axis ticks, grid lines and labels.

    A Tick bundles the child artists drawn at a single tick location.
    Derived classes supply the _get_default_* factories, set_loc and
    _tick_is_midpoint.

    Publicly accessible attributes

      tick1line  : a Line2D instance
      tick2line  : a Line2D instance
      gridline   : a Line2D instance
      label      : an AxisText instance
      gridOn     : a boolean which determines whether to draw the gridline
      tick1On    : a boolean which determines whether to draw the 1st tickline
      tick2On    : a boolean which determines whether to draw the 2nd tickline
      labelOn    : a boolean which determines whether to draw tick label

    """
    def __init__(self, axes, loc, label,
                 size=4,  # points
                 gridOn=False, tick1On=True, tick2On=True, labelOn=True):
        """
        axes is the parent Axes instance; its dpi and bbox are passed
        on to the Artist base class
        loc is the tick location in data coords
        label is handed to the _get_default_text factory
        size is the tick size handed to the tick line factories
        """
        Artist.__init__(self, axes.dpi, axes.bbox)
        self.axes = axes
        self._loc = loc
        self._size = size

        # switches consulted by _draw
        self.gridOn = gridOn
        self.tick1On = tick1On
        self.tick2On = tick2On
        self.labelOn = labelOn

        # the child artists are built by the derived class factories
        self.tick1line = self._get_default_tick1line(loc, size)
        self.tick2line = self._get_default_tick2line(loc, size)
        self.gridline = self._get_default_gridline(loc)
        self.label = self._get_default_text(loc, label)
        self.label.set_clip_on(False)

    def _get_default_text(self, loc, label):
        'Factory for the tick label AxisText instance'
        raise NotImplementedError('Derived must override')

    def _get_default_tick1line(self, loc, size):
        'Factory for the Line2D instance used as tick1'
        raise NotImplementedError('Derived must override')

    def _get_default_tick2line(self, loc, size):
        'Factory for the Line2D instance used as tick2'
        raise NotImplementedError('Derived must override')

    def _get_default_gridline(self, loc):
        'Factory for the grid Line2D instance of this tick'
        raise NotImplementedError('Derived must override')

    def get_child_artists(self):
        'Return the label, both tick lines and the grid line as a tuple'
        return (self.label, self.tick1line, self.tick2line, self.gridline)

    def get_loc(self):
        'Return the tick location (data coords) as a scalar'
        return self._loc

    def _draw(self, drawable):
        'Draw the enabled child artists onto drawable'
        # the lines are drawn only when _tick_is_midpoint accepts the
        # location; the label does not depend on that test
        if self._tick_is_midpoint(self.get_loc()):
            if self.tick1On:
                self.tick1line.draw(drawable)
            if self.tick2On:
                self.tick2line.draw(drawable)
            if self.gridOn:
                self.gridline.draw(drawable)
        if self.labelOn:
            self.label.draw(drawable)

    def set_loc(self, loc):
        'Set the location of tick in data coords with scalar loc'
        raise NotImplementedError('Derived must override')

    def set_label(self, s):
        'Set the text of the ticklabel to the string s'
        self.label.set_text(s)

    def _tick_is_midpoint(self, loc):
        'return true if tickloc is not on axes boundary'
        raise NotImplementedError('Derived must override')


class XTick(Tick):
    """
    The Artists that together make one x tick: the two tick lines,
    the label text and the grid line
    """

    def _get_default_text(self, loc, label):
        'Build the AxisText instance used for the tick label'
        # the top of the label is 3 points below the min of y axis
        ypos = self.bbox.y.get_refmin() - self.dpi*RRef(3/72.0)
        return backends.AxisText(
            dpi=self.dpi,
            bbox=self.bbox,
            x=loc,
            y=ypos,
            fontsize=10,
            verticalalignment='top',
            horizontalalignment='center',
            # x is in data coords; transy is left at its default
            transx = self.axes.xaxis.transData,
            )

    def _get_default_tick1line(self, loc, size):
        'Build the Line2D instance for the first tick line'
        # y runs from 0 to size through the y axis points transform
        tickline = Line2D( self.dpi, self.bbox,
                           xdata=(loc, loc), ydata=(0, size),
                           color='k',
                           transx = self.axes.xaxis.transData,
                           transy = self.axes.yaxis.get_pts_transform(),
                           )
        tickline.set_data_clipping(False)
        return tickline

    def _get_default_tick2line(self, loc, size):
        'Build the Line2D instance for the second tick line'
        # same as tick1 but offset to the max of the y display limits
        # and extending from -size up to 0
        ymaxRef = self.axes.yaxis.displaylim.get_refmax()
        ptsTrans = self.axes.yaxis.get_pts_transform(ymaxRef)
        tickline = Line2D( self.dpi, self.bbox,
                           xdata=(loc, loc), ydata=(-size, 0),
                           color='k',
                           transx = self.axes.xaxis.transData,
                           transy = ptsTrans,
                           )
        tickline.set_data_clipping(False)
        return tickline

    def _get_default_gridline(self, loc):
        'Build the dotted Line2D instance for the grid line'
        # y spans 0 to 1 in axis coords
        gridline = Line2D( self.dpi, self.bbox,
                           xdata=(loc, loc), ydata=(0, 1),
                           color='k', linestyle=':',
                           transx = self.axes.xaxis.transData,
                           transy = self.axes.yaxis.transAxis,
                           )
        gridline.set_data_clipping(False)
        return gridline

    def set_loc(self, loc):
        'Set the location of tick in data coords with scalar loc'
        xdata = (loc, loc)
        self.tick1line.set_xdata(xdata)
        self.tick2line.set_xdata(xdata)
        self.gridline.set_xdata(xdata)
        self.label.set_x(loc)
        self._loc = loc

    def _tick_is_midpoint(self, loc):
        'return true if loc is in the open interval of the x view limits'
        xview = self.axes.xaxis.viewlim
        return xview.in_open_interval(loc)


        
class YTick(Tick):
    """
    Contains all the Artists needed to make a Y tick - the tick line,
    the label text and the grid line
    """

    def _get_default_text(self, loc, label):
        'Get the default AxisText instance'
        # the right edge of the ticklabel is placed 3/72 inch (scaled
        # by dpi) before the min of the x display bound
        right = self.bbox.x.get_refmin() - self.dpi*RRef(3/72.0)
        return backends.AxisText(
            dpi=self.dpi,
            bbox=self.bbox,
            x=right,
            y=loc,
            fontsize=10,
            verticalalignment='center',
            horizontalalignment='right',
            #transx - default transform is identity
            transy = self.axes.yaxis.transData,
            )

    def _get_default_tick1line(self, loc, size):
        'Get the default line2D instance for the first tick line'
        # x runs from 0 to size through the x axis points transform,
        # y is the tick location in data coords
        line = Line2D( self.dpi, self.bbox,
                       (0, size), (loc, loc), color='k',
                       transx = self.axes.xaxis.get_pts_transform(),
                       transy = self.axes.yaxis.transData,
                       )
        line.set_data_clipping(False)
        return line

    def _get_default_tick2line(self, loc, size):
        'Get the default line2D instance for the second tick line'
        # offset the points transform to the max of the x display limits
        offset = self.axes.xaxis.displaylim.get_refmax()
        xtrans = self.axes.xaxis.get_pts_transform(offset)

        # NOTE(review): XTick uses (-size, 0) for its second tick line;
        # the (1-size, 1) extent here is kept unchanged -- confirm the
        # shift by 1 is intentional
        line = Line2D( self.dpi, self.bbox,
                       (1-size,1), (loc, loc), color='k',
                       transx = xtrans,
                       transy = self.axes.yaxis.transData,
                       )
        # every other tick and grid line factory of XTick and YTick
        # turns data clipping off; do the same here
        line.set_data_clipping(False)
        return line

    def _get_default_gridline(self, loc):
        'Get the default line2D instance for the grid line'
        # x spans 0 to 1 in axis coords, y is the tick location
        line = Line2D( self.dpi, self.bbox,
                       xdata=(0,1), ydata=(loc,loc),
                       color='k', linestyle=':',
                       transx = self.axes.xaxis.transAxis,
                       transy = self.axes.yaxis.transData,
                       )
        line.set_data_clipping(False)
        return line


    def set_loc(self, loc):
        'Set the location of tick in data coords with scalar loc'
        self.tick1line.set_ydata((loc, loc))
        self.tick2line.set_ydata((loc, loc))
        self.gridline.set_ydata((loc, loc))
        self.label.set_y( loc )
        self._loc = loc

    def _tick_is_midpoint(self, loc):
        'return true if tickloc is not on axes boundary'
        return self.axes.yaxis.viewlim.in_open_interval(loc)


    
class Axis(Artist):
    """
    Base class for the x and y axis of an Axes.  It owns the tick
    instances and the axis label, the data/view/display intervals and
    the transforms between them.  Derived classes provide
    _get_default_tick, _get_default_label, _get_display_lim and
    _update_label_postion.

    Public attributes
      datalim    - Bound1D data interval; data coords
      viewlim    - Bound1D viewport interval; data coords
      displaylim - the interval returned by _get_display_lim
      transData  - transform data coords to display coords
      transAxis  - transform axis coords to display coords
    """

    # func / inverse func pairs for transforms
    _scalemap = {'linear' : (identity, identity),
                 'log'    : (logwarn, pow10)}

    # offset added to the axis label position by the derived classes
    LABELPAD = -0.01

    def __init__(self, axes):
        """
        Init the axis with the parent Axes instance
        """
        Artist.__init__(self, axes.dpi, axes.bbox)
        self.axes = axes
        self._scale = 'linear'

        self.datalim = Bound1D(None, None) # data interval instance; data coords
        self.viewlim = Bound1D(-1, 1)      # viewport interval instance; data coords

        self.displaylim = self._get_display_lim()

        self.transData = Transform(self.viewlim, self.displaylim)
        self.transAxis = Transform(Bound1D(0,1), self.displaylim)

        self._ticklocs = None  # None unless explicitly set
        self._ticklabelStrings = None
        # attr -> val pairs recorded by set_child_attr so that ticks
        # created later by _grow_ticks receive them too
        self._setChildAttrs = {}
        # strip trailing zeros from ticks; raw string so the regex
        # escapes are not read as string escapes
        self._zerorgx = re.compile(r'^(.*?)\.?0+(e[+-]\d+)?$')

        self._gridOn = False  # off by default for compat with matlab

    def autoscale_view(self):
        """
        Try to choose the view limits intelligently from the data
        limits.  Only the linear scale is handled; for any other scale
        the view limits are left untouched.
        """

        vmin, vmax = self.datalim.bounds()
        if self._scale=='linear':

            if vmin==vmax:
                # degenerate interval; open it up by one on each side
                vmin-=1
                vmax+=1
            try:
                (exponent, remainder) = divmod(math.log10(vmax - vmin), 1)
            except OverflowError:
                sys.stderr.write('Overflow error in autoscale %s %s\n'
                                 % (vmin, vmax))
                return
            if remainder < 0.5:
                exponent -= 1
            # round the bounds outward to a multiple of 10**exponent
            scale = 10**(-exponent)
            vmin = math.floor(scale*vmin)/scale
            vmax = math.ceil(scale*vmax)/scale

            self.viewlim.set_bounds(vmin, vmax)

    def build_artists(self):
        """
        Create the axis label and an initial pool of 20 ticks.

        Call only after xaxis and yaxis have been initialized in axes
        since the transforms require both
        """
        self._label = self._get_default_label()
        self._label.set_clip_on( False )
        self._ticks = [self._get_default_tick() for i in range(20)]

    def _draw(self, drawable, *args, **kwargs):
        'Draw the axis lines, grid lines, tick lines and labels'

        # update the tick locs
        ticks = self.get_ticks()
        locs = self.get_ticklocs()

        if self._ticklabelStrings is not None:
            labels = self._ticklabelStrings
            if len(labels)<len(locs):
                # grow the labels with empty strings as necessary; note
                # this extends the stored list in place
                labels.extend(['']*(len(locs)-len(labels)))
        else:
            labels = [self.format_tickval(loc) for loc in locs]

        ticklabelBoxes = []
        for tick, loc, label in zip(ticks, locs, labels):
            # ticks outside the view interval are not drawn
            if not self.viewlim.in_interval(loc): continue
            tick.set_loc(loc)
            tick.set_label(label)
            tick.draw(drawable)
            ticklabelBoxes.append(tick.label.get_window_extent())


        # find the tick labels that are close to the axis labels.  I
        # scale up the axis label box to also find the neighbors, not
        # just the tick labels that actually overlap.  Note we need a
        # *copy* of the axis label box because we don't want to scale
        # the actual bbox
        labelBox = self._label.get_window_extent().copy()
        labelBox.x.scale(1.25)
        labelBox.y.scale(1.25)
        overlaps = [b for b in ticklabelBoxes if b.overlap(labelBox)]

        self._update_label_postion(overlaps)
        self._label.draw(drawable)

        if 0: # draw the bounding boxes around the text for debug
            for tick in ticks:
                label = tick.label
                bbox_artist(label, drawable)
            bbox_artist(self._label, drawable)

    def format_tickval(self, x):
        """
        Format the number x as a string for use as a tick label.

        On a log axis with no explicit tick labels, values that are
        not on a decade format as the empty string.
        """
        d = self.viewlim.interval()
        if self._scale == 'log':
            # only label the decades
            fx = self.transData.func(x)
            isdecade = abs(fx-int(fx))<1e-10
            if self._ticklabelStrings is None and not isdecade: return ''

        #if the number is not too big and it's an int, format it as an
        #int
        if abs(x)<1e4 and x==int(x): return '%d' % x

        # if the value is just a fraction off an int, use the int.
        # NOTE: int() truncates toward zero, so only values whose
        # magnitude is slightly larger than an integer are caught here
        if abs(x-int(x))<0.0001*d: return '%d' % int(x)

        # use exponential formatting for really big or small numbers,
        # else use float
        if abs(x) < 1e-4: fmt = '%1.3e'
        elif abs(x) > 1e5: fmt = '%1.3e'
        else: fmt = '%1.3f'
        s =  fmt % x

        # strip trailing zeros, remove '+', and handle exponential formatting
        m = self._zerorgx.match(s)
        if m:
            s = m.group(1)
            if m.group(2) is not None: s += m.group(2)
        s = s.replace('+', '')
        return s

    def get_child_artists(self):
        'Return a list of all Artist instances contained by Axis'
        artists = []
        # don't use get_ticks here! it conflicts with set_child_attr
        artists.extend(self._ticks)
        artists.append(self._label)
        return artists

    def _get_size_transform(self, unitClass, offset, transOffset):
        """
        Shared implementation of the get_*_transform methods: return a
        TransformSize from unitClass(dpi) to Dots(dpi).

        offset None means the displaylim min reference.  transOffset
        None means a new default Transform(); one is built per call so
        that the returned transforms do not share a single instance.
        """
        if offset is None:
            offset = self.displaylim.get_refmin()
        if transOffset is None:
            transOffset = Transform()

        units = unitClass( self.dpi)
        dots =  Dots( self.dpi)
        return TransformSize(units, dots, offset, transOffset)

    def get_cm_transform(self, offset=None, transOffset=None):
        """
        Transform val cm to dots

        offset is a RRef instance for TransformSize.  if None, use
        displaylim min.  transOffset is a transform to transform the
        offset to display coords; if None, a default Transform() is
        used
        """
        return self._get_size_transform(Centimeter, offset, transOffset)

    def _get_default_label(self):
        'Return the axis label AxisText instance'
        raise NotImplementedError('Derived must override')

    def _get_default_tick(self):
        'Return a new tick instance for this axis'
        raise NotImplementedError('Derived must override')

    def _get_display_lim(self):
        'Return the display interval of this axis'
        raise NotImplementedError('Derived must override')

    def get_gridlines(self):
        'Return the grid lines as a list of Line2D instance'
        return [tick.gridline for tick in self._ticks]

    def get_inches_transform(self, offset=None, transOffset=None):
        """
        Transform val inches to dots

        offset is a RRef instance for TransformSize.  if None, use
        displaylim min.  transOffset is a transform to transform the
        offset to display coords; if None, a default Transform() is
        used
        """
        return self._get_size_transform(Inches, offset, transOffset)

    def get_label(self):
        'Return the axis label as an AxisText instance'
        return self._label


    def get_mm_transform(self, offset=None, transOffset=None):
        """
        Transform val mm to dots

        offset is a RRef instance for TransformSize.  if None, use
        displaylim min.  transOffset is a transform to transform the
        offset to display coords; if None, a default Transform() is
        used
        """
        return self._get_size_transform(Millimeter, offset, transOffset)

    def get_numticks(self):
        """
        Return the number of ticks.  With explicit tick locations this
        is their count.  Otherwise a log axis with defined view limits
        gets 10 ticks per decade spanned (growing the tick pool to
        match), and any other axis gets 11 ticks when it is more than
        4 inches long, else 6.
        """
        if self._ticklocs is None:
            if self._scale == 'log':
                if self.viewlim.defined():
                    vmin, vmax = self.viewlim.bounds()
                    decMax = int(math.ceil(self.transData.func(vmax)))
                    decMin = int(math.floor(self.transData.func(vmin)))
                    numticks =  10*(decMax-decMin)
                    self._grow_ticks(numticks)
                    return numticks

            # size of axes in inches
            intv = self.displaylim.interval()/self.dpi.get()
            if intv>4: return 11
            else: return 6
        else:
            return len(self._ticklocs)

    def get_pts_transform(self, offset=None, transOffset=None):
        """
        Transform val points to dots

        offset is a RRef instance for TransformSize.  if None, use
        displaylim min.  transOffset is a transform to transform the
        offset to display coords; if None, a default Transform() is
        used
        """
        return self._get_size_transform(Points, offset, transOffset)

    def get_ticklabels(self):
        'Return a list of AxisText instances for ticklabels'
        return [tick.label for tick in self.get_ticks()]

    def get_ticklines(self):
        'Return the ticklines lines as a list of Line2D instance'
        lines = []
        for tick in self._ticks:
            lines.append(tick.tick1line)
            lines.append(tick.tick2line)
        return lines

    def get_ticklocs(self):
        """
        Get the tick locations in data coordinates.  Explicitly set
        locations are returned as is; otherwise they are computed from
        the view limits (an empty list if there are no ticks).
        """
        if self._ticklocs is not None:
            return self._ticklocs

        vmin, vmax = self.viewlim.bounds()
        # view limits in the scaled space of the data transform
        vmin = self.transData.func(vmin)
        vmax = self.transData.func(vmax)
        if self._scale == 'linear':
            numticks = self.get_numticks()
            if numticks==0: return []
            ticklocs = linspace(vmin, vmax, numticks)
        elif self._scale == 'log':
            ticklocs = []
            for decadeStart in 10.0**arange(math.floor(vmin),math.ceil(vmax)):
                ticklocs.append(decadeStart)
                ticklocs.extend( arange(2,10)*decadeStart )
                ticklocs.append(10*decadeStart)
            ticklocs = array(ticklocs)
            # keep only the locations inside the view limits.  Each
            # comparison must be parenthesized: '*' binds tighter than
            # '>=' and '<=', so without the parentheses the expression
            # is a chained comparison rather than the product (logical
            # and) of the two masks
            inView = (ticklocs>=10.0**vmin) * (ticklocs<=10.0**vmax)
            ticklocs = take(ticklocs, nonzero(inView))
        return ticklocs

    def get_ticks(self):
        'Return the first get_numticks() ticks, growing the pool if needed'
        numticks = self.get_numticks()
        self._grow_ticks(numticks)
        ticks = self._ticks[:numticks]

        return ticks

    def grid(self, b):
        "Set the axis grid on or off; b is a boolean"
        self._gridOn = b
        for tick in self._ticks:  # don't use get_ticks here!
            tick.gridOn = self._gridOn

    def _grow_ticks(self, numticks):
        """
        Dynamically grow the number of ticklines as necessary.  New
        ticks pick up the current grid state and every attribute
        recorded by set_child_attr
        """
        if len(self._ticks)<numticks:
            for i in range(numticks-len(self._ticks)):
                tick = self._get_default_tick()
                if self._gridOn: tick.gridOn = True
                for attr, val in self._setChildAttrs.items():
                    tick.set_child_attr(attr, val)
                self._ticks.append(tick)

    def is_log(self):
        'Return true if log scaling is on'
        return self._scale=='log'

    def pan(self, numsteps):
        """
        Shift the view limits by numsteps steps (can be positive or
        negative); one step is the view interval divided by the
        number of ticks
        """
        step = numsteps*self.viewlim.interval()/self.get_numticks()
        self.viewlim.shift(step)


    def set_child_attr(self, attr, val):
        """
        Set attribute attr for self, and all child artists.  The pair
        is also recorded so that ticks created later receive it
        """
        Artist.set_child_attr(self, attr, val)
        self._setChildAttrs[attr] = val

    def set_scale(self, value):
        """
        Set the axis scale; either 'linear' or 'log'.  Raises
        ValueError for any other value
        """
        if value not in self._scalemap.keys():
            raise ValueError('scale must be in %s'%str(self._scalemap.keys()))
        self._scale = value
        self.transData.set_funcs(self._scalemap[value])
        self.viewlim.is_positive( self.is_log() )
        self.axes.update_viewlim()

    def set_ticklabels(self, ticklabels, *args, **kwargs):
        """
        Set the text values of the tick labels.  ticklabels is a
        sequence of strings.  Return a list of AxisText instances.
        Additional args and kwargs are text properties applied to
        every label
        """
        ticklabels = ['%s'%l for l in ticklabels]

        self._ticklabelStrings = ticklabels
        override = {}
        override = backends._process_text_args(override, *args, **kwargs)

        Nnew = len(self._ticklabelStrings)
        existingLabels = self.get_ticklabels()
        for i, label in enumerate(existingLabels):
            # labels beyond the supplied strings are blanked
            if i<Nnew: label.set_text(self._ticklabelStrings[i])
            else: label.set_text('')
            label.update_properties(override)
        return existingLabels

    def set_ticks(self, ticks):
        """
        Set the locations of the tick marks from sequence ticks and
        update the view limits with them
        """
        self._ticklocs = asarray(ticks)
        self.viewlim.update(self._ticklocs)

    def _update_label_postion(self, overlaps):
        """
        Update the label position based on the sequence of bounding
        boxes overlaps of all the ticklabels that overlap the axis
        label.  overlaps are the bounding boxes returned by the
        get_window_extent of ticklabels
        """
        raise NotImplementedError('Derived must override')

    def zoom(self, direction):
        "Zoom in/out on axis; if direction is >0 zoom in, else zoom out"
        vmin, vmax = self.viewlim.bounds()
        interval = self.viewlim.interval()
        # move each bound by 10% of the interval per unit of direction
        step = 0.1*interval*direction
        self.viewlim.set_min(vmin + step)
        self.viewlim.set_max(vmax - step)


class XAxis(Axis):
    __name__ = 'xaxis'

    def _get_default_tick(self):
        'Return a new XTick at location 0 with an empty label'
        return XTick(self.axes, 0, '')

    def _get_default_label(self):
        'Return the AxisText instance used as the x axis label'
        return backends.AxisText(
            self.dpi,
            self.bbox,
            x = 0.5,               # centered
            y = self.LABELPAD,     # rel coords under axis line
            fontsize=12,
            verticalalignment='top',
            horizontalalignment='center',
            transx = self.transAxis,
            transy = self.axes.yaxis.transAxis,
            )

    def _get_display_lim(self):
        'The display interval of the x axis is the x bound of the bbox'
        return self.bbox.x

    def _update_label_postion(self, overlaps):
        """
        Reposition the axis label relative to the ticklabels close to
        it.  overlaps is the sequence of bounding boxes, as returned
        by the get_window_extent of the ticklabels, that overlap the
        label; when it is empty the label is left where it is
        """

        x,y = self._label.get_position()
        if len(overlaps) == 0:
            return

        # bottom of the union of the overlapping boxes, expressed as a
        # fraction of the axes height measured from the axes bottom
        union = bound2d_all(overlaps)
        unionBottom = union.y.min()
        axesBottom = self.bbox.y.min()
        bottomAxes = (unionBottom-axesBottom)/self.bbox.y.interval()
        self._label.set_position((x, bottomAxes+self.LABELPAD))


class YAxis(Axis):
    __name__ = 'yaxis'

    def _get_default_tick(self):
        'Return a new YTick at location 0 with an empty label'
        return YTick(self.axes, 0, '')

    def _get_default_label(self):
        'Return the AxisText instance used as the y axis label'
        return backends.AxisText(
            dpi=self.dpi,
            bbox = self.bbox,
            x = self.LABELPAD,    # rel coords to the left of axis line
            y = 0.5,              # centered
            fontsize=12,
            verticalalignment='center',
            horizontalalignment='right',
            rotation='vertical',
            transx = self.axes.xaxis.transAxis,
            transy = self.transAxis,
            )

    def _get_display_lim(self):
        'The display interval of the y axis is the y bound of the bbox'
        return self.bbox.y

    def _update_label_postion(self, overlaps):
        """
        Reposition the axis label relative to the ticklabels close to
        it.  overlaps is the sequence of bounding boxes, as returned
        by the get_window_extent of the ticklabels, that overlap the
        label; when it is empty the label is left where it is
        """
        x,y = self._label.get_position()
        if len(overlaps) == 0:
            return

        # left edge of the union of the overlapping boxes, expressed
        # as a fraction of the axes width measured from the axes left
        union = bound2d_all(overlaps)
        unionLeft = union.x.min()
        axesLeft = self.bbox.x.min()
        left = (unionLeft-axesLeft)/self.bbox.x.interval()
        self._label.set_position((left+self.LABELPAD, y))




def _process_plot_format(fmt):
    """
    Process a matlab style color/line style format string.  Return a
    (linestyle, marker, color) tuple as a result of the processing.

    The color defaults to 'b'.  linestyle and marker default to None,
    except that the linestyle becomes '-' when the string names
    neither a linestyle nor a marker.  Example format strings include

    'ko'    : black circles
    '.b'    : blue dots
    'r--'   : red dashed lines

    See lineStyles and lineMarkers for the recognized style and marker
    symbols; the color symbols are b, g, r, c, m, y, k and w.

    Raises ValueError if the string has two linestyle symbols, two
    marker symbols or an unrecognized character.
    """

    colors = ('b', 'g', 'r', 'c', 'm', 'y', 'k', 'w')

    linestyle = None
    marker = None
    color = 'b'

    # handle the multi char special cases and strip them from a
    # working copy, so that error messages can still show the string
    # the caller passed in
    remaining = fmt
    for symbol in ('--', '-.'):
        if remaining.find(symbol)>=0:
            if linestyle is not None:
                raise ValueError(
                    'Illegal format string "%s"; two linestyle symbols' % fmt)
            linestyle = symbol
            remaining = remaining.replace(symbol, '')

    for c in remaining:
        if c in lineStyles:
            if linestyle is not None:
                raise ValueError(
                    'Illegal format string "%s"; two linestyle symbols' % fmt)
            linestyle = c
        elif c in lineMarkers:
            if marker is not None:
                raise ValueError(
                    'Illegal format string "%s"; two marker symbols' % fmt)
            marker = c
        elif c in colors:
            # a later color symbol replaces an earlier one
            color = c
        else:
            err = 'Unrecognized character %c in format string' % c
            raise ValueError(err)
    if linestyle is None and marker is None:
        linestyle = '-'
    return linestyle, marker, color



class _process_plot_var_args:
    """

    Process variable length arguments to the plot command, so that
    plot commands like the following are supported

      plot(t, s)
      plot(t1, s1, t2, s2)
      plot(t1, s1, 'ko', t2, s2)
      plot(t1, s1, 'ko', t2, s2, 'r--', t3, e3)

    an arbitrary number of x,y,fmt are allowed.  Lines given without
    a format string take their color from the colors cycle.
    """

    colors = ('b','g','r','c','m','y','k')
    Ncolors = len(colors)

    def __init__(self, dpi, bbox, transx, transy):
        # everything stored here is handed to each Line2D created
        self.dpi = dpi
        self.bbox = bbox
        self.transx = transx
        self.transy = transy
        # number of lines created so far with a cycled color
        self.count = 0

    def __call__(self, *args):
        # NOTE(review): args is unpacked into _grab_next_args, which
        # takes a single sequence -- presumably callers pass the plot
        # arguments as one tuple; verify against the caller
        return self._grab_next_args(*args)

    def _auto_line(self, x, y):
        'Return a Line2D of x, y in the next color of the cycle'
        line = Line2D(self.dpi, self.bbox, x, y,
                      color = self.colors[self.count % self.Ncolors],
                      transx = self.transx,
                      transy = self.transy,
                      )
        self.count += 1
        return line

    def _styled_line(self, x, y, style):
        'Return a Line2D of x, y; style is a linestyle, marker, color tuple'
        linestyle, marker, color = style
        return Line2D(self.dpi, self.bbox,
                      x, y, color=color, linestyle=linestyle, marker=marker,
                      transx = self.transx,
                      transy = self.transy,
                      )

    def _plot_1_arg(self, y):
        'Line for a lone y; x is arange(len(y))'
        return self._auto_line(arange(len(y)), y)

    def _plot_2_args(self, tup2):
        'Line for a (y, fmt) pair or an (x, y) pair'
        if not is_string_like(tup2[1]):
            x, y = tup2
            return self._auto_line(x, y)

        y, fmt = tup2
        style = _process_plot_format(fmt)
        return self._styled_line(arange(len(y)), y, style)

    def _plot_3_args(self, tup3):
        'Line for an (x, y, fmt) triple'
        x, y, fmt = tup3
        return self._styled_line(x, y, _process_plot_format(fmt))

    def _grab_next_args(self, args):
        """
        Generator yielding one Line2D per group of plot arguments in
        the sequence args
        """
        pending = args
        while len(pending) > 0:
            numleft = len(pending)
            if numleft == 1:
                yield self._plot_1_arg(pending[0])
                return
            if numleft == 2:
                yield self._plot_2_args(pending)
                return
            if numleft == 3:
                if not is_string_like(pending[2]):
                    raise ValueError('third arg must be a format string')
                yield self._plot_3_args(pending)
                return
            # more than three left: a format string in third place
            # closes an x, y, fmt group, otherwise take two arguments
            if is_string_like(pending[2]):
                yield self._plot_3_args(pending[:3])
                pending = pending[3:]
            else:
                yield self._plot_2_args(pending[:2])
                pending = pending[2:]

        


class Legend(Artist):
    """
    Place a legend on the axes at location loc.  Labels are a
    sequence of strings and loc can be a string or an integer
    specifying the legend location

    The location codes are

      'best' : 0,          (currently not supported, defaults to upper right)
      'upper right'  : 1,  (default)
      'upper left'   : 2,
      'lower left'   : 3,
      'lower right'  : 4,
      'right'        : 5,
      'center left'  : 6,
      'center right' : 7,
      'lower center' : 8,
      'upper center' : 9,
      'center'       : 10,

    The legend is made of a background Rectangle, one text instance
    per label and one handle (a Line2D or Rectangle) for each entry of
    handles that is a Line2D or Rectangle; get_child_artists returns
    all of them.
    """


    codes = {'best'        : 0,
             'upper right' : 1,  # default
             'upper left'  : 2,
             'lower left'  : 3,
             'lower right' : 4,
             'right'       : 5,
             'center left'  : 6,
             'center right' : 7,
             'lower center'  : 8,
             'upper center' : 9,
             'center' : 10,
             }


    NUMPOINTS = 4      # the number of points in the legend line
    FONTSIZE = 10
    PAD = 0.2          # the fractional whitespace inside the legend border
    # the following dimensions are in axes coords
    LABELSEP = 0.005   # the vertical space between the legend entries
    HANDLELEN = 0.05     # the length of the legend lines
    HANDLETEXTSEP = 0.02 # the space between the legend line and legend text
    AXESPAD = 0.02     # the border between the axes and legend edge


    def __init__(self, axis, handles, labels, loc):
        """
        axis is the axes instance the legend is drawn on; its dpi, bbox
        and background color are used.  handles are the artists to make
        legend entries for, labels the matching strings and loc a
        location string or integer code (see the class docs).  An
        unrecognized location string is reported on stderr and replaced
        by 'upper right'.
        """
        Artist.__init__(self, axis.dpi, axis.bbox)
        if is_string_like(loc):
            if loc not in self.codes:
                # sys.stderr.write behaves the same on every python
                # version, unlike the print >> statement
                sys.stderr.write('Unrecognized location %s. Falling back on upper right; valid locations are\n%s\t' %(loc, '\n\t'.join(self.codes.keys())) + '\n')
            loc = self.codes.get(loc, 1)

        self._axis = axis
        self._loc = loc

        # use axes coords
        self.transx = Transform( Bound1D(0,1), self.bbox.x )
        self.transy = Transform( Bound1D(0,1), self.bbox.y )

        # make a trial box in the middle of the axes.  relocate it
        # based on its bbox
        left, upper = 0.5, 0.5
        self._xdata = linspace(left, left + self.HANDLELEN, self.NUMPOINTS)
        textleft = left+ self.HANDLELEN+self.HANDLETEXTSEP
        self._texts = self._get_texts(labels, textleft, upper)
        self._handles = self._get_handles(handles, self._texts)

        left, top = self._texts[-1].get_position()
        HEIGHT = self._approx_text_height()
        bottom = top-HEIGHT
        left -= self.HANDLELEN + self.HANDLETEXTSEP + self.PAD
        # trial frame; _update_positions sets the real bounds at draw time
        self._patch = Rectangle(
            self.dpi, self.bbox,
            xy=(left, bottom), width=0.5, height=HEIGHT*len(self._texts),
            facecolor=axis.get_axis_bgcolor(), edgecolor='k',
            transx = self.transx,
            transy = self.transy,
            )

    def _approx_text_height(self):
        'Estimate the height of a FONTSIZE text line in axes coords'
        return self.FONTSIZE/72.0*self.dpi.get()/self.bbox.y.interval()


    def _draw(self, drawable):
        'Update the layout, then draw the frame, the handles and the texts'
        self._update_positions()
        self._patch.draw(drawable)
        for h in self._handles: h.draw(drawable)
        for t in self._texts: t.draw(drawable)

    def _get_handle_text_bbox(self):
        'Get a bbox for the text and lines in axes coords'
        boxes = [t.get_window_extent() for t in self._texts]
        boxes.extend([h.get_window_extent() for h in self._handles])
        bbox = bound2d_all(boxes)
        return inverse_transform_bound2d(bbox, self.transx, self.transy)


    def _get_handles(self, handles, texts):
        """
        Return the list of legend handles, one for each entry of
        handles that is a Line2D or a Rectangle.  Each is placed to the
        left of its text and copies the properties of the original.
        """
        HEIGHT = self._approx_text_height()

        ret = []   # the returned legend lines
        for handle, label in zip(handles, texts):
            x, y = label.get_position()
            x -= self.HANDLELEN + self.HANDLETEXTSEP
            if isinstance(handle, Line2D):
                ydata = (y-HEIGHT/2)*ones(self._xdata.shape, typecode=Float)
                legline = Line2D(self.dpi, self.bbox, self._xdata, ydata,
                                 transx=self.transx, transy=self.transy)
                legline.copy_properties(handle)
                legline.set_markersize(0.6*legline.get_markersize())
                legline.set_data_clipping(False)
                ret.append(legline)
            elif isinstance(handle, Rectangle):
                # 0.75 rather than 3/4, which is 0 under integer
                # division; the box spans the middle half of the line
                p = Rectangle(self.dpi, self.bbox,
                              xy=(min(self._xdata), y-0.75*HEIGHT),
                              width = self.HANDLELEN, height=HEIGHT/2,
                              transx=self.transx, transy=self.transy)
                p.copy_properties(handle)
                ret.append(p)

        return ret


    def _get_texts(self, labels, left, upper):
        """
        Return one text instance per label, left aligned at x=left and
        stacked downward from y=upper (axes coords).
        """

        # height in axes coords
        HEIGHT = self._approx_text_height()
        pos = upper
        x = left

        ret = []  # the returned list of text instances
        for l in labels:
            text = backends.AxisText(
                self.dpi, self.bbox,
                x=x, y=pos,
                text=l, fontsize=self.FONTSIZE,
                verticalalignment='top',
                horizontalalignment='left',
                transx = self.transx,
                transy = self.transy,
                )
            ret.append(text)
            pos -= HEIGHT

        return ret

    def get_child_artists(self):
        'Return the frame patch, the texts and the handles as a list'
        l =  [self._patch]
        l.extend(self._texts)
        l.extend(self._handles)
        return l

    def get_window_extent(self):
        'Return the window extent of the legend frame patch'
        return self._patch.get_window_extent()


    def _offset(self, ox, oy):
        'Move all the artists by ox,oy (axes coords)'
        for t in self._texts:
            x,y = t.get_position()
            t.set_position( (x+ox, y+oy) )

        for h in self._handles:
            if isinstance(h, Line2D):
                x,y = h.get_xdata(), h.get_ydata()
                h.set_data( x+ox, y+oy)
            elif isinstance(h, Rectangle):
                # use the same accessors as for the frame patch below
                # rather than assigning into h.xy item by item
                h.set_x(h.get_x() + ox)
                h.set_y(h.get_y() + oy)

        x, y = self._patch.get_x(), self._patch.get_y()
        self._patch.set_x(x+ox)
        self._patch.set_y(y+oy)

    def _update_positions(self):
        # called from drawable to allow more precise estimates of
        # widths and heights with get_window_extent

        def get_tbounds(text):  #get text bounds in axes coords
            bbox = text.get_window_extent()
            bboxa = inverse_transform_bound2d(bbox, self.transx, self.transy)
            return bboxa.get_bounds()

        # stack each text just below the one above it, recording the
        # (bottom, height) of every line for the handles
        hpos = []
        for t, tabove in zip(self._texts[1:], self._texts[:-1]):
            x,y = t.get_position()
            l,b,w,h = get_tbounds(tabove)
            hpos.append( (b,h) )
            t.set_position( (x, b-0.1*h) )

        # now do the same for last line
        l,b,w,h = get_tbounds(self._texts[-1])
        hpos.append( (b,h) )

        for handle, tup in zip(self._handles, hpos):
            y,h = tup
            if isinstance(handle, Line2D):
                ydata = y*ones(self._xdata.shape, Float)
                handle.set_ydata(ydata+h/2)
            elif isinstance(handle, Rectangle):
                # 0.25 rather than 1/4, which is 0 under integer
                # division; centers the half-height box on the line
                handle.set_y(y+0.25*h)
                handle.set_height(h/2)

        # Set the data for the legend patch
        bbox = self._get_handle_text_bbox()
        bbox.x.scale(1 + self.PAD)
        bbox.y.scale(1 + self.PAD)
        l,b,w,h = bbox.get_bounds()
        self._patch.set_bounds(l,b,w,h)

        # translate everything so the frame sits at the requested location
        BEST, UR, UL, LL, LR, R, CL, CR, LC, UC, C = range(11)
        ox, oy = 0, 0                      # center
        if self._loc in (UL, LL, CL):      # left
            ox = self.AXESPAD - l
        if self._loc in (BEST, UR, LR, R, CR):  # right
            ox = 1 - (l + w + self.AXESPAD)
        if self._loc in (BEST, UR, UL, UC):     # upper
            oy = 1 - (b + h + self.AXESPAD)
        if self._loc in (LL, LR, LC):           # lower
            oy = self.AXESPAD - b
        if self._loc in (LC, UC, C):        # center x
            ox = (0.5-w/2)-l
        if self._loc in (CL, CR, C):        # center y
            oy = (0.5-h/2)-b
        self._offset(ox, oy)
        
class Axes(Artist):
    """
    Emulate matlab's axes command, creating axes with

      Axes(position=[left, bottom, width, height])

    where all the arguments are fractions in [0,1] which specify the
    fraction of the total figure window.  

    axisbg is the color of the axis background



    """

    def __init__(self, fig, position, axisbg = 'w'):
        """
        fig is the figure instance holding these axes; its dpi and
        bbox are used.  position is [left, bottom, width, height] in
        fractions of the figure (see resize).  axisbg is the color of
        the axes background.
        """
        bbox = Bound2D(0,0,1,1) # resize will update
        Artist.__init__(self, fig.dpi, bbox) 

        self._position = position
        self.figure = fig
        self.xaxis = XAxis(self)
        self.yaxis = YAxis(self)
        # these next two calls must be made immediately after both
        # axis are built since they are tightly coupled
        self.xaxis.build_artists()
        self.yaxis.build_artists()
        self.resize()  # compute bounding box


        self._axisbg = axisbg
        self._gridOn = False
        self._lines = []
        self._patches = []
        self._text = []     # text in axis coords




        # apply the initial grid state (off) to both axis
        self.grid(self._gridOn)


        # callable used by plot to turn its arguments into lines that
        # use the data transforms of the two axis
        self._get_lines = _process_plot_var_args(
            self.dpi, self.bbox, self.xaxis.transData, self.yaxis.transData)




        # empty title positioned with the transAxis transforms at
        # x=0.5, y=1.01; clipping is turned off for it
        self._title =  backends.AxisText(
            self.dpi, 
            self.bbox, 
            x=0.5, y=1.01, text='', fontsize=14,
            verticalalignment='bottom',
            horizontalalignment='center',
            transx = self.xaxis.transAxis,
            transy = self.yaxis.transAxis,            
            )
        self._title.set_clip_on(False)

        # background rectangle from (0,0) with width and height 1 in
        # the transAxis transforms, filled with the axisbg color
        self._axesPatch = Rectangle(
            self.dpi, self.bbox,
            xy=(0,0), width=1, height=1,
            facecolor=self._axisbg, edgecolor='k',
            transx = self.xaxis.transAxis,
            transy = self.yaxis.transAxis,            
            )
        self._axesPatch.set_clip_on(False)
        # when False, _draw skips the background patch and both axis
        self.axison = True
        
    def _pass_func(self, *args, **kwargs):
        pass
    
    def add_line(self, line):
        "Add a line to the list of plot lines"
        self.xaxis.datalim.update(line.get_xdata())
        self.yaxis.datalim.update(line.get_ydata())
        self._lines.append(line)

    def add_patch(self, patch):
        "Add a line to the list of plot lines"
        


    def bar(self, x, y, width=0.8, color='b',
            yerr=None, xerr=None, capsize=3):
        """
        Make a bar plot with rectangles at x, x+width, 0, y
        x and y are sequences of equal length

        Return value is a list of Rectangle patch instances

        xerr and yerr, if not None, will be used to generate errorbars
        on the bar chart, centered on each bar.  capsize is passed on
        to errorbar.
        """
        patches = []
        for thisX,thisY in zip(x,y):
            r = Rectangle(
                self.dpi, self.bbox,
                xy=(thisX,0), width=width, height=thisY,
                facecolor=color,
                transx = self.xaxis.transData,
                transy = self.yaxis.transData,
                )
            self._patches.append(r)
            patches.append(r)

        self.xaxis.datalim.update(x)
        self.yaxis.datalim.update(y)

        if xerr is not None or yerr is not None:
            # asarray so that a plain list x works too; 2.0 so that an
            # integer width is not truncated by integer division
            centers = asarray(x) + width/2.0
            l1, l2 = self.errorbar(centers, y, yerr=yerr, xerr=xerr,
                                   fmt='o', capsize=capsize)
            for line in l1:
                line.set_markerfacecolor('k')
                line.set_markeredgecolor('k')

        self.xaxis.autoscale_view()
        self.yaxis.autoscale_view()

        return patches



    def clear(self):
        # TODO: figure out what you want clear to do in relation to axes
        self._lines = []
        self._patches = []
        self.wash_brushes()

    def cohere(self, x, y, NFFT=256, Fs=2, detrend=mlab.detrend_none,
               window=mlab.window_hanning, noverlap=0):
        """
        Compute and plot the coherence between x and y.  Coherence is
        the normalized cross spectral density

        Cxy = |Pxy|^2/(Pxx*Pyy)

        The computation is done by mlab.cohere; the result is plotted
        against frequency with the grid on.  See the docs for psd and
        csd for information about the function arguments NFFT, detrend,
        window, noverlap, as well as the methods used to compute Pxy,
        Pxx and Pyy.

        Returns the tuple Cxy, freqs, where freqs are the frequencies
        of the coherence vector.

        Refs:
          Bendat & Piersol -- Random Data: Analysis and Measurement
            Procedures, John Wiley & Sons (1986)

        """
        result = mlab.cohere(x, y, NFFT, Fs, detrend, window, noverlap)
        cxy, freqs = result

        self.plot(freqs, cxy)
        self.set_xlabel('Frequency')
        self.set_ylabel('Coherence')
        self.grid(True)

        return cxy, freqs

    def csd(self, x, y, NFFT=256, Fs=2, detrend=mlab.detrend_none,
            window=mlab.window_hanning, noverlap=0):
        """
        The cross spectral density Pxy by Welches average periodogram
        method.  The vectors x and y are divided into NFFT length
        segments.  Each segment is detrended by function detrend and
        windowed by function window.  noverlap gives the length of the
        overlap between segments.  The product of the direct FFTs of x and
        y are averaged over each segment to compute Pxy, with a scaling to
        correct for power loss due to windowing.  Fs is the sampling
        frequency.

        NFFT must be a power of 2

        detrend and window are functions, unlike in matlab where they are
        vectors.  For detrending you can use detrend_none, detrend_mean,
        detrend_linear or a custom function.  For windowing, you can use
        window_none, window_hanning, or a custom function

        Returns the tuple Pxy, freqs.  Pxy is the cross spectrum (complex
        valued), and 10*log10(|Pxy|) is plotted

        Refs:
          Bendat & Piersol -- Random Data: Analysis and Measurement
            Procedures, John Wiley & Sons (1986)

        """

        pxy, freqs = mlab.csd(x, y, NFFT, Fs, detrend, window, noverlap)
        pxy.shape = len(freqs),
        # pxy is complex

        self.plot(freqs, 10*log10(absolute(pxy)))
        self.set_xlabel('Frequency')
        self.set_ylabel('Cross Spectrum Magnitude (dB)')
        self.grid(True)
        vmin, vmax = self.yaxis.viewlim.bounds()
        intv = vmax-vmin
        # 10*int(log10(intv)) is 0 or negative when the view range is
        # under 10, and log10 is undefined for a non-positive range;
        # fall back on a tick step of 1 so arange always gets a
        # positive step
        step = 1
        if intv > 0:
            step = max(10*int(log10(intv)), 1)
        ticks = arange(math.floor(vmin), math.ceil(vmax)+1, step)
        self.set_yticks(ticks)

        return pxy, freqs


        
    def _draw(self, drawable, *args, **kwargs):
        "Draw everything (plot lines, axes, labels)"

        if not ( self.xaxis.viewlim.defined() and
                 self.yaxis.viewlim.defined() ):
            self.update_viewlim()

        if self.axison:
            self._axesPatch.draw(drawable)
            self.xaxis.draw(drawable)
            self.yaxis.draw(drawable)

            
                      
        for p in self._patches:
            p.draw(drawable)

        for line in self._lines:
            line.draw(drawable)

        for t in self._text:
            t.draw(drawable)

        self._title.draw(drawable)

        # optional artists
        if hasattr(self, '_legend'):
            self._legend.draw(drawable)


    def errorbar(self, x, y, yerr=None, xerr=None, fmt='b-', capsize=3):
        """
        Plot x versus y with error deltas in yerr and xerr.
        Vertical errorbars are plotted if yerr is not None
        Horizontal errorbars are plotted if xerr is not None

        xerr and yerr may be any of:
            a scalar (rank-0)          - symmetric errorbars +/- value
            an N-element sequence      - symmetric errorbars +/- value
            a rank-2 array             - asymmetric errorbars; element
                                         [0] holds the lower/left deltas
                                         and element [1] the upper/right

        fmt is the plot format symbol for y.  if fmt is None, just
        plot the errorbars with no line symbols.  This can be useful
        for creating a bar plot with errorbars

        Return value is a length 2 tuple.  The first element is a list of
        y symbol lines (None if fmt is None).  The second element is a
        list of error bar lines.

        capsize is the size of the error bar caps in points
        """
        x = asarray(x)
        y = asarray(y)

        if fmt is not None: l0 = self.plot(x,y,fmt)
        else: l0 = None
        l = []

        if xerr is not None: xerr = asarray(xerr)
        if yerr is not None: yerr = asarray(yerr)


        def get_ybar_cap(x, y):
            # horizontal cap of a vertical bar: x extent is +/- capsize
            # through the pts transform at x, y is in data coords
            xtrans = self.xaxis.get_pts_transform(
                x, self.xaxis.transData)
            line = Line2D( self.dpi, self.bbox,
                           xdata=(-capsize, capsize), ydata=(y, y),
                           color='k',
                           transx = xtrans,
                           transy = self.yaxis.transData,
                           )
            line.set_data_clipping(False)
            return line

        def get_xbar_cap(x, y):
            # vertical cap of a horizontal bar: y extent is +/- capsize
            # through the pts transform at y, x is in data coords
            ytrans = self.yaxis.get_pts_transform(
                y, self.yaxis.transData)

            line = Line2D( self.dpi, self.bbox,
                           xdata=(x,x), ydata=(-capsize, capsize),
                           color='k',
                           transx = self.xaxis.transData,
                           transy = ytrans,
                           )
            line.set_data_clipping(False)
            return line

        # horizontal errorbars
        if xerr is not None:
            # <= 1 so that a scalar (rank-0) error is treated as
            # symmetric instead of being indexed as xerr[0], xerr[1]
            if len(xerr.shape) <= 1:
                left  = x-xerr
                right = x+xerr
            else:
                left  = x-xerr[0]
                right = x+xerr[1]

            # horizontal errorbars
            l.extend( self.hlines(y, x, left) )
            l.extend( self.hlines(y, x, right) )
            # bar caps
            for yval, lval, rval in zip(y, left, right):
                l1 = get_xbar_cap(lval, yval)
                l2 = get_xbar_cap(rval, yval)
                self.add_line(l1)
                self.add_line(l2)
                l.extend( (l1,l2) )


        # vertical errorbars
        if yerr is not None:
            # same rank test as for xerr above
            if len(yerr.shape) <= 1:
                lower = y-yerr
                upper = y+yerr
            else:
                lower = y-yerr[0]
                upper = y+yerr[1]

            # vertical errorbars
            l.extend( self.vlines(x, y, lower) )
            l.extend( self.vlines(x, y, upper) )

            # bar caps
            for xval, uval, lval in zip(x, upper, lower):
                l1 = get_ybar_cap(xval, uval)
                l2 = get_ybar_cap(xval, lval)
                self.add_line(l1)
                self.add_line(l2)
                l.extend( (l1,l2) )

        return (l0, l)


    def get_axis_bgcolor(self):
        'Return the axis background color'
        return self._axisbg

    def get_child_artists(self):
        artists = [self._title, self._axesPatch, self.xaxis, self.yaxis]
        artists.extend(self._lines)
        artists.extend(self._patches)
        artists.extend(self._text)
        if hasattr(self, '_legend'):
            artists.append(self._legend)
        return artists
    
        
    def get_xaxis(self):
        "Return the XAxis instance"
        return self.xaxis

    def get_xgridlines(self):
        "Get the x grid lines as a list of Line2D instances"
        return self.xaxis.get_gridlines()

    def get_xlim(self):
        "Get the x axis range [xmin, xmax]"
        return self.xaxis.viewlim.bounds()

    def get_xticklabels(self):
        "Get the xtick labels as a list of AxisText instances"
        return self.xaxis.get_ticklabels()

    def get_xticklines(self):
        "Get the xtick lines as a list of Line2D instances"
        return self.xaxis.get_ticklines()
    

    def get_xticks(self):
        "Return the x ticks as a list of locations"
        return self.xaxis.get_ticklocs()

    def get_yaxis(self):
        "Return the YAxis instance"
        return self.yaxis

    def get_ylim(self):
        "Get the y axis range [ymin, ymax]"
        return self.yaxis.viewlim.bounds()

    def get_ygridlines(self):
        "Get the y grid lines as a list of Line2D instances"
        return self.yaxis.get_gridlines()

    def get_yticklabels(self):
        "Get the ytick labels as a list of AxisText instances"
        return self.yaxis.get_ticklabels() 

    def get_yticklines(self):
        "Get the ytick lines as a list of Line2D instances"
        return self.yaxis.get_ticklines()

    def get_yticks(self):
        "Return the y ticks as a list of locations"
        return self.yaxis.get_ticklocs()  

    def grid(self,b):
        "Set the axes grids on or off; b is a boolean"
        self.xaxis.grid(b)
        self.yaxis.grid(b)

    def hist(self, x, bins=10, normed=0):
        """
        Compute the histogram of x with mlab.hist and plot it as a bar
        chart.  bins is either an integer number of bins or a sequence
        giving the bins.  x are the data to be binned.

        The bars are 0.9 times as wide as the spacing between the
        first two bins.

        If normed is true, the first element of the return tuple will be the
        counts normalized to form a probability distribution, ie,
        n/(len(x)*dbin)

        Returns n, bins, patches where patches are the bars returned by
        bar
        """
        counts, edges = mlab.hist(x, bins, normed)
        barwidth = 0.9*(edges[1]-edges[0])
        rects = self.bar(edges, counts, width=barwidth)
        return counts, edges, rects
        
    def in_axes(self, xwin, ywin):
        return self.bbox.x.in_interval(xwin) and \
               self.bbox.y.in_interval(ywin)

    def hlines(self, y, xmin, xmax, fmt='k-'):
        """
        plot horizontal lines at each y from xmin to xmax.  xmin or
        xmax can be scalars or len(y) arrays.  If they are
        scalars, then the respective values are constant, else the
        widths of the lines are determined by xmin and xmax

        fmt is a format string giving the line style, marker and color

        Returns a list of line instances that were added

        Raises ValueError if xmin or xmax is neither a scalar nor of
        the same length as y
        """
        linestyle, marker, color = _process_plot_format(fmt)

        # todo: fix me for y is scalar and xmin and xmax are iterable
        y = asarray(y)
        xmin = asarray(xmin)
        xmax = asarray(xmax)

        # a true scalar becomes a rank-0 array, on which len() fails, so
        # test the rank before the length
        if len(xmin.shape)==0 or len(xmin)==1:
            xmin = xmin*ones(y.shape, y.typecode())
        if len(xmax.shape)==0 or len(xmax)==1:
            xmax = xmax*ones(y.shape, y.typecode())

        if len(xmin)!=len(y):
            raise ValueError('xmin and y are unequal sized sequences')
        if len(xmax)!=len(y):
            raise ValueError('xmax and y are unequal sized sequences')

        lines = []
        for (thisY, thisMin, thisMax) in zip(y,xmin,xmax):
            line = Line2D(
                self.dpi, self.bbox,
                [thisMin, thisMax], [thisY, thisY],
                color=color, linestyle=linestyle, marker=marker,
                transx = self.xaxis.transData,
                transy = self.yaxis.transData,
                )
            self.add_line( line )
            lines.append(line)
        return lines


    def legend(self, *args, **kwargs):
        """
        Place a legend on the current axes at location loc.  Labels are a
        sequence of strings and loc can be a string or an integer
        specifying the legend location

        USAGE: 

          Make a legend with existing lines
          legend( LABELS )
          >>> legend( ('label1', 'label2', 'label3') ) 

          Make a legend for Line2D instances lines1, line2, line3
          legend( LINES, LABELS )
          >>> legend( (line1, line2, line3), ('label1', 'label2', 'label3') )

          Make a legend at LOC
          legend( LABELS, LOC )  or
          legend( LINES, LABELS, LOC )
          >>> legend( ('label1', 'label2', 'label3'), loc='upper left')
          >>> legend( (line1, line2, line3),
                      ('label1', 'label2', 'label3'),
                      loc=2)

        LOC may be given positionally or as the keyword argument loc;
        a positional LOC takes precedence.

        The LOC location codes are

          'best' : 0,          (currently not supported, defaults to upper right)
          'upper right'  : 1,  (default)
          'upper left'   : 2,
          'lower left'   : 3,
          'lower right'  : 4,
          'right'        : 5,
          'center left'  : 6,
          'center right' : 7,
          'lower center' : 8,
          'upper center' : 9,
          'center'       : 10,

        The legend replaces any legend made by an earlier call.  Return
        value is the Legend instance that was created.

        Raises RuntimeError if called with no or more than three
        positional arguments
        """

        # honor the documented loc keyword; 1 is upper right
        loc = kwargs.get('loc', 1)
        if len(args)==1:
            # LABELS
            labels = args[0]
            lines = [line for line, label in zip(self._lines, labels)]
        elif len(args)==2:
            if is_string_like(args[1]) or isinstance(args[1], int):
                # LABELS, LOC
                labels, loc = args
                lines = [line for line, label in zip(self._lines, labels)]
            else:
                # LINES, LABELS
                lines, labels = args
        elif len(args)==3:
            # LINES, LABELS, LOC
            lines, labels, loc = args
        else:
            raise RuntimeError('Invalid arguments to legend')

        lines = flatten(lines)
        self._legend = Legend(self, lines, labels, loc)
        return self._legend

    def loglog(self, *args, **kwargs):
        """
        Make a loglog plot with log scaling on the a and y axis.  The args
        to semilog x are the same as the args to plot.  See help plot for
        more info
        """

        self.set_xscale('log')
        self.set_yscale('log')
        l = self.plot(*args, **kwargs)
        return l


    def panx(self, numsteps):
        "Pan the x axis numsteps (plus pan right, minus pan left)"
        self.xaxis.pan(numsteps)
        xmin, xmax = self.xaxis.viewlim.bounds()
        for line in self._lines:
            line.set_xclip(xmin, xmax)

    def pany(self, numsteps):
        "Pan the x axis numsteps (plus pan up, minus pan down)"
        self.yaxis.pan(numsteps)


    def pcolor(self, *args, **kwargs):
        """
        pcolor(C) - make a pseudocolor plot of matrix C

        pcolor(X, Y, C) - a pseudo color plot of C on the matrices X and Y  

        One rectangle is drawn per grid cell, colored by the value of C
        at the cell's [i,j] corner, so the last row and column of C are
        not used for coloring.

        Shading:

          The optional keyword arg shading ('flat' or 'faceted') will
          determine whether a black grid is drawn around each pcolor
          square.  Default 'faceted'
             e.g.,   
             pcolor(C, shading='flat')  
             pcolor(X, Y, C, shading='faceted')

        returns a list of patch objects

        Raises RuntimeError unless called with one or three positional
        arguments
        """


        shading = kwargs.get('shading', 'faceted')

        if len(args)==1:
            C = args[0]
            numRows, numCols = C.shape
            # meshgrid(x, y) gives arrays of shape (len(y), len(x)), so
            # the column indices go first for X and Y to have the same
            # shape as C
            X, Y = meshgrid(range(numCols), range(numRows))
        elif len(args)==3:
            X, Y, C = args
        else:
            raise RuntimeError('Illegal arguments to pcolor; see help(pcolor)')


        Nx, Ny = X.shape
        cmap = ColormapJet(256)

        cmin = MLab.min(MLab.min(C))
        cmax = MLab.max(MLab.max(C))

        patches = []
        for i in range(Nx-1):
            for j in range(Ny-1):
                c = C[i,j]
                color = cmap.get_color(c, cmin, cmax)
                left = X[i,j]
                bottom = Y[i,j]
                width = X[i,j+1]-left
                height = Y[i+1,j]-bottom
                rect = Rectangle(
                    self.dpi, self.bbox,
                    (left, bottom), width, height,
                    transx = self.xaxis.transData,
                    transy = self.yaxis.transData,
                    )
                rect.set_facecolor(color)
                if shading == 'faceted':
                    rect.set_linewidth(0.25)
                    rect.set_edgecolor('k')
                else:
                    rect.set_edgecolor(color)
                self._patches.append(rect)
                patches.append(rect)
        self.grid(0)

        minx = MLab.min(MLab.min(X))
        maxx = MLab.max(MLab.max(X))
        miny = MLab.min(MLab.min(Y))
        maxy = MLab.max(MLab.max(Y))
        self.xaxis.datalim.update((minx, maxx))
        self.yaxis.datalim.update((miny, maxy))
        self.xaxis.autoscale_view()
        self.yaxis.autoscale_view()
        return patches


    def plot(self, *args):
        """
        Emulate matlab's plot command.  *args is a variable length
        argument, allowing for multiple x,y pairs with an optional
        format string.  For example, all of the following are legal,
        assuming a is the Axis instance:
        
          a.plot(x,y)            # plot Numeric arrays y vs x
          a.plot(x,y, 'bo')      # plot Numeric arrays y vs x with blue circles
          a.plot(y)              # plot y using x = arange(len(y))
          a.plot(y, 'r+')        # ditto with red plusses

        An arbitrary number of x, y, fmt groups can be specified, as in 
          a.plot(x1, y1, 'g^', x2, y2, 'l-')  

        Returns a list of lines that were added
        """

        lines = []
        for line in self._get_lines(args):
            self.add_line(line)
            lines.append(line)
        self.xaxis.autoscale_view()
        self.yaxis.autoscale_view()
        return lines

    def psd(self, x, NFFT=256, Fs=2, detrend=mlab.detrend_none,
            window=mlab.window_hanning, noverlap=0):
        """
        The power spectral density by Welches average periodogram method.
        The vector x is divided into NFFT length segments.  Each segment
        is detrended by function detrend and windowed by function window.
        noverlap gives the length of the overlap between segments.  The
        absolute(fft(segment))**2 of each segment are averaged to compute Pxx,
        with a scaling to correct for power loss due to windowing.  Fs is
        the sampling frequency.

        -- NFFT must be a power of 2

        -- detrend and window are functions, unlike in matlab where they
           are vectors.  For detrending you can use detrend_none,
           detrend_mean, detrend_linear or a custom function.  For
           windowing, you can use window_none, window_hanning, or a custom
           function

        -- if length x < NFFT, it will be zero padded to NFFT


        Returns the tuple Pxx, freqs

        For plotting, the power is plotted as 10*log10(pxx) for decibels,
        though pxx itself is returned

        Refs:
          Bendat & Piersol -- Random Data: Analysis and Measurement
            Procedures, John Wiley & Sons (1986)

        """
        pxx, freqs = mlab.psd(x, NFFT, Fs, detrend, window, noverlap)
        pxx.shape = len(freqs),

        self.plot(freqs, 10*log10(pxx))
        self.set_xlabel('Frequency')
        self.set_ylabel('Power Spectrum Magnitude (dB)')
        self.grid(True)
        vmin, vmax = self.yaxis.viewlim.bounds()
        intv = vmax-vmin
        # 10*int(log10(intv)) is 0 or negative when the view range is
        # under 10, and log10 is undefined for a non-positive range;
        # fall back on a tick step of 1 so arange always gets a
        # positive step
        step = 1
        if intv > 0:
            step = max(10*int(log10(intv)), 1)
        ticks = arange(math.floor(vmin), math.ceil(vmax)+1, step)
        self.set_yticks(ticks)

        return pxx, freqs


    def resize(self):
        l,b,w,h = self.figure.bbox.get_bounds()
        l = l + self._position[0]*w
        b = b + self._position[1]*h
        w *= self._position[2]
        h *= self._position[3]
        self.bbox.set_bounds(l,b,w,h)
        
    def set_axis_off(self):
        self.axison = False

    def set_axis_on(self):
        self.axison = True

    def scatter(self, x, y, s=None, c='b'):
        """
        Make a scatter plot of x versus y.  s is a size (in data
        coords) and can be either a scalar or an array of the same
        length as x or y.  c is a color and can be a single color
        format string or an length(x) array of intensities which will
        be mapped by the colormap jet.        

        If s is None a default size of 0.015 times the range of y will
        be used

        Returns the list of circle patches that were added

        Raises ValueError if c or s is a sequence whose length differs
        from that of x
        """

        if is_string_like(c) or not iterable(c):
            # a single color is used for every point
            c = [c]*len(x)
        else:
            jet = ColormapJet(1000)
            c = jet.get_colors(c)

        if s is None:
            s = [abs(0.015*(max(y)-min(y)))]*len(x)
        elif not iterable(s):
            s = [s]*len(x)

        if len(c)!=len(x):
            raise ValueError('c and x are not equal lengths')
        if len(s)!=len(x):
            raise ValueError('s and x are not equal lengths')

        patches = []
        for thisX, thisY, thisS, thisC in zip(x,y,s,c):
            circ = Circle( self.dpi, self.bbox, (thisX, thisY),
                           radius=thisS,
                           transx = self.xaxis.transData,
                           transy = self.yaxis.transData,
                           )
            circ.set_facecolor(thisC)
            self._patches.append(circ)
            patches.append(circ)
        self.xaxis.datalim.update(x)
        self.yaxis.datalim.update(y)
        self.xaxis.autoscale_view()
        self.yaxis.autoscale_view()
        return patches

    def semilogx(self, *args, **kwargs):
        """
        Make a semilog plot with log scaling on the x axis.  The args to
        semilog x are the same as the args to plot.  See help plot for
        more info    
        """
        self.set_xscale('log')
        l = self.plot(*args, **kwargs)
        return l


    def semilogy(self, *args, **kwargs):
        """
        Make a semilog plot with log scaling on the y axis.  The args to
        semilog x are the same as the args to plot.  See help plot for
        more info    
        """
        self.set_yscale('log')
        l = selfplot(*args, **kwargs)
        return l


    def set_axis_bgcolor(self, color):
        self._axisbg = color

                                
    def set_title(self, label, fontdict=None, **kwargs):
        """
        Set the title for the axes to the string label

        fontdict and **kwargs are run through
        backends._process_text_args and the result is applied to the
        title with update_properties.  See the text docstring for
        information of how override and the optional args work

        Returns the title text instance
        """
        self._title.set_text(label)
        # start from an empty dict so that only the properties supplied
        # by the caller are forwarded to the title
        # NOTE(review): default title properties are presumably set
        # where self._title is created -- confirm
        override = backends._process_text_args({}, fontdict, **kwargs)
        self._title.update_properties(override)
        return self._title


    def set_xlabel(self, xlabel, fontdict=None, **kwargs):
        """
        Set the label for the xaxis to the string xlabel

        fontdict and **kwargs are run through
        backends._process_text_args, starting from an empty dict, and
        the result is applied to the label.  See the text docstring
        for information of how override and the optional args work

        Returns the xaxis label instance
        """
        axisLabel = self.xaxis.get_label()
        axisLabel.set_text(xlabel)
        props = backends._process_text_args({}, fontdict, **kwargs)
        axisLabel.update_properties(props)
        return axisLabel

    def set_xlim(self, v):
        "Set the limits for the xaxis; v = [xmin, xmax]"

        xmin, xmax = v
        self.xaxis.viewlim.set_bounds(xmin, xmax)
        for line in self._lines:
            line.set_xclip(xmin, xmax)
        
    def set_xscale(self, value):
        self.xaxis.set_scale(value)

    def set_xticklabels(self, labels, fontdict=None, **kwargs):
        """
        Set the xtick labels with list of strings labels
        Return a list of axis text instances
        """
        return self.xaxis.set_ticklabels(labels, fontdict, **kwargs)

    def set_xticks(self, ticks):
        "Set the x ticks with list of ticks"
        self.xaxis.set_ticks(ticks)
        

    def set_ylabel(self, ylabel, fontdict=None, **kwargs):
        """
        Set the label for the yaxis to the string ylabel

        fontdict and **kwargs are run through
        backends._process_text_args, starting from an empty dict, and
        the result is applied to the label.  See the text docstring
        for information of how override and the optional args work

        Returns the yaxis label instance
        """
        axisLabel = self.yaxis.get_label()
        axisLabel.set_text(ylabel)
        props = backends._process_text_args({}, fontdict, **kwargs)
        axisLabel.update_properties(props)
        return axisLabel

    def set_ylim(self, v):
        "Set the limits for the xaxis; v = [ymin, ymax]"
        ymin, ymax = v
        self.yaxis.viewlim.set_bounds(ymin, ymax)

    def set_yscale(self, value):
        self.yaxis.set_scale(value)


    def set_yticklabels(self, labels, fontdict=None, **kwargs):
        """
        Set the ytick labels with list of strings labels.
        Return a list of AxisText instances
        """
        return self.yaxis.set_ticklabels(labels, fontdict, **kwargs)
        
    def set_yticks(self, ticks):
        "Set the y ticks with list of ticks"
        self.yaxis.set_ticks(ticks)

    
    def text(self, x, y, text, fontdict=None, **kwargs):
        """
        Add the string text to the axes at location x,y (data coords
        unless transx, transy are overridden) and return the new
        AxisText instance

        fontdict is a dictionary to override the default text
        properties.  The override dictionary starts out as

          'fontsize'            : 10,
          'verticalalignment'   : 'bottom',
          'horizontalalignment' : 'left'
          'transx'              : self.xaxis.transData,
          'transy'              : self.yaxis.transData,

        and is combined with fontdict and **kwargs by
        backends._process_text_args.

        **kwargs can in turn be used to override the override, as in

          a.text(x,y,label, fontsize=12)

        will have verticalalignment=bottom and
        horizontalalignment=left but will have a fontsize of 12


        The AxisText defaults are
            'color'               : 'k',
            'fontname'            : 'Sans',
            'fontsize'            : 10,
            'fontweight'          : 'bold',
            'fontangle'           : 'normal',
            'horizontalalignment' : 'left'
            'rotation'            : 'horizontal',
            'verticalalignment'   : 'bottom',
            'transx'              : self.xaxis.transData,
            'transy'              : self.yaxis.transData,

        transx and transy specify that text is in data coords,
        alternatively, you can specify text in axis coords (0,0 lower
        left and 1,1 upper right).  The example below places text in
        the center of the axes

        ax = subplot(111)
        text(0.5, 0.5,'matplotlib',
             horizontalalignment='center',
             verticalalignment='center',
             transx = ax.xaxis.transAxis,
             transy = ax.yaxis.transAxis,
        )

        """
        defaults = {
            'fontsize' : 10,
            'verticalalignment' : 'bottom',
            'horizontalalignment' : 'left',
            'transx' : self.xaxis.transData,
            'transy' : self.yaxis.transData,
            }
        props = backends._process_text_args(defaults, fontdict, **kwargs)

        axisText = backends.AxisText(
            dpi=self.dpi, bbox=self.bbox,
            x=x, y=y, text=text,
            **props)
        axisText.set_renderer(self._drawable)

        # keep a reference so the axes knows about its text instances
        self._text.append(axisText)
        return axisText
    
    def update_viewlim(self):
        'Update the view limits with all the data in self'

        for line in self._lines:
            xdata = line.get_xdata()
            self.xaxis.viewlim.update(xdata)
            ydata = line.get_ydata()
            self.yaxis.viewlim.update(ydata)

        for p in self._patches:
            l,b,w,h = p.get_data_extent().get_bounds()
            self.xaxis.viewlim.update( (l, l+w) )
            self.yaxis.viewlim.update( (b, b+h) )


    def vlines(self, x, ymin, ymax, color='k'):
        """
        Plot vertical lines at each x from ymin to ymax.  ymin or ymax
        can be scalars or len(x) arrays.  If they are scalars (or
        length one sequences), then the respective values are constant,
        else the heights of the lines are determined by ymin and ymax

        Returns a list of lines that were added

        Raises ValueError if ymin or ymax is neither a single value nor
        the same length as x
        """

        # asarray turns a bare scalar into a rank-0 array, which has no
        # len(); promote scalars to length one sequences first so the
        # scalar form promised above actually works
        if not iterable(x):
            x = [x]
        if not iterable(ymin):
            ymin = [ymin]
        if not iterable(ymax):
            ymax = [ymax]

        x = asarray(x)
        ymin = asarray(ymin)
        ymax = asarray(ymax)

        # broadcast a single value to one entry per x
        if len(ymin)==1:
            ymin = ymin*ones(x.shape, x.typecode())
        if len(ymax)==1:
            ymax = ymax*ones(x.shape, x.typecode())

        # call form of raise is valid in both python 2 and python 3
        if len(ymin)!=len(x):
            raise ValueError('ymin and x are unequal sized sequences')
        if len(ymax)!=len(x):
            raise ValueError('ymax and x are unequal sized sequences')

        # one (ymin, ymax) row per x
        Y = transpose(array([ymin, ymax]))
        lines = []
        for thisX, thisY in zip(x,Y):
            line = Line2D(
                self.dpi, self.bbox,
                [thisX, thisX], thisY, color=color, linestyle='-',
                transx = self.xaxis.transData,
                transy = self.yaxis.transData,
                )
            self.add_line(line)
            lines.append(line)
        return lines


    def zoomx(self, numsteps):
        """
        Zoom in on the x xaxis numsteps (plus for zoom in, minus for zoom out)
        """
        self.xaxis.zoom(numsteps)
        xmin, xmax = self.xaxis.viewlim.bounds()
        for line in self._lines:
            line.set_xclip(xmin, xmax)

    def zoomy(self, numsteps):
        """
        Zoom in on the x xaxis numsteps (plus for zoom in, minus for zoom out)
        """
        self.yaxis.zoom(numsteps)

class Subplot(Axes):
    """
    Emulate matlab's subplot command, creating axes with

      Subplot(fig, numRows, numCols, plotNum)

    where plotNum=1 is the first plot number and increasing plotNums
    fill rows first.  max(plotNum)==numRows*numCols

    You can leave out the commas if numRows, numCols and plotNum are
    each a single digit, as in

      Subplot(fig, 211)    # 2 rows, 1 column, first (upper) plot
    """

    def __init__(self, fig, *args, **kwargs):
        """
        fig is passed on to Axes.__init__.  *args is either the three
        values numRows, numCols, plotNum or a single 3 digit value
        combining them.  **kwargs are passed on to Axes.__init__.

        Raises ValueError if the args are not in one of those two
        forms or if plotNum is outside 1..numRows*numCols
        """
        if len(args)==1:
            s = str(args[0])
            if len(s) != 3:
                raise ValueError('Argument to subplot must be a 3 digits long')
            rows, cols, num = map(int, s)
        elif len(args)==3:
            rows, cols, num = args
        else:
            raise ValueError('Illegal argument to subplot')

        total = rows*cols
        num -= 1    # convert from matlab to python indexing ie num in range(0,total)
        if num < 0:
            # a negative index would silently wrap around in divmod below
            raise ValueError('Subplot number must be at least 1')
        if num >= total:
            raise ValueError('Subplot number exceeds total subplots')

        # region of the figure (in fractions of the figure size) that
        # the grid of subplots occupies
        left, right = .125, .9
        bottom, top = .11, .9
        rat = 0.2             # separator size as a fraction of the axes size
        totWidth = right-left
        totHeight = top-bottom

        # rows axes plus rows-1 separators fill the total height
        figH = totHeight/(rows + rat*(rows-1))
        sepH = rat*figH

        # cols axes plus cols-1 separators fill the total width
        figW = totWidth/(cols + rat*(cols-1))
        sepW = rat*figW

        # plot numbers fill rows first
        rowNum, colNum =  divmod(num, cols)

        figBottom = top - (rowNum+1)*figH - rowNum*sepH
        # horizontal offsets must use the horizontal separator sepW;
        # sepH only matches it when rows==cols
        figLeft = left + colNum*(figW + sepW)

        Axes.__init__(self, fig, [figLeft, figBottom, figW, figH], **kwargs)

        self.rowNum = rowNum
        self.colNum = colNum
        self.numRows = rows
        self.numCols = cols

    def is_first_col(self):
        'Return true if this subplot is in the first (leftmost) column'
        return self.colNum==0

    def is_first_row(self):
        'Return true if this subplot is in the first (top) row'
        return self.rowNum==0

    def is_last_row(self):
        'Return true if this subplot is in the last (bottom) row'
        return self.rowNum==self.numRows-1


    def is_last_col(self):
        'Return true if this subplot is in the last (rightmost) column'
        return self.colNum==self.numCols-1
