from __future__ import print_function

from . import absetup
from . import abutils
from . import amoeba
from . import kurucz
import numpy
import tempfile
import scipy
import scipy.ndimage as ndimage
import scipy.interpolate as interpolate
import scipy.integrate as integrate
import matplotlib.pyplot as pyplot
import datetime
import getpass
import os
import logging
import string
import shutil

# Some "User-defined" variables. We need to find a better
# mechanism for this (e.g. a parameter file)

# Fitting range and synthetic-spectrum step: [lambda_min, lambda_max, step].
# Observed pixels outside [SPECRNG[0], SPECRNG[1]] are ignored by fitfunc/deltarv.
SPECRNG = [5000., 5050., 0.01]
# Extra wavelength padding added on each side of the synthetic-spectrum range
# in fitfunc (on top of one observed pixel and one synthetic step).
EXPAND = 1.0
# MAXPROC = 12
# Continuum fit order: polynomial degree (CFITFNC='poly') or spline degree k
# (CFITFNC='spline'); also subtracted from the degrees of freedom.
CFITORD = 3
CFITFNC = 'spline'   # 'spline' or 'poly'
CFRAC = 0.0          # Threshold for flagging as "continuum" level (0 -> all pixels are used)
CDLAM = 5.0          # Width of the sliding wavelength window used for the continuum search
CITER = 1            # Number of iterations for continuum fitting
CSIGP = 2.0          # Upper clipping threshold: pixels > 2*CSIGP std above the mean are rejected
CSIGM = 1.5          # Lower clipping threshold: pixels > 2*CSIGM std below the mean are rejected
CSEL = 'MODEL'       # Select continuum points based on 'MODEL' or 'DATA'
NKNOTS  = 5          # Number of spline intervals (NKNOTS-1 evenly spaced interior knots)
NPINT = 1            # Sub-samples averaged per observed pixel (1 = plain interpolation)
SMOOTHFNC = 'gaussian' # 'gaussian' or 'uniform'

# These global variables should normally be left alone.

SIGTOL  = 1e-4       # Convergence tolerance for the smoothing-width search in fitfunc
LOGVTDEF = 0.30103   # Default microturbulent velocity = 2 km/s (log10(2))
LOGZDEF  = 0.0       # Default global scaling = Solar
R = 0.61803399       # Golden-section ratio used by the minimum searches
C = 1 - R            # Complementary golden-section step
ENABLE_NLTE = False  # Not referenced in this part of the file

# Spectrum-initialization flag and scratch directory intended to be shared
# across the fitting calls of a multi-parameter fit.
# NOTE(review): check whether fitnpar/maxfunc read these globals or attributes
# of another object before relying on them.
INIT_SPEC = True
INITDIR = '.'

def markcont(lam, flx, frac=0.95, dlam=5.0):
    '''Mark continuum regions of the flx array.

    Pixel i is flagged as continuum (1.0) when flx[i] > frac * fmax, where
    fmax is the maximum of flx over every pixel whose wavelength lies in the
    inclusive window [lam[i] - dlam/2, lam[i] + dlam/2]; otherwise it gets
    0.0. With frac == 0 every pixel is flagged.

    lam must be sorted in increasing order and dlam must be >= 0.

    Returns a numpy float array of 0.0/1.0 flags, one per pixel.'''

    logging.debug('MARKCONT: frac=%0.3f  dlam=%0.2f' % (frac, dlam))

    nlam = numpy.asarray(lam, dtype=float)
    nflx = numpy.asarray(flx, dtype=float)

    if frac == 0:
        c = numpy.ones(len(nlam))
    else:
        # Index range [i1, i2) of each pixel's window. side='left'/'right'
        # make both window edges inclusive, and the range always contains
        # pixel i itself, so it is never empty (unlike a manual search that
        # can truncate the window at the array edges).
        i1 = numpy.searchsorted(nlam, nlam - dlam/2., side='left')
        i2 = numpy.searchsorted(nlam, nlam + dlam/2., side='right')
        fmax = numpy.array([nflx[a:b].max() for a, b in zip(i1, i2)])
        c = numpy.where(nflx > frac * fmax, 1.0, 0.0)
    nm = int(numpy.count_nonzero(c))

    logging.debug('MARKCONT: Marked %d / %d points with F(lam)>%0.3f Fmax' % (nm, len(nlam), frac))

    return c
    
    
def markusercont(lam, cont):
    '''Project a user-supplied continuum specification onto a wavelength grid.

    cont is a pair [wavelengths, flags]; the flags are linearly interpolated
    onto the wavelengths in lam. Wavelengths outside the range of cont[0]
    receive NaN (interp1d with bounds_error=False and its default fill value).

    Returns a numpy array with one value per element of lam.'''
    # A plain linear interpolation for now; a more elaborate mapping could be
    # substituted later.
    knots, flags = cont[0], cont[1]
    to_grid = interpolate.interp1d(knots, flags, kind='linear', bounds_error=False)
    return numpy.array(to_grid(lam))
    

def getchisq(lobs, fobs, eobs, lsyn, fsyn, sigsm, wobs=None, cont=None):
    '''Compare an observed spectrum with a smoothed synthetic spectrum.

    The synthetic spectrum (lsyn, fsyn) is smoothed with a kernel of width
    sigsm (wavelength units; shape chosen by SMOOTHFNC), resampled onto the
    observed wavelengths lobs (point interpolation when NPINT == 1,
    otherwise the mean of NPINT sub-samples per observed pixel) and divided
    by its mean. The ratio of observed to normalized synthetic flux is fitted
    at continuum pixels with a polynomial or spline (CFITFNC, CFITORD,
    NKNOTS), followed by sigma clipping of the continuum pixels (CITER,
    CSIGM, CSIGP); the observed flux and errors are divided by that fit.

    Parameters
    ----------
    lobs, fobs, eobs : observed wavelengths, fluxes and flux errors (arrays)
    lsyn, fsyn : synthetic wavelength grid and flux (arrays)
    sigsm : smoothing width in wavelength units
    wobs : optional per-pixel weights (default: all ones)
    cont : optional user continuum specification [[lam...], [flag...]];
        when None, continuum pixels are chosen by markcont from the model
        (CSEL == 'MODEL') or from the data.

    Returns
    -------
    (chisq, chisqn, fdat) : chi-square, chi-square divided by
    (number of pixels - CFITORD - 1), and a dict with the scaled observed
    flux 'obs', normalized synthetic flux 'syn', continuum scaling 'scl' and
    continuum flags 'cflag'.
    '''
    midpt = (SPECRNG[0] + SPECRNG[1])/2.

    # Mean step of the synthetic grid. The synthetic spectrum is computed over
    # a range padded beyond SPECRNG (see fitfunc and deltarv), so its own
    # extent must be used to convert sigsm from wavelength units to pixels.
    dlam = abs(lsyn[-1] - lsyn[0])/(len(lsyn) - 1.)
    sigpix = sigsm/dlam

    if SMOOTHFNC == 'gaussian':
        logging.debug('GETCHISQ: Using Gaussian smoothing')
        if sigpix > 0:
            fsynsm = ndimage.gaussian_filter1d(fsyn, sigpix)
        else:
            # Zero width means no smoothing. scipy builds the Gaussian kernel
            # with a 1/sigma**2 factor, so sigma = 0 cannot be passed on.
            fsynsm = numpy.array(fsyn, dtype=float)
    elif SMOOTHFNC == 'uniform':
        logging.debug('GETCHISQ: Using Uniform smoothing')
        # uniform_filter1d needs a size of at least 1 (size 1 = no smoothing).
        fsynsm = ndimage.uniform_filter1d(fsyn, max(1, int(2*sigsm/dlam)))
    else:
        raise Exception('Error: unknown SMOOTHFNC (%s)' % SMOOTHFNC)

    # Interpolate in the smoothed spectrum, or average over each pixel,
    # to give values at the observed wavelength sampling points.
    fncsyn = interpolate.interp1d(lsyn, fsynsm, bounds_error=False, kind='linear')

    if NPINT == 1:
        fsynsmi = fncsyn(lobs)
    else:
        npix = len(lobs)
        ipix = numpy.arange(npix)
        fnclam = interpolate.interp1d(ipix, lobs, bounds_error=False, kind='linear')
        fsynsmi = numpy.zeros(npix)

        # Pixel edges lie halfway between neighbouring pixel centres; the
        # outermost edges are mirrored about the first/last pixel centre.
        l1 = 2*lobs[0] - fnclam(0.5)
        for i in ipix:
            if i > 0:
                l1 = l2
            if i < ipix[-1]:
                l2 = fnclam(i + 0.5)
            else:
                l2 = 2*lobs[i] - l1

            dlsub = (l2 - l1)/NPINT
            lsub = l1 + dlsub/2 + numpy.arange(NPINT)*dlsub
            fsynsmi[i] = sum(fncsyn(lsub))/NPINT

    # Fit a polynomial or spline to the ratio F(obs)/F(synt) and scale F(obs).
    if wobs is None:
        wobs = numpy.ones(len(lobs))

    if cont is not None:
        cflag = markusercont(lobs, cont)
    elif CSEL.upper() == 'MODEL':
        cflag = markcont(lobs, fsynsmi, frac=CFRAC, dlam=CDLAM)
    else:
        cflag = markcont(lobs, fobs, frac=CFRAC, dlam=CDLAM)

    cw = numpy.where(cflag > 0.5)

    msyn = sum(fsynsmi)/len(fsynsmi)
    f2n = fsynsmi/msyn
    w = wobs * f2n/eobs

    # At least one continuum fit is needed to define pscl, f1n and e1n.
    for i in range(max(CITER, 1)):
        if CFITFNC.upper() == 'POLY':
            coeff = numpy.polyfit((lobs-midpt)[cw], (fobs/f2n)[cw], CFITORD, w=w[cw])
            pscl = numpy.polyval(coeff, lobs-midpt)
        else:
            # NKNOTS-1 interior knots evenly spaced across SPECRNG.
            t = SPECRNG[0] + numpy.arange(1, NKNOTS)/float(NKNOTS) * (SPECRNG[1]-SPECRNG[0])
            spline = interpolate.LSQUnivariateSpline((lobs-midpt)[cw], (fobs/f2n)[cw],
                                                     t-midpt, w=w[cw], k=CFITORD)
            pscl = spline(lobs-midpt)

        # Avoid extrapolation: beyond the outermost continuum pixels, hold the
        # fit at its value at those pixels (assumes lobs is sorted).
        lcwmin = min(lobs[cw])
        lcwmax = max(lobs[cw])
        w1 = numpy.where(lobs < lcwmin)
        if len(w1[0]) > 0:
            pscl[w1] = pscl[cw[0][0]]
        w2 = numpy.where(lobs > lcwmax)
        if len(w2[0]) > 0:
            pscl[w2] = pscl[cw[0][-1]]

        f1n = fobs / pscl
        e1n = eobs / pscl

        # Sigma-clip the continuum pixels. This runs after every fit,
        # including the last one, so the returned cflag reflects the clipping.
        cdiff = f1n - f2n
        cdmean = sum((cdiff*wobs)[cw])/sum(wobs[cw])
        cstd = numpy.std((wobs*cdiff)[cw])
        ww = numpy.where((cdiff - cdmean < -2*CSIGM*cstd) | (cdiff - cdmean > 2*CSIGP*cstd))
        cflag[ww] = 0.
        cw = numpy.where(cflag > 0.5)
        logging.debug('Continuum scaling, ITER=%d: Sigma=%0.3e, %d pixels left' % (i+1, cstd, len(cw[0])))

    # Calculate chi-square
    chisq = sum(wobs*((f1n-f2n)/e1n)**2.)
    chisqn = chisq / (len(f1n) - CFITORD - 1)

    fdat = {'obs': f1n, 'syn': f2n, 'scl': pscl, 'cflag': cflag}

    return (chisq, chisqn, fdat)


def _minimize_sigsm(lobs, fobs, eobs, lsyn, fsyn, wobs, cont, sigrange):
    '''Find the smoothing width within sigrange = [min, max] minimizing chi-square.

    The minimum is first bracketed by repeatedly halving the step from the
    better end of the range, then refined by a golden-section search to
    within SIGTOL. Returns (sigbest, chisq, chisqn, fdat), where the last
    three items are exactly what getchisq returned for sigbest.'''

    def evaluate(sig):
        return getchisq(lobs, fobs, eobs, lsyn, fsyn, sig, wobs=wobs, cont=cont)

    # Bracket the minimum for sigma
    ax = sigrange[0]
    cx = sigrange[1]
    fa = evaluate(ax)[0]
    fc = evaluate(cx)[0]

    tmpb = (cx - ax)/2.
    bx = ax + tmpb
    resb = evaluate(bx)
    fb = resb[0]

    while ((fb > fa) or (fb > fc)) and (tmpb > SIGTOL):
        tmpb /= 2.
        if fa < fc:
            bx = ax + tmpb
        else:
            bx = cx - tmpb
        resb = evaluate(bx)
        fb = resb[0]

        print('sigsm(a,b,c) = %0.3f, %0.3f, %0.3f. Chisqr(a,b,c) = %0.1f, %0.1f, %0.1f' % (ax,bx,cx,fa,fb,fc))
        logging.debug('    sigsm(a,b,c) = %0.3f, %0.3f, %0.3f. Chisqr(a,b,c) = %0.1f, %0.1f, %0.1f' % (ax,bx,cx,fa,fb,fc))

    # Test the bracket itself rather than the step size: the last halving may
    # have produced a valid bracket.
    if (fb > fa) or (fb > fc):
        print('**  Warning: No minimum found for SIGSM between %0.3f and %0.3f' % (ax, cx))
        logging.info('    Warning: No minimum found for SIGSM between %0.3f and %0.3f' % (ax, cx))
        return (bx,) + tuple(resb)

    # Golden Section search. res1/res2 always hold the full getchisq result
    # for x1/x2, so the reported fit data match the reported best sigma.
    x0 = ax
    x3 = cx
    if abs(cx - bx) > abs(bx - ax):
        x1, res1 = bx, resb
        x2 = bx + C * (cx - bx)
        res2 = evaluate(x2)
    else:
        x2, res2 = bx, resb
        x1 = bx - C * (bx - ax)
        res1 = evaluate(x1)

    while abs(x2 - x1) > SIGTOL:
        if res2[0] < res1[0]:
            x0, x1 = x1, x2
            x2 = R * x1 + C * x3
            res1 = res2
            res2 = evaluate(x2)
        else:
            x3, x2 = x2, x1
            x1 = R * x2 + C * x0
            res2 = res1
            res1 = evaluate(x1)

        print('sigsm(1,2) = %0.3f, %0.3f. Chisqr(1,2) = %0.1f, %0.1f' % (x1,x2,res1[0],res2[0]))
        logging.debug('    sigsm(1,2) = %0.3f, %0.3f. Chisqr(1,2) = %0.1f, %0.1f' % (x1,x2,res1[0],res2[0]))

    if res1[0] < res2[0]:
        return (x1,) + tuple(res1)
    return (x2,) + tuple(res2)


def fitfunc(specobs, stelpar, logz, atoms=None, abun=None, output=None, cont=None,
            sigsm=[0.0, 1.0], logvt=LOGVTDEF, initialize=True, initdir='.'):
    '''Compute a synthetic spectrum and compare it with the observed one.

    Parameters
    ----------
    specobs : observed spectrum [[lam, flux, err, <weight>], ...]; only
        pixels with SPECRNG[0] <= lam <= SPECRNG[1] are used and the weight
        defaults to 1.
    stelpar, logz, atoms, abun, logvt, initialize, initdir : passed on to
        abutils.hrd2spec, which writes the synthetic spectrum to
        'specTMP.asc' in the current directory. logvt may be the string
        'USER'. atoms/abun default to empty lists.
    output : optional path of a text file to receive the best-fit comparison.
    cont : optional user continuum specification (see getchisq).
    sigsm : either a two-element list/tuple [min, max] giving the range over
        which the best smoothing width (wavelength units) is searched, or a
        single number used as a fixed smoothing width.

    Returns
    -------
    (chisqmin, chisqnmin, sigbest) : chi-square, reduced chi-square and the
    smoothing width of the best fit.
    '''
    if atoms is None:
        atoms = []
    if abun is None:
        abun = []

    print('FITFUNC:')
    print('    LogZ  = %+0.3f' % logz)
    if logvt == 'USER':
        print('    LogVT = USER')
    else:
        print('    LogVT = %0.3f' % logvt)
    print('    Atoms = ', atoms)
    print('    Abun  = ', abun)

    # Observed pixels inside the fitting range; the weight column is optional.
    lobs, fobs, eobs, wobs = [], [], [], []
    for pixdat in specobs:
        lam = float(pixdat[0])
        if SPECRNG[0] <= lam <= SPECRNG[1]:
            lobs.append(lam)
            fobs.append(float(pixdat[1]))
            eobs.append(float(pixdat[2]))
            wobs.append(float(pixdat[3]) if len(pixdat) > 3 else 1.)

    lobs = numpy.array(lobs)
    fobs = numpy.array(fobs)
    eobs = numpy.array(eobs)
    wobs = numpy.array(wobs)

    # Synthetic spectrum, padded on both sides (one observed pixel, one
    # synthetic step and EXPAND) so the comparison never extrapolates.
    dp1 = lobs[1] - lobs[0]
    dp2 = lobs[-1] - lobs[-2]
    dps = SPECRNG[2] + EXPAND
    specr = [SPECRNG[0]-dp1-dps, SPECRNG[1]+dp2+dps, SPECRNG[2]]
    abutils.hrd2spec(stelpar, 'specTMP.asc', specr, logz, atoms=atoms,
                     logvt=logvt,
                     abun=abun,
                     tmproot=absetup.TMPROOT,
                     initialize=initialize, initdir=initdir)

    lsyn, fsyn = [], []
    with open('specTMP.asc', 'r') as f:
        for line in f:
            fields = line.split()
            # Skip comment lines and blank lines (a blank line would
            # otherwise raise IndexError).
            if not fields or line[0] == '#':
                continue
            lsyn.append(float(fields[0]))
            fsyn.append(float(fields[1]))
    lsyn = numpy.array(lsyn)
    fsyn = numpy.array(fsyn)

    logging.info('FITFUNC:')
    logging.info('    LogZ= %+0.3f' % logz)
    if logvt == 'USER':
        logging.info('    LogVT = USER')
    else:
        logging.info('    LogVT = %0.3f (%0.2f km/s)' % (logvt, 10**logvt))
    logging.info('    Atoms, Abun:')
    for r, s in zip(atoms, abun):
        logging.info('    %-2s %+0.3f' % (r, s))

    if isinstance(sigsm, (list, tuple)):
        sigbest, chisqmin, chisqnmin, fdat = _minimize_sigsm(lobs, fobs, eobs, lsyn, fsyn,
                                                             wobs, cont, sigsm)
    else:
        chisqmin, chisqnmin, fdat = getchisq(lobs, fobs, eobs, lsyn, fsyn, sigsm, wobs=wobs, cont=cont)
        sigbest = sigsm

    print('    Best fit: smooth = %0.3f, chi-square = %0.1f' % (sigbest, chisqmin))
    print('              Reduced chi-square = %0.3f' % chisqnmin)

    logging.info('    Best fit: smooth = %0.3f, chi-square = %0.1f' % (sigbest, chisqmin))
    logging.info('              Reduced chi-square = %0.3f' % chisqnmin)

    # Write the best fit to output file (optional)
    if output is not None:
        with open(output, 'w') as fo:
            fo.write('# LogZ= %+0.3f\n' % logz)
            if logvt == 'USER':
                fo.write('# LogVT = USER\n')
            else:
                fo.write('# LogVT = %0.3f (%0.2f km/s)\n' % (logvt, 10**logvt))
            fo.write('# Atoms, Abun:\n')
            for at, ab in zip(atoms, abun):
                fo.write('#  %-2s %+0.3f\n' % (at, ab))
            fo.write('# Fit range = %0.3f - %0.3f\n' % (SPECRNG[0], SPECRNG[1]))
            fo.write('# Smoothing = %0.3f\n' % sigbest)
            fo.write('# Continuum fitting function = %s\n' % CFITFNC)
            fo.write('# Order of continuum fit = %d\n' % CFITORD)
            if CFITFNC.upper() == 'SPLINE':
                fo.write('# Number of knots = %d\n' % NKNOTS)
            fo.write('# Reduced chi-square = %0.3f\n' % chisqnmin)
            fo.write('#  Lambda  F(obs) F(mod) Pscl       Wgt  Cont\n')

            for r, s, t, u, v, cf in zip(lobs, fdat['obs'], fdat['syn'], fdat['scl'], wobs, fdat['cflag']):
                # Write the flag as getchisq uses it (> 0.5 means continuum);
                # interpolated user flags may be fractional or NaN.
                fo.write('%10.3f %6.4f %6.4f %0.4e %0.2f %0d\n' % (r, s, t, u, v, 1 if cf > 0.5 else 0))

    return (chisqmin, chisqnmin, sigbest)
    
 
def abunvar(abfix, atfit, v, fitz, met, fixz):
    '''Build the (logz, abundance list) pair for a single trial value v.

    The abundance list is the fixed abundances abfix followed by one copy of
    v for each element in atfit. logz is v when the metallicity is being
    fitted (fitz), otherwise met when it is fixed (fixz); at least one of the
    two flags must be set.'''
    varied = list(numpy.zeros(len(atfit)) + v)
    abun = [] + abfix + varied
    if fitz:
        logz = v
    elif fixz:
        logz = met
    return (logz, abun)

    

def _fit1par_error_bound(evaluate, vfit, chsq, tol, direction, label):
    '''Locate the value on one side of vfit where chi-square reaches chsq + 1.

    evaluate(v) must return the (chi-square, reduced chi-square, smoothing)
    tuple of fitfunc for parameter value v. direction is +1 (upper bound) or
    -1 (lower bound) and label ('upper'/'lower') is used in messages. The
    search steps outward in units of 1 for at most 10 steps, then bisects
    between vfit and the first value where chi-square reached chsq + 1, to
    within tol. If chi-square stays below chsq + 1, a warning is issued and
    the last probed value (vfit +/- 10) is returned.'''
    print('Searching for %s error bound' % label)
    logging.info('FIT1PAR:')
    logging.info('    Searching for %s error bound' % label)

    target = chsq + 1
    inner = vfit
    nstep = 1
    outer = vfit + direction * nstep
    fout = evaluate(outer)[0]
    while fout < target and nstep < 10:
        nstep += 1
        outer = vfit + direction * nstep
        fout = evaluate(outer)[0]

    if nstep >= 10:
        print('** Warning: %s bound on error is unconstrained' % label)
        logging.info('FIT1PAR warning: %s bound on error is unconstrained' % label)
    else:
        while abs(outer - inner) > tol:
            vm = (inner + outer)/2.
            if evaluate(vm)[0] < target:
                inner = vm
            else:
                outer = vm

    title = label.capitalize()
    print('%s 1-sigma limit is %0.3f ' % (title, outer))
    logging.info('FIT1PAR:')
    logging.info('    %s 1-sigma limit is %0.3f' % (title, outer))
    return outer


def fit1par(specobs, stelpar, pfit, prange, pfix=[], vfix=[], tol=0.001, calcerr=True, sigsm=[0,1.0], cont=None):
    '''Fit elements or metallicity, applying a common scaling.
   fit1par(specobs, stelpar, pfit, prange, pfix, vfix):
   specobs = observed spectrum: [[lam1,flux1,err1,<w1>], [lam2,flux2,err2,<w2>], ...].
             where the user-specified weights (w1, w2, ..) are optional
   stelpar = stellar parameters, a list of [teff, logg, weight] data.
   pfit    = ['element1','element2', ...]
             where <elementi> are the names of the elements to fit,
             'O', 'Mg', 'Ba', etc.
             The overall (logarithmic) metallicity can be specified
             as the 'special' element 'LogZ'.
   prange  = [pmin, pmax] where pmin and pmax
             are the minimum and maximum of the range to consider.
   pfix    = ['element1','element2', ...]
             Elements to keep fixed during the fitting process
             ('LogZ' and 'LogVT' may be given here as well).
   vfix    = [a1, a2, ...]
             Abundances for elements to keep fixed.
   tol     = convergence tolerance on the fitted value (also used for the
             error-bound bisection).
   calcerr = if True, find where chi-square has increased by 1 on both
             sides of the best fit.
   sigsm   = smoothing width, or [min, max] range to fit it over (see fitfunc).
   cont    = [[lam1, lam2, ...], [c1, c2, ...]]
             User-specified continuum sampling points (c1=1)
   Returns (vfit, sigsm_best, err_plus, err_minus, reduced_chisq), where
   err_minus = vmin - vfit is non-positive; five NaNs are returned if no
   minimum is bracketed within prange.
   Examples:
       fit1par(gcspec, gcstelpar, ['O','Mg','Ca','Ti'], [0.0, 1.0])
       fit1par(gcspec, gcstelpar, ['LogZ'], [-4.0, -1.0])'''

    atfix = list(pfix)
    abfix = list(vfix)

    if 'LogVT' in pfit:
        logging.error('Error: Cannot fit for LogVT in fit1par')
        raise Exception('Error: Cannot fit for LogVT in fit1par')

    # If nothing is specified for the global scaling, Log Z/Z0 = LOGZDEF is assumed.
    fixz = False
    met = LOGZDEF
    logvt = LOGVTDEF

    if 'LogZ' in atfix:
        i = atfix.index('LogZ')
        atfix.pop(i)
        met = abfix.pop(i)
        fixz = True

    if 'LogVT' in atfix:
        i = atfix.index('LogVT')
        atfix.pop(i)
        logvt = abfix.pop(i)

    if len(atfix) > 0:
        print('Modifying the following FIXED abundances:')
        for i in range(len(atfix)):
            print('  Delta '+atfix[i]+' = %0.3f' % abfix[i])

    atfit = list(pfit)
    fitz = 'LogZ' in atfit
    if fitz:
        atfit.remove('LogZ')

    if not (fitz or fixz):
        print('Warning: overall metallicity neither fitted nor fixed')
        print('         Using default LogZ/Z0 = %0.3f' % LOGZDEF)
        fixz = True

    print('Now fitting ', pfit)

    atoms = atfix + atfit
    initdir = tempfile.mkdtemp(dir=absetup.TMPROOT, prefix='tmpinit')

    def evaluate(v, initialize=False, output=None):
        '''Run fitfunc for trial value v; returns (chisq, chisqn, sigsm).'''
        logz, abun = abunvar(abfix, atfit, v, fitz, met, fixz)
        return fitfunc(specobs, stelpar, logz, atoms, abun, logvt=logvt,
                       output=output, sigsm=sigsm, cont=cont,
                       initialize=initialize, initdir=initdir)

    try:
        logging.info('# %s@%s %s' % (getpass.getuser(), os.uname()[1], os.getcwd()))
        logging.info('# %s %s %s' % (os.uname()[0], os.uname()[2], str(datetime.datetime.now())))
        logging.info('FIT1PAR (%0.3f - %0.3f):' % (SPECRNG[0], SPECRNG[1]))

        # First bracket the minimum
        logging.info('    Bracketing minimum')

        ax = prange[0]
        cx = prange[1]
        fa = evaluate(ax, initialize=True)[0]
        fc = evaluate(cx)[0]

        tmpb = (cx - ax)/2.
        bx = ax + tmpb
        resb = evaluate(bx)
        fb = resb[0]

        while ((fb > fa) or (fb > fc)) and (tmpb > tol):
            tmpb /= 2.
            if fa < fc:
                bx = ax + tmpb
            else:
                bx = cx - tmpb
            resb = evaluate(bx)
            fb = resb[0]

        # Test the bracket itself rather than the step size: the last halving
        # may have found a valid bracket, while tmpb == tol can end the loop
        # without one.
        if (fb > fa) or (fb > fc):
            logging.info('ERROR: No minimum found between %0.3f and %0.3f' % (ax, cx))
            print('** ERROR: No minimum found between %0.3f and %0.3f' % (ax, cx))
            return numpy.nan, numpy.nan, numpy.nan, numpy.nan, numpy.nan

        print('Minimum bracketed at %0.3f %0.3f %0.3f' % (ax,bx,cx))
        print('Function values are %0.1f %0.1f %0.1f' % (fa,fb,fc))

        print('Now proceeding with Golden Section Search')

        logging.info('FIT1PAR:')
        logging.info('    Minimum bracketed at %0.3f %0.3f %0.3f' % (ax,bx,cx))
        logging.info('    Function values are %0.1f %0.1f %0.1f' % (fa,fb,fc))

        # Then apply Golden Section Search
        logging.info('FIT1PAR:')
        logging.info('    Golden section search for minimum:')

        x0 = ax
        x3 = cx
        if abs(cx - bx) > abs(bx - ax):
            x1, res1 = bx, resb
            x2 = bx + C * (cx - bx)
            res2 = evaluate(x2)
        else:
            x2, res2 = bx, resb
            x1 = bx - C * (bx - ax)
            res1 = evaluate(x1)

        while abs(x2 - x1) > tol:
            if res2[0] < res1[0]:
                x0, x1 = x1, x2
                x2 = R * x1 + C * x3
                res1 = res2
                res2 = evaluate(x2)
            else:
                x3, x2 = x2, x1
                x1 = R * x2 + C * x0
                res2 = res1
                res1 = evaluate(x1)

        if res1[0] < res2[0]:
            vfit, best = x1, res1
        else:
            vfit, best = x2, res2
        chsq, chsqn, sigbest = best

        logging.info('FIT1PAR:')
        logging.info('    Minimum found at %0.3f  (chi-square=%0.1f, sigsm=%0.3f)' % (vfit, chsq, sigbest))

        # Re-run the best fit to write it to fitabun.txt.
        evaluate(vfit, output='fitabun.txt')

        # Determine the error on the fit by searching for parameter values
        # where chi-square has increased by one.
        if calcerr:
            vmax = _fit1par_error_bound(evaluate, vfit, chsq, tol, +1, 'upper')
            vmin = _fit1par_error_bound(evaluate, vfit, chsq, tol, -1, 'lower')
        else:
            vmax = vfit
            vmin = vfit

        logging.info('FIT1PAR results (%0.3f - %0.3f):' % (SPECRNG[0], SPECRNG[1]))
        logging.info('  Best fit = %0.3f +%0.3f -%0.3f' % (vfit, vmax-vfit, vfit-vmin))
        logging.info('  Broadening = %0.3f' % sigbest)
        logging.info('  Chi-square (reduced) = %0.1f (%0.3f)' % (chsq, chsqn))

        return vfit, sigbest, vmax-vfit, vmin-vfit, chsqn
    finally:
        # Always remove the scratch directory, also on early return or error.
        shutil.rmtree(initdir, ignore_errors=True)

def deltarv(specobs, stelpar, logz, atoms, abun, sigsm=0.10, rvtol=0.01, drvmax=5.0, cont=None):
    '''Find the radial-velocity offset that best aligns observed and synthetic spectra.

    A synthetic spectrum is computed with abutils.hrd2spec (written to
    'specTMP.asc') over SPECRNG, widened to allow shifts of up to +/- drvmax.
    The observed wavelengths are rescaled by (1 - drv/c) with
    c = 299792.458 (so velocities are in km/s). A golden-section search over
    [-drvmax, drvmax], started at drv = 0, minimizes the getchisq chi-square
    at fixed smoothing sigsm until the interior points are closer than rvtol.

    Returns the best-fitting offset drv.'''
    clight = 299792.458  # speed of light in km/s

    # Observed spectrum restricted to the fitting range; weights are optional.
    rows = [pixdat for pixdat in specobs if SPECRNG[0] <= float(pixdat[0]) <= SPECRNG[1]]
    lobs = numpy.array([float(p[0]) for p in rows])
    fobs = numpy.array([float(p[1]) for p in rows])
    eobs = numpy.array([float(p[2]) for p in rows])
    wobs = numpy.array([float(p[3]) if len(p) > 3 else 1. for p in rows])

    # Synthetic spectrum, padded for the largest allowed shift plus 1 unit.
    specr = [SPECRNG[0]*(1-drvmax/clight)-1, SPECRNG[1]*(1+drvmax/clight)+1., SPECRNG[2]]
    abutils.hrd2spec(stelpar, 'specTMP.asc', specr, logz, atoms=atoms,
                     abun=abun,
                     tmproot=absetup.TMPROOT)

    lsyn, fsyn = [], []
    with open('specTMP.asc', 'r') as f:
        for line in f:
            fields = line.split()
            # Skip comment lines and blank lines (a blank line would
            # otherwise raise IndexError).
            if not fields or line[0] == '#':
                continue
            lsyn.append(float(fields[0]))
            fsyn.append(float(fields[1]))
    lsyn = numpy.array(lsyn)
    fsyn = numpy.array(fsyn)

    def chisq_at(drv):
        '''Chi-square after shifting the observed wavelengths by drv.'''
        lamc = lobs * (1 - drv/clight)
        return getchisq(lamc, fobs, eobs, lsyn, fsyn, sigsm, wobs=wobs, cont=cont)[0]

    # Golden-section search, assuming the initial guess (drv = 0) is good.
    x0 = -drvmax
    x3 = drvmax
    x1 = 0.
    f1 = chisq_at(x1)
    x2 = x1 + C * (x3 - x1)
    f2 = chisq_at(x2)

    while abs(x2 - x1) > rvtol:
        if f2 < f1:
            x0, x1 = x1, x2
            x2 = R * x1 + C * x3
            f1 = f2
            f2 = chisq_at(x2)
        else:
            x3, x2 = x2, x1
            x1 = R * x2 + C * x0
            f2 = f1
            f1 = chisq_at(x1)

        print('DRV1,2 = %0.3f, %0.3f  CHISQ1,2 = %0.1f, %0.1f' % (x1, x2, f1, f2))
        logging.info('DRV1,2 = %0.3f, %0.3f  CHISQ1,2 = %0.1f, %0.1f' % (x1, x2, f1, f2))

    drvmin = x2 if f2 < f1 else x1

    logging.info('DELTARV: Best fit dRV = %0.3f' % drvmin)

    return drvmin
    

def abparse(id_fit, val_fit, id_fix, val_fix):
    '''Translate fitted and fixed parameter specifications into fitfunc arguments.

    id_fit/id_fix are sequences of element names or groups (lists of names);
    val_fit/val_fix give one value per entry. The special names 'LogZ' and
    'LogVT' set the global metallicity and the log microturbulent velocity,
    also when they appear inside a group. All other names become
    (atom, abundance) pairs, each member of a group receiving the group's
    value.

    Returns (atoms, abun, logz, logvt), with logz/logvt defaulting to
    LOGZDEF/LOGVTDEF. On a length mismatch an error is printed and None is
    returned.'''

    nfit = len(val_fit)
    nfix = len(val_fix)
    logz = LOGZDEF
    logvt = LOGVTDEF

    # A couple of sanity checks
    if nfit != len(id_fit):
        print('** Error: number of abundances to fit does not match number of element groups')
        return

    if nfix != len(id_fix):
        print('** Error: number of fixed abundances does not match number of element groups')
        return

    # Convert to lists so that array-valued inputs (e.g. a numpy array of
    # fitted values) are concatenated rather than added element-wise.
    ids = list(id_fit) + list(id_fix)
    vals = list(val_fit) + list(val_fix)

    atoms, abun = [], []
    for idf, valf in zip(ids, vals):
        group = idf if isinstance(idf, list) else [idf]
        for s in group:
            if s == 'LogZ':
                logz = valf
            elif s == 'LogVT':
                logvt = valf
            else:
                atoms.append(s)
                abun.append(valf)

    return (atoms, abun, logz, logvt)
    
    
def maxfunc(var, data):
    '''Objective function used by amoeba in fitnpar: minus the reduced chi-square.

    var holds the current values of the fitted parameters; data is the
    dictionary built by fitnpar (keys 'specobs', 'stelpar', 'pfit', 'pfix',
    'vfix', 'sigsm', 'cont'). The module-level INIT_SPEC flag is passed to
    fitfunc as `initialize` and cleared after the call, so only the first
    evaluation after fitnpar sets it requests initialization; INITDIR is
    passed as the initialization directory.'''
    global INIT_SPEC

    atoms, abun, logz, logvt = abparse(data['pfit'], var, data['pfix'], data['vfix'])

    chisqmin, chisqnmin, sigbest = fitfunc(data['specobs'], data['stelpar'], logz, atoms, abun,
                                           logvt=logvt, cont=data['cont'],
                                           sigsm=data['sigsm'], initialize=INIT_SPEC,
                                           initdir=INITDIR)
    INIT_SPEC = False

    # amoeba is used as a maximizer here, hence the sign.
    return -chisqnmin


def fitnpar(specobs, stelpar, pfit, pinit, pfix, vfix, tol=0.001, sigsm=[0,1.0], cont=None):
    '''Fit elements or metallicity, applying individual scalings.
   fitnpar(specobs, stelpar, pfit, pinit, pfix, vfix):
   specobs = observed spectrum: [[lam1,flux1,err1,<w1>], [lam2,flux2,err2,<w2>], ...].
             where the user-specified weights (w1, w2, ..) are optional
   stelpar = stellar parameters, a list of [teff, logg, weight] data.
   pfit    = ['element1','element2', ...]
             where <elementi> are the names of the elements to fit,
             'O', 'Mg', 'Ba', etc.
             The overall (logarithmic) metallicity can be specified
             as the 'special' element 'LogZ'.
             Another 'special' element is the Log of the microturbulent
             velocity, 'LogVT'
             Elements can be grouped together:
             ['element1', ['element2','element3'], 'element4'..]
   pinit   = [p1, p2, ...] where pi are the initial guesses for
             the individual parameters to fit.
   pfix    = ['element1','element2', ...]
             Elements to keep fixed during the fitting process
   vfix    = [a1, a2, ...]
             Abundances for elements to keep fixed.
   tol     = relative convergence tolerance (scaled by the 0.2 simplex size).
   sigsm   = smoothing width, or [min, max] range to fit it over (see fitfunc).
   cont    = [[lam1, lam2, ...], [c1, c2, ...]]
             User-specified continuum sampling points (c1=1)
   Returns the best-fitting parameter values as returned by amoeba.
   Examples:
       fitnpar(gcspec, gcstelpar, ['Ba', 'O','Mg','Ca','Ti'],
                                  [0.0, 0.4, 0.4, 0.4, 0.4])
       fitnpar(gcspec, gcstelpar, ['LogZ', 'Mg'], [-2.0, 0.4])
       fitnpar(gcspec, gcstelpar, ['LogZ', ['O','Mg','Ti'], 'Ba'],
                                  [-2.0, 0.4, 0.0])'''
    global INIT_SPEC, INITDIR

    logging.info('# %s@%s %s' % (getpass.getuser(), os.uname()[1], os.getcwd()))
    logging.info('# %s %s %s' % (os.uname()[0], os.uname()[2], str(datetime.datetime.now())))
    logging.info('FITNPAR (%0.3f - %0.3f):' % (SPECRNG[0], SPECRNG[1]))
    logging.info('    Searching for minimum')

    # As in fit1par, no specification of overall scaling (either fixed value
    # or to fit) will be taken to mean Log Z/Z0 = LOGZDEF.
    simscl = 0.2
    scale = [simscl for _ in pinit]
    xtolerance = simscl * tol
    data = {'specobs': specobs, 'stelpar': stelpar, 'pfit': pfit,
            'pfix': pfix, 'vfix': vfix, 'sigsm': sigsm, 'cont': cont}

    # Shared with maxfunc through module globals: the first evaluation
    # initializes, later ones reuse the scratch directory.
    INIT_SPEC = True
    initdir = tempfile.mkdtemp(dir=absetup.TMPROOT, prefix='tmpinit')
    INITDIR = initdir
    try:
        vfit, chisqnmin, niter = amoeba.amoeba(pinit, scale, maxfunc, data=data,
                                               ftolerance=-1., xtolerance=xtolerance)

        atoms, abun, logz, logvt = abparse(pfit, vfit, pfix, vfix)
        f, fn, sigbest = fitfunc(specobs, stelpar, logz, atoms, abun,
                                 output='fitabun.txt', sigsm=sigsm, cont=cont, logvt=logvt,
                                 initialize=False, initdir=initdir)
    finally:
        # Always remove the scratch directory, also when the fit fails.
        shutil.rmtree(initdir, ignore_errors=True)

    logging.info('FITNPAR (%0.3f - %0.3f):' % (SPECRNG[0], SPECRNG[1]))
    logging.info('    Final Fit:')

    for pp, a in zip(pfit, vfit):
        names = pp if isinstance(pp, list) else [pp]
        logs = '     ' + ''.join(s + ' ' for s in names)
        logging.info('%s: %+0.3f' % (logs, a))

    logging.info('    Broadening = %0.3f' % sigbest)
    logging.info('    Reduced chi-square = %0.3f' % -chisqnmin)
    logging.info('    Number of iterations = %i' % niter)

    return vfit
 
 
def fitabun(specobs, stelpar, pfit):
    '''Placeholder: currently only prints 'Hello' and returns None.

    NOTE(review): the parameters are unused. Elsewhere in this file the name
    `fitabun` may be used as an attribute holder (e.g. `fitabun.INIT_SPEC`);
    check for such references before renaming or removing this function.'''
    print('Hello')
    
    
