import numpy
import numpy.random as ran
import math
import os
# import PySimbad
# import astropysics.obstools
import scipy.stats as stats

def avg(x0, err=False, nb=100):
    """Return the arithmetic mean of x0, ignoring NaN entries.

    Parameters
    ----------
    x0 : array_like
        Input values; NaN entries are discarded before averaging.
    err : bool
        If True, also return a bootstrap estimate of the error on the mean.
    nb : int
        Number of bootstrap resamples drawn when err is True.

    Returns
    -------
    float, or (float, float) when err is True
        The mean, and optionally the standard deviation (via stddev) of the
        means of nb resamples drawn with replacement.
    """
    arr = numpy.array(x0)
    good = arr[~numpy.isnan(arr)]
    count = len(good)
    mean = (1.0 * sum(good)) / count

    if err != True:
        return mean

    boot_means = []
    for _ in range(nb):
        idx = ran.randint(0, count, count)   # resample with replacement
        boot_means.append(1.0 * sum(good[idx]) / count)
    return (mean, stddev(boot_means))


def stddev(x0,sigclip=None, nit=2,rng=None,extras=False):
    """Sample standard deviation (N-1 normalisation), optionally range-cut and sigma-clipped.

    Parameters
    ----------
    x0 : array_like
        Input values (list or ndarray).
    sigclip : float, optional
        If given, iteratively keep only points with |x - mean| < sigclip*sigma.
        The loop starts from sigma = 1e10, so the first of the ``nit`` passes
        keeps every value within 1e10*sigclip of the mean (NaN entries fail
        the comparison and are dropped there) and only sets the initial sigma;
        effective clipping starts on the second pass.
    nit : int
        Number of clipping passes (only used when sigclip is given).
    rng : (lo, hi), optional
        Before anything else, keep only values strictly between lo and hi.
    extras : bool
        If True, also return the values that were finally used.

    Returns
    -------
    float, or (float, values) when extras is True.

    Notes
    -----
    The mean comes from avg(), which ignores NaN, but in the unclipped path
    the residuals and N are taken over all values, so NaN input gives NaN.
    """
    if (rng is None):
        x = x0
    else:
        # Convert first: comparing a plain list with a number raises TypeError,
        # so range filtering previously only worked for ndarray input.
        xa = numpy.asarray(x0)
        w = numpy.where((xa > rng[0]) & (xa < rng[1]))
        x = xa[w]

    if (sigclip is None):
        sigx = (sum((numpy.array(x) - avg(x))**2.)/(len(x)-1.))**0.5
        xx = x
    else:
        sigx = 1e10
        xx = numpy.copy(x)
        xavg = avg(xx)
        for it in range(nit):
            w = numpy.where(abs(xx - xavg) < sigclip*sigx)
            xx = numpy.copy(xx[w])
            xavg = avg(xx)
            sigx = (sum((numpy.array(xx) - xavg)**2.)/(len(xx)-1.))**0.5

    if (extras==True):
        return sigx, xx
    else:
        return sigx

def ravg(x0, sigclip=2, nit=2, rng=None):
    """Robust mean: drop NaNs, sigma-clip with stddev(), then average the survivors.

    sigclip, nit and rng are passed straight to stddev(); the mean of the
    values it keeps is returned via avg().
    """
    vals = numpy.array(x0)
    vals = vals[~numpy.isnan(vals)]
    _, kept = stddev(vals, sigclip=sigclip, nit=nit, rng=rng, extras=True)
    return avg(kept)

def rms(x):
    """Root-mean-square deviation of x about its mean, dividing by N (not N-1).

    The mean comes from avg() (NaN-ignoring), while the residual sum and N
    use all entries of x.
    """
    resid = numpy.array(x) - avg(x)
    return (sum(resid**2.) / len(x)) ** 0.5

def med(x,err=False,nb=100):
    """Return the median of x, optionally with a bootstrap error.

    Parameters
    ----------
    x : array_like
        Input values (NaN is not removed; numpy.median propagates it).
    err : bool
        If True, also return the standard deviation (via stddev) of the
        medians of nb resamples drawn with replacement.
    nb : int
        Number of bootstrap resamples.
    """
    centre = numpy.median(x)
    arr = numpy.array(x)
    size = len(arr)

    if err != True:
        return centre

    boot = [numpy.median(arr[ran.randint(0, size, size)]) for _ in range(nb)]
    return (centre, stddev(boot))

def sigbs(x,nb=1000,sigclip=None,nit=2):
    """Standard deviation of x with a bootstrap estimate of its uncertainty.

    Returns (sigma, sigma_err): sigma from stddev(x, sigclip, nit), and the
    standard deviation of stddev() over nb resamples drawn with replacement.
    """
    sig = stddev(x, sigclip=sigclip, nit=nit)
    arr = numpy.array(x)
    size = len(arr)
    boot = [stddev(arr[ran.randint(0, size, size)], sigclip=sigclip, nit=nit)
            for _ in range(nb)]
    return (sig, stddev(boot))

def ncol(fname, nskip=0):
    """Return the smallest number of whitespace-separated fields on any data line.

    Parameters
    ----------
    fname : str
        Path of the text file.
    nskip : int
        Number of leading lines to ignore.

    Lines starting with '#' and blank lines are ignored. If no data line is
    found, the sentinel 1000000 is returned (as before).
    """
    ncmin = 1000000

    with open(fname, 'r') as f:
        for lineno, line in enumerate(f):
            if lineno < nskip or line[0] == '#':
                continue
            nfields = len(line.split())
            # Blank lines carry no columns; counting them forced the result to 0.
            if nfields:
                ncmin = min(ncmin, nfields)

    return ncmin

def getcol(fname, n, nskip=0, nrd=float("inf"), skipstr=None):
    """Read whitespace-separated numeric column(s) from a text file.

    Parameters
    ----------
    fname : str
        Path of the text file.
    n : int or tuple of int
        Zero-based column index, or a tuple of indices to read several
        columns in a single pass.
    nskip : int
        Number of leading lines to ignore (counted before any filtering).
    nrd : int or float
        Maximum number of data rows to read (default: all).
    skipstr : str, optional
        Lines containing this substring are ignored.

    Lines whose first character is '#' and blank lines are also ignored.

    Returns
    -------
    numpy.ndarray of float, or a tuple of such arrays (one per index) when
    n is a tuple.

    Raises
    ------
    IndexError if a data row lacks a requested column; ValueError if a
    field is not a number.
    """
    cols = n if isinstance(n, tuple) else (n,)
    data = [[] for _ in cols]
    nrows = 0

    with open(fname, 'r') as f:
        for lineno, s in enumerate(f):
            if nrows >= nrd:
                break
            if lineno < nskip or s[0] == '#':
                continue
            if (skipstr is not None) and (skipstr in s):
                continue
            ss = s.split()
            if not ss:
                continue   # blank line: previously crashed with IndexError
            for values, nn in zip(data, cols):
                values.append(float(ss[nn]))
            nrows += 1

    arrays = tuple(numpy.array(values) for values in data)
    return arrays if isinstance(n, tuple) else arrays[0]


def vartst(sigref, sigdata, N):
    """Chi-square test of a sample standard deviation against a reference value.

    Computes T = (N-1) * (sigdata/sigref)**2 and prints T together with the
    two-sided 99% critical values of the chi-square distribution with N-1
    degrees of freedom.

    Returns
    -------
    float
        Upper-tail probability chi2.sf(T, N-1).
    """
    dof = N - 1
    T = dof * (sigdata/sigref)**2.
    print ('T-value = %0.2e' % T)
    print ('Upper 99%% critical value = %0.2e' % stats.chi2.isf(0.01/2, dof))
    print ('Lower 99%% critical value = %0.2e' % stats.chi2.isf(1-0.01/2, dof))
    return stats.chi2.sf(T, dof)


def wavg(x00, err00, rng=None):
    """Inverse-variance weighted mean.

    Entries whose x is NaN are dropped (together with their errors); if rng
    = (lo, hi) is given, only values strictly between lo and hi are used.

    Returns
    -------
    (mean, error, n)
        mean = sum(x/err**2) / sum(1/err**2); error is algebraically
        1/sqrt(sum(1/err**2)); n is the number of values used.
    """
    xa = numpy.array(x00)
    finite = ~numpy.isnan(xa)
    xs = xa[finite]
    es = numpy.array(err00)[finite]

    if rng is not None:
        inside = numpy.where((xs > rng[0]) & (xs < rng[1]))
        xs, es = xs[inside], es[inside]

    vals = numpy.array(xs)
    sig = numpy.array(es)
    wsum = sum(1/sig**2.)
    mean = sum(vals/sig**2.) / wsum
    emean = ((1/wsum)**2. * wsum)**0.5
    return mean, emean, len(xs)

def wrms(x0,err0):
    """Unweighted rms scatter (dividing by N) about the error-weighted mean.

    Entries whose x is NaN are dropped together with their errors; the
    centre is the weighted mean from wavg().
    """
    xa = numpy.array(x0)
    keep = ~numpy.isnan(xa)
    xs = xa[keep]
    es = numpy.array(err0)[keep]

    centre = wavg(xs, es)[0]
    return (sum((numpy.array(xs) - centre)**2.) / len(xs)) ** 0.5

def wwrms(x00,err00, rng=None):
    """Error-weighted rms scatter about the weighted mean.

    Returns sqrt( sum(((x - wa)/err)**2) / sum(1/err**2) ), i.e. the rms
    with weights 1/err**2, where wa comes from wavg(). Entries whose x is NaN
    are dropped; rng = (lo, hi) keeps only values strictly inside.
    """
    xa = numpy.array(x00)
    finite = ~numpy.isnan(xa)
    xs = xa[finite]
    es = numpy.array(err00)[finite]

    if rng is not None:
        inside = numpy.where((xs > rng[0]) & (xs < rng[1]))
        xs, es = xs[inside], es[inside]

    centre = wavg(xs, es)[0]
    sig = numpy.array(es)
    vals = numpy.array(xs)
    return (sum(((vals - centre)/sig)**2.) / (sum(1/sig**2.)))**0.5

def wwstddev(x,err, sigclip=None, nit=2,extras=False):
    """Error-weighted standard deviation, optionally with iterative sigma clipping.

    With ne = err / max(err) (errors normalised to the largest one) and wa
    the weighted mean from wavg(), the spread is
        sqrt( sum(((x - wa)/ne)**2) / (sum(1/ne**2) - 1) ).

    Parameters
    ----------
    x, err : array_like
        Values and their errors.
    sigclip : float, optional
        If given, keep only points with |x - wa| < sigclip * sigma on each of
        nit passes; sigma starts at 1e10, so the first pass only sets it.
    nit : int
        Number of clipping passes.
    extras : bool
        If True, also return the values finally used.
    """
    if sigclip is None:
        centre = wavg(x, err)[0]
        scaled = numpy.array(err) / max(err)
        vals = numpy.array(x)
        sig = (sum(((vals - centre)/scaled)**2.) / (sum(1/scaled**2.) - 1))**0.5
        kept = x
    else:
        sig = 1e10
        kept = numpy.copy(x)
        kept_err = numpy.copy(err)
        centre = wavg(kept, kept_err)[0]
        for _ in range(nit):
            sel = numpy.where(abs(kept - centre) < sigclip*sig)
            kept = numpy.copy(kept[sel])
            kept_err = numpy.copy(kept_err[sel])
            centre = wavg(kept, kept_err)[0]
            scaled = numpy.array(kept_err) / max(kept_err)
            vals = numpy.array(kept)
            sig = (sum(((vals - centre)/scaled)**2.) / (sum(1/scaled**2.) - 1))**0.5

    if extras == True:
        return sig, kept
    return sig


def dm2d(dm):
    """Convert a distance modulus to a distance in parsec: d = 10 * 10**(dm/5)."""
    return 10*10**(0.2*dm)

def d2dm(d):
    """Convert a distance in parsec to a distance modulus: 5*log10(d) - 5."""
    return 5 * numpy.log10(d) - 5

def pint(Mtot, mmin, m1, m2, alpha=-2.35, mmax=100.):
    """Number of stars with masses between m1 and m2 for a power-law mass function.

    The mass function dN/dm = A * m**alpha is normalised so that the total
    mass between mmin and mmax equals Mtot:
        N = Mtot * (alpha+2) * (m2**(alpha+1) - m1**(alpha+1))
                 / ((alpha+1) * (mmax**(alpha+2) - mmin**(alpha+2)))
    Singular for alpha = -1 or alpha = -2. The default alpha = -2.35 is the
    Salpeter slope.
    """
    a1 = alpha + 1
    a2 = alpha + 2
    numer = a2 * (m2**a1 - m1**a1)
    denom = a1 * (mmax**a2 - mmin**a2)
    return Mtot * numer / denom

def bbody_nu(T, nu, z=0.):
    """Planck function B_nu in SI units, optionally evaluated for redshift z.

    Uses h [J s], c [m/s] and k [J/K], so T is in K, nu in Hz and the result
    is in W m^-2 Hz^-1 sr^-1. The Planck function is evaluated at
    nu*(1+z) and the result is multiplied by (1+z).
    """
    h = 6.62606876e-34
    c = 299792458.
    k = 1.380658e-23

    nu_rest = nu*(1+z)
    planck = 2*h*nu_rest**3 / (c**2 * (numpy.exp(h*nu_rest/(k*T)) - 1))
    return planck * (1+z)

def getgdat(gname,col):
    """Look up a numeric value for object gname in $HOME/python/gal.dat.

    Each non-blank line of gal.dat is split on whitespace; when the first
    field equals gname, field number ``col`` is converted to float. If gname
    occurs several times the last occurrence wins; if it is absent, 0. is
    returned.
    """
    fname = os.path.join(os.getenv('HOME'), 'python', 'gal.dat')
    val = 0.
    with open(fname, 'r') as f:
        for line in f:
            fields = line.split()
            # Blank lines used to crash on fields[0]; skip them.
            if fields and fields[0] == gname:
                val = float(fields[col])

    return val

def dm(gname):
    """Return column 1 of gal.dat for gname (via getgdat).

    Presumably the distance modulus (m - M); confirm against gal.dat.
    """
    return getgdat(gname, col=1)

def ab(gname):
    """Return column 2 of gal.dat for gname (via getgdat).

    Presumably the B-band foreground extinction A_B; confirm against gal.dat.
    """
    return getgdat(gname, col=2)

def air2vac(air):
    """Convert air wavelength(s) in Angstrom to vacuum wavelength(s).

    Multiplies by the refractive index of air from Morton (1991, ApJS 77, 119):
        n = 1 + 6.4328e-5 + 2.94981e-2/(146 - s2) + 2.5540e-4/(41 - s2),
    with s2 = (1e4/lambda)**2 (lambda in Angstrom, so 1e4/lambda is in
    micron^-1). The index is evaluated directly at the air wavelength, with
    no iteration, which is a first-order approximation. Operates
    element-wise on numpy arrays.
    """
    sigma2 = (1e4/air)**2
    index = 1.0 + 6.4328e-5 + 2.94981e-2/(146.0 - sigma2) + 2.5540e-4/(41.0 - sigma2)
    return air * index

def vac2air(vac):
    """Convert vacuum wavelength(s) in Angstrom to air wavelength(s).

    Divides by the Morton (1991, ApJS 77, 119) refractive index of air,
        n = 1 + 6.4328e-5 + 2.94981e-2/(146 - s2) + 2.5540e-4/(41 - s2),
    evaluated at s2 = (1e4/lambda_vac)**2. Operates element-wise on numpy
    arrays.
    """
    sigma2 = (1e4/vac)**2
    index = 1.0 + 6.4328e-5 + 2.94981e-2/(146.0 - sigma2) + 2.5540e-4/(41.0 - sigma2)
    return vac / index

def rext(filter):
    """Return the tabulated extinction ratio for a broad-band filter letter.

    The table is normalised so that B = 1.0, i.e. presumably A_filter / A_B.
    Raises KeyError for filters other than U, B, V, R, I, J, H, K.
    """
    ratios = dict(U=1.195, B=1.0, V=0.756, R=0.598, I=0.415,
                  J=0.196, H=0.124, K=0.083)
    return ratios[filter]


# def getebv(objname):
#     coo = PySimbad.SimbadGalCoord(objname)
#     glen = float(coo.split()[0])
#     glat = float(coo.split()[1])
# 
#     cwd = os.getcwd()
#     os.chdir(os.getenv('HOME')+'/cats/Schlegel')
#     ebv = astropysics.obstools.get_SFD_dust(glen, glat)
#     os.chdir(cwd)
# 
#     return ebv
# 
# def getevi(objname):
#     return getebv(objname) * 1.4
# 
# def getav(objname):
#     return getebv(objname) * 3.1
# 
# def getab(objname):
#     return getebv(objname) * 4.1

def extc89(lam, Rv=3.1):
    """Cardelli, Clayton & Mathis (1989) extinction curve A(lambda)/A(V).

    Uses the CCM89 optical/near-IR polynomials a(x) + b(x)/R_V in
    y = x - 1.82, with x = 1/lambda in micron^-1. CCM89 define this form for
    1.1 <= x <= 3.3 micron^-1 (roughly 3030-9090 Angstrom); outside that
    range the polynomial is simply extrapolated.

    Parameters
    ----------
    lam : float or ndarray
        Wavelength(s) in Angstrom.
    Rv : float, optional
        Ratio of total to selective extinction, A_V / E(B-V). Defaults to
        3.1, the value that was previously hard-coded.
    """
    x = 1/(lam * 1E-4)     # inverse wavelength in micron^-1

    y = x-1.82
    a = 1 + 0.17699 * y - 0.50447 * y**2. - 0.02427 * y**3. + 0.72085 * y**4. + 0.01979 * y**5. - 0.77530 * y**6. + 0.32999 * y**7.
    b = 1.41338 * y + 2.28305 * y**2. + 1.07233 * y**3. - 5.38434 * y**4. - 0.62251 * y**5. + 5.30260 * y**6. - 2.09002 * y**7.

    return a + b/Rv



def st2ab(lam,tran):
    """Return the STMAG -> ABMAG offset for a filter transmission curve.

    The curve is sampled at wavelengths lam (Angstrom) with transmission
    tran. Both integrals are left Riemann sums over consecutive samples:
        wlam = sum(T_i * dlam_i),   wnu = sum(T_i * dnu_i),
    with dnu_i = dlam_i * 1e10 * c / lam_i**2 (c in m/s; the factor 1e10
    converts it to Angstrom/s). The last transmission value is never used.

    Returns -2.5*log10(wlam/wnu) + 21.1 - 48.6, where 21.1 and 48.6 are the
    STMAG and ABMAG zero points.
    """
    c = 299792458.
    wlam = 0.
    wnu = 0.
    for i in range(1, len(lam)):
        step = lam[i] - lam[i-1]
        t = tran[i-1]
        wlam = wlam + t*step
        wnu = wnu + t*(step*1e10*c/(lam[i-1]**2.))

    return -2.5*math.log10(wlam/wnu) + 21.1 - 48.6


def eff(r, a, gamma):
    """EFF (Elson, Fall & Freeman) surface-brightness profile, equal to 1 at r = 0.

    mu(r) = (1 + r**2/a**2) ** (-gamma/2), with scale radius a.
    """
    return (1 + r*r/(a*a))**(-gamma/2)

def hlreff(a, gamma, rmax=None):
    """Projected half-light radius of an EFF profile with scale a and slope gamma.

    With eta = gamma/2, the luminosity inside R is proportional to
    (1 + R**2/a**2)**(1-eta) - 1. Without rmax (infinite extent, needs
    eta > 1):
        Reff = a * sqrt(0.5**(1/(1-eta)) - 1).
    With rmax, half of the light inside rmax is used:
        Reff = a * sqrt((0.5*((1 + (rmax/a)**2)**(1-eta) + 1))**(1/(1-eta)) - 1).
    """
    eta = gamma/2

    if rmax is not None:
        outer = (1 + (rmax/a)**2.)**(1. - eta) + 1
        ratio2 = (0.5 * outer)**(1./(1-eta)) - 1
    else:
        ratio2 = 0.5**(1./(1-eta)) - 1

    return a * numpy.sqrt(ratio2)

def fwhm2reff_eff(fwhm, eta):
    """Convert the FWHM of an EFF profile to its half-light radius.

    eta is the EFF exponent in mu = (1 + r**2/a**2)**(-eta). The half-maximum
    condition gives a = fwhm / (2*sqrt(2**(1/eta) - 1)), and for infinite
    extent Reff = a * sqrt(0.5**(1/(1-eta)) - 1).
    """
    half_light = math.sqrt(0.5**(1/(1-eta))-1)
    half_max = 2*math.sqrt(2**(1/eta)-1)
    return fwhm * half_light / half_max

def fwhm2reff_king(fwhm, c):
    """Convert the FWHM of a King profile with concentration c to a half-light radius.

    The core radius follows from the half-maximum condition of
    mu ~ (1/sqrt(1 + r**2/rc**2) - 1/sqrt(1 + c**2))**2; the half-light
    radius then uses the approximation Reff = 0.547 * rc * c**0.486.
    (NOTE(review): that fit is presumably only accurate for moderate-to-large
    c; confirm the validity range against its source.)
    """
    s = math.sqrt(0.5)
    x_half = ((s + (1 - s)/math.sqrt(1 + c*c))**(-2.) - 1)**(0.5)
    rc = fwhm / (2 * x_half)
    return rc * 0.547 * c**0.486

def is_number(s):
    """Return True if s can be converted by complex() (ints, floats, complex, numeric strings).

    Non-numeric strings raise ValueError in complex(), and unsupported types
    such as None or lists raise TypeError. Both now yield False; previously
    the TypeError escaped.
    """
    try:
        complex(s)   # accepts int, float, complex and numeric strings
    except (TypeError, ValueError):
        return False

    return True

def _mcm_read_columns(fname, fields):
    """Read fixed-width text fields from every line of fname not starting with '#'.

    fields maps a column id to a (start, stop) character slice; returns a
    dict mapping the same ids to lists of raw (unconverted) strings.
    """
    out = {key: [] for key in fields}
    with open(fname, 'r') as f:
        for line in f:
            if line[0] == '#':
                continue
            for key, (start, stop) in fields.items():
                out[key].append(line[start:stop])
    return out


def getmcm(cols=('ID','RA','DEC','R_Sun','R_Gc'),
           catdir='/Users/soeren/cats/mcmaster'):
    """Read selected columns of the McMaster Milky Way globular cluster catalogue.

    Reads mwgc-I.dat, mwgc-II.dat and mwgc-III.dat from catdir (presumably
    the three parts of the Harris McMaster catalogue) as fixed-width text.
    Values are returned as raw strings, not converted to numbers.

    Parameters
    ----------
    cols : sequence of str
        Column ids to return, in order. Known ids: ID, RA, DEC, R_Sun, R_gc,
        X, Y, Z (part I); FeH, EBV, mM, V_t, M_Vt, UB, BV, VR, VI (part II);
        r_c, r_h (part III). Unknown ids are silently skipped.
    catdir : str
        Directory holding the three catalogue files (default: the path that
        was previously hard-coded).

    Returns
    -------
    tuple of lists
        One list of strings per recognised id in cols.

    NOTE(review): the default cols contains 'R_Gc', which does not match the
    id 'R_gc', so a default call returns only four lists. This is kept for
    backward compatibility; pass 'R_gc' explicitly to get the column.
    """
    table = {}
    table.update(_mcm_read_columns(os.path.join(catdir, 'mwgc-I.dat'), {
        'ID': (0, 10), 'RA': (25, 36), 'DEC': (38, 49),
        'R_Sun': (68, 73), 'R_gc': (74, 79),
        'X': (80, 85), 'Y': (86, 91), 'Z': (92, 97),
    }))
    table.update(_mcm_read_columns(os.path.join(catdir, 'mwgc-II.dat'), {
        'FeH': (13, 18), 'EBV': (23, 28), 'mM': (35, 40),
        'V_t': (41, 46), 'M_Vt': (47, 53),
        'UB': (55, 60), 'BV': (61, 66), 'VR': (67, 72), 'VI': (73, 78),
    }))
    table.update(_mcm_read_columns(os.path.join(catdir, 'mwgc-III.dat'), {
        'r_c': (59, 64), 'r_h': (65, 70),
    }))

    return tuple(table[colid] for colid in cols if colid in table)


def _kmm_skip_to(f, marker, line=''):
    """Return the first line, starting with ``line`` itself, that contains marker.

    Reads further lines from the open file f as needed and raises
    RuntimeError at end of file instead of looping forever.
    """
    while marker not in line:
        line = f.readline()
        if not line:
            raise RuntimeError('kmm.out: end of file before %r' % marker)
    return line


def kmm(xx, c0, v0, mix=None):
    """Fit a 1-D mixture model by running the external ``kmm`` program.

    Writes ``kmm.in`` and a driver script ``kmm.com`` to the current working
    directory, runs the script through bash (the script's csh shebang is
    not used), and parses ``kmm.out``. (Presumably the KMM mixture-modelling
    code of McLachlan & Basford; confirm the expected input layout against
    its documentation.)

    Parameters
    ----------
    xx : sequence of float
        Data values, written one per line.
    c0 : sequence of float
        Initial group means; len(c0) is written as the number of groups.
    v0 : sequence of float
        Initial variances; len(v0) is written as the number of covariance
        groups.
    mix : sequence of float, optional
        Initial mixing proportions; defaults to equal weights 1/len(c0).

    Returns
    -------
    dict with keys
        'n'         : numbers assigned to each group (None if not in output),
        'p'         : the reported p value (None if not in output),
        'mean'      : estimated group means,
        'var'       : estimated group (co)variances,
        'partition' : group number assigned to each entity.

    Raises
    ------
    RuntimeError
        If kmm.out ends before the expected section headers or before the
        blank line that ends the partition block (previously an infinite
        loop).
    """
    if (mix is None): mix = numpy.ones(len(c0))/len(c0)

    with open('kmm.in', 'w') as fkmm:
        fkmm.write('%d %d\n' % (len(xx), 1))   # len(xx), then 1 (presumably the data dimension)
        for x in xx:
            fkmm.write('%11.5f\n' % x)
        fkmm.write('%d\n' % (len(c0)))
        fkmm.write('%d\n' % (len(v0)))     # Number of covariance groups
        fkmm.write('%d\n' % (len(c0)))
        for c in c0: fkmm.write('%0.3f\n' % c)
        for v in v0: fkmm.write('%0.5f\n' % v)
        for m in mix: fkmm.write('%0.3f ' % m)
        fkmm.write('\n')

    with open('kmm.com', 'w') as frun:
        frun.write('#!/bin/csh\n')
        frun.write('\\rm -f kmm.out\n')
        frun.write('kmm<<EOF\n')
        frun.write('kmm.in\n')
        frun.write('kmm.out\n')
        frun.write('EOF\n')
    os.system("bash -c 'source kmm.com'")

    estimated_mean = []
    estimated_covariance = []
    number_assigned = None
    pvalue = None

    with open('kmm.out', 'r') as fout:
        l = _kmm_skip_to(fout, 'Entity: Final estimates')
        l = _kmm_skip_to(fout, 'Resulting partition of the entities', l)

        # The partition block runs until the first empty line.
        partition = []
        l = fout.readline()
        while (l != '\n'):
            if not l:
                raise RuntimeError('kmm.out: end of file inside the partition block')
            partition.extend(int(tok) for tok in l.split())
            l = fout.readline()

        while len(l) > 0:
            if ('Number assigned to each group' in l):
                l = fout.readline()
                number_assigned = [float(tok) for tok in l.split()]

            if ('Estimated mean (as a row vector) for each group' in l):
                for j in c0:
                    l = fout.readline()
                    estimated_mean.append(float(l))

            for j in range(len(v0)):
                if ('Estimated covariance matrix for group%3d' %(j+1) in l):
                    l = fout.readline()
                    estimated_covariance.append(float(l))

            if ('and the p value for this statistic is' in l):
                pvalue = float(l.split()[-1])

            l = fout.readline()

    ret = {'n': number_assigned, 'p': pvalue,
           'mean': estimated_mean, 'var': estimated_covariance,
           'partition': partition}

    return ret


