"""
Classes for the ticks and the x and y axes
"""
from __future__ import division
import sys, math, re
from Numeric import arange, array, asarray, ones, zeros, \
     nonzero, take, Float, log10



from artist import Artist
from cbook import enumerate, True, False
from lines import Line2D
from mlab import linspace
from patches import bbox_artist
from transforms import Bound1D, Bound2D, Transform, bound2d_all,\
     identity, logwarn, pow10, Centimeter, Dots, Inches, Millimeter,\
     Points, TransformSize, RRef
from text import Text, _process_text_args


class Tick(Artist):
    """
    Abstract base class for the axis ticks, grid lines and labels

    Publicly accessible attributes

      tick1line  : a Line2D instance
      tick2line  : a Line2D instance
      gridline   : a Line2D instance
      label      : a Text instance
      gridOn     : a boolean which determines whether to draw the gridline
      tick1On    : a boolean which determines whether to draw the 1st tickline
                   (bottom for xticks and left for yticks)
      tick2On    : a boolean which determines whether to draw the 2nd tickline
                   (top for xticks and right for yticks)
      labelOn    : a boolean which determines whether to draw the tick label

    Derived classes must implement the _get_default_* factory methods,
    set_loc and _tick_is_midpoint.
    """
    def __init__(self, axes, loc, label,
                 size=4,  # points
                 gridOn=False, tick1On=True, tick2On=True, labelOn=True):
        """
        axes is the parent Axes instance; its dpi and bbox are used
        loc is the tick location in data coords
        label is the initial tick label string
        size is the tick length in points
        gridOn, tick1On, tick2On, labelOn set the initial visibility flags
        """
        Artist.__init__(self, axes.dpi, axes.bbox)
        self.axes = axes

        self.tick1line = self._get_default_tick1line(loc, size)
        self.tick2line = self._get_default_tick2line(loc, size)
        self.gridline = self._get_default_gridline(loc)
        self.label = self._get_default_text(loc, label)
        self.label.set_clip_on( False )
        # the _get_default_text factories do not use the label string,
        # so apply it explicitly rather than silently dropping it
        self.set_label(label)
        self.gridOn = gridOn
        self.tick1On = tick1On
        self.tick2On = tick2On
        self.labelOn = labelOn
        self._loc = loc
        self._size = size

    def _get_default_text(self, loc, label):
        'Get the default Text instance'
        raise NotImplementedError('Derived must override')

    def _get_default_tick1line(self, loc, size):
        'Get the default line2D instance for tick1'
        raise NotImplementedError('Derived must override')

    def _get_default_tick2line(self, loc, size):
        'Get the default line2D instance for tick2'
        raise NotImplementedError('Derived must override')

    def _get_default_gridline(self, loc):
        'Get the default grid Line2d instance for this tick'
        raise NotImplementedError('Derived must override')

    def get_child_artists(self):
        'Return the label, both ticklines and the gridline'
        return (self.label, self.tick1line, self.tick2line, self.gridline)

    def get_loc(self):
        'Return the tick location (data coords) as a scalar'
        return self._loc

    def _draw(self, renderer):
        # tick and grid lines are suppressed when the tick sits exactly on
        # the view-limit boundary; the label is drawn regardless
        midPoint = self._tick_is_midpoint(self.get_loc())
        if midPoint and self.tick1On: self.tick1line.draw(renderer)
        if midPoint and self.tick2On: self.tick2line.draw(renderer)
        if midPoint and self.gridOn:  self.gridline.draw(renderer)
        if self.labelOn: self.label.draw(renderer)

    def set_loc(self, loc):
        'Set the location of tick in data coords with scalar loc'
        raise NotImplementedError('Derived must override')

    def set_label(self, s):
        'Set the text of ticklabel in with string s'
        self.label.set_text(s)

    def _tick_is_midpoint(self, loc):
        'return true if tickloc is not on axes boundary'
        raise NotImplementedError('Derived must override')

class XTick(Tick):
    """
    An x-axis tick: bundles the bottom and top tick lines, the grid line
    and the label text for one location along the x axis.
    """

    def _get_default_text(self, loc, label):
        'Build the tick label Text, hung 3 points below the bottom of the axes'
        ymin = self.bbox.y.get_refmin()
        top = ymin - self.dpi*RRef(3/72.0)
        # only x follows the data; y keeps the default identity transform
        return Text(dpi=self.dpi, bbox=self.bbox,
                    x=loc, y=top,
                    fontsize=10,
                    horizontalalignment='center',
                    verticalalignment='top',
                    transx=self.axes.xaxis.transData)

    def _get_default_tick1line(self, loc, size):
        'Build the bottom tick: size points up from the bottom of the axes'
        xaxis, yaxis = self.axes.xaxis, self.axes.yaxis
        tick = Line2D(self.dpi, self.bbox,
                      xdata=(loc, loc), ydata=(0, size), color='k',
                      transx=xaxis.transData,
                      transy=yaxis.get_pts_transform())
        tick.set_data_clipping(False)
        return tick

    def _get_default_tick2line(self, loc, size):
        'Build the top tick: size points down from the top of the axes'
        yaxis = self.axes.yaxis
        topEdge = yaxis.displaylim.get_refmax()
        transy = yaxis.get_pts_transform(topEdge)
        tick = Line2D(self.dpi, self.bbox,
                      xdata=(loc, loc), ydata=(-size, 0), color='k',
                      transx=self.axes.xaxis.transData,
                      transy=transy)
        tick.set_data_clipping(False)
        return tick

    def _get_default_gridline(self, loc):
        'Build the dotted grid line spanning the full axes height'
        grid = Line2D(self.dpi, self.bbox,
                      xdata=(loc, loc), ydata=(0, 1),
                      color='k', linestyle=':',
                      transx=self.axes.xaxis.transData,
                      transy=self.axes.yaxis.transAxis)
        grid.set_data_clipping(False)
        return grid

    def set_loc(self, loc):
        'Move every artist of this tick to the scalar data location loc'
        xs = (loc, loc)
        for line in (self.tick1line, self.tick2line, self.gridline):
            line.set_xdata(xs)
        self.label.set_x(loc)
        self._loc = loc

    def _tick_is_midpoint(self, loc):
        'True when loc lies strictly inside the x view limits'
        viewlim = self.axes.xaxis.viewlim
        return viewlim.in_open_interval(loc)

        
class YTick(Tick):
    """
    Contains all the Artists needed to make a Y tick - the tick line,
    the label text and the grid line
    """

    def _get_default_text(self, loc, label):
        'Get the default Text instance'
        # the right edge of the label is 3 points left of the min of x axis
        right = self.bbox.x.get_refmin() - self.dpi*RRef(3/72.0)
        text = Text(
            dpi=self.dpi,
            bbox=self.bbox,
            x=right,
            y=loc,
            fontsize=10,
            verticalalignment='center',
            horizontalalignment='right',
            # transx is default, identity transform
            transy = self.axes.yaxis.transData,
            )
        return text

    def _get_default_tick1line(self, loc, size):
        'Get the default line2D instance for tick1 (left edge of the axes)'
        line = Line2D( self.dpi, self.bbox,
                       (0, size), (loc, loc), color='k',
                       transx = self.axes.xaxis.get_pts_transform(),
                       transy = self.axes.yaxis.transData,
                       )
        line.set_data_clipping(False)
        return line

    def _get_default_tick2line(self, loc, size):
        'Get the default line2D instance for tick2 (right edge of the axes)'
        # the points transform is offset to the right edge of the axes, so
        # the tick runs from size points inside that edge to the edge itself,
        # mirroring XTick's top tick (-size, 0).  The former (1-size, 1) was
        # left over from axes coords and stuck 1 point outside the axes.
        offset = self.axes.xaxis.displaylim.get_refmax()
        xtrans = self.axes.xaxis.get_pts_transform(offset)

        line = Line2D( self.dpi, self.bbox,
                       (-size, 0), (loc, loc), color='k',
                       transx = xtrans,
                       transy = self.axes.yaxis.transData,
                       )
        # consistent with every other tick/grid line
        line.set_data_clipping(False)
        return line

    def _get_default_gridline(self, loc):
        'Get the default line2D instance'
        line = Line2D( self.dpi, self.bbox,
                       xdata=(0,1), ydata=(loc,loc),
                       color='k', linestyle=':',
                       transx = self.axes.xaxis.transAxis,
                       transy = self.axes.yaxis.transData,
                       )
        line.set_data_clipping(False)
        return line


    def set_loc(self, loc):
        'Set the location of tick in data coords with scalar loc'
        self.tick1line.set_ydata((loc, loc))
        self.tick2line.set_ydata((loc, loc))
        self.gridline.set_ydata((loc, loc))
        self.label.set_y( loc )
        self._loc = loc

    def _tick_is_midpoint(self, loc):
        'return true if tickloc is not on axes boundary'
        return self.axes.yaxis.viewlim.in_open_interval(loc)

    
class Axis(Artist):
    """
    Base class for XAxis and YAxis: owns the data/view limits, the data
    and axis transforms, the ticks and the axis label.

    Public attributes
      transData  - transform data coords to display coords
      transAxis  - transform axis coords to display coords
      datalim    - Bound1D data interval; data coords
      viewlim    - Bound1D viewport interval; data coords
      displaylim - display interval returned by _get_display_lim
    """

    # func / inverse func pairs for transforms
    _scalemap = {'linear' : (identity, identity),
                 'log'    : (logwarn, pow10)}

    # axis label offset in axes coords, away from the axis line
    LABELPAD = -0.01

    def __init__(self, axes):
        """
        Init the axis with the parent Axes instance
        """
        Artist.__init__(self, axes.dpi, axes.bbox)
        self.axes = axes
        self._scale = 'linear'

        self.datalim = Bound1D(None, None) # data interval instance; data coords
        self.viewlim = Bound1D(-1, 1)      # viewport interval instance; data coords

        self.displaylim = self._get_display_lim()

        self.transData = Transform(self.viewlim, self.displaylim)
        self.transAxis = Transform(Bound1D(0,1), self.displaylim)

        self._ticklocs = None  # None unless explicitly set
        self._ticklabelStrings = None
        # attrs set via set_child_attr; replayed onto ticks grown later
        self._setChildAttrs = {}
        # strip trailing zeros from ticks
        self._zerorgx = re.compile(r'^(.*?)\.?0+(e[+-]\d+)?$')

        self._gridOn = False  # off by default; compat with matlab

    def autoscale_view(self):
        'Try to choose the view limits intelligently'

        vmin, vmax = self.datalim.bounds()
        if vmin is None or vmax is None:
            # no data has been added yet; leave the view limits alone
            return
        if self._scale=='linear':

            if vmin==vmax:
                vmin-=1
                vmax+=1
            try:
                (exponent, remainder) = divmod(math.log10(vmax - vmin), 1)
            except (OverflowError, ValueError):
                # log10 of a non-positive interval raises OverflowError or
                # ValueError depending on the python version / platform
                print >>sys.stderr, 'Overflow error in autoscale', vmin, vmax
                return
            if remainder < 0.5:
                exponent -= 1
            scale = 10**(-exponent)
            vmin = math.floor(scale*vmin)/scale
            vmax = math.ceil(scale*vmax)/scale

            self.viewlim.set_bounds(vmin, vmax)

    def build_artists(self):
        """
        Call only after xaxis and yaxis have been initialized in axes
        since the transforms require both
        """
        self._label = self._get_default_label()
        self._label.set_clip_on( False )
        self._ticks = [self._get_default_tick() for i in range(20)]

    def _draw(self, renderer, *args, **kwargs):
        'Draw the axis lines, grid lines, tick lines and labels'

        # update the tick locs
        ticks = self.get_ticks()
        locs = self.get_ticklocs()

        if self._ticklabelStrings is not None:
            labels = self._ticklabelStrings
            if len(labels)<len(locs):
                # pad with empty strings as necessary; build a new list so
                # the labels stored by set_ticklabels are not grown in place
                labels = labels + ['']*(len(locs)-len(labels))
        else:
            labels = [self.format_tickval(loc) for loc in locs]

        ticklabelBoxes = []
        for tick, loc, label in zip(ticks, locs, labels):
            if not self.viewlim.in_interval(loc): continue
            tick.set_loc(loc)
            tick.set_label(label)
            tick.draw(renderer)
            ticklabelBoxes.append(tick.label.get_window_extent(renderer))


        # find the tick labels that are close to the axis labels.  I
        # scale up the axis label box to also find the neighbors, not
        # just the tick labels that actually overlap note we need a
        # *copy* of the axis label box because we don't want to scale
        # the actual bbox
        labelBox = self._label.get_window_extent(renderer).copy()
        labelBox.x.scale(1.25)
        labelBox.y.scale(1.25)
        overlaps = [b for b in ticklabelBoxes if b.overlap(labelBox)]

        self._update_label_postion(overlaps)
        self._label.draw(renderer)

        if 0: # draw the bounding boxes around the text for debug
            for tick in ticks:
                label = tick.label
                bbox_artist(label, renderer)
            bbox_artist(self._label, renderer)

    def format_tickval(self, x):
        """
        Format the number x as a tick label string.  On a log scale,
        non-decade values get '' unless tick labels were set explicitly.
        """
        d = self.viewlim.interval()
        if self._scale == 'log':
            # only label the decades
            fx = self.transData.func(x)
            isdecade = abs(fx-int(fx))<1e-10
            if self._ticklabelStrings is None and not isdecade: return ''

        #if the number is not too big and it's an int, format it as an
        #int
        if abs(x)<1e4 and x==int(x): return '%d' % x

        # if the value is just a fraction off an int, use the int
        if abs(x-int(x))<0.0001*d: return '%d' % int(x)

        # use exponential formatting for really big or small numbers,
        # else use float
        if abs(x) < 1e-4: fmt = '%1.3e'
        elif abs(x) > 1e5: fmt = '%1.3e'
        else: fmt = '%1.3f'
        s =  fmt % x

        # strip trailing zeros, remove '+', and handle exponential formatting
        m = self._zerorgx.match(s)
        if m:
            s = m.group(1)
            if m.group(2) is not None: s += m.group(2)
        s = s.replace('+', '')
        return s

    def get_child_artists(self):
        'Return a list of all Artist instances contained by Axis'
        artists = []
        # don't use get_ticks here! it conflicts with set_child_attr
        artists.extend(self._ticks)
        artists.append(self._label)
        return artists

    def _get_size_transform(self, units, offset, transOffset):
        """
        Shared implementation of the get_*_transform methods: return a
        TransformSize from units to display dots.  offset defaults to the
        displaylim min; transOffset defaults to a fresh Transform() (a new
        one per call rather than one instance shared by every call)
        """
        if offset is None:
            offset = self.displaylim.get_refmin()
        if transOffset is None:
            transOffset = Transform()
        return TransformSize(units, Dots(self.dpi), offset, transOffset)

    def get_cm_transform(self, offset=None, transOffset=None):
        """
        Transform val cm to dots

        offset is a RRef instance for TransformSize.  if None, use
        displaylim min.  transOffset is a transform to transform the
        offset to display coords; if None, use Transform()
        """
        return self._get_size_transform(Centimeter(self.dpi), offset,
                                        transOffset)

    def _get_default_label(self):
        raise NotImplementedError('Derived must override')

    def _get_display_lim(self):
        raise NotImplementedError('Derived must override')

    def get_gridlines(self):
        'Return the grid lines as a list of Line2D instance'
        return [tick.gridline for tick in self._ticks]

    def get_inches_transform(self, offset=None, transOffset=None):
        """
        Transform val inches to dots

        offset is a RRef instance for TransformSize.  if None, use
        displaylim min.  transOffset is a transform to transform the
        offset to display coords; if None, use Transform()
        """
        return self._get_size_transform(Inches(self.dpi), offset,
                                        transOffset)

    def get_label(self):
        'Return the axis label as an Text instance'
        return self._label


    def get_mm_transform(self, offset=None, transOffset=None):
        """
        Transform val mm to dots

        offset is a RRef instance for TransformSize.  if None, use
        displaylim min.  transOffset is a transform to transform the
        offset to display coords; if None, use Transform()
        """
        return self._get_size_transform(Millimeter(self.dpi), offset,
                                        transOffset)

    def get_numticks(self):
        """
        Return the number of ticks to draw: the number of explicitly set
        tick locations if any; on a log scale with defined view limits,
        10 per decade spanned (growing the tick pool to match); otherwise
        11 when the axis is longer than 4 inches, else 6
        """
        if self._ticklocs is None:
            if self._scale == 'log':
                if self.viewlim.defined():
                    vmin, vmax = self.viewlim.bounds()
                    decMax = int(math.ceil(self.transData.func(vmax)))
                    decMin = int(math.floor(self.transData.func(vmin)))
                    numticks =  10*(decMax-decMin)
                    self._grow_ticks(numticks)
                    return numticks

            # size of axes in inches
            intv = self.displaylim.interval()/self.dpi.get()
            if intv>4: return 11
            else: return 6
        else:
            return len(self._ticklocs)

    def get_pts_transform(self, offset=None, transOffset=None):
        """
        Transform val points to dots

        offset is a RRef instance for TransformSize.  if None, use
        displaylim min.  transOffset is a transform to transform the
        offset to display coords; if None, use Transform()
        """
        return self._get_size_transform(Points(self.dpi), offset,
                                        transOffset)

    def get_ticklabels(self):
        'Return a list of Text instances for ticklabels'
        return [tick.label for tick in self.get_ticks()]

    def get_ticklines(self):
        'Return the ticklines lines as a list of Line2D instance'
        lines = []
        for tick in self._ticks:
            lines.append(tick.tick1line)
            lines.append(tick.tick2line)
        return lines

    def get_ticklocs(self):
        "Get the tick locations in data coordinates as a Numeric array"
        if self._ticklocs is not None:
            return self._ticklocs

        vmin, vmax = self.viewlim.bounds()
        vmin = self.transData.func(vmin)
        vmax = self.transData.func(vmax)
        if self._scale == 'linear':
            numticks = self.get_numticks()
            if numticks==0: return []
            ticklocs = linspace(vmin, vmax, numticks)
        elif self._scale == 'log':
            ticklocs = []
            for decadeStart in 10.0**arange(math.floor(vmin),math.ceil(vmax)):
                ticklocs.append(decadeStart)
                ticklocs.extend( arange(2,10)*decadeStart )
                ticklocs.append(10*decadeStart)
            ticklocs = array(ticklocs)
            lo, hi = 10.0**vmin, 10.0**vmax
            # keep only ticks inside [lo, hi].  The comparisons must be
            # parenthesized: '*' binds tighter than '>=' and an unbracketed
            # 'a>=b*a<=c' is a chained comparison, not an elementwise and.
            ind = nonzero((ticklocs>=lo) * (ticklocs<=hi))
            ticklocs = take(ticklocs,ind)
        return ticklocs

    def get_ticks(self):
        'Return the first get_numticks() ticks, growing the pool if needed'
        numticks = self.get_numticks()
        self._grow_ticks(numticks)
        ticks = self._ticks[:numticks]

        return ticks

    def grid(self, b):
        "Set the axis grid on or off; b is a boolean"
        self._gridOn = b
        for tick in self._ticks:  # don't use get_ticks here!
            tick.gridOn = self._gridOn

    def _grow_ticks(self, numticks):
        """
        Dynamically grow the number of ticklines as necessary; new ticks
        inherit the grid state and any attrs set via set_child_attr
        """
        if len(self._ticks)<numticks:
            for i in range(numticks-len(self._ticks)):
                tick = self._get_default_tick()
                if self._gridOn: tick.gridOn = True
                for attr, val in self._setChildAttrs.items():
                    tick.set_child_attr(attr, val)
                self._ticks.append(tick)

    def is_log(self):
        'Return true if log scaling is on'
        return self._scale=='log'

    def pan(self, numsteps):
        'Pan numsteps tick intervals (can be positive or negative)'
        numticks = self.get_numticks()
        if numticks == 0:
            # no ticks (e.g. after set_ticks([])) means no step size
            return
        step = numsteps*self.viewlim.interval()/numticks
        self.viewlim.shift(step)

    def set_child_attr(self, attr, val):
        """
        Set attribute attr for self, and all child artists; remembered so
        that ticks created later get it too
        """
        Artist.set_child_attr(self, attr, val)
        self._setChildAttrs[attr] = val

    def set_scale(self, value):
        "Set the axis scale; either 'linear' or 'log'"
        if value not in self._scalemap.keys():
            raise ValueError('scale must be in %s'%str(self._scalemap.keys()))
        self._scale = value
        self.transData.set_funcs(self._scalemap[value])
        self.viewlim.is_positive( self.is_log() )
        self.axes.update_viewlim()

    def set_ticklabels(self, ticklabels, *args, **kwargs):
        """
        Set the text values of the tick labels.  ticklabels is a
        sequence of strings.  Extra args/kwargs are text properties
        applied to every label.  Return a list of Text instances
        """
        ticklabels = ['%s'%l for l in ticklabels]

        self._ticklabelStrings = ticklabels
        override = _process_text_args({}, *args, **kwargs)

        Nnew = len(self._ticklabelStrings)
        existingLabels = self.get_ticklabels()
        for i, label in enumerate(existingLabels):
            if i<Nnew: label.set_text(self._ticklabelStrings[i])
            else: label.set_text('')
            label.update_properties(override)
        return existingLabels

    def set_ticks(self, ticks):
        'Set the locations of the tick marks from sequence ticks'
        self._ticklocs = asarray(ticks)
        self.viewlim.update(self._ticklocs)

    def _update_label_postion(self, overlaps):
        """
        Update the label position based on the sequence of bounding
        boxes overlaps of all the ticklabels that overlap the current
        ticklabel.  overlaps are the bounding boxes of ticklabels
        """
        raise NotImplementedError('Derived must override')

    def zoom(self, direction):
        "Zoom in/out on axis; if direction is >0 zoom in, else zoom out"
        vmin, vmax = self.viewlim.bounds()
        interval = self.viewlim.interval()
        step = 0.1*interval*direction
        self.viewlim.set_min(vmin + step)
        self.viewlim.set_max(vmax - step)

class XAxis(Axis):
    """
    The horizontal axis: its display interval is the x extent of the
    axes bbox and its label sits centered under the tick labels.
    """
    __name__ = 'xaxis'

    def _get_default_tick(self):
        'Return a fresh XTick at location 0 with an empty label'
        return XTick(self.axes, 0, '')

    def _get_default_label(self):
        'Return the axis label Text, centered below the axis line'
        return Text(
            self.dpi,
            self.bbox,
            x=0.5,               # centered
            y=self.LABELPAD,     # rel coords under axis line
            fontsize=12,
            horizontalalignment='center',
            verticalalignment='top',
            transx=self.transAxis,
            transy=self.axes.yaxis.transAxis,
            )

    def _get_display_lim(self):
        'The x interval of the axes bounding box'
        return self.bbox.x

    def _update_label_postion(self, overlaps):
        """
        Push the label below the lowest overlapping tick label.  overlaps
        is the list of tick label bounding boxes near the axis label; when
        it is empty the label is left where it is.
        """
        if not overlaps:
            return
        x, y = self._label.get_position()

        extent = bound2d_all(overlaps)
        ybox = self.bbox.y
        bottom = (extent.y.min() - ybox.min())/ybox.interval()
        self._label.set_position((x, bottom + self.LABELPAD))
        

class YAxis(Axis):
    """
    The vertical axis: its display interval is the y extent of the axes
    bbox and its rotated label sits centered left of the tick labels.
    """
    __name__ = 'yaxis'

    def _get_default_tick(self):
        'Return a fresh YTick at location 0 with an empty label'
        return YTick(self.axes, 0, '')


    def _get_default_label(self):
        'Return the vertical axis label Text, left of the axis line'
        return Text(
            dpi=self.dpi,
            bbox=self.bbox,
            x=self.LABELPAD,     # rel coords to the left of axis line
            y=0.5,               # centered
            fontsize=12,
            horizontalalignment='right',
            verticalalignment='center',
            rotation='vertical',
            transx=self.axes.xaxis.transAxis,
            transy=self.transAxis,
            )

    def _get_display_lim(self):
        'The y interval of the axes bounding box'
        return self.bbox.y


    def _update_label_postion(self, overlaps):
        """
        Push the label left of the leftmost overlapping tick label.
        overlaps is the list of tick label bounding boxes near the axis
        label; when it is empty the label is left where it is.
        """
        if not overlaps:
            return
        x, y = self._label.get_position()

        extent = bound2d_all(overlaps)
        xbox = self.bbox.x
        left = (extent.x.min() - xbox.min())/xbox.interval()
        self._label.set_position((left + self.LABELPAD, y))


