# TODO: let the user pass the default font to the axes so that the
# text labels, tick labels, etc, can all be updated

# BUG -- the markers for circ and square in the gtk backend do not
# have the same center.  See the stock_demo for example.

from __future__ import generators
from __future__ import division

import math, re, sys

from Numeric import arange, array, ones, zeros, logical_and, \
     nonzero, take, Float, transpose, log

from cbook import iterable, is_string_like, flatten, enumerate
from mlab import linspace
from lines import Line2D, lineStyles, lineMarkers
from patches import Rectangle, Circle
from artist import Artist
from colormap import ColormapJet
import backends
#import AxisText, _process_text_args

# Bind True/False to plain integers at module level (presumably so the
# names exist on interpreters that lack the boolean builtins -- confirm
# before removing).
True = 1
False = 0

def to_arrays(typecode, *args):
    """
    Convert every positional argument into an array of the given typecode.

    A value that already has a 'shape' attribute is passed through
    untouched.  Any other iterable is handed to array(); a non-iterable
    value is wrapped in a one element list first.

    Return the converted value itself when exactly one argument was
    given, otherwise a list of the converted values (an empty list when
    no arguments were given).
    """
    missing = []   # unique sentinel; no real attribute can be this object

    def convert(val):
        if getattr(val, 'shape', missing) is not missing:
            return val
        if not iterable(val):
            val = [val]
        return array(val, typecode=typecode)

    converted = [convert(val) for val in args]
    if len(converted) != 1:
        return converted
    return converted[0]
        
class Axis(Artist):

    def __init__(self):
        """
        Create an axis with no window extent, no data and linear scaling.

        A fixed pool of tick lines, grid lines and tick labels is built
        here; the tick labels come from default_ticklabel, which derived
        classes must provide.
        """
        Artist.__init__(self)

        # window geometry; stays None until set_window_extent is called
        self._left = self._right = None
        self._bottom = self._top = None
        self._width = self._height = None

        self._dataLim = None  # min, max of data coords
        self._viewLim = None  # min, max of view in data coords
        self._axisLim = None  # min, max of win coords

        # flags cleared by the derived update_* methods once they have run
        self._updateAxisLines = 1
        self._updateLabel = 1

        self._ticksize = 0.01  # fraction of win lim
        self._ticklocs = None  # None unless explicitly set
        self._ticklabelStrings = None

        # strip trailing zeros from ticks
        self._zerorgx = re.compile(r'^(.*?)\.?0+(e[+-]\d+)?$')

        def identity(x):
            return x

        def logwarn(x):
            # keep only the strictly positive entries before taking the log
            positive = nonzero(x>0)
            if len(positive) < len(x):
                sys.stderr.write('log scale warning, negative data ignored\n')
                x = take(x, positive)
            return log(x)

        self._scale = 'linear'
        self._scalemap = {'linear': identity, 'log': logwarn}
        self._scalefunc = self._scalemap[self._scale]
        self._gridOn = 1

        # Persistent references to the tick lines, tick labels and grid
        # lines are wanted, so a large pool (100 of each) is created once
        # and only the first numticks entries are handed out on request.
        maxTicks = 100
        self._ticklines1 = []   # 1 and 2 are the left/right axes (or top/bottom)
        self._ticklines2 = []
        self._gridlines = []
        self._ticklabels = []
        for i in range(maxTicks):
            self._ticklines1.append(Line2D([0, 0], [0, 0], color='k'))
            self._ticklines2.append(Line2D([0, 0], [0, 0], color='k'))
            self._gridlines.append(
                Line2D([0, 0], [0, 0], color='k', linestyle=':'))
            self._ticklabels.append(self.default_ticklabel())


    def autoscale_view(self):
        """
        Choose the view limits and number of ticks to make nice tick labels
        """
        vmin, vmax = self.get_view_lim()
        if vmin==vmax:
            vmin-=1
            vmax+=1

        (exponent, remainder) = divmod(math.log10(vmax - vmin), 1)
        if remainder < 0.84:
            exponent -= 1
        scale = 10**(-exponent)

        vmin = math.floor(scale*vmin)/scale
        vmax = math.ceil(scale*vmax)/scale
        self.set_view_lim(vmin, vmax)
        
    def default_ticklabel(self):
        """
        Create an axis text instance with the proper attributes (but
        no x,y,label) info
        """
        raise NotImplementedError, 'Derived must override'

    def _draw(self, drawable, *args, **kwargs):
        'Draw the axis lines, grid lines, tick lines and labels'

        lines = []
        self.update_axis_lines()


        lines.extend(self.get_ticklines())
        lines.extend(self.get_gridlines())
        
        for line in lines:
            line.draw(drawable)

        for t in self.get_ticklabels():
            t.erase()
            t.draw(drawable)

        self.update_label_position()
        self._label.draw(drawable)

    def get_ticklines(self):
        'Return a list of tick Line2D tick instances'
        numticks = self.get_numticks()
        # don't plot the first or last ticklines
        lines = self._ticklines1[1:numticks-1]
        lines.extend(self._ticklines2[1:numticks-1])
        return lines

    def get_gridlines(self):
        'Return a list of grid Line2D instances'
        if not self._gridOn: return []
        numticks = self.get_numticks()
        lines = self._gridlines[1:numticks-1]
        return lines

    def get_numticks(self):
        'Return the number of ticks'
        raise NotImplementedError, 'Derived must override'

    def get_data_distance(self):
        'Return  the distance max(datalim) - min(datalim)'
        if self._dataLim is None:
            raise RuntimeError, 'No data in range'
        return self._dataLim[1] - self._dataLim[0]

    def get_data_lim(self):
        'Return the tuple min(datalim), max(datalim)'
        if self._dataLim is None:
            raise RuntimeError, 'No data in range'
        #print 'datalim', self._dataLim
        return self._dataLim

    def get_label(self):
        'Return the axis label (AxisText instance)'
        # NOTE(review): get_label is defined a second time further down in
        # this class body; that later definition replaces this one, so this
        # copy is effectively dead code.
        return self._label
    
    def get_data_extent(self):
        "Data extent == window extent for Axis because tranfunc is identity"
        if self._left is None:
            raise RuntimeError, 'Extent is not set'
        return self._left, self._right, self._width, self._height

    def get_view_distance(self):
        'Return the distance max(viewlim) - min(viewlim)'
        if self._dataLim is None:
            vmin, vmax = self._scalefunc(array(self.get_view_lim()))
            return vmax - vmin
        if self._viewLim is None:
            return self.get_data_distance()
        else:
            return self._viewLim[1] - self._viewLim[0]

    def get_view_lim(self):
        'Return the view limits as tuple min(viewlim), max(viewlim)'
        if self._dataLim is None:
            lim = -1,1
            return lim
        if self._viewLim is None:
            lim = self.get_data_lim()
        else:
            lim =  self._viewLim
        assert(lim is not None)
        return lim

    def get_window_distance(self):
        'Return the distance max(windolim) - min(windowlim)'
        if self._left is None:
            raise RuntimeError, 'Window range not set range'
        wmin, wmax = self.get_window_lim()
        return int(wmax - wmin)
    
    def get_window_lim(self):
        'Return the window limits as tuple min(winlim), max(winlim)'
        raise NotImplementedError, 'Derived must override'

    def get_label(self):
        'Return the axis label as an AxisText instance'
        return self._label

    def get_ticklocs(self):
        "Get the tick locations in data coordinates as a Numeric array"
        if self._ticklocs is not None:
            return self._ticklocs
        
        numticks = self.get_numticks()

        if numticks==0: return []
        vmin, vmax = self.get_view_lim()
        d = self.get_view_distance()
        if numticks==1: 0.5*d
        step = d/(numticks-1)
        # add a small offset to include endpoint
        ticklocs = arange(vmin, vmax+0.1*d, step)
        return ticklocs

    def get_ticklocs_win(self):
        "Get the tick locations in window coordinates as a Numeric array"
        if self._ticklocs is None:
            wmin, wmax = self.get_window_lim()
            return linspace(wmin, wmax, self.get_numticks())
        else:
            return self.transform_to_display(self._ticklocs)

    def get_ticklabels(self):
        'Return a list of tick labels as AxisText instances'
        return self._ticklabels[:self.get_numticks()]

    def get_ticklabel_extent(self):
        """
        Get the extent of all the tick labels as tuple bottom, top,
        width, height
        """
        # this is reveresed not because of gtk but to insure a false
        # comparison
        bottom, top = self._top, self._bottom
        left, right = self._right, self._left
        for label in self.get_ticklabels():
            l,b,w,h = label.get_window_extent()
            r, t = l+w, b+h                        
            if b<bottom: bottom=b
            if t>top: top=t
            if l<left: left = l
            if r>right: right=r

        return left, bottom, right-left, top-bottom

    def grid(self, b):
        "Set the axis grid on or off; b is a boolean"
        self._gridOn = b
        
    def pan(self, numsteps):
        'Pan numticks (can be positive or negative)'
        vmin, vmax = self.get_view_lim()
        ticks =  self.get_ticklocs()
        if len(ticks)>2:
            step = (ticks[1]-ticks[0])*numsteps
        else:
            step = 0.1*(vmax-vmin)*numsteps
        vmin += step
        vmax += step
        self.set_view_lim(vmin, vmax)
        
    def set_window_extent(self, l, b, w, h):
        'Set the window extent as left, bottom, width, height'
        self._left, self._right = l, l+w
        self._bottom, self._top = b, b+h
        self._width, self._height = w, h
        self._updateAxisLines = 1
        self._updateLabel = 1

    def get_child_artists(self):
        'Return a list of all Artist instances contained by Axis'
        artists = []
        artists.extend(self._ticklabels)
        artists.extend(self._ticklines1)
        artists.extend(self._ticklines2)
        artists.extend(self._gridlines)
        artists.append(self._label)
        return artists


    def set_data_lim(self, dmin, dmax):
        'Set the data limits to dmin, dmax'
        self._dataLim = [dmin, dmax]

    def set_ticks(self, ticks):
        'Set the locations of the tick marks from sequence ticks'
        try: ticks.shape
        except AttributeError: ticks = array(ticks)
        self._ticklocs = ticks
        if self._viewLim is None and len(self._ticklocs)>1:
            self.set_view_lim(min(self._ticklocs), max(self._ticklocs))
        self.update_axis_lines(force=1)
        
    def set_ticklabels(self, ticklabels, *args, **kwargs):
        """
        Set the text of the tick labels from the sequence ticklabels.

        Every entry is formatted with '%s'.  The remaining positional and
        keyword arguments are handed to backends._process_text_args and
        the resulting properties are applied to each label that receives
        a string.
        """
        strings = ['%s'%l for l in ticklabels]
        self._ticklabelStrings = strings

        override = backends._process_text_args({}, *args, **kwargs)
        # zip stops at the shorter sequence, so surplus labels in the
        # pool are left as they are
        for text, ticklabel in zip(strings, self._ticklabels):
            ticklabel.set_text(text)
            ticklabel.update_properties(override)
            
    def set_view_lim(self, vmin, vmax):
        'Set the view limits (data coords) to vmin, vmax'

        self._viewLim = vmin, vmax
        locs = self.get_ticklocs()
        if self._ticklabelStrings is None:
            for label, loc in zip(self._ticklabels, locs):
                label.set_text(self.format_tickval(loc))

        
    def update_data(self, d):
        """
        Update the min, max of the data lim with values in min(d), max(d)
        if min(d) or max(d) exceed the existing limits
        """
        if len(d)==0: return
        mind = min(d)
        maxd = max(d)
        if self._dataLim is None:
            self._dataLim = [mind, maxd]
            return
        if mind < self._dataLim[0]: self._dataLim[0] = mind
        if maxd > self._dataLim[1]: self._dataLim[1] = maxd


    def transform_to_display(self, v):
        """
        Map v (a scalar or Numeric array in data coords) to window
        coords.  Empty input is returned as it is.

        v is passed through the current scale function and then mapped
        linearly from the view limits onto the window limits.
        """
        if iterable(v) and len(v)==0: return v
        vmin, vmax = self.get_view_lim()
        wmin, wmax = self.get_window_lim()
        scaled = self._scalefunc(v)
        slope = (wmax-wmin)/(vmax-vmin)
        return slope*(scaled-vmin)+wmin




    def transform_scale_to_display(self, v):
        """
        Convert a length v in data coords (scalar or array) to a length
        in window coords by multiplying with
        abs(window distance / view distance).  Empty input is returned
        as it is.
        """
        if iterable(v) and len(v)==0: return v
        wd = self.get_window_distance()
        vd = self.get_view_distance()
        try:
            v.shape
        except AttributeError:
            # not an array yet
            v = array(v)
        ratio = abs(wd/vd)
        return ratio*v

    def transform_to_user(self, w):
        """
        Transform w (a scalar or Numeric array in window coords) into
        user (data) coords.  Empty input is returned as it is.

        The mapping is linear: the view distance over the window
        distance times the offset of w from the window minimum, plus the
        view minimum.
        """
        # this used to return the undefined name 'v' and so raised
        # NameError for every empty sequence
        if iterable(w) and len(w)==0: return w
        vmin, vmax = self.get_view_lim()
        wmin, wmax = self.get_window_lim()
        wd = self.get_window_distance()
        vd = self.get_view_distance()
        return vd/wd*(w-wmin)+vmin

        
    def transform_to_user_scale(self, w):
        """
        Convert a length w in window coords (scalar or array) to a
        length in user (data) coords by multiplying with
        abs(view distance / window distance).  Empty input is returned
        as it is.
        """
        if iterable(w) and len(w)==0: return w
        wd = self.get_window_distance()
        vd = self.get_view_distance()
        try:
            w.shape
        except AttributeError:
            # not an array yet
            w = array(w)
        ratio = abs(vd/wd)
        return ratio*w

    def format_tickval(self, x):
        'Format the number x as a string'
        d = self.get_view_distance()
        #if the number is not too big and it's an int, format it as an
        #int
        if abs(x)<1e4 and x==int(x): return '%d' % x

        # if the value is just a fraction off an int, use the int
        if abs(x-int(x))<0.0001*d: return '%d' % int(x)

        # use exponential formatting for really big or small numbers,
        # else use float
        if abs(x) < 1e-4: fmt = '%1.3e'
        elif abs(x) > 1e5: fmt = '%1.3e'
        else: fmt = '%1.3f'
        s =  fmt % x

        # strip trailing zeros, remove '+', and handle exponential formatting
        m = self._zerorgx.match(s)
        if m:
            s = m.group(1)
            if m.group(2) is not None: s += m.group(2)
        s = s.replace('+', '')
        return s

    def update_label_position(self):
        """
        Update the position of the axis label so it doesn't conflict with
        the tick labels
        """
        raise NotImplementedError, 'Derived must override'

    def update_axis_lines(self):
        """
        Update the axis, tick and grid lines
        """
        raise NotImplementedError, 'Derived must override'

    def zoom(self, direction):
        "Zoom in/out on axis"
        vmin, vmax = self.get_view_lim()
        d = self.get_view_distance()
        vmin += 0.1*d*direction
        vmax -= 0.1*d*direction        
        self.set_view_lim(vmin, vmax)
        #self.autoscale_view()
    
class XAxis(Axis):
    """
    The horizontal axis.  Its window limits are the left and right edges
    of the window extent.  Tick lines start at the bottom and top edges
    and grid lines run from bottom to top.
    """

    def __init__(self, *args, **kwargs):
        'Set up the base axis and create the axis label'
        Axis.__init__(self, *args, **kwargs)
        self._label = backends.AxisText(
            fontsize=12,
            verticalalignment='top',
            horizontalalignment='center')

    def default_ticklabel(self):
        'Return a new tick label text, top aligned and centered'
        return backends.AxisText(
            fontsize=10,
            verticalalignment='top',
            horizontalalignment='center')

    def get_numticks(self):
        """
        Return the number of ticks: the number of explicitly set tick
        locations if there are any, 0 while the width is unknown, else
        11 for axes wider than 4 inches and 6 otherwise.
        """
        if self._ticklocs is not None:
            return len(self._ticklocs)
        if self._width is None:
            return 0
        widthInches = self._width/self._dpi # width of axes in inches
        if widthInches>4:
            return 11
        return 6

    def get_window_lim(self):
        'Return the window limits as the tuple left, right'
        if self._left is None:
            raise RuntimeError('set_window_extent must be called first')
        return self._left, self._right

    def set_window_extent(self, l, b, w, h):
        'Store the window extent and cache the limits along x in winLim'
        Axis.set_window_extent(self, l, b, w, h)
        self.winLim = self._left, self._right

    def update_axis_lines(self, force=0):
        """
        Reposition the tick lines, grid lines and tick labels.  Nothing
        happens unless force is true or an update is pending.
        """
        if not (force or self._updateAxisLines): return
        numticks = self.get_numticks()
        ticklocsData = self.get_ticklocs()
        ticklocsWin = self.get_ticklocs_win()
        # todo: ticksize in points?
        ticksize = self._ticksize*self._height
        tickLabels = [self.format_tickval(loc) for loc in ticklocsData]

        for i in range(numticks):
            x = ticklocsWin[i]

            lower = self._ticklines1[i]
            lower.set_data([x, x], [self._bottom, self._bottom+ticksize])
            lower.clip_gc = self.clip_gc

            upper = self._ticklines2[i]
            upper.set_data([x, x], [self._top, self._top-ticksize])
            upper.clip_gc = self.clip_gc

            gridline = self._gridlines[i]
            gridline.set_data([x, x], [self._bottom, self._top])
            gridline.clip_gc = self.clip_gc

            ticklabel = self._ticklabels[i]
            ticklabel.set_position((x, self._bottom-0.02*self._height))
            # custom label strings, once set, are never overwritten
            if self._ticklabelStrings is None:
                ticklabel.set_text(tickLabels[i])

        self._updateLabel = 1
        self._updateAxisLines = 0

    def update_label_position(self):
        """
        Place the axis label at the horizontal center of the extent,
        at the largest bottom value found among the tick label extents.
        """
        # this cannot be done in set_window_extent because we can't assume
        # that children know their extent during a set extent call
        if self._left is None:
            raise RuntimeError('You must first call set_window_extent on the xaxis')
        if not self._updateLabel: return
        tickBottom = 0
        for i in range(self.get_numticks()):
            l, b, w, h = self._ticklabels[i].get_window_extent()
            tickBottom = max(tickBottom, b)

        # todo: fix me
        center = 0.5*(self._left+self._right)
        self._label.set_position((center, tickBottom))
        self._updateLabel = 0
        

        
            
    
class YAxis(Axis):
    """
    The vertical axis.  Its window limits are the bottom and top edges
    of the window extent.  Tick lines start at the left and right edges
    and grid lines run from left to right.
    """

    def __init__(self, *args, **kwargs):
        'Set up the base axis and create the vertically rotated axis label'
        Axis.__init__(self, *args, **kwargs)
        self._label = backends.AxisText(
            fontsize=12,
            verticalalignment='center',
            horizontalalignment='right',
            rotation='vertical')

    def default_ticklabel(self):
        'Return a new tick label text, vertically centered and right aligned'
        return backends.AxisText(
            fontsize=10,
            verticalalignment='center',
            horizontalalignment='right')

    def get_numticks(self):
        """
        Return the number of ticks: the number of explicitly set tick
        locations if there are any, 0 while the height is unknown, else
        11 for axes taller than 4 inches and 6 otherwise.
        """
        if self._ticklocs is not None:
            return len(self._ticklocs)
        if self._height is None:
            return 0
        heightInches = self._height/self._dpi  # height of axes in inches
        if heightInches>4:
            return 11
        return 6

    def get_window_lim(self):
        'Return the window limits as the tuple bottom, top'
        if self._top is None:
            raise RuntimeError('set_window_extent must be called first')
        return self._bottom, self._top

    def set_window_extent(self, l, b, w, h):
        'Store the window extent and cache the limits along y in winLim'
        Axis.set_window_extent(self, l, b, w, h)
        self.winLim = b, b+h

    def update_axis_lines(self, force=0):
        """
        Reposition the tick lines, grid lines and tick labels.  Nothing
        happens unless force is true or an update is pending.
        """
        # NOTE(review): unlike XAxis.update_axis_lines this does not set
        # clip_gc on the lines -- confirm whether that is intentional.
        if not (force or self._updateAxisLines): return
        numticks = self.get_numticks()
        ticklocsData = self.get_ticklocs()
        ticklocsWin = self.get_ticklocs_win()
        ticksize = self._ticksize*self._width
        tickLabels = [self.format_tickval(loc) for loc in ticklocsData]

        for i in range(numticks):
            y = ticklocsWin[i]
            self._ticklines1[i].set_data(
                [self._left, self._left+ticksize], [y, y])
            self._ticklines2[i].set_data(
                [self._right, self._right-ticksize], [y, y])
            self._gridlines[i].set_data(
                [self._left, self._right], [y, y])

            ticklabel = self._ticklabels[i]
            ticklabel.set_position((self._left-0.02*self._width, y))
            # custom label strings, once set, are never overwritten
            if self._ticklabelStrings is None:
                ticklabel.set_text(tickLabels[i])

        self._updateAxisLines = 0
        self._updateLabel = 1

    def update_label_position(self):
        """
        Place the axis label 3 units left of the leftmost tick label,
        at the vertical middle of the extent.
        """
        # this cannot be done in set_window_extent because we can't assume
        # that children know their extent during a set extent call
        if self._left is None:
            raise RuntimeError('You must first call set_window_extent on the yaxis')
        if not self._updateLabel: return
        tickLeft = self._right  # compare false on first comparison
        for i in range(self.get_numticks()):
            l, b, w, h = self._ticklabels[i].get_window_extent()
            tickLeft = min(tickLeft, l)
        middle = (self._top+self._bottom)/2
        self._label.set_position((tickLeft-3, middle))
        self._updateLabel = 0


def _process_plot_format(fmt):
    """
    Process a matlab style color/line style format string.  Return a
    linestyle, color tuple as a result of the processing.  Default
    values are ('-', 'b').  Example format strings include

    'ko'    : black circles
    '.b'    : blue dots
    'r--'   : red dashed lines

    See Line2D.lineStyles and GraphicsContext.colors for all possible
    styles and color format string.

    """

    colors = {
        'b' : 1,
        'g' : 1,
        'r' : 1,
        'c' : 1,
        'm' : 1,
        'y' : 1,
        'k' : 1,
        'w' : 1,
        }

    
    linestyle = None
    marker = None
    color = 'b'

    # handle the multi char special cases and strip them from the
    # string
    if fmt.find('--')>=0:
        linestyle = '--'
        fmt = fmt.replace('--', '')
    if fmt.find('-.')>=0:
        linestyle = '-.'
        fmt = fmt.replace('-.', '')
    
    chars = [c for c in fmt]

    for c in chars:        
        if lineStyles.has_key(c):
            if linestyle is not None:
                raise ValueError, 'Illegal format string "%s"; two linestyle symbols' % fmt
            
            linestyle = c
        elif lineMarkers.has_key(c):
            if marker is not None:
                raise ValueError, 'Illegal format string "%s"; two marker symbols' % fmt
            marker = c
        elif colors.has_key(c):
            color = c
        else:
            err = 'Unrecognized character %c in format string' % c
            raise ValueError, err
    if linestyle is None and marker is None:
        linestyle = '-'
    return linestyle, marker, color


class _process_plot_var_args:    
    """

    Process variable length arguments to the plot command, so that
    plot commands like the followig are supported

      plot(t, s)
      plot(t1, s1, t2, s2)
      plot(t1, s1, 'ko', t2, s2)
      plot(t1, s1, 'ko', t2, s2, 'r--', t3, e3)

    an arbitrary number of x,y,fmt are allowed
    """
    def __call__(self, *args):
        return self._grab_next_args(*args)
            
    def _plot_1_arg(self, y):
        return Line2D(arange(len(y)), y)

    def _plot_2_args(self, tup2):
        if is_string_like(tup2[1]):
            y, fmt = tup2
            linestyle, marker, color = _process_plot_format(fmt)
            return Line2D(xdata=arange(len(y)), ydata=y,
                          color=color, linestyle=linestyle, marker=marker)
        else:
            x,y = tup2
            return Line2D(x, y)

    def _plot_3_args(self, tup3):
        x, y, fmt = tup3
        linestyle, marker, color = _process_plot_format(fmt)
        return Line2D(x, y, color=color, linestyle=linestyle, marker=marker)



    def _grab_next_args(self, args):
        remaining = args
        while 1:
            if len(remaining)==0: return
            if len(remaining)==1:
                yield self._plot_1_arg(remaining[0])
                remaining = []
                continue
            if len(remaining)==2:
                yield self._plot_2_args(remaining)
                remaining = []
                continue
            if len(remaining)==3:
                if not is_string_like(remaining[2]):
                    raise ValueError, 'third arg must be a format string'
                yield self._plot_3_args(remaining)
                remaining=[]
                continue
            if is_string_like(remaining[2]):
                #print 'is', remaining[2]
                yield self._plot_3_args(remaining[:3])
                remaining=remaining[3:]
            else:
                #print 'not', remaining[2]
                yield self._plot_2_args(remaining[:2])
                remaining=remaining[2:]
            #yield self._plot_2_args(remaining[:2])
            #remaining=args[2:]
        
    
class Axes(Artist):
    """
    Emulate matlab's axes command, creating axes with

      Axes(position=[left, bottom, width, height])

    where all the arguments are fractions in [0,1] which specify the
    fraction of the total figure window.  

    figbg is the color background of the figure
    axisbg is the color of the axis background
    """

    def __init__(self, position, figbg='w', axisbg = 'w'):
        """
        Store the position and background colors and create the child
        artists: both axis objects, the title, the axes patch and a fixed
        pool of legend texts and legend lines.
        """
        Artist.__init__(self)
        self._position = position
        self._figbg = figbg
        self._axisbg = axisbg
        self._gridState = 0

        # filled by the plotting commands
        self._lines = []
        self._patches = []
        self._text = []     # text in axis coords
        self._get_lines = _process_plot_var_args()

        self._xaxis = XAxis()
        self._yaxis = YAxis()

        self._title = backends.AxisText(
            x=0, y=0, text='', fontsize=14,
            verticalalignment='bottom',
            horizontalalignment='center')

        self._axesPatch = Rectangle(
            (0,0), 0, 0, facecolor=self._axisbg, edgecolor='k')

        # The legend artists are created once; _legendLabels decides how
        # many of them are in use.
        MAXLEGEND = 10
        self._legendLoc = 1  # default upper right
        self._legendLabels = []
        self._legendTexts = []
        for i in range(MAXLEGEND):
            self._legendTexts.append(
                backends.AxisText(x=0, y=0, text='', fontsize=10,
                                  verticalalignment='center',
                                  horizontalalignment='left'))
        self._legendPatch = Rectangle(
            (0,0), 0, 0, facecolor=self._axisbg, edgecolor='k')
        self._legendLines = []
        for i in range(MAXLEGEND):
            self._legendLines.append(Line2D(xdata=[], ydata=[]))

        self.axison = True
        
    def _pass_func(self, *args, **kwargs):
        pass
    
    def add_line(self, line):
        "Add a line to the list of plot lines"
        self._xaxis.update_data(line.get_xdata())
        self._yaxis.update_data(line.get_ydata())
        line.transform_xy_to_display = self.transform_xy_to_display
        line.transform_xyscale_to_display = self.transform_xyscale_to_display
        line.clip_gc = self.clip_gc
        self._lines.append(line)

    def add_patch(self, patch):
        "Add a line to the list of plot lines"
        patch.transform_xy_to_display = self.transform_xy_to_display
        patch.transform_xyscale_to_display = self.transform_xyscale_to_display
        patch.clip_gc = self.clip_gc
        l, b, w, h = patch.get_data_extent()
        self._xaxis.update_data((l, l+w))
        self._yaxis.update_data((b, b+h))

        #patch.clip_gc = self.clip_gc

        self._patches.append(patch)


    def bar(self, x, y, width=0.8):
        """
        Make a bar plot with rectangles at x, x+width, 0, y
        x and y are Numeric arrays

        Return value is a list of Rectangle patch instances
        """
        patches = []
        for thisX,thisY in zip(x,y):
            r = Rectangle( (thisX,0), width=width, height=thisY)
            self.add_patch(r)
            patches.append(r)
        return patches


        return True

    def clip_gc(self, gc):
        l,b,w,h = self._axesPatch.get_data_extent()
        lw = self._axesPatch.get_linewidth()
        # TODO: note this is currently device dependent depending on
        # whether the patch stroke is inside, outside, or split on the
        # egdge of the patch path
        gc.set_clip_rectangle( (l+0.5*lw, b+0.5*lw, w-lw, h-lw))


    def clear(self):
        # TODO: figure out what you want clear to do in relation to axes
        self._lines = []
        self._patches = []
        self.wash_brushes()
        
    def _draw(self, drawable, *args, **kwargs):
        "Draw everything (plot lines, axes, labels)"

        if self.axison:
            self._axesPatch.draw(drawable)
            self._xaxis.draw(drawable)
            self._yaxis.draw(drawable)
        self._draw_lines(drawable)
        self._draw_patches(drawable)
        self._draw_text(drawable)
        self._draw_legend(drawable)

        self._title.set_position(
            ( 0.5*(self._left+self._right),self._top+5) )
        self._title.draw(drawable)

    def _draw_legend(self, drawable):
        #I'm using the length of self._legendLabels to control the
        #number of texts drawn
        if not len(self._legendLabels): return
        
        self._legendPatch.draw(drawable)

        for line, label in zip(self._legendLines, self._legendLabels):
            line.draw(drawable)
        
        for t,label in zip(self._legendTexts, self._legendLabels):
            t.draw(drawable)

    def _draw_lines(self, drawable):
        "Draw the plot lines"
        for line in self._lines:
            line.draw(drawable)
                      
    def _draw_patches(self, drawable):
        "Draw the plot lines"
        for p in self._patches:
            p.draw(drawable)

    def _draw_text(self, drawable):
        text = self._text

        for t in self._text:
            t.draw(drawable)


    def get_child_artists(self):
        artists = []
        artists.append(self._title)
        artists.append(self._xaxis)
        artists.append(self._yaxis)
        artists.extend(self._lines)
        artists.extend(self._patches)
        artists.extend(self._text)
        artists.extend(self._legendTexts)
        artists.extend(self._legendLines)
        return artists
    
        
    def get_xaxis(self):
        "Return the XAxis instance"
        return self._xaxis

    def get_xlim(self):
        "Get the x axis range [xmin, xmax]"
        return self._xaxis.get_view_lim()

    def get_xticklabels(self):
        "Get the xtick labels as a list of strings"
        return self._xaxis.get_ticklabels()

    def get_xticks(self):
        "Return the y ticks as a list of locations"
        return self._xaxis.get_ticklocs()


    def get_yaxis(self):
        "Return the YAxis instance"
        return self._yaxis

    def get_ylim(self):
        "Get the y axis range [ymin, ymax]"
        return self._yaxis.get_view_lim()


    def get_yticklabels(self):
        "Get the ytick labels as a list of strings"
        return self._yaxis.get_ticklabels()

    def get_yticks(self):
        "Return the y ticks as a list of locations"
        return self._yaxis.get_ticklocs()

    def grid(self,b):
        "Set the axes grids on or off; b is a boolean"
        self._xaxis.grid(b)
        self._yaxis.grid(b)
        
    def in_axes(self, xwin, ywin):
        if xwin<self._left or xwin > self._right:
            return 0
        if ywin<self._bottom or ywin>self._top:
            return 0
        return 1


    def hlines(self, y, xmin, xmax, fmt='k-'):
        """
        Plot a horizontal line at each y from xmin to xmax.  xmin and
        xmax can be scalars or sequences of the same length as y.  A
        scalar is used for every line; a sequence gives the end point of
        each line separately.

        ValueError is raised when xmin or xmax differs in length from y.

        Returns a list of the line instances that were added
        """
        linestyle, marker, color = _process_plot_format(fmt)

        # todo: fix me for y is scalar and xmin and xmax are iterable
        y = to_arrays(Float, y)

        # broadcast scalar end points to one value per line
        if not iterable(xmin):
            xmin = xmin*ones(y.shape, y.typecode())
        if not iterable(xmax):
            xmax = xmax*ones(y.shape, y.typecode())

        xmin, xmax = to_arrays(Float, xmin, xmax)
        numlines = len(y)
        if len(xmin)!=numlines:
            raise ValueError('xmin and y are unequal sized sequences')
        if len(xmax)!=numlines:
            raise ValueError('xmax and y are unequal sized sequences')

        added = []
        for thisY, thisMin, thisMax in zip(y, xmin, xmax):
            line = Line2D([thisMin, thisMax], [thisY, thisY],
                          color=color, linestyle=linestyle, marker=marker)
            self.add_line(line)
            added.append(line)
        return added


    def legend(self, labels, loc=1):
        """
        Place a legend on the current axes at location loc.  Labels are a
        sequence of strings and loc can be a string or an integer
        specifying the legend location

        The location codes are

          'best' : 0,
          'upper right' : 1,  (default)
          'upper left'  : 2,
          'lower left'  : 3,
          'lower right' : 4,
          'right'       : 5,

        An unrecognized loc is reported on stderr and the default
        location (1, 'upper right') is used instead, so that the later
        legend layout always has a valid location code to work with.

        Labels are paired, in order, with the lines already on the axes
        and with the legend text and legend line instances; pairing
        stops at the shortest of those sequences.

        Return value is a list of (text, line) tuples for the legend
        entries that received a label
        """
        codes = {'best' : 0,
                 'upper right' : 1,
                 'upper left'  : 2,
                 'lower left'  : 3,
                 'lower right' : 4,
                 'right'       : 5,
                 }
        if loc in codes:
            loc = codes[loc]
        if loc not in range(6):
            # wrap loc in a tuple so that a tuple-valued loc formats cleanly
            sys.stderr.write('Unrecognized legend loc %s\n' % (loc,))
            loc = 1

        self._legendLabels = labels
        self._legendLoc = loc
        for label, line, legtext, legline in zip(
            self._legendLabels, self._lines,
            self._legendTexts, self._legendLines):
            legtext.set_text(label)
            legline.copy_properties(line)

        # the labels take part in the zip only to limit the number of
        # entries returned to the number of labels given
        return [(legtext, legline) for legtext, legline, label in zip(
            self._legendTexts, self._legendLines, self._legendLabels)]

    def _update_legend_positions(self):
        """
        Recompute the window positions of the legend patch, the legend
        sample lines and the legend texts.

        The legend box is sized from the window extents of the legend
        texts (widest text plus room for the sample line, heights
        summed) and anchored inside the axes rectangle according to the
        location code in self._legendLoc.  Code 0 ('best') is not
        computed yet and is treated as 1 ('upper right'); any
        unrecognized code is laid out as 'upper right' as well rather
        than failing on an unset position.
        """

        loc = self._legendLoc
        # first pass: get width and height of strings
        totwid, totheight = 0,0
        # use labels to constrain number of legend texts

        for label, legtext in zip(self._legendLabels, self._legendTexts):
            l,b,w,h = legtext.get_window_extent()
            if w > totwid: totwid = w
            totheight += h

        # spacing is expressed as a fraction of the dpi
        textsep = 0.05*self._dpi
        linewid = 0.3*self._dpi
        border = 0.2*self._dpi
        totwid += textsep + linewid


        if loc==2:  # upper left
            left = self._left + border
            upper = self._top - border
        elif loc==3:  # lower left
            left = self._left + border
            upper = self._bottom + border + totheight
        elif loc==4:  # lower right
            left = self._right - border - totwid
            upper = self._bottom + border + totheight
        elif loc==5:  # right
            left = self._right + border
            upper = self._top - border
        else:
            # upper right: loc 1, loc 0 (todo: compute best), and the
            # fallback for any unrecognized code
            left = self._right - border - totwid
            upper = self._top - border

        self._legendPatch.set_y(upper - totheight + 0.1*border)
        self._legendPatch.set_x(left-0.1*border)
        self._legendPatch.set_width(totwid + 0.2*border)
        self._legendPatch.set_height(totheight+ 0.2*border)


        # now reset the positions with the info
        x = arange(left+0.2*linewid,left+linewid, 0.2*linewid)
        for label, legline, legtext in zip(
            self._legendLabels, self._legendLines, self._legendTexts):
            # each sample line is drawn flat at the current row height
            y = upper * ones(x.shape, Float)
            legline.set_data( x, y)
            legtext.set_position( (left+linewid+textsep, upper) )
            l,b,w,h = legtext.get_window_extent()
            # step down by this text's height for the next entry
            upper -= h

    def panx(self, numsteps):
        "Pan the x axis numsteps (plus pan right, minus pan left)"
        self._xaxis.pan(numsteps)
        xmin, xmax = self._xaxis.get_view_lim()
        # re-clip every line to the new view; an explicit loop because
        # the calls are made purely for their side effects
        for line in self._lines:
            line.set_xclip(xmin, xmax)

    def pany(self, numsteps):
        "Pan the y axis numsteps (plus pan up, minus pan down)"
        self._yaxis.pan(numsteps)
        ymin, ymax = self._yaxis.get_view_lim()
        # re-clip every line to the new view; an explicit loop because
        # the calls are made purely for their side effects
        for line in self._lines:
            line.set_yclip(ymin, ymax)



    def plot(self, *args):
        """
        Emulate matlab's plot command.  *args is a variable length
        argument list of x, y pairs, each optionally followed by a
        format string.  Assuming a is the Axis instance, all of these
        are legal:

          a.plot(x,y)            # plot Numeric arrays y vs x
          a.plot(x,y, 'bo')      # plot Numeric arrays y vs x with blue circles
          a.plot(y)              # plot y using x = arange(len(y))
          a.plot(y, 'r+')        # ditto with red plusses

        Any number of x, y, fmt groups may be given in one call, as in
          a.plot(x1, y1, 'g^', x2, y2, 'r+')

        Every line built from the arguments is added to the axes, then
        both axes autoscale their view.

        Returns a list of lines that were added
        """

        added = []
        for newLine in self._get_lines(args):
            self.add_line(newLine)
            added.append(newLine)
        for axis in (self._xaxis, self._yaxis):
            axis.autoscale_view()
        return added

    def set_axis_off(self):
        """Clear the axison flag of these axes."""
        state = False
        self.axison = state

    def set_axis_on(self):
        """Set the axison flag of these axes."""
        state = True
        self.axison = state

    def scatter(self, x, y, s=None, c='b'):
        """
        Make a scatter plot of x versus y by adding one circle patch
        per point.

        s is the circle radius and may be a scalar or a sequence the
        same length as x; if s is None, 1.5% of the range of y is used
        for every point.  c is the face color and may be a single color
        format string (or other non-iterable value) applied to every
        point, or a length(x) sequence of intensities which are mapped
        to colors by the colormap jet.

        Raises ValueError if c or s does not end up the same length
        as x.

        Returns the list of circle patches that were added
        """
        # TODO

        # one shared color for all points, or one jet color per intensity
        if is_string_like(c) or not iterable(c):
            c = [c]*len(x)
        else:
            c = ColormapJet(1000).get_colors(c)

        if s is None:
            yspan = max(y)-min(y)
            s = [abs(0.015*yspan)]*len(x)
        elif not iterable(s):
            s = [s]*len(x)

        if len(c) != len(x):
            raise ValueError('c and x are not equal lengths')
        if len(s) != len(x):
            raise ValueError('s and x are not equal lengths')

        circles = []
        for xcenter, ycenter, radius, facecolor in zip(x, y, s, c):
            circ = Circle((xcenter, ycenter), radius=radius)
            circ.set_facecolor(facecolor)
            self.add_patch(circ)
            circles.append(circ)
        return circles


    def set_axis_bgcolor(self, color):
        """Store color as the axes background color (self._axisbg)."""
        background = color
        self._axisbg = background

    def set_fig_bgcolor(self, color):
        """Store color as the figure background color (self._figbg)."""
        background = color
        self._figbg = background
                        
    def set_size(self, figsize, dpi):
        """
        Reset the window params for a new figure size and resolution.

        figsize is indexed as (width, height) and is multiplied by dpi
        to get the figure extent in window coordinates.  The axes
        rectangle is the fraction of that extent given by
        self._position, read as (left, bottom, width, height).  The x
        and y axes, the axes patch and the legend positions are all
        updated to match the new rectangle.
        """
        Artist.set_size(self, figsize, dpi)
        width  = figsize[0]*dpi
        height = figsize[1]*dpi

        self._left = self._position[0] * width
        self._bottom = self._position[1] * height
        self._width = self._position[2] * width
        self._height = self._position[3] * height
        self._right = self._left + self._width
        self._top = self._bottom + self._height

        # set the new axis information
        self._xaxis.set_window_extent(self._left, self._bottom,
                                      self._width, self._height)
        self._yaxis.set_window_extent(self._left, self._bottom,
                                      self._width, self._height)

        # todo: how best to set the border for the axes patch
        self._axesPatch.set_x(self._left)
        self._axesPatch.set_y(self._bottom)
        self._axesPatch.set_width(self._width)
        self._axesPatch.set_height(self._height)

        # the legend is anchored to the axes rectangle, so it moves too
        self._update_legend_positions()
        
    def set_title(self, label, *args, **kwargs):
        """
        Set the title text of the axes and return the title instance.

        Any extra positional or keyword arguments are passed to
        backends._process_text_args, starting from an empty override
        dictionary, and the resulting properties are applied to the
        title.  See the text docstring for information of how override
        and the optional args work

        """
        # no title-specific defaults are merged in here; only what the
        # caller passes is applied on top of the title's current state
        self._title.set_text(label)
        override = backends._process_text_args({}, *args, **kwargs)
        self._title.update_properties(override)
        return self._title


    def set_xlabel(self, xlabel, *args, **kwargs):
        """
        Set the text of the x axis label and return the label instance.

        Extra positional and keyword arguments are run through
        backends._process_text_args, starting from an empty override
        dictionary, and the result is applied to the label.  See the
        text docstring for information of how override and the
        optional args work

        """
        xtext = self._xaxis.get_label()
        xtext.set_text(xlabel)
        props = backends._process_text_args({}, *args, **kwargs)
        xtext.update_properties(props)
        return xtext

    def set_xlim(self, v):
        "Set the limits for the xaxis; v = [xmin, xmax]"
        xmin, xmax = v
        self._xaxis.set_view_lim(xmin, xmax)
        # re-clip every line to the new view; an explicit loop because
        # the calls are made purely for their side effects
        for line in self._lines:
            line.set_xclip(xmin, xmax)
        
    def set_xticklabels(self, labels, *args, **kwargs):
        """
        Set the x tick labels from the list of strings labels; any
        extra arguments are forwarded to the x axis unchanged.
        """
        xaxis = self._xaxis
        xaxis.set_ticklabels(labels, *args, **kwargs)

    def set_xticks(self, ticks):
        """Set the x tick locations from the list ticks."""
        xaxis = self._xaxis
        xaxis.set_ticks(ticks)
        

    def set_ylabel(self, ylabel, *args, **kwargs):
        """
        Set the text of the y axis label and return the label instance.

        Extra positional and keyword arguments are run through
        backends._process_text_args, starting from an empty override
        dictionary, and the result is applied to the label; no
        label-specific defaults are merged in by this method.

        See the text docstring for information of how override and
        the optional args work
        """
        ytext = self._yaxis.get_label()
        ytext.set_text(ylabel)
        props = backends._process_text_args({}, *args, **kwargs)
        ytext.update_properties(props)
        return ytext

    def set_ylim(self, v):
        "Set the limits for the yaxis; v = [ymin, ymax]"
        ymin, ymax = v
        self._yaxis.set_view_lim(ymin, ymax)

        # I set the gc clip to be just outside the actual range so
        # that the flat, artifactual lines caused by the fact that the
        # x data clip is done first will be drawn outside the gc clip
        # rectangle.  5% is an arbitrary factor chosen so that only a
        # fraction of unnecessary data is plotted, since the data
        # clipping is done for plot efficiency.  See _set_clip in
        # lines.py for more info.  [ Note: now that I have disabled y
        # clipping for connected lines in lines.py, this hack is no
        # longer needed, but I'm going to preserve it since I may want
        # to re-enable y clipping for connected lines and I can afford
        # the small performance hit. ]
        offset = 0.05*(ymax-ymin)
        # explicit loop: the calls are made purely for their side effects
        for line in self._lines:
            line.set_yclip(ymin-offset, ymax+offset)


    def set_yticklabels(self, labels, *args, **kwargs):
        """
        Set the y tick labels from the list of strings labels; any
        extra arguments are forwarded to the y axis unchanged.
        """
        yaxis = self._yaxis
        yaxis.set_ticklabels(labels, *args, **kwargs)
        
    def set_yticks(self, ticks):
        """Set the y tick locations from the list ticks."""
        yaxis = self._yaxis
        yaxis.set_ticks(ticks)

    
    def text(self, x, y, text, *args, **kwargs):
        """
        Add text to axis at location x,y (data coords) and return the
        new text instance.

        args, if present, must be a single argument which is a
        dictionary to override the default text properties.  The
        override dictionary this method starts from is

          'fontsize'            : 9,
          'verticalalignment'   : 'bottom',
          'horizontalalignment' : 'left'

        **kwargs can in turn be used to override the override, as in

          a.text(x,y,label, fontsize=12)

        which will have verticalalignment=bottom and
        horizontalalignment=left but will have a fontsize of 12.  The
        merging itself is done by backends._process_text_args.

        Properties that are not overridden fall back to the AxisText
        defaults, which are defined in the backend and have been
        documented as

            'color'               : 'k',
            'fontname'            : 'Sans',
            'fontsize'            : 10,
            'fontweight'          : 'bold',
            'fontangle'           : 'normal',
            'horizontalalignment' : 'left'
            'rotation'            : 'horizontal',
            'verticalalignment'   : 'bottom',

        """
        defaults = {
            'horizontalalignment' : 'left',
            'verticalalignment' : 'bottom',
            'fontsize' : 9,
            }
        props = backends._process_text_args(defaults, *args, **kwargs)

        axtext = backends.AxisText(x=x, y=y, text=text, **props)
        axtext.set_renderer(self._drawable)

        # the text instance shares the coordinate transforms and the
        # clip gc of the axes it lives in
        for name in ('transform_xy_to_display',
                     'transform_xyscale_to_display',
                     'transform_xy_to_user',
                     'transform_xyscale_to_user',
                     'clip_gc'):
            setattr(axtext, name, getattr(self, name))

        self._text.append(axtext)
        return axtext
    
    def transform_xy_to_display(self, x, y):
        """
        Return the tuple (x, y) with x transformed to display by the
        x axis and y transformed to display by the y axis.
        """
        xdisplay = self._xaxis.transform_to_display(x)
        ydisplay = self._yaxis.transform_to_display(y)
        return (xdisplay, ydisplay)
    
    def transform_xyscale_to_display(self, x, y):
        """
        Return the tuple (x, y) with each value passed through the
        transform_scale_to_display of the x and y axis respectively.
        """
        xscaled = self._xaxis.transform_scale_to_display(x)
        yscaled = self._yaxis.transform_scale_to_display(y)
        return (xscaled, yscaled)


    def transform_xy_to_user(self, x, y):
        """
        Return the tuple (x, y) with x transformed to user by the
        x axis and y transformed to user by the y axis.
        """
        xuser = self._xaxis.transform_to_user(x)
        yuser = self._yaxis.transform_to_user(y)
        return (xuser, yuser)
    
    def transform_xyscale_to_user(self, x, y):
        """
        Return the tuple (x, y) with each value passed through the
        transform_to_user_scale of the x and y axis respectively.
        """
        xscaled = self._xaxis.transform_to_user_scale(x)
        yscaled = self._yaxis.transform_to_user_scale(y)
        return (xscaled, yscaled)
            
    def vlines(self, x, ymin, ymax, color='k'):
        """
        Draw one solid vertical line for each entry of x, running from
        ymin to ymax.

        ymin and ymax may each be a scalar, in which case the same
        value is used for every line, or a sequence with one entry per
        element of x giving the extent of each line individually.

        Raises ValueError if ymin or ymax does not end up the same
        length as x.

        Returns a list of lines that were added
        """
        x = to_arrays(Float, x)

        # a scalar limit is repeated once per line
        if not iterable(ymin):
            ymin = ymin*ones(x.shape, x.typecode())
        if not iterable(ymax):
            ymax = ymax*ones(x.shape, x.typecode())
        ymin, ymax = to_arrays(Float, ymin, ymax)

        if len(ymin) != len(x):
            raise ValueError('ymin and x are unequal sized sequences')
        if len(ymax) != len(x):
            raise ValueError('ymax and x are unequal sized sequences')

        # each row of spans is the (ymin, ymax) pair of one line
        spans = transpose(array([ymin, ymax]))
        added = []
        for xpos, span in zip(x, spans):
            vline = Line2D([xpos, xpos], span, color=color, linestyle='-')
            self.add_line(vline)
            added.append(vline)
        return added


    def zoomx(self, numsteps):
        """
        Zoom in on the x axis numsteps (plus for zoom in, minus for zoom out)
        """
        self._xaxis.zoom(numsteps)
        xmin, xmax = self._xaxis.get_view_lim()
        # re-clip every line to the new view; an explicit loop because
        # the calls are made purely for their side effects
        for line in self._lines:
            line.set_xclip(xmin, xmax)

    def zoomy(self, numsteps):
        """
        Zoom in on the y axis numsteps (plus for zoom in, minus for zoom out)
        """
        self._yaxis.zoom(numsteps)
        ymin, ymax = self._yaxis.get_view_lim()
        # re-clip every line to the new view; an explicit loop because
        # the calls are made purely for their side effects
        for line in self._lines:
            line.set_yclip(ymin, ymax)

class Subplot(Axes):
    """
    Emulate matlab's subplot command, creating axes with

      Subplot(numRows, numCols, plotNum)

    where plotNum=1 is the first plot number and increasing plotNums
    fill rows first.  max(plotNum)==numRows*numCols

    You can leave out the commas if numRows, numCols and plotNum are
    all single digits, as in

      Subplot(211)    # 2 rows, 1 column, first (upper) plot

    Any keyword arguments are passed on to Axes.__init__.

    Raises ValueError if the arguments are not one 3 digit value or
    three separate values, or if plotNum is outside
    1..numRows*numCols.
    """

    def __init__(self, *args, **kwargs):
        # Axes __init__ below
        if len(args)==1:
            s = str(args[0])
            # int() would reject a non-digit anyway; checking here
            # keeps the error message about the argument format
            if len(s) != 3 or not s.isdigit():
                raise ValueError('Argument to subplot must be a 3 digits long')
            rows, cols, num = map(int, s)
        elif len(args)==3:
            rows, cols, num = args
        else:
            raise ValueError('Illegal argument to subplot')
        total = rows*cols
        num -= 1    # convert from matlab to python indexing ie num in range(0,total)
        if num < 0:
            # a negative index would silently place the axes off the grid
            raise ValueError('Subplot number must be at least 1')
        if num >= total:
            raise ValueError('Subplot number exceeds total subplots')
        left, right = .125, .9
        bottom, top = .11, .9
        rat = 0.2             # ratio of fig to separator for multi row/col figs
        totWidth = right-left
        totHeight = top-bottom

        # n plots plus (n-1) separators, each rat times a plot, fill
        # the available extent in each direction
        figH = totHeight/(rows + rat*(rows-1))
        sepH = rat*figH

        figW = totWidth/(cols + rat*(cols-1))
        sepW = rat*figW

        # plot numbers fill rows first
        rowNum, colNum =  divmod(num, cols)

        # rows are counted down from the top edge
        figBottom = top - (rowNum+1)*figH - rowNum*sepH
        # columns are spaced by the horizontal separator sepW; using
        # the vertical one misplaces columns whenever rows != cols
        figLeft = left + colNum*(figW + sepW)

        Axes.__init__(self, [figLeft, figBottom, figW, figH], **kwargs)

        self.rowNum = rowNum
        self.colNum = colNum
        self.numRows = rows
        self.numCols = cols

    def is_first_col(self):
        "Return true if this subplot is in the leftmost column"
        return self.colNum==0

    def is_first_row(self):
        "Return true if this subplot is in the top row"
        return self.rowNum==0

    def is_last_row(self):
        "Return true if this subplot is in the bottom row"
        return self.rowNum==self.numRows-1


    def is_last_col(self):
        "Return true if this subplot is in the rightmost column"
        return self.colNum==self.numCols-1
