import sys
from artist import Artist
from axes import Axes, Subplot, PolarSubplot, PolarAxes
from cbook import flatten, allequal, popd, Stack, iterable
import _image
from colors import normalize
from image import FigureImage
from matplotlib import rcParams
from patches import Rectangle
from text import Text, _process_text_args

from legend import Legend
from transforms import Bbox, Value, Point, get_bbox_transform, unit_bbox
from numerix import array, clip, transpose
from mlab import linspace
from ticker import FormatStrFormatter

class Figure(Artist):
    """
    Top level container for the figure-level artists: axes, lines,
    patches, images, texts and legends, drawn over a background
    rectangle (figurePatch) whose coordinates are in figure units.
    """

    def __init__(self,
                 figsize   = None,  # defaults to rc figure.figsize
                 dpi       = None,  # defaults to rc figure.dpi
                 facecolor = None,  # defaults to rc figure.facecolor
                 edgecolor = None,  # defaults to rc figure.edgecolor
                 linewidth = 1.0,   # the default linewidth of the frame
                 frameon = True,
                 ):
        """
        figsize is a w,h tuple in inches
        dpi is dots per inch
        facecolor, edgecolor and linewidth style the background rectangle
        frameon controls whether the background rectangle is drawn
        """
        Artist.__init__(self)
        self._axstack = Stack()  # maintain the current axes
        self._axobservers = []   # callables invoked as func(self) on axes changes
        self._seen = {}          # axes args we've seen -> the axes they created

        if figsize is None  : figsize   = rcParams['figure.figsize']
        if dpi is None      : dpi       = rcParams['figure.dpi']
        if facecolor is None: facecolor = rcParams['figure.facecolor']
        if edgecolor is None: edgecolor = rcParams['figure.edgecolor']

        self.dpi = Value(dpi)
        self.figwidth = Value(figsize[0])
        self.figheight = Value(figsize[1])
        # the figure bbox in pixels is lazily derived from size * dpi, so
        # changing figwidth/figheight/dpi later updates it automatically
        self.ll = Point( Value(0), Value(0) )
        self.ur = Point( self.figwidth*self.dpi,
                         self.figheight*self.dpi )
        self.bbox = Bbox(self.ll, self.ur)
        self.frameon = frameon

        # maps the unit square (figure coords, 0-1) onto the pixel bbox
        self.transFigure = get_bbox_transform( unit_bbox(), self.bbox)

        self.figurePatch = Rectangle(
            xy=(0,0), width=1, height=1,
            facecolor=facecolor, edgecolor=edgecolor,
            linewidth=linewidth,
            )
        self._set_artist_props(self.figurePatch)

        self._hold = rcParams['axes.hold']
        self.canvas = None
        self.clf()

    def set_canvas(self, canvas):
        """\
Set the canvas that contains the figure

ACCEPTS: a FigureCanvas instance"""
        self.canvas = canvas

    def hold(self, b=None):
        """
        Set the hold state.  If b is None (default), toggle the
        hold state.  Else set the hold state to boolean value b.

        Eg
        hold()      # toggle hold
        hold(True)  # hold is on
        hold(False) # hold is off
        """
        if b is None: self._hold = not self._hold
        else: self._hold = b

    def figimage(self, X,
                 xo=0,
                 yo=0,
                 alpha=1.0,
                 norm=None,
                 cmap=None,
                 vmin=None,
                 vmax=None,
                 origin=None):
        """\
FIGIMAGE(X) # add non-resampled array to figure

FIGIMAGE(X, xo, yo) # with pixel offsets

FIGIMAGE(X, **kwargs) # control interpolation ,scaling, etc

Add a nonresampled figure to the figure from array X.  xo and yo are
offsets in pixels

X must be a float array

    If X is MxN, assume luminance (grayscale)
    If X is MxNx3, assume RGB
    If X is MxNx4, assume RGBA

The following kwargs are allowed:

  * cmap is a cm colormap instance, eg cm.jet.  If None, default to
    the rc image.cmap value

  * norm is a matplotlib.colors.normalize instance; default is
    normalize().  This scales luminance -> 0-1

  * vmin and vmax are used to scale a luminance image to 0-1.  If
    either is None, the min and max of the luminance values will be
    used.  Note if you pass a norm instance, the settings for vmin and
    vmax will be ignored.

  * alpha = 1.0 : the alpha blending value

  * origin is either 'upper' or 'lower', which indicates where the [0,0]
    index of the array is in the upper left or lower left corner of
    the axes.  Defaults to the rc image.origin value

This complements the axes image (Axes.imshow) which will be resampled
to fit the current axes.  If you want a resampled image to fill the
entire figure, you can define an Axes with size [0,1,0,1].

An image.FigureImage instance is returned.
"""

        if not self._hold: self.clf()

        im = FigureImage(self, cmap, norm, xo, yo, origin)
        im.set_array(X)
        im.set_alpha(alpha)
        # vmin/vmax only apply when no explicit norm was supplied
        if norm is None:
            im.set_clim(vmin, vmax)
        self.images.append(im)
        return im

    def set_figsize_inches(self, *args):
        """
Set the figure size in inches

Usage: set_figsize_inches(self, w,h)  OR
       set_figsize_inches(self, (w,h) )

ACCEPTS: a w,h tuple with w,h in inches
"""
        if len(args)==1:
            w,h = args[0]
        else:
            w,h = args
        self.figwidth.set(w)
        self.figheight.set(h)

    def get_size_inches(self):
        'Return the (width, height) of the figure in inches'
        return self.figwidth.get(), self.figheight.get()

    def get_edgecolor(self):
        'Get the edge color of the Figure rectangle'
        return self.figurePatch.get_edgecolor()

    def get_facecolor(self):
        'Get the face color of the Figure rectangle'
        return self.figurePatch.get_facecolor()

    def get_figwidth(self):
        'Return the figwidth as a float'
        return self.figwidth.get()

    def get_figheight(self):
        'Return the figheight as a float'
        return self.figheight.get()

    def get_dpi(self):
        'Return the dpi as a float'
        return self.dpi.get()

    def get_frameon(self):
        'get the boolean indicating frameon'
        return self.frameon

    def set_edgecolor(self, color):
        """
Set the edge color of the Figure rectangle

ACCEPTS: any matplotlib color - see help(colors)"""
        self.figurePatch.set_edgecolor(color)

    def set_facecolor(self, color):
        """
Set the face color of the Figure rectangle

ACCEPTS: any matplotlib color - see help(colors)"""
        self.figurePatch.set_facecolor(color)

    def set_dpi(self, val):
        """
Set the dots-per-inch of the figure

ACCEPTS: float"""
        self.dpi.set(val)

    def set_figwidth(self, val):
        """
Set the width of the figure in inches

ACCEPTS: float"""
        self.figwidth.set(val)

    def set_figheight(self, val):
        """
Set the height of the figure in inches

ACCEPTS: float"""
        self.figheight.set(val)

    def set_frameon(self, b):
        """
Set whether the figure frame (background) is displayed or invisible

ACCEPTS: boolean"""
        self.frameon = b

    def delaxes(self, a):
        'remove a from the figure and update the current axes'
        self.axes.remove(a)
        self._axstack.remove(a)
        # collect the stale keys first; deleting from a dict while
        # iterating over it is an error on Python 3
        stale = [key for key, thisax in self._seen.items() if a==thisax]
        for key in stale:
            del self._seen[key]
        for func in self._axobservers: func(self)

    def add_axes(self, *args, **kwargs):
        """
Add an axes with axes rect [left, bottom, width, height] where all
quantities are in fractions of figure width and height.  kwargs are
legal Axes kwargs plus "polar" which sets whether to create a polar axes

    add_axes((l,b,w,h))
    add_axes((l,b,w,h), frameon=False, axisbg='g')
    add_axes((l,b,w,h), polar=True)
    add_axes(ax)   # add an Axes instance


If the figure already has an axes with key *args, *kwargs then it will
simply make that axes current and return it.  If you do not want this
behavior, eg you want to force the creation of a new axes, you must
use a unique set of args and kwargs.  The artist "label" attribute has
been exposed for this purpose.  Eg, if you want two axes that are
otherwise identical to be added to the axes, make sure you give them
unique labels:

    add_axes((l,b,w,h), label='1')
    add_axes((l,b,w,h), label='2')

The Axes instance will be returned; None is returned if no args are given
        """

        # check for args before touching args[0] to avoid an IndexError
        if not len(args): return

        # rects may be lists, so convert to a tuple to get a hashable key
        if iterable(args[0]):
            key = tuple(args[0]), tuple(kwargs.items())
        else:
            key = args[0], tuple(kwargs.items())

        if key in self._seen:
            ax = self._seen[key]
            self.sca(ax)
            return ax

        if isinstance(args[0], Axes):
            a = args[0]
            a.set_figure(self)
        else:
            rect = args[0]
            ispolar = popd(kwargs, 'polar', False)

            if ispolar:
                a = PolarAxes(self, rect, **kwargs)
            else:
                a = Axes(self, rect, **kwargs)

        self.axes.append(a)
        self._axstack.push(a)
        self.sca(a)
        self._seen[key] = a
        return a

    def add_subplot(self, *args, **kwargs):
        """
Add a subplot.  Examples

    add_subplot(111)
    add_subplot(212, axisbg='r')  # add subplot with red background
    add_subplot(111, polar=True)  # add a polar subplot
    add_subplot(sub)              # add Subplot or PolarSubplot instance sub

kwargs are legal Axes kwargs plus "polar" which sets whether to create a
polar axes.  The Axes instance will be returned; None is returned if no
args are given.

If the figure already has a subplot with key *args, *kwargs then it will
simply make that subplot current and return it
        """

        key = args, tuple(kwargs.items())
        if key in self._seen:
            ax = self._seen[key]
            self.sca(ax)
            return ax

        if not len(args): return

        # an existing subplot instance is adopted rather than re-created;
        # test args[0] (not the args tuple) for both subplot kinds
        if isinstance(args[0], Subplot) or isinstance(args[0], PolarSubplot):
            a = args[0]
            a.set_figure(self)
        else:
            ispolar = popd(kwargs, 'polar', False)
            if ispolar:
                a = PolarSubplot(self, *args, **kwargs)
            else:
                a = Subplot(self, *args, **kwargs)

        self.axes.append(a)
        self._axstack.push(a)
        self.sca(a)
        self._seen[key] = a
        return a

    def clf(self):
        """
        Clear the figure: remove all axes, lines, patches, texts,
        images and legends and forget the current axes
        """
        self.axes = []
        self._axstack.clear()
        self._seen = {}
        self.lines = []
        self.patches = []
        self.texts=[]
        self.images = []
        self.legends = []

    def clear(self):
        """
        Clear the figure; an alias for clf
        """
        self.clf()

    def draw(self, renderer):
        """
        Render the figure using the given renderer instance.

        Draw order: background patch (if frameon), patches, lines,
        images, axes, texts, legends.  Multiple figure images are
        composited into a single image, which requires that they all
        share the same origin; otherwise ValueError is raised.
        """
        if not self.get_visible(): return
        renderer.open_group('figure')
        self.transFigure.freeze()  # eval the lazy objects
        try:
            # draw the figure bounding box, perhaps none for white figure
            if self.frameon: self.figurePatch.draw(renderer)

            for p in self.patches: p.draw(renderer)
            for l in self.lines: l.draw(renderer)

            if len(self.images)==1:
                self.images[0].draw(renderer)
            elif len(self.images)>1:
                # make a composite image blending alpha
                if not allequal([im.origin for im in self.images]):
                    raise ValueError('Composite images with different origins not supported')
                origin = self.images[0].origin

                # list of (_image.Image, ox, oy)
                ims = [(im.make_image(), im.ox, im.oy) for im in self.images]
                im = _image.from_images(self.bbox.height(), self.bbox.width(), ims)
                im.is_grayscale = False
                renderer.draw_image(0, 0, im, origin, self.bbox)

            # render the axes
            for a in self.axes: a.draw(renderer)

            # render the figure text
            for t in self.texts: t.draw(renderer)

            for legend in self.legends:
                legend.draw(renderer)
        finally:
            # release the lazy objects even if drawing failed, so the
            # transform is not left frozen for later draws
            self.transFigure.thaw()
        renderer.close_group('figure')

    def get_axes(self):
        'Return the list of axes in the figure'
        return self.axes

    def legend(self, handles, labels, loc, **kwargs):
        """
        Place a legend in the figure.  Labels are a sequence of
        strings, handles is a sequence of line or patch instances, and
        loc can be a string or an integer specifying the legend
        location

        USAGE:
          legend( (line1, line2, line3),
                  ('label1', 'label2', 'label3'),
                  'upper right')

        The LOC location codes are

          'best' : 0,          (currently not supported, defaults to upper right)
          'upper right'  : 1,  (default)
          'upper left'   : 2,
          'lower left'   : 3,
          'lower right'  : 4,
          'right'        : 5,
          'center left'  : 6,
          'center right' : 7,
          'lower center' : 8,
          'upper center' : 9,
          'center'       : 10,

        loc can also be an (x,y) tuple in figure coords, which
        specifies the lower left of the legend box.  figure coords are
        (0,0) is the left, bottom of the figure and 1,1 is the right,
        top.

        The legend instance is returned
        """

        handles = flatten(handles)
        l = Legend(self, handles, labels, loc, isaxes=False, **kwargs)
        self._set_artist_props(l)
        self.legends.append(l)
        return l

    def text(self, x, y, s, *args, **kwargs):
        """
        Add text to figure at location x,y (relative 0-1 coords) See
        the help for Axis text for the meaning of the other arguments

        The Text instance is returned
        """

        override = _process_text_args({}, *args, **kwargs)
        t = Text(
            x=x, y=y, text=s,
            )

        t.update(override)
        self._set_artist_props(t)
        self.texts.append(t)
        return t

    def _set_artist_props(self, a):
        # attach artist a to this figure and place it in figure coords
        if a!= self:
            a.set_figure(self)
        a.set_transform(self.transFigure)

    def get_width_height(self):
        'return the figure width and height in pixels'
        w = self.bbox.width()
        h = self.bbox.height()
        return w, h

    def gca(self, **kwargs):
        """
Return the current axes, creating one if necessary (a 111 subplot
created with kwargs)
        """
        ax = self._axstack()
        if ax is not None: return ax
        return self.add_subplot(111, **kwargs)

    def sca(self, a):
        'Set the current axes to be a and return a'
        self._axstack.bubble(a)
        for func in self._axobservers: func(self)
        return a

    def add_axobserver(self, func):
        'whenever the axes state change, func(self) will be called'
        self._axobservers.append(func)

    def savefig(self, *args, **kwargs):
        """
SAVEFIG(fname, dpi=150, facecolor='w', edgecolor='w',
orientation='portrait'):

Save the current figure to filename fname.  dpi is the resolution
in dots per inch.  dpi, facecolor and edgecolor default to the rc
savefig.dpi, savefig.facecolor and savefig.edgecolor values.

Output file types currently supported are jpeg and png and will be
deduced by the extension to fname

facecolor and edgecolor are the colors of the figure rectangle

orientation is either 'landscape' or 'portrait' - not supported on
all backends; currently only on postscript output.

The figure must have a canvas (see set_canvas)."""

        for key in ('dpi', 'facecolor', 'edgecolor'):
            if key not in kwargs:
                kwargs[key] = rcParams['savefig.%s'%key]

        self.canvas.print_figure(*args, **kwargs)

    def colorbar(self, mappable, tickfmt='%1.1f', cax=None, orientation='vertical'):
        """
        Create a colorbar for mappable image

        tickfmt is a format string to format the colorbar ticks

        cax is a colorbar axes instance in which the colorbar will be
        placed.  If None, a default axes will be created, resizing the
        current axes to make room for it.  If not None, the supplied axes
        will be used and the other axes positions will be unchanged.

        orientation is the colorbar orientation: one of 'vertical' | 'horizontal'
        return value is the colorbar axes instance
        """

        if orientation not in ('horizontal', 'vertical'):
            raise ValueError('Orientation must be horizontal or vertical')

        if isinstance(mappable, FigureImage) and cax is None:
            raise TypeError('Colorbars for figure images currently not supported unless you provide a colorbar axes in cax')

        ax = self.gca()

        cmap = mappable.cmap
        norm = mappable.norm

        if norm.vmin is None or norm.vmax is None:
            mappable.autoscale()
        cmin = norm.vmin
        cmax = norm.vmax

        vertical = orientation=='vertical'

        if cax is None:
            # shrink the current axes by 20% and put the colorbar in the
            # freed space
            l,b,w,h = ax.get_position()
            if vertical:
                ax.set_position((l,b,0.8*w,h))
                cax = self.add_axes([l + 0.9*w, b, 0.1*w, h])
            else:
                ax.set_position((l,b+0.2*h,w,0.8*h))
                cax = self.add_axes([l, b, w, 0.1*h])
        elif not isinstance(cax, Axes):
            raise TypeError('Expected an Axes instance for cax')

        # a two-sample-wide ramp of cmap.N values spanning the color limits
        c = linspace(cmin, cmax, cmap.N)
        C = array([c,c])

        if vertical:
            C = transpose(C)
            extent=(0, 1, cmin, cmax)
        else:
            extent=(cmin, cmax, 0, 1)
        coll = cax.imshow(C,
                          interpolation='nearest',
                          origin='lower',
                          cmap=cmap, norm=norm,
                          extent=extent)
        mappable.add_observer(coll)
        mappable.set_colorbar(coll, cax)

        if vertical:
            cax.set_xticks([])
            cax.yaxis.tick_right()
            cax.yaxis.set_major_formatter(FormatStrFormatter(tickfmt))
        else:
            cax.set_yticks([])
            cax.xaxis.set_major_formatter(FormatStrFormatter(tickfmt))

        # restore the original current axes (add_axes made cax current)
        self.sca(ax)
        return cax


def figaspect(arr):
    """
    Determine the width and height for a figure that would fit array
    preserving aspect ratio.  The figure width, height in inches are
    returned.  Be sure to create an axes with equal width and height, eg

    w, h = figaspect(A)
    fig = Figure(figsize=(w,h))
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
    ax.imshow(A, **kwargs)

    arr may also be a scalar, in which case it is used directly as the
    aspect ratio height/width, eg figaspect(2.0) requests a figure twice
    as tall as it is wide.

    The height starts from the rc figure.figsize height, and the result
    is kept within hard-wired min/max sizes (see figsize_min and
    figsize_max below), so extreme ratios are not preserved exactly.

    Thanks to Fernando Perez for this function
    """

    # min/max sizes to respect when autoscaling.  If John likes the idea, they
    # could become rc parameters, for now they're hardwired.
    figsize_min = array((4.0,2.0)) # min length for width/height
    figsize_max = array((16.0,16.0)) # max length for width/height

    # Extract the aspect ratio (rows/cols == height/width).  Anything
    # without a non-empty shape is taken to be the ratio itself.
    if hasattr(arr, 'shape') and len(arr.shape) > 0:
        nr,nc = arr.shape[:2]
        arr_ratio = float(nr)/nc
    else:
        arr_ratio = float(arr)

    # Height of user figure defaults
    fig_height = rcParams['figure.figsize'][1]

    # New size for the figure, keeping the aspect ratio of the caller
    newsize = array((fig_height/arr_ratio,fig_height))

    # Sanity checks, don't drop either dimension below figsize_min
    newsize /= min(1.0,*(newsize/figsize_min))

    # Avoid humongous windows as well
    newsize /= max(1.0,*(newsize/figsize_max))

    # Finally, if we have a really funky aspect ratio, break it but respect
    # the min/max dimensions (we don't want figures 10 feet tall!)
    newsize = clip(newsize,figsize_min,figsize_max)
    return newsize
