"""
This module contains the newfangled transform class which allows the
placement of artists (lines, patches, text) in a variety of coordinate
systems (display, arbitrary data, relative axes, physical sizes).

The default Transform() is identity.

  t = Transform()
  x == t.positions(x)  # True
  x == t.scale(x)      # True

A linear Transformation is specified by giving a Bound1D (min, max)
instance for the domain and range.  The transform below maps the
interval [0,1] to [-10,10]:

  t = Transform( Bound1D(0,1), Bound1D(-10,10) )
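
With the default identity funcs, the Transform above reduces to a
linear map.  A minimal sketch of that arithmetic on plain (min, max)
tuples follows; linear_map is a hypothetical helper, not part of this
module, and it ignores RRef dereferencing and the func/ifunc pair:

```python
def linear_map(x, boundin, boundout):
    # Map x from boundin=(minin, maxin) to boundout=(minout, maxout):
    # out = (x - minin) * (maxout - minout) / (maxin - minin) + minout
    minin, maxin = boundin
    minout, maxout = boundout
    return (x - minin) * (maxout - minout) / float(maxin - minin) + minout


# [0,1] -> [-10,10], as in the Transform example
low = linear_map(0.0, (0, 1), (-10, 10))   # -10.0
mid = linear_map(0.5, (0, 1), (-10, 10))   # 0.0
high = linear_map(1.0, (0, 1), (-10, 10))  # 10.0
```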

Since every Transform knows its inverse function, you can compute an
inverse transformation by calling inverse_positions or inverse_scale:

  t = Transform( Bound1D(0,1), Bound1D(-10,10) )
  val = t.inverse_positions(5)  # maps [-10,10] to [0,1]
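
For the linear case the inverse simply swaps the roles of the input
and output intervals.  A minimal sketch under the same assumptions
(plain tuples, identity funcs); linear_map and inverse_linear_map are
hypothetical helpers:

```python
def linear_map(x, boundin, boundout):
    # Forward map from boundin=(minin, maxin) to boundout=(minout, maxout)
    minin, maxin = boundin
    minout, maxout = boundout
    return (x - minin) * (maxout - minout) / float(maxin - minin) + minout


def inverse_linear_map(y, boundin, boundout):
    # The inverse of a linear map is the linear map with the bounds swapped
    return linear_map(y, boundout, boundin)


val = inverse_linear_map(5, (0, 1), (-10, 10))  # 0.75, as for t above
```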

The difference between 'positions' and 'scale' is that the positions
func is appropriate for locations (eg, x,y) and the scale func is
appropriate for lengths (eg, width, height).
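
To make the distinction concrete, here is a sketch with hypothetical
free functions that mirror positions and scale for the linear case
only: a position is rescaled and shifted by the output origin, while
a length is only rescaled.

```python
def positions(x, boundin, boundout):
    # Locations: rescale, then shift by the output origin
    minin, maxin = boundin
    minout, maxout = boundout
    return (x - minin) * (maxout - minout) / float(maxin - minin) + minout


def scale(s, boundin, boundout):
    # Lengths: rescale only; an origin shift is meaningless for a width
    minin, maxin = boundin
    minout, maxout = boundout
    return s * (maxout - minout) / float(maxin - minin)


# A box spanning 0.25..0.75 in [0,1], mapped to [-10,10]
left = positions(0.25, (0, 1), (-10, 10))   # -5.0
right = positions(0.75, (0, 1), (-10, 10))  # 5.0
width = scale(0.5, (0, 1), (-10, 10))       # 10.0
```

Note that right - left equals width, whereas positions(0.5, ...) would
give 0.0, the midpoint, which is not a meaningful width.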

The Bound1D methods provide a number of utility functions: the
interval max-min, determining if a point is in the open or closed
interval, constraining the bound to be positive (useful for log
transforms), and so on.  These are useful for storing view, data and
display limits of a given axis.
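
As an illustration of those utilities, here is a simplified stand-in
for Bound1D that stores plain floats instead of RWRefs.  The Interval
class and its attribute names are hypothetical, and it assumes
min <= max rather than swapping the endpoints as Bound1D does:

```python
class Interval(object):
    # Hypothetical stand-in for Bound1D holding plain floats
    def __init__(self, vmin=None, vmax=None, positive=False):
        self.vmin, self.vmax = vmin, vmax
        self.positive = positive

    def defined(self):
        return self.vmin is not None and self.vmax is not None

    def interval(self):
        # max - min, or None if either endpoint is missing
        if not self.defined():
            return None
        return self.vmax - self.vmin

    def in_interval(self, val):
        # closed interval [min, max]
        return self.defined() and self.vmin <= val <= self.vmax

    def in_open_interval(self, val):
        # open interval (min, max)
        return self.defined() and self.vmin < val < self.vmax

    def update(self, values):
        # Grow the bound to include values; when constrained positive
        # (e.g. for log axes), non-positive values are ignored
        if self.positive:
            values = [v for v in values if v > 0]
        if not values:
            return self
        lo, hi = min(values), max(values)
        self.vmin = lo if self.vmin is None else min(self.vmin, lo)
        self.vmax = hi if self.vmax is None else max(self.vmax, hi)
        return self


b = Interval(positive=True).update([-3.0, 0.0, 2.0, 8.0])
# b.vmin == 2.0 and b.vmax == 8.0, since -3.0 and 0.0 were dropped
```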

The Bound2D is a straightforward generalization of Bound1D, and
stores 2 Bound1D instances 'x' and 'y' to represent a 2D Bound (useful
for Axes bounding boxes, clipping etc).  All Artists are responsible
for returning their extent in display coords as a Bound2D instance,
which is useful for determining whether 2 Artists overlap.  Some
utility functions, eg, bound2d_all, return the Bound2D instance that
bounds all the Bound2D instances passed as args.  This helps in text
layout, eg, in positioning the axis labels to not overlap the tick
labels.
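
A sketch of the union and overlap logic on (left, bottom, width,
height) tuples.  bbox_union, intervals_overlap and bbox_overlap are
hypothetical helpers that approximate bound2d_all and Bound2D.overlap
without the Bound1D machinery or the handling of undefined bounds:

```python
def bbox_union(boxes):
    # Smallest (left, bottom, width, height) box containing all boxes
    lefts = [l for l, b, w, h in boxes]
    bottoms = [b for l, b, w, h in boxes]
    rights = [l + w for l, b, w, h in boxes]
    tops = [b + h for l, b, w, h in boxes]
    left, bottom = min(lefts), min(bottoms)
    return left, bottom, max(rights) - left, max(tops) - bottom


def intervals_overlap(amin, amax, bmin, bmax):
    # Closed intervals overlap if either one starts inside the other
    return amin <= bmin <= amax or bmin <= amin <= bmax


def bbox_overlap(a, b):
    # Two boxes overlap only if both their x and y extents overlap
    (al, ab, aw, ah), (bl, bb, bw, bh) = a, b
    return (intervals_overlap(al, al + aw, bl, bl + bw) and
            intervals_overlap(ab, ab + ah, bb, bb + bh))


ticklabel = (10, 20, 30, 8)
axislabel = (12, 25, 40, 10)
both = bbox_union([ticklabel, axislabel])  # (10, 20, 42, 15)
clash = bbox_overlap(ticklabel, axislabel)  # True: move the axis label
```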

The Bound1D instances store their max and min values as RWRef instances
(read/write references).  These are mutable scalars that can be shared
among all the figure components.  When a figure resizes and thus
changes the display limits of an Axes, the Axes and all its components
know about the changes because they store a reference to the
displaylim, not the scalar value of the display lim.  Likewise for
DPI.
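
A minimal sketch of why sharing references works; Ref and Limits are
hypothetical, simplified stand-ins for RWRef and Bound1D, and the
pixel values are made up:

```python
class Ref(object):
    # Hypothetical minimal read/write reference, in the spirit of RWRef
    def __init__(self, val):
        self._val = val

    def get(self):
        return self._val

    def set(self, val):
        self._val = val


class Limits(object):
    # Stores references to its endpoints, not copies of their values
    def __init__(self, refmin, refmax):
        self.refmin, self.refmax = refmin, refmax

    def interval(self):
        return self.refmax.get() - self.refmin.get()


left, right = Ref(80.0), Ref(640.0)
axes_xlim = Limits(left, right)  # held by the Axes
tick_xlim = Limits(left, right)  # held by a tick, sharing the same refs

right.set(800.0)  # e.g. a window resize updates a single place
# both holders now report an interval of 720.0
```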

Also, it is possible to do simple arithmetic on RRefs via the derived
BinOp class, which stores both operands of a binary arithmetic
operation along with the binary function; calling get() applies that
function to the dereferenced scalars.  This allows you to place artists
with locations like '3 centimeters below the x axis'.
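
A minimal sketch of this lazy arithmetic, using the standard operator
module for the binary functions.  Ref and BinOp here are simplified
stand-ins for RRef/RWRef and BinOp, and the dpi and axis position are
made-up values:

```python
import operator


class Ref(object):
    # Hypothetical minimal reference supporting lazy arithmetic
    def __init__(self, val):
        self._val = val

    def get(self):
        return self._val

    def set(self, val):
        self._val = val

    def __add__(self, other):
        return BinOp(self, other, operator.add)

    def __sub__(self, other):
        return BinOp(self, other, operator.sub)

    def __mul__(self, other):
        return BinOp(self, other, operator.mul)


class BinOp(Ref):
    # Stores both operands and the function; evaluates only on get()
    def __init__(self, ref1, ref2, func):
        self.ref1, self.ref2, self.func = ref1, ref2, func

    def get(self):
        return self.func(self.ref1.get(), self.ref2.get())


dpi = Ref(72.0)
xaxis_y = Ref(100.0)  # display y of the x axis, in dots
# 3 cm below the x axis: 3 cm = 3/2.54 inches = dpi*3/2.54 dots
label_y = xaxis_y - dpi * Ref(3 / 2.54)

before = label_y.get()
dpi.set(144.0)  # the new dpi is picked up on the next get()
after = label_y.get()
```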

Here are some concepts and how to apply them via the transform
architecture:

  * Map view limits to display limits via a linear transformation

    # viewlim and displaylim are Bound1D instances
    tx = Transform( axes.xaxis.viewlim, axes.xaxis.displaylim )
    ty = Transform( axes.yaxis.viewlim, axes.yaxis.displaylim )
    l = Line2D(dpi, bbox, xdata, ydata, transx=tx, transy=ty)

  * Map relative axes coords ( 0,0 is lower left and 1,1 is upper
    right ) to display coords.  This example puts text in the middle
    of the axes (0.5, 0.5)
    
    tx = Transform( Bound1D(0,1), axes.xaxis.displaylim )
    ty = Transform( Bound1D(0,1), axes.yaxis.displaylim )
    text = AxisText(dpi, bbox, 0.5, 0.5, transx=tx, transy=ty)
    
  * Map x view limits to display limits via a log transformation and y
    view limits via a linear transformation.  The funcs pair is the
    transform/inverse pair:

    funcs = logwarn, pow10
    tx = Transform( axes.xaxis.viewlim, axes.xaxis.displaylim, funcs )
    ty = Transform( axes.yaxis.viewlim, axes.yaxis.displaylim )
    l = Line2D(dpi, bbox, xdata, ydata, transx=tx, transy=ty)

  * You can also transform from one physical unit (inches, cm,
    points, ...) to another.  You must specify an offset in output
    coords.

      offset = 100  # dots
      cm = Centimeter(self.dpi)
      dots = Dots(self.dpi)
      t = TransformSize(cm, dots, offset)

    If you don't know the offset in output coords, you can supply an
    optional transform that maps the offset to output coords.  Eg,
    if you want to offset by x in data coords, and the output is
    display coords, you can do:

      offset = 0.2  # x data coords
      cm = Centimeter(self.dpi)
      dots = Dots(self.dpi)
      t = TransformSize(cm, dots, offset, axes.xaxis.transData)
   
  * Combining the above, we can specify that a text instance is at an x
    location in data coords and a y location in points relative to an
    axis.  Eg, the transformation below indicates that the x value of
    an xticklabel position is in data coordinates and the y value is 3
    points below the x axis, top justified:

        # the top of the x ticklabel text is the bottom of the y axis
        # minus 3 points.  Note the code below uses the overloading of
        # __sub__ and __mul__ to return a BinOp.  Changes in the
        # position of bbox.y or dpi, eg, on a resize event, are
        # automatically reflected in the tick label position.
        # dpi*3/72 converts 3 points to dots.

        top = self.bbox.y.get_refmin() - self.dpi*RRef(3/72.0)
        text = backends.AxisText(dpi, bbox, x=xdata, y=top,
            verticalalignment='top',
            horizontalalignment='center',
            transx=self.axes.xaxis.transData,
            # transy is default, identity transform
            )
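
The arithmetic behind the log-x example and the cm-to-dots
TransformSize example above can be sketched with plain floats.
log_linear_map and cm_to_dots are hypothetical helpers, not part of
this module; they assume scalar inputs and omit the RRef dereferencing
the real classes perform:

```python
import math


def log_linear_map(x, boundin, boundout):
    # Log-x transform: apply log10 to x and to the input bounds, then map
    # linearly; this mirrors using funcs = (logwarn, pow10)
    minin, maxin = boundin
    minout, maxout = boundout
    if x <= 0 or minin <= 0 or maxin <= 0:
        raise ValueError('Cannot take log of non-positive data')
    lmin, lmax = math.log10(minin), math.log10(maxin)
    return (math.log10(x) - lmin) * (maxout - minout) / (lmax - lmin) + minout


def cm_to_dots(x, dpi, offset=0.0):
    # Physical size transform: cm -> inches -> dots, plus an offset
    # that is already in output (dot) coordinates
    return offset + x / 2.54 * dpi


# View limits [1, 100] on a log x axis spanning display x 0..400:
# one decade out of two is half the axis
mid = log_linear_map(10, (1, 100), (0, 400))  # 200.0
# 2.54 cm at 72 dpi is one inch, i.e. 72 dots, shifted by 100 dots
dots = cm_to_dots(2.54, 72.0, offset=100)     # 172.0
```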

The unittest code for the transforms module is unit/transforms_unit.py.

 
"""
from __future__ import division
import sys
from numerix import array, asarray, log10, take, nonzero
from cbook import iterable



class Size:
    def __init__(self, dpi, val=1):
        self._dpi = dpi
        self._val = val

    def get(self):
        return self._val

    def to_dots(self):
        return self._dpi.get() * self.to_inches()


    def __repr__(self): return '%s %s ' % (self._val, self.units())

    def set_val(self, val):
        self._val  = val

    def to_inches(self):
        raise NotImplementedError('Derived must override')

    def units(self):
        raise NotImplementedError('Derived must override')
    
class Inches(Size):
    def to_inches(self):
        return self._val

    def units(self):
        return 'inches'

class Dots(Size):
    def to_inches(self):
        return self._val/self._dpi.get()

    def to_dots(self):
        return self._val

    def units(self):
        return 'dots'

class Points(Size):
    def to_inches(self):
        return self._val/72.0

    def units(self):
        return 'pts'

class Millimeter(Size):
    def to_inches(self):
        return self._val/25.4

    def units(self):
        return 'mm'

class Centimeter(Size):
    def to_inches(self):
        return self._val/2.54

    def units(self):
        return 'cm'


def bintimes(x,y): return x*y
def binadd(x,y): return x+y
def binsub(x,y): return x-y

    
class RRef:
    'A read only ref'
    def __init__(self, val):
        self._val = val

    def get(self):
        return self._val

    def __add__(self, other):
        return BinOp(self, other, binadd)

    def __mul__(self, other):
        return BinOp(self, other, bintimes)

    def __rmul__(self, other):
        return BinOp(self, other, bintimes)

    def __sub__(self, other):
        return BinOp(self, other, binsub)

class BinOp(RRef):
    'A read only ref that handles binary ops of refs'
    def __init__(self, ref1, ref2, func=binadd):
        assert(isinstance(ref1, RRef))
        assert(isinstance(ref2, RRef))
        self.ref1 = ref1
        self.ref2 = ref2
        
        self.func = func

    def get(self):
        return self.func(self.ref1.get(), self.ref2.get())

class RWRef(RRef):
    'A readable and writable ref'
    def set(self, val):
        self._val = val



class Bound1D:
    """
    Store and update information about a 1D bound
    """
    def __init__(self, minval=None, maxval=None, isPos=False):        
        self._min = RWRef(minval)
        self._max = RWRef(maxval)
        self._isPos = isPos
        self._defined = False
        #if self.defined(): assert(self._max.get()>=self._min.get())
        
    def bounds(self):
        'Return the min, max of the bounds'
        return self._min.get(), self._max.get()

    def defined(self):
        'return true if both endpoints defined'
        if self._defined: return True

        self._defined = (self._min.get() is not None and
                         self._max.get() is not None)
        return self._defined

    def get_refmin(self):
        return self._min

    def get_refmax(self):
        return self._max

    def in_interval(self, val):
        'Return true if val is in [min,max]'
        if not self.defined(): return False
        smin, smax = self.bounds()
        if smax<smin: smin, smax = smax, smin
        if val>=smin and val<=smax: return True
        else: return False

    def in_open_interval(self, val):
        'Return true if val is in (min,max)'
        if not self.defined(): return False
        smin, smax = self.bounds()
        if smax<smin: smin, smax = smax, smin
        if val>smin and val<smax: return True
        else: return False
        
    def interval(self):
        'return max - min if defined, else None'
        if not self.defined(): return None
        smin, smax = self.bounds()
        return smax-smin

    def max(self):
        'return the max of the bounds'
        return self._max.get()

    def min(self):
        'return the min of the bounds'
        return self._min.get()

    def overlap(self, bound):
        """
        Return true if bound overlaps with self.

        Return False if either bound undefined
        """
        if not self.defined() or not bound.defined():
            return False
        smin, smax = self.bounds()
        bmin, bmax = bound.bounds()
        return ((bmin <= smin <= bmax) or
                (smin <= bmin <= smax))

       
    def __repr__(self):
        smin, smax = self.bounds()
        return 'Bound1D: %s %s' % (smin, smax)


    def shift(self, val):
        'Shift min and max by val and return a reference to self'
        smin, smax = self.bounds()

        if self._min.get() is not None:
            newmin = smin + val
            self.set_min(newmin)
        if self._max.get() is not None:
            newmax = smax + val
            self.set_max(newmax)

        return self
    
    def set_bounds(self, vmin, vmax):
        'set the min and max to vmin, vmax'
        self.set_min(vmin)
        self.set_max(vmax)

    def set_min(self, vmin):
        'set the min to vmin'
        if self._isPos and vmin<=0: return
        self._min.set(vmin)
        if vmin is None:
            self._defined = False

    def set_max(self, vmax):
        'set the max to vmax'
        if self._isPos and vmax<=0: return
        self._max.set(vmax)
        if vmax is None:
            self._defined = False


    def scale(self, s):
        'scale the min and max by ratio s and return a reference to self'
        if self._isPos and s<=0: return self
        if not self.defined(): return self
        i = self.interval()
        # grow (s>1) or shrink (s<1) the interval about its center so
        # that the new interval is s*i
        delta = (s*i-i)/2

        vmin, vmax = self.bounds()
        self._min.set(vmin - delta)
        self._max.set(vmax + delta)
        return self

    def update(self, x):
        """
        Update the min and max with values in x.  Eg, only update min
        if min(x)<self.min().  Return a reference to self
        """
        if iterable(x) and len(x)==0: return self

        if not iterable(x):
            if self._isPos and x<=0: return self
            minx, maxx = x, x
        else:
            if self._isPos: x = take(x, nonzero(x>0))
            if not len(x): return self
            minx, maxx = min(x), max(x)
            

        if self._max.get() is None: self._max.set(maxx)
        else: self._max.set(max(self._max.get(), maxx))

        if self._min.get() is None: self._min.set(minx)
        else: self._min.set( min(self._min.get(), minx))
        return self

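    def is_positive(self, b):
        """
        If b is true, the bound will only store positive endpoints;
        any existing non-positive endpoint is reset to None.
        """
        self._isPos = b
        if b:
            vmin, vmax = self.bounds()
            if vmin is not None and vmin<=0:
                self._min.set(None)
                self._defined = False
            if vmax is not None and vmax<=0:
                self._max.set(None)
                self._defined = False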
            
        
def bound1d_all(bounds):
    """
    Return a Bound1D instance that bounds all the Bound1D instances in
    sequence bounds.

    If the min or max val for any of the bounds is None, the
    respective value for the returned bbox will also be None
    """
    
    if not len(bounds): return Bound1D(None, None)
    if len(bounds)==1: return bounds[0]

    # min with a sequence with None is None
    minval = min([b.min() for b in bounds])
    
    # max with a sequence with None is not None, so we use a different
    # approach
    maxvals = [b.max() for b in bounds]
    if None in maxvals: maxval = None
    else: maxval = max(maxvals)
    return Bound1D(minval, maxval)

class Bound2D:
    """
    Store and update 2D bounding box information

    Publicly accessible attributes

     x the x Bound1D instance
     y the y Bound1D instance

    """
    def __init__(self, left, bottom, width, height):
        self.x = Bound1D()
        self.y = Bound1D()
        self.set_bounds(left, bottom, width, height)
        
    def __repr__(self):
        return 'Bound2D: %s\n\tx: %s\n\ty: %s' % \
               (list(self.get_bounds()), self.x, self.y) 

    def copy(self):
        'Return a deep copy of self'
        return Bound2D(*self.get_bounds())

    def defined(self):
        return self.x.defined() and self.y.defined()

    def set_bounds(self, left, bottom, width, height):
        'Reset the bounds'
        assert(left is not None)
        assert(bottom is not None)
        assert(width is not None)
        assert(height is not None)
        assert(width>=0)
        assert(height>=0)

        minx = left
        maxx = left + width
        miny = bottom
        maxy = bottom + height
        self.x.set_bounds(minx, maxx)
        self.y.set_bounds(miny, maxy)
        
    def get_bounds(self):
        left = self.x.min()
        bottom = self.y.min()
        right = self.x.max()
        top = self.y.max()
        width = right-left
        height = top-bottom
        return left, bottom, width, height

    def overlap(self, bound):
        """
        Return true if bound overlaps with self.

        Return False if either bound undefined
        """
        if not self.defined() or not bound.defined():
            return False
        return self.x.overlap(bound.x) and self.y.overlap(bound.y)


   
def bound2d_all(bounds):
    """
    Return a Bound2D instance that bounds all the Bound2D instances in
    sequence bounds.

    If the min or max val for any of the bounds is None, the
    respective value for the returned bbox will also be None
    """

    bx = bound1d_all([b.x for b in bounds])
    by = bound1d_all([b.y for b in bounds])
    
    left = bx.min()
    bottom = by.min()
    width = bx.interval()
    height = by.interval()
    return Bound2D(left, bottom, width, height)

def iterable_to_array(x):
    if iterable(x): return asarray(x)
    else: return x

def identity(x):
    'The identity function; returns x.get() if x is a ref with a get method'
    try: return x.get()
    except AttributeError: return x


def logwarn(x):
    'Return log10(x); raise ValueError if any value in x is non-positive'
    # x is a scalar
    if not iterable(x):
        if x<=0:
            raise ValueError('Cannot take log of non-positive data')
    elif min(x)<=0:
        raise ValueError('Cannot take log of non-positive data')
    return log10(x)

def pow10(x):
    'the inverse of log10; 10**x'
    return 10**asarray(x)

class Transform:
    """
    Abstract base class for transforming data

    Publicly accessible attributes are
      func            : the transform func
      ifunc           : the inverse tranform func

    A tranform from in->out is defined by

    scale = (maxout-maxin)/( func(maxin)-func(minin) )
    out =  scale * ( func(in)-func(minin) ) + minout

    funcs are paired with inverses, allowing Transforms to return
    their inverse
    
    """

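    def __init__(self, boundin=None, boundout=None,
                 funcs = (identity, identity),
                 ):
        """
        The default transform is identity.

        To do a linear transform, replace the bounds with the
        coordinate bounds of the input and output spaces

        To do a log transform, use funcs=(log10, pow10)
        """
        # create fresh default bounds for each instance; a Bound1D
        # default argument would be shared by every Transform
        if boundin is None: boundin = Bound1D(0,1)
        if boundout is None: boundout = Bound1D(0,1)
        self._boundin = boundin
        self._boundout = boundout
        self.set_funcs(funcs)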
        

    def inverse_positions(self, x):
        'Return the inverse transform of x'
        x = iterable_to_array(x)
        minin, maxin = self._boundin.bounds()
        if self.func != identity:
            minin, maxin = self.func(minin), self.func(maxin)
        scale = (maxin-minin)/self._boundout.interval()
        return  self.ifunc(scale*(x-self._boundout.min()) + minin)

    def inverse_scale(self, x):
        'Return the inverse transform of scale x'
        x = iterable_to_array(x)
        minin, maxin = self._boundin.bounds()
        if self.func != identity:
            minin, maxin = self.func(minin), self.func(maxin)
        scale = (maxin-minin)/self._boundout.interval()
        return  self.ifunc(scale*x)
        

    def position_scale(self, pos, scale, isScalarOrArray=False):
        """
        Transform scalar position and scale

        Doing both together reduces duplication of function calls, eg,
        for a x position and width if both use same transform

        set isScalarOrArray to True if you know pos and scale are
        scalars or arrays, for performance

        """

        if not isScalarOrArray:
            pos = iterable_to_array(pos)
            scale = iterable_to_array(scale)
        minin, maxin = self._boundin.bounds()
        minout, maxout = self._boundout.bounds()
        if self.func != identity:
            minin = self.func(minin)
            maxin = self.func(maxin)
            pos = self.func(pos)
            scale = self.func(scale)
        m = (maxout-minout)/(maxin-minin)
        pos =  m*(pos-minin) + minout
        scale = m*scale
        return pos, scale

    def positions(self, x, isScalarOrArray=False):
        """
        Transform the positions in x.

        set isScalarOrArray to True if you know x is a scalar or array,
        for performance
        """
        if not isScalarOrArray: x = iterable_to_array(x)
        minin, maxin = self._boundin.bounds()
        if self.func != identity:
            minin, maxin = self.func(minin), self.func(maxin)
        scale = self._boundout.interval()/(maxin-minin)
        return  scale*(self.func(x)-minin) + self._boundout.min()

        
    def scale(self, s, isScalarOrArray=False):
        """
        Transform the scale in s

        set isScalarOrArray to True if you know s is a scalar or array,
        for performance

        """
        if not isScalarOrArray: s = iterable_to_array(s)
        minin, maxin = self._boundin.bounds()
        if self.func != identity:
            minin, maxin = self.func(minin), self.func(maxin)
        scale = self._boundout.interval()/(maxin-minin)
        return  scale*self.func(s)

    def __repr__(self):
        return  'Transform: %s to %s' %(self._boundin, self._boundout)
    
    def set_funcs(self, funcs):
        'Set the func, ifunc to funcs'
        self.func, self.ifunc = funcs


class TransformSize:
    def __init__(self, sin, sout, offset, transOffset=Transform()):
        """
        transform size in Size instance sin to Size instance sout,
        offsetting by Bound1D instance RRef instance offset.
        transOffset is used to transform the offset if not None
        """
        self._sin = sin
        self._sout = sout
        self._offset = offset
        self._transOffset = transOffset

    def positions(self, x):
        offset = self._get_offset() 
        return offset + self._sin.to_inches()/self._sout.to_inches()*x

    def inverse_positions(self, x):
        offset = self._get_offset() 
        return (x - offset)*self._sout.to_inches()/self._sin.to_inches()

    def inverse_scale(self, x):
        return self._sout.to_inches()/self._sin.to_inches()*x

    def scale(self, x):
        return self._sin.to_inches()/self._sout.to_inches()*x

    def _get_offset(self):
        try: o =  self._offset.get()
        except AttributeError: o =  self._offset
        return self._transOffset.positions(o)



def transform_bound1d(bound, trans):
    """
    Transform a Bound1D instance using transform trans
    """
    
    tmin, tmax = trans.positions(bound.bounds())
    return Bound1D(tmin, tmax)

def inverse_transform_bound1d(bound, trans):
    """
    Inverse transform a Bound1D instance using trans.inverse_positions
    """
    
    tmin, tmax = trans.inverse_positions(bound.bounds())
    return Bound1D(tmin, tmax)


def transform_bound2d(bbox, transx, transy):
    """
    Transform a Bound2D instance using transforms transx, transy
    """
    b1 = transform_bound1d(bbox.x, transx)
    b2 = transform_bound1d(bbox.y, transy)
    l = b1.min()
    w = b1.interval()
    b = b2.min()
    h = b2.interval()
    return Bound2D(l,b,w,h)

def inverse_transform_bound2d(bbox, transx, transy):
    """
    Inverse transform a Bound2D instance using transforms transx, transy
    """
    b1 = inverse_transform_bound1d(bbox.x, transx)
    b2 = inverse_transform_bound1d(bbox.y, transy)
    l = b1.min()
    w = b1.interval()
    b = b2.min()
    h = b2.interval()
    return Bound2D(l,b,w,h)
    


    
