#!/usr/bin/env python
#
# Copyright (C) 2017 Michael Janssen
#
# This library is free software; you can redistribute it and/or modify it
# under the terms of the GNU Library General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This library is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Library General Public
# License for more details.
#
# You should have received a copy of the GNU Library General Public License
# along with this library; if not, write to the Free Software Foundation,
# Inc., 675 Massachusetts Ave, Cambridge, MA 02139, USA.
#
"""
Additional CASA utils for interactive use around the rPICARD pipeline.
Add pipeline to PYTHONPATH and import within the CASA interactive shell.
"""
import os
import sys
import shutil
from glob import glob
from collections import namedtuple
import numpy as np
import pipe_modules.auxiliary   as auxiliary
import pipe_modules.calibration as calibration
from pipe_modules.default_casa_imports import *



####################################################### flagging functions ######################################################

def flag_sbtable(caltb='calibration_tables/ff_sb.t'):
    """
    Interactive flagging based on delay outliers of a single-band fringe-fit calibration table.

    For every (antenna, field) combination read from the table, the user is prompted in the terminal and then
    plotcal_interactive() is called to show delay vs time for that antenna and field, iterating over spws.
    Entering any text at the prompt ends the session early.
    Whether the session completes or is aborted, extend_flags_phasedelayrate() is run on the table once at the end.

    Input:
      - caltb: Path of the single-band fringe-fit calibration table.
    """
    prompt   = '\nPress Enter to plot or write anything and press Enter to exit here\n>'
    antennas = get_caltb_antennas(caltb)
    sources  = get_caltb_fields(caltb)
    for antenna, source in ((a, s) for a in antennas for s in sources):
        if input23(prompt):
            # Any non-empty answer aborts the session; flags are still extended after the loop.
            break
        plotcal_interactive(caltb, 'time', 'delay', antenna=antenna, field=source, iteration='spw')
        print('\n -- Showing {0!s}, {1!s} -- \n'.format(antenna, source))
    extend_flags_phasedelayrate(caltb)


def flag_mbtable(caltb='calibration_tables/ff_mb_cal.t', pols=('R', 'L')):
    """
    Interactive flagging based on delay outliers of a multi-band fringe-fit table.

    For each polarization in <pols> and each field in the table, the user is prompted in the terminal and then
    plotcal_interactive() is called to show delay vs time for that polarization and field, iterating over antennas.
    Entering any text at the prompt ends the session early.
    Whether the session completes or is aborted, extend_flags_phasedelayrate() is run on the table once at the end.

    Input:
      - caltb: Path of the multi-band fringe-fit calibration table.
      - pols: Sequence of polarization labels to cycle through. The default is RCP and LCP, ('R', 'L').
              It is a tuple rather than a list so that the default value cannot be shared and mutated
              between calls.
    """
    fields = get_caltb_fields(caltb)
    for pol in pols:
        for field in fields:
            inp = input23('\nPress Enter to plot or write anything and press Enter to exit here\n>')
            if inp:
                # Early exit requested: still extend the flags made so far before leaving.
                extend_flags_phasedelayrate(caltb)
                return
            plotcal_interactive(caltb, 'time', 'delay', pol, field=field, iteration='antenna')
            print('\n -- Showing ' + str(pol) + ', ' + str(field) + ' -- \n')
    extend_flags_phasedelayrate(caltb)


def flag_bptable(caltb='calibration_tables/bpass_scalar.t', pols=('R', 'L')):
    """
    Interactive flagging based on amplitude outliers of a bandpass table.

    For each polarization in <pols>, the user is prompted in the terminal and then plotcal_interactive() is called
    to show amplitude vs frequency, iterating over antennas (solutions for all scans are stacked).
    Entering any text at the prompt ends the session early.
    Unlike the fringe-fit flagging helpers, no flag extension is run on the table afterwards.

    Input:
      - caltb: Path of the bandpass calibration table.
      - pols: Sequence of polarization labels to cycle through. The default is RCP and LCP, ('R', 'L').
              It is a tuple rather than a list so that the default value cannot be shared and mutated
              between calls.
    """
    for pol in pols:
        inp = input23('\nPress Enter to plot or write anything and press Enter to exit here\n>')
        if inp:
            return
        plotcal_interactive(caltb, 'freq', 'amp', pol, iteration='antenna')
        print('\n -- Showing ' + str(pol) + ' -- \n')



####################################################### imaging functions ######################################################

def imager(field,
           uvdata                               = 'VLBI.ms.avg',
           startsolint                          = 10.,
           smallest_sc_solint                   = 0,
           solint_denominator                   = 2.,
           startmod_sc                          = 0,
           timeavg                              = 0,
           phase_only_selfcal                   = [300,30,0],
           N_sciter                             = 3,
           stop_at_DRdrop                       = False,
           niter0                               = 1000,
           cleaniterations                      = 'constant',
           cellsize                             = '',
           imsize                               = 0,
           robust                               = 0.5,
           threshold                            = 'auto',
           imagefiles                           = 'imfiles',
           calibrationfiles                     = 'selfcaltables',
           diagnostic_plots                     = 'image_diagplots',
           output_ID                            = '',
           mask                                 = '',
           flagtable                            = '',
           feeds_polarization                   = 'circular',
           minsnr_sc_phase                      = 1.0,
           minsnr_sc_amp                        = 1.5,
           combine_sc_data                      = 'spw',
           flag_last_sc                         = False,
           uvrange_sc                           = '',
           goautomated                          = False,
           parang                               = False,
           solnorm                              = True,
           amp_selfcal_ants                     = '',
           station_constraints                  = {},
           station_weights                      = {},
           uvzero_mod                           = '',
           switching_pol_stations               = 'JC',
           save_fits                            = 'image',
           gain                                 = 0.05,
           nterms                               = 1,
           multiscale                           = [0, 2, 6],
           stokes                               = 'pseudoI',
           interactive                          = True,
           nsigma                               = 0.0,
           cycleniter                           = -1,
           cyclefactor                          = 1.0,
           minpsffraction                       = 0.05,
           maxpsffraction                       = 0.8,
           smallscalebias                       = 0.4,
           sidelobethreshold                    = 1.35,
           noisethreshold                       = 8.8,
           lownoisethreshold                    = 6.9,
           negativethreshold                    = 0.0,
           smoothfactor                         = 1.0,
           minbeamfrac                          = 0.008,
           cutthreshold                         = 0.01,
           growiterations                       = 75,
           dogrowprune                          = True,
           minpercentchange                     = -1.0,
           usemask                              = 'auto-multithresh'
          ):
    """
    Use CASA tclean to image data. Must start CASA in mpi mode: $mpicasa -n <num cores> path_to_casa/bin/casa <casa_options>.
    Documentation: https://casa.nrao.edu/casadocs-devel/stable/global-task-list/task_tclean/about.
    Uses cycles of CLEAN + phase+amplitude phase calibrations.
    Splits off field from ms and creates a new MS <ms.field> to work with (unless it already exists).
    Input:
      - field: Name of the source to be imaged. Example: field='3C279'.

      - uvdata: Data from field that is to be imaged.
                If a measurement set called <uvdata>.<field> already exists, it will be used for imaging. Else it will be
                created from the specified uvdata, which can either be a MS or a uvfits file.

      - startsolint: The starting solution interval for the amplitude calibration in hours. After each self-calibration iteration
                     this solint will be reduced by a factor given by the solint_denominator parameter.
                     Phase self-calibration is determined by the smallest_sc_solint and phase_only_selfcal parameters.

      - smallest_sc_solint: The smallest allowed self-calibration solution interval in seconds.
                            The amplitude self-calibration will stop when this timescale is reached unless specified otherwise
                            (see N_sciter parameter).
                            The phase self-calibration is controlled by the phase_only_selfcal parameter.
                            > If =0, the smallest self-calibration timescale is set to the data integration time (correlator
                            accumulation period unless the data has been averaged in time).
                            > If =10, the smallest self-calibration timescale will be 10 seconds.

      - solint_denominator: The factor by which the amplitude self-calibration timescale is lowered for each iteration.
                            Can also affect the increase of the number of clean iterations (niter) which set how deep to clean,
                            see cleaniterations parameter.
                            > If =2, the self-calibration timescale is lowered by a factor of 2 in each iteration.
                              E.g., 10h, 5h, 2.5h, ...
                            > If =1.5, the self-calibration timescale is lowered by a factor of 1.5 in each iteration.
                              E.g., 10h, 6.67h, 4.44h, ...

      - startmod_sc: Similar to the Difmap <startmod> command, this imager can first align the phases by self-calibrating to a
                     point source. Can set the timescale [in seconds] for this phase self-calibration as input parameter.
                     > If =10, phase solutions will be obtained every 10 seconds.
                     > If =9999999, a single phase solution will be obtained for every scan.
                     > If =0, a phase solution will be obtained for the smallest_sc_solint.
                     > If =False, no phase self-calibration to a point source is done.
                       Setting it to False should be done when you want to continue imaging a dataset, because it will
                       automatically disable time-averaging (see timeavg parameter) as well.

      - timeavg: If >0, average the data in time over the specified interval [in seconds] after startmod_sc.
                 This is not done if startmod_sc=False.

      - phase_only_selfcal: List of solution intervals in seconds for first phase-only self-calibration steps.
                            This is done after startmod_sc and timeavg.
                            The smallest number in this list is used for the phase self-calibration solution interval of all
                            subsequent combined phase + amplitude self-calibration rounds. Numbers smaller than smallest_sc_solint
                            integration time are set to smallest_sc_solint.
                            > If =[120,60,30], three rounds of phase-only self-calibration rounds are done first, with solution
                              intervals of 120s, 60s, and then 30s. Starting at the fourth iteration, the usual
                              phase + amplitude self-calibration rounds are done with a 30s phase self-calibration solution
                              interval).
                            > If =[45,0], phase-only self-calibration is first done for 45s and then for the data integration
                              time. All phase + amplitude self-calibration rounds are then done with the phase self-calibration
                              solution intervals set to the smallest_sc_solint.
                            > If =[60], a first phase-only self calibration is done for 60s and all subsequent phase + amplitude
                              self-calibration rounds will use 60s for the phase self-calibration solution interval.
                            > If =[], phase + amplitude self-calibration is done for every step. The phase self-calibration
                              solution interval is set to the smallest_sc_solint.

      - N_sciter: Number of times amplitude+phase calibration is performed. After each iteration, the model will be build up from
                  scratch again.
                    > If =0, N_sciter is determined such that the last self-cal solution interval will be equal to the
                      smallest_sc_solint.
                    > If =-1, only the phase_only_selfcal steps will be performed.
                    > If =-2, the self-calibration step will be skipped entirely:
                      Set N_sciter=-2 to only image without self-calibration.
                    > If =-3, the self-calibration will be skipped as when N_sciter=-2.
                      Additionally, clearcal, startmod_sc, and timeavg will be skipped:
                      Set N_sciter=-3 to make an image from pre-self-calibrated data.

      - stop_at_DRdrop: Exit the code when the dynamic range (DR) of the image from the current iteration drops below the maximum
                        DR of all images. If this happens, the idea is to re-run the imaging script with N_sciter set to the
                        last good iteration. Most likely, the data is not good enough for very deep self-calibration.
                        > If =False, disable this feature.
                        > If =0.8, stop imaging when the DR of the current image is less tahn 80% of the image with the highest
                          DR (out of all images from all iterations).

      - niter0: Maximum number of CLEAN iterations if threshold is not reached.
                Starting at niter0, the niter (see <help(tclean)> in CASA) parameter will be increased after each self-cal
                iteration depending on the cleaniterations parameter.

      - cleaniterations: Setting to control the increase of niter with self-cal iterations. At the first iteration, niter is set
                         to the niter0 parameter.
                         > If ='constant', niter will stay at niter0 for each iteration.
                         > If =x, where x is some number, niter0 will be used for the phase self-cal steps and niter0*x will be
                           used after the first amplitude self-cal.
                         > If ='shallow', niter will be increased by niter0 for each iteration.
                         > If ='deep', niter will be increased by the solint_denominator factor for each iteration. This is not
                           recommended because cleaning will likely go too deep.

      - cellsize: Size of a single cell. Example: cellsize='1.0e-05arcsec'.
                  > If ='', CASA will try to find a suitable cellsize for you.

      - imsize: Number of pixels.
                > If =0, CASA will try to find a suitable imsize for you.

      - robust: Briggs robust weighting parameter. Can be set to a single number to be used for each imaging cycle.
                Or a list [a,b] can be given to smoothly go from robust=a to robust=b in the imaging cycles.
                For N cycles, robust = a, a+(b-a)/N, a+2*(b-a)/N, a+3*(b-a)/N, ..., b will be used.
                So the final image will be made with robust=b, while self-calibration solutions can be obtained from other
                weighting schemes.
                > If =-2, uniform weighting will be used.
                > If =+2, natural weighting will be used.
                > If =[0.5,1.5], the weighting will change from 0.5 initially to 1.5 in the end.
                > If =[1.0,-2], the weighting will change from 1.0 initially to -2.0 in the end.

      - threshold: Stopping threshold [Jy], see  <help(tclean)> in CASA. Example: threshold='1mJy'.
                   > If ='auto', CASA will try to estimate threshold as the expected point source sensitivity of the data.

      - imagefiles: Folder where image files will be created.

      - calibrationfiles: Folder where calibraton tables for self-cal will be written to.

      - diagnostic_plots: Folder where diagnostics plots (images, gain solutions, and data+model phases+amplitudes) will be
                          written to.
                          Set to False to not generate any plots.

      - output_ID: Can add a string to imagefiles, calibrationfiles, diagnostic_plots, the name of the MS created, and the name
                   of the final fits file written. Useful for batch processing of different datasets.

      - mask: Imaging mask to specify regions where emission is to be expected (see  <help(tclean)> in CASA).
              > If ='', can draw masks in interactive mode, which will be kept through self-cal cycles.
              > Else, can give a filename here. A wildcard (*) can be used to have different masks at different iterations.
                If mask='ANG.mask_*' and the files 'ANG.mask_0', 'ANG.mask_1', 'ANG.mask_5', and 'ANG.mask_8' are available,
                'ANG.mask_0' is used for iteration 0, 'ANG.mask_1' is used for iterations 1,2,3, and 5, 'ANG.mask_5' is used for
                iteration 5,6, and 7, and 'ANG.mask_8' is used for iteration 8 and all other remaining iterations.
              > Note that if usemask='auto-multithresh', CASA will try to automatically determine a suitable image mask and any
                file given here will be ignored.

      - flagtable: File with CASA flagging instructions that are applied before starting to image.

      - feeds_polarization: Polarization of antenna feeds.
                            > If ='circular', will use R and L.
                            > If ='linear', will use X and Y.

      - minsnr_sc_phase: Minimum signal to noise ratio for phase self-calibration solutions.

      - minsnr_sc_amp: Minimum signal to noise ratio for amplitude self-calibration solutions.

      - combine_sc_data: Which data to combine for self-calibration solutions.
                         > If ='spw', will use the full bandwidth for self-calibration.
                         > If ='', will not combine any data axis (need high SNR in each spw for this to work).

      - uvrange_sc: uvrange parameter for amplitude self-calibration. Selects range for which data is used to obtain gain
                    solutions. Default is to use all data. Examples taken from <help(gaincal)> in CASA:
                    > If ='0~1000klambda', select data from 0-1000 kilo-lambda
                    > If ='>4klambda', select data with uvrange greater than 4 kilo lambda.

      - flag_last_sc: By default, flagged solutions are removed and the data will be calibrated by interpolating only over good
                      solutions (with SNR>minsnr_sc). This parameter can be set to flag data for which no good self-calibration
                      solutions are obtained in the last round of self-calibration.
                      >If =True, do not remove flagged solutions from last self-cal tables.
                      >If =False, remove flagged solutions from the last self-cal table.

      - goautomated: If True, the user can overwrite <interactive> to False at some point during the run.
                     Once the first mask is available, the user will be asked for terminal input after each self-cal iteration
                     cycle, until interactive=False has been specified.
                     The idea is that masks can be drawn interactively for the first few imaging rounds until no new source
                     features are appearing. Then the user can continue with interactive=False.

      - parang: If True, perform a parallactic angle correction when self-calibrating. Set to False for data from rPICARD, where
                this correction has already been applied to the data for the created .avg measurement set.

      - solnorm: If True, normalize the average amplitude self-calibration solutions to unity.

      - amp_selfcal_ants: String containing a list of comma-separated antenna names or indices for which amplitude
                          self-calibration solutions are to be obtained and applied.
                          Note that solutions will only be obtained when at least five baselines can be formed from the list of
                          stations provided.
                          > If ='BR,LA,MK,FD', only the amplitudes of the BR, LA, MK, and FD stations will be adjusted.
                          > If ='', the amplitudes of all stations are allowed to be adjusted with self-calibration.

      - station_constraints: Dictionary with constraints for the station-based amplitude self-calibration gains.
                             If any constraints are given, amplitude self-calibration solutions are always obtained for each scan
                             separately (no combination along scan axis anymore), i.e. solutions on very long timescales,
                             spanning multiple scans are no longer possible.
                             The syntax is {'station_name1':{scanID1: x1, scanID2: x2, ...}, 'station_name2':..., ...}
                             with x1, x2, ... percentage values over which gains are allowed to vary per station and scan.
                             The scanIDs must be integers, They are optional and can provide different constraints for different
                             scans.
                             An incomplete dictionary can be given; the constraints are only enforced for the stations mentioned.
                             Note that the self-calibrated amplitudes are likely to differ from the model visiilties if
                             station_constraints are enforced.
                             > If ={}, no constraints are in place and solutions can span multiple scans.
                             > If ={'BR': 20, 'LA': 10, 'FD': 150}, the BR gains are allowed to vary within 20%, LA gains within
                               10%, and the gains for the FD station are allowed to vary by 150%.

      - station_weights: Dictionary with antenna weights.
                         This parameter should be set with caution as it will mess with the statistics of the data.
                         > If ={}, the antenna weights remain unmodified.
                         > If ={'BR': 10.0, 'LA': 0.3, 'FD':0.5}, the visibility weights for baselines including the BR, LA, FD
                           stations will be multiplied by 10, 0.3, 0.5 respectively.

      - uvzero_mod: Can give a string of 'uvzero-distance,uvzero-flux' to keep the CLEAN model fixed at a set zero-spacing flux.
                    All RR and LL model visibilities at a u-v distance smaller than the set uvzero-distance [in meters] will be
                    set to have a uvzero-flux amplitude [in Jy] and zero phase.
                    > If ='', this feature is disabled.
                    > If ='1000,1.5', the flux will be fixed at 1.5Jy for all baselines less than 1000m apart.

      - switching_pol_stations: List of stations that sometimes observe in RCP and sometimes in LCP.

      - save_fits: Which file to export as fits file.
                   > If ='image', will export '.image' file.
                   > If ='model', will export '.model' file.

      * All other parameters (including multi-term and multi-scale): See <help(tclean)> in CASA.
      * usemask: 'user' or 'auto-multithresh' (may need to tweak parameters to get satisfactory results).

    Output:
      - Calibrated <ms.field> and <field.fits> files.

      - CASA calibration tables from self-calibration steps in the <calibrationfiles> folder.

      - tclean imaging files in the the <imagefiles> folder (.image, .mask, .psf, .residual,... files).

    Example usage (see also the rPICARD documentation):
      - Make a script <goimage.py> with these contents:
          import interactive_utils
          interactive_utils.imager('3C279', timeavg=60, goautomated=True)
        and call that script with:
          $mpicasa -n 3 <path/to/your/casa/installation/>/bin/casa -c goimage.py 2>/dev/null
    """
    mytb       = casac.table()
    myim       = casac.imager()
    myms       = casac.ms()
    output_ID  = output_ID.replace('/', '')
    _ms        = uvdata + '.' + field + '.ms' + output_ID
    if not os.path.isdir(_ms):
        if not os.path.isdir(uvdata):
            ms = uvdata + '.ms.tmp'
            if os.path.isdir(ms):
                shutil.rmtree(ms, ignore_errors=True)
            tasks.importuvfits(fitsfile=uvdata, vis=ms)
            remove_tmp_ms = True
        else:
            ms            = uvdata
            remove_tmp_ms = False
        mytb.open(ms)
        all_cols = mytb.colnames()
        if 'CORRECTED_DATA' in all_cols:
            col = 'corrected'
        elif 'DATA' in all_cols:
            col = 'data'
        else:
            raise ValueError('No data found in the measurement set {0}.\n'.format(ms))
        mytb.close()
        if os.path.isdir(_ms+'.flagversions'):
            shutil.rmtree(_ms+'.flagversions', ignore_errors=True)
        tasks.split(vis=ms,outputvis=_ms,keepmms=False,field=field,spw="",scan="",antenna="",correlation="",timerange="",
                    intent="",array="",uvrange="",observation="",feed="",datacolumn=col,keepflags=False,width=1,timebin="0s",
                    combine=""
                   )
        if remove_tmp_ms:
            shutil.rmtree(ms, ignore_errors=True)
    myms.open(_ms)
    if not smallest_sc_solint:
        scansum   = myms.getscansummary()
        ssum_fkey = list(scansum.keys())[0]
        ac_period = round(float(scansum[ssum_fkey]['0']['IntegrationTime']), 6)
    else:
        ac_period = smallest_sc_solint
    myms.close()
    cumulative_sc = []
    cumulative_ap = []
    imfiles       = imagefiles + output_ID + '/'
    calfiles      = calibrationfiles + output_ID + '/'
    if diagnostic_plots:
        diagplots = diagnostic_plots + output_ID + '/'
        os.system('rm -r {0}*'.format(diagplots))
        os.system('mkdir {0}'.format(diagplots))
    else:
        diagplots = ''
    _inp_p, _ms_m = mock_rpicard_inps(switching_pol_stations, _ms)
    os.system('rm -r {0}*'.format(imfiles))
    os.system('rm -r {0}*'.format(calfiles))
    os.system('mkdir {0}'.format(imfiles))
    os.system('mkdir {0}'.format(calfiles))
    if flagtable:
        tasks.flagcmd(vis=_ms, inpmode='list', inpfile=flagtable)
    if N_sciter!=-3:
        tasks.clearcal(vis=_ms,field="",spw="",intent="",addmodel=False)
        if not isinstance(startmod_sc, bool):
            startmod_caltb = calfiles+'startmod'
            gaincal_calib(_ms, _ms_m, caltable=startmod_caltb, field=field, solint=str(max(startmod_sc,ac_period)),
                          combine='', minblperant=3, minsnr=minsnr_sc_phase, gaintype="G", calmode="p",
                          parang=parang, smodel=[1,0,0,0]
                         )
            calibration.flag_below_SNR(_inp_p, _ms_m, startmod_caltb, minsnr_sc_phase)
            calibration.interpolate_over_flags(startmod_caltb, param='CPARAM')
            calibration.remove_flagged_solutions(_inp_p, _ms_m, startmod_caltb)
            applycal_calib(_ms, _ms_m, gaintable=startmod_caltb, calwt=False, field=field, parang=parang)
            if timeavg:
                tmpms = _ms+'._tmp_'
                if os.path.isdir(tmpms):
                    shutil.rmtree(tmpms, ignore_errors=True)
                shutil.move(_ms, tmpms)
                if os.path.isdir(_ms+'.flagversions'):
                    shutil.rmtree(_ms+'.flagversions', ignore_errors=True)
                tasks.mstransform(vis=tmpms, outputvis=_ms, datacolumn='corrected', timeaverage=True, timebin=str(timeavg)+'s')
                shutil.rmtree(tmpms, ignore_errors=True)
            else:
                cumulative_sc.append(startmod_caltb)
    nsigma    = float(nsigma)
    ac_period = max(ac_period, timeavg)
    multimask = False
    if mask:
        if '*' in mask:
            masklist  = []
            multimask = glob(mask)
            auxiliary.natural_sort_Ned_Batchelder(multimask)
            if not multimask:
                raise ValueError('No mask files found for ' + str(mask))
        else:
            _mask     = mask
    if not cellsize or not imsize:
        myim.open(_ms)
        myim.selectvis(field=field)
        advice = myim.advise()
        myim.close()
    if not cellsize:
        _cellsize = str( float(advice[2]['value']) / 2. ) + advice[2]['unit']
    else:
        _cellsize = cellsize
    if not imsize:
        _imsize = int(advice[1]) * 3
    else:
        _imsize = imsize
    if threshold=='auto':
        try:
            _threshold = tasks.apparentsens(vis=_ms, field=field, imsize=_imsize, cell=_cellsize, weighting='briggs')['effSens']
        except AttributeError:
            myim.open(_ms)
            myim.selectvis(field=field)
            _threshold = myim.apparentsens()[1]
            myim.close()
    else:
        _threshold = threshold
    if nterms > 1:
        _deconvolver = 'mtmfs'
    elif multiscale:
        _deconvolver = 'multiscale'
    else:
        _deconvolver = 'hogbom'
    if station_constraints:
        ampsc_combine = ''
    else:
        ampsc_combine = 'scan'
    if 'scan' in combine_sc_data:
        ampsc_combine = combine_sc_data
    else:
        ampsc_combine = combine_sc_data + ',scan'
    if station_weights:
        calibration.mod_MS_antwt(_ms, station_weights)
    if uvzero_mod:
        uvzero_mod   = uvzero_mod.split(',')
        uvzero_maxuv = float(uvzero_mod[0])
        uvzero_flux  = float(uvzero_mod[1])
    else:
        uvzero_maxuv = None
        uvzero_flux  = None
    ssolint            = startsolint * 3600.
    _interactive       = interactive
    doselfcal          = True
    ph_only_steps      = len(phase_only_selfcal)
    solint_denominator = float(solint_denominator)
    if not N_sciter:
        _N_sciter = int(np.log(ssolint/ac_period)/np.log(solint_denominator)) + 5
    elif N_sciter==-1:
        _N_sciter = 0
    elif N_sciter<-1:
        _N_sciter = 1
        doselfcal = False
    else:
        _N_sciter = N_sciter
    if not ph_only_steps:
        doamp_selfcal = True
        smallest_phs  = str(ac_period)
    else:
        doamp_selfcal      = False
        phase_only_selfcal = [max(ph_os, ac_period) for ph_os in phase_only_selfcal]
        smallest_phs       = str(min(phase_only_selfcal))
    if doselfcal:
        _N_sciter += ph_only_steps
    amp_gaintype  = 'G'
    ph_gaintype   = 'G'
    solints_array = []
    niter_array   = []
    float_citer   = False
    for i in np.arange(_N_sciter):
        if i>=ph_only_steps:
            thisiter      = i - ph_only_steps
            adjust_factor = solint_denominator**int(thisiter)
            if cleaniterations == 'shallow':
                deepness_incr = thisiter + 1
            elif cleaniterations == 'deep':
                deepness_incr = adjust_factor
            elif cleaniterations == 'constant':
                deepness_incr = 1
            else:
                if thisiter > 0:
                    deepness_incr = float(cleaniterations)
                else:
                    deepness_incr = 1
                float_citer   = True
            niter_array.append(min(int(niter0 * deepness_incr), 500000000))
            _solint_float = ssolint / adjust_factor
        else:
            niter_array.append(int(min(niter0, 500000000)))
            _solint_float = ssolint
        if multimask:
            thisiter_mask = mask.replace('*', str(i))
            if thisiter_mask in multimask:
                masklist.append(thisiter_mask)
            else:
                try:
                    masklist.append(masklist[-1])
                except IndexError:
                    masklist.append(multimask[0])
        if _solint_float <= ac_period:
            _solint_float = str(ac_period)
            solints_array.append(_solint_float)
            break
        else:
            solints_array.append(_solint_float)
    N_cycles = len(niter_array)
    if isinstance(robust, list):
        robust_array = np.linspace(robust[0], robust[-1], N_cycles)
        robust_array = [round(briggs_param,2) for briggs_param in robust_array]
    else:
        robust_array = [robust] * N_cycles
    last_DR = -1
    for i, _niter in enumerate(niter_array):
        _imagename     = imfiles+str(i)+'.im'
        _caltabnam_ph  = calfiles+str(i)+'.ph'
        _caltabnam_amp = calfiles+str(i)+'.amp'
        _solint        = str(solints_array[i])+'s'
        _robust        = robust_array[i]
        if not mask:
            if i > 0:
                _mask = 'imfiles/'+str(i-1)+'.im.mask'
            else:
                _mask = []
        elif multimask:
            _mask = masklist[i]
        try:
            if goautomated and os.path.isdir(_mask) and _interactive:
                gauto = input23('\n\nEnter anything to set interactive=False. Enter nothing to continue interactively.\n')
                if gauto:
                    _interactive = False
        except TypeError:
            pass
        tasks.delmod(vis=_ms,otf=True,field='',scr=True)
        casa_tclean(_ms, field, _imagename, _imsize, _cellsize, stokes, _deconvolver, multiscale, nterms, smallscalebias,
                    _robust, _niter, gain, _threshold, nsigma, cycleniter, cyclefactor, minpsffraction, maxpsffraction,
                    _interactive, usemask, _mask, sidelobethreshold, noisethreshold, lownoisethreshold, negativethreshold,
                    smoothfactor, minbeamfrac, cutthreshold, growiterations, dogrowprune, minpercentchange
                   )
        if uvzero_flux:
            calibration.set_uvzero_modelflux(_ms, uvzero_maxuv, uvzero_flux)
        this_DR = plot_images_nterms(_imagename, diagplots+str(i), diagnostic_plots, nterms)
        if stop_at_DRdrop:
            if last_DR>0 and this_DR/last_DR < stop_at_DRdrop:
                print ('\nStopping at iteration {0} as the DR dropped by too much.\n'.format(str(i)))
                return
            if this_DR > last_DR:
                last_DR = this_DR
        if doselfcal:
            if not doamp_selfcal:
                try:
                    _solint  = str(phase_only_selfcal[i])
                    phsolint = _solint + 's'
                    progress = '\n -- sc phase-only iteration {0}/{1} with a {2} solint -- \n'.format(str(i+1), str(N_cycles),
                                                                                                      str(phsolint)
                                                                                                     )
                except IndexError:
                    doamp_selfcal = True
            if doamp_selfcal:
                progress = '\n -- sc iteration {0}/{1} with a {2} solint -- \n'.format(str(i+1), str(N_cycles), str(_solint))
                phsolint = smallest_phs
            print (progress)
            gaincal_calib(_ms, _ms_m, caltable=_caltabnam_ph, field=field, solint=phsolint, combine=combine_sc_data,
                          minblperant=3, minsnr=minsnr_sc_phase, gaintype=ph_gaintype, calmode="p", gaintable=cumulative_sc,
                          parang=parang
                         )
            ph_gaintype = 'T'
            if not flag_last_sc or i!=N_cycles-1:
                calibration.flag_below_SNR(_inp_p, _ms_m, _caltabnam_ph, minsnr_sc_phase)
                calibration.interpolate_over_flags(_caltabnam_ph, param='CPARAM')
                calibration.remove_flagged_solutions(_inp_p, _ms_m, _caltabnam_ph)
            cumulative_sc.append(_caltabnam_ph)
            if doamp_selfcal:
                gaincal_calib(_ms, _ms_m, caltable=_caltabnam_amp, field=field, solint=_solint, minblperant=5,
                              minsnr=minsnr_sc_amp, gaintype=amp_gaintype, calmode="a", gaintable=cumulative_sc,
                              parang=parang, combine=ampsc_combine, uvrange=uvrange_sc, antenna=amp_selfcal_ants,
                              solnorm=solnorm
                             )
                amp_gaintype = 'T'
                if not flag_last_sc or i!=N_cycles-1:
                    calibration.flag_below_SNR(_inp_p, _ms_m, _caltabnam_amp, minsnr_sc_amp)
                    calibration.interpolate_over_flags(_caltabnam_amp, param='CPARAM')
                    calibration.remove_flagged_solutions(_inp_p, _ms_m, _caltabnam_amp)
                _, cal_ants = get_caltb_antennas(_caltabnam_amp, True)
                if amp_selfcal_ants:
                    amp_selfcal_IDs = []
                    for ant in amp_selfcal_ants.split(','):
                        if ant in cal_ants:
                            amp_selfcal_IDs.append(np.where(cal_ants==ant)[0][0])
                        else:
                            print('Warning: '+ant+' in amp_selfcal_ants input parameter is not in the data.')
                    calibration.unity_amplitude_based_on_ant(_caltabnam_amp, amp_selfcal_IDs, True)
                if station_constraints:
                    calibration.constrain_gains(_caltabnam_amp, station_constraints, cal_ants, ac_period, cumulative_ap)
                cumulative_sc.append(_caltabnam_amp)
                cumulative_ap.append(_caltabnam_amp)
            applycal_calib(_ms, _ms_m, gaintable=cumulative_sc, calwt=[True]*len(cumulative_sc), field=field, parang=parang)
            if station_weights:
                calibration.mod_MS_antwt(_ms, station_weights)
            if diagnostic_plots:
                plotall_diagnostics(i, diagplots, _ms, _caltabnam_amp, _caltabnam_ph, feeds_polarization)
            if i==N_cycles-1:
                _imagename = imfiles+str(i+1)+'.im'
                if cleaniterations == 'shallow':
                    _niter += niter0
                elif cleaniterations == 'deep':
                    _niter *= solint_denominator
                elif float_citer and _niter == niter0:
                    _niter *= float(cleaniterations)
                _niter = int(_niter)
                tasks.delmod(vis=_ms,otf=True,field='',scr=True)
                casa_tclean(_ms, field, _imagename, _imsize, _cellsize, stokes, _deconvolver, multiscale, nterms, smallscalebias,
                            _robust, _niter, gain, _threshold, nsigma, cycleniter, cyclefactor, minpsffraction, maxpsffraction,
                            _interactive, usemask, _mask, sidelobethreshold, noisethreshold, lownoisethreshold,
                            negativethreshold, smoothfactor, minbeamfrac, cutthreshold, growiterations, dogrowprune,
                            minpercentchange
                           )
                plot_images_nterms(_imagename, diagplots+str(i+1), diagnostic_plots, nterms)
                print('\n -- stopping here as the smallest self-cal solint has been reached -- \n')
                break
    if nterms>1:
        for tt in np.arange(nterms):
            tasks.exportfits(imagename=_imagename+'.'+save_fits+'.tt'+str(tt),
                             fitsimage=field+'.tt' + str(tt) + '.fits' + output_ID, overwrite=True
                            )
    else:
        tasks.exportfits(imagename=_imagename+'.'+save_fits, fitsimage=field+'.fits'+output_ID, overwrite=True)


def plotall_diagnostics(iteration, plotfolder, vis, amp_caltable, phase_caltable, feeds_polarization):
    """
    Save diagnostic plots for one self-calibration iteration of the imager.

    All plots are written to plotfolder/<iteration>/ (created when missing):
      - amplitude gains vs time of every station, one plot per polarization (if amp_caltable exists);
      - phase gains vs time of every station, one plot per polarization (if phase_caltable exists);
      - amplitude and phase vs uv-distance of the corrected data and of the model column of vis,
        one set per parallel-hand correlation.

    feeds_polarization must be 'circular' (R/L labels) or 'linear' (X/Y labels), otherwise a
    ValueError is raised. Returns False without plotting anything when neither calibration table
    exists; returns None otherwise.
    """
    if feeds_polarization == 'linear':
        pol_basis = ['X', 'Y']
    elif feeds_polarization == 'circular':
        pol_basis = ['R', 'L']
    else:
        raise ValueError('feeds_polarization must be circular or linear.')

    iter_folder = plotfolder+'/'+str(iteration)+'/'
    for folder in (plotfolder, iter_folder):
        if not os.path.isdir(folder):
            os.makedirs(folder)

    have_amp   = os.path.isdir(amp_caltable)
    have_phase = os.path.isdir(phase_caltable)
    # The station list is read from the amplitude table when present, else from the phase table.
    if have_amp:
        station_table = amp_caltable
    elif have_phase:
        station_table = phase_caltable
    else:
        return False
    _, stations = get_caltb_antennas(station_table, True)

    gain_plots = []
    if have_amp:
        gain_plots.append((amp_caltable, 'amp_', 'amp'))
    if have_phase:
        gain_plots.append((phase_caltable, 'phase_', 'phase'))
    for station in stations:
        station_name = str(station)
        for table, prefix, quantity in gain_plots:
            for pol in pol_basis:
                plotcal_savefig(table, iter_folder+prefix+station_name+'_'+pol, 'time', quantity, pol, station_name)

    # Per correlation: data amp, data phase, model amp, model phase.
    radplots = (('amp',   'amplitude', 'corrected', 'radplot_{0}_amp_data.png'),
                ('phase', 'phase',     'corrected', 'radplot_{0}_phase_data.png'),
                ('amp',   'amplitude', 'model',     'radplot_{0}_amp_model.png'),
                ('phase', 'phase',     'model',     'radplot_{0}_phase_model.png'),
               )
    for pol in pol_basis:
        correlation = "{0}{1}".format(pol,pol)
        for yaxis, ylabel, column, filename in radplots:
            plotms_savefig(vis, 'uvdist', yaxis, xlabel='uvdistance', ylabel=ylabel, ydatacolumn=column,
                           plotfile=iter_folder+filename.format(pol), correlation=correlation
                          )



####################################################### analysis functions ######################################################

def concatuvf(extension='*.uvfits'):
    """
    Glue uvfits files together via a measurement set.
    Useful for EHT data from HOPS.

    Every file matching the glob pattern `extension` (in sorted order, and excluding the output
    file concat.uvfits itself) is imported into a temporary measurement set <file>.ms. These are
    concatenated into tmp.ms, which is exported to concat.uvfits in the current directory.
    The temporary measurement sets are removed afterwards, also when a CASA task fails.

    Raises ValueError if no input file matches the pattern.
    """
    outfile  = 'concat.uvfits'
    concatms = 'tmp.ms'
    # Sorted for a reproducible concatenation order; the output of an earlier run also matches
    # the default pattern and must not be fed back in as input.
    uvfs = sorted(uvf for uvf in glob(extension) if os.path.abspath(uvf) != os.path.abspath(outfile))
    if not uvfs:
        raise ValueError('No uvfits files match {0}'.format(extension))
    mss = []
    try:
        # concat adds to an already existing concatvis, so clear leftovers of an aborted run.
        shutil.rmtree(concatms, ignore_errors=True)
        for uvf in uvfs:
            thisms = uvf+'.ms'
            # These are intermediate products of this function (deleted below anyway).
            shutil.rmtree(thisms, ignore_errors=True)
            mss.append(thisms)
            tasks.importuvfits(fitsfile=uvf, vis=thisms)
        tasks.concat(vis=mss, concatvis=concatms)
        tasks.exportuvfits(vis=concatms, fitsfile=outfile, multisource=False)
    finally:
        for ms in mss:
            shutil.rmtree(ms, ignore_errors=True)
        shutil.rmtree(concatms, ignore_errors=True)



######################################################### help functions #########################################################

def get_caltb_antennas(caltb, also_names=False):
    """
    Return the antennas listed in the ANTENNA subtable of calibration table caltb.

    With also_names=False, returns the NAME column.
    With also_names=True, returns the tuple (NAME column, STATION column).
    The table is closed again even if reading a column fails.
    """
    mytb = casac.table()
    mytb.open(caltb+'/ANTENNA')
    try:
        ant_names = mytb.getcol('NAME')
        if also_names:
            return ant_names, mytb.getcol('STATION')
        return ant_names
    finally:
        mytb.close()


def get_caltb_fields(caltb):
    """
    Returns array of all field (source) names present in caltb.

    Reads the NAME column of the FIELD subtable; the table is closed again even if reading fails.
    """
    mytb = casac.table()
    mytb.open(caltb+'/FIELD')
    try:
        return mytb.getcol('NAME')
    finally:
        mytb.close()


def plot_images_nterms(infile, plotfile, doplot=True, nterms=1, plotthese=['image', 'model', 'residual', 'mask']):
    """
    Run plot_image on the products of one tclean run and return a dynamic range.

    With nterms>1 each Taylor-term product set (suffix '.tt0', '.tt1', ...) is handled by
    plot_image and the mean of the per-term dynamic ranges is returned. Otherwise the single
    product set is handled directly and its dynamic range returned.
    """
    if not nterms > 1:
        return plot_image(infile, plotfile, doplot, plotthese)
    term_DRs = [plot_image(infile, plotfile, doplot, plotthese, '.tt'+str(term)) for term in np.arange(nterms)]
    return sum(term_DRs)/nterms


def plot_image(infile, plotfile, doplot=True, plotthese=('image', 'model', 'residual', 'mask'), extension=''):
    """
    Print statistics of a CASA image, optionally save plots of it, and return its dynamic range.

    For each name in plotthese, the product infile + '.' + name + extension is used.
    plotthese[0] must be 'image': the statistics and the dynamic range (max / sigma, rounded
    to one decimal) are computed on that first product.
    If doplot is set, every product is rendered with imview to plotfile + '.' + name + extension
    + '.png'. The first plot's file name gets '_DR-<value>' inserted before '.png'.
    """
    plotthese_ext = [pt + extension for pt in plotthese]
    infls         = [infile + '.' + pt for pt in plotthese_ext]
    pltfls        = [plotfile + '.' + pt for pt in plotthese_ext]
    myia          = casac.image()
    myia.open(infls[0])
    try:
        stat = myia.statistics()
    finally:
        # Do not leave the image open when statistics() raises.
        myia.close()
    print('{0} statistics:'.format(infls[0]))
    for key in sorted(stat.keys()):
        print('{0}: {1}'.format(key, str(stat[key])))
    DR = round(float(stat['max'])/float(stat['sigma']),1)
    if doplot:
        tasks.imview(raster  = {'file': infls[0],'colorwedge':True},
                     out     = '{0}_DR-{1}.png'.format(pltfls[0], str(DR))
                    )
        for infl, pltfl in zip(infls[1:], pltfls[1:]):
            tasks.imview(raster  = {'file': infl,'colorwedge':True},
                         out     = '{0}.png'.format(pltfl)
                        )
    return DR


def plotms_savefig(vis,
                   xaxis       = '',
                   yaxis       = '',
                   ydatacolumn = 'corrected',
                   field       = '',
                   spw         = '',
                   antenna     = '',
                   scan        = '',
                   correlation = '',
                   avgchannel  = '999999',
                   avgspw      = True,
                   avgtime     = '',
                   coloraxis   = 'baseline',
                   xlabel      = '',
                   ylabel      = '',
                   plotfile    = ''
                  ):
    """
    Write a single-panel CASA plotms figure of vis to plotfile without opening the GUI.

    The caller chooses the axes, the y-axis data column, the data selection (field, spw, antenna,
    scan, correlation), channel/spw/time averaging, the color axis and the axis labels.
    Everything else is fixed: averaged data, custom filled circle symbols of size 12,
    no grids or legend, and a 4000x3200 export that overwrites existing files.
    """
    layout = dict(gridrows=1, gridcols=1, rowindex=0, colindex=0, plotindex=0)
    axes = dict(xaxis=xaxis, xdatacolumn='', yaxis=yaxis, ydatacolumn=ydatacolumn, yaxislocation=None)
    selection = dict(selectdata=True, field=field, spw=spw, timerange='', uvrange='', antenna=antenna,
                     scan=scan, correlation=correlation, array='', observation='', intent='', feed='',
                     msselect='')
    averaging = dict(averagedata=True, avgchannel=avgchannel, avgtime=avgtime, avgscan=False,
                     avgfield=False, avgbaseline=False, avgantenna=False, avgspw=avgspw, scalar=False)
    transformations = dict(transform=False, freqframe='', restfreq='', veldef='RADIO', shift=[0.0, 0.0])
    flag_extension = dict(extendflag=False, extcorr=False, extchannel=False)
    iteration_opts = dict(iteraxis='', xselfscale=False, yselfscale=False, xsharedaxis=False,
                          ysharedaxis=False)
    symbols = dict(customsymbol=True, symbolshape='circle', symbolsize=12, symbolcolor='0000ff',
                   symbolfill='fill', symboloutline=False, coloraxis=coloraxis,
                   customflaggedsymbol=False, flaggedsymbolshape='circle', flaggedsymbolsize=12,
                   flaggedsymbolcolor='ff0000', flaggedsymbolfill='fill', flaggedsymboloutline=False)
    canvas = dict(plotrange=[], title='', titlefont=0, xlabel=xlabel, xaxisfont=0, ylabel=ylabel,
                  yaxisfont=0, showmajorgrid=False, majorwidth=1, majorstyle='', majorcolor='B0B0B0',
                  showminorgrid=False, minorwidth=1, minorstyle='', minorcolor='D0D0D0',
                  showlegend=False, legendposition=None)
    export = dict(plotfile=plotfile, expformat='', exprange='', highres=False, dpi=-1, width=4000,
                  height=3200, overwrite=True, showgui=False, clearplots=True, callib=[''])
    options = {'vis': vis}
    for group in (layout, axes, selection, averaging, transformations, flag_extension, iteration_opts,
                  symbols, canvas, export):
        options.update(group)
    tasks.plotms(**options)


def plotcal_savefig(caltable,
                    figfile,
                    xaxis     = '',
                    yaxis     = '',
                    poln      = 'R,L',
                    antenna   = '',
                    field     = '',
                    spw       = '',
                    timerange = '',
                    overplot  = False,
                    iteration = 'antenna',
                    plotrange = [],
                    showflags = False,
                   ):
    """
    Render caltable with CASA plotcal into figfile without opening the GUI.

    poln, field and antenna are converted to strings before being handed to plotcal; the
    remaining options are fixed (single panel, blue 'o' symbols, marker size 10, font size 15).
    """
    options = dict(caltable=caltable, xaxis=xaxis, yaxis=yaxis)
    # Data selection.
    options.update(poln=str(poln), field=str(field), antenna=str(antenna), spw=spw, timerange=timerange)
    # Panel layout and iteration.
    options.update(subplot=111, overplot=overplot, clearpanel='Auto', iteration=iteration,
                   plotrange=plotrange, showflags=showflags)
    # Appearance.
    options.update(plotsymbol='o', plotcolor='blue', markersize=10, fontsize=15)
    # Output: file only, no window.
    options.update(showgui=False, figfile=figfile)
    tasks.plotcal(**options)


def plotcal_interactive(caltable,
                        xaxis     = '',
                        yaxis     = '',
                        poln      = '',
                        antenna   = '',
                        field     = '',
                        spw       = '',
                        timerange = '',
                        overplot  = False,
                        iteration = 'antenna',
                        plotrange = [],
                        showflags = False,
                       ):
    """
    Open CASA plotcal on caltable with its GUI (needs an X server); nothing is saved to disk.

    field and antenna are converted to strings, poln is passed through unchanged; the
    remaining options are fixed (single panel, blue 'o' symbols, marker size 10, font size 15).
    """
    options = dict(caltable=caltable, xaxis=xaxis, yaxis=yaxis)
    # Data selection.
    options.update(poln=poln, field=str(field), antenna=str(antenna), spw=spw, timerange=timerange)
    # Panel layout and iteration.
    options.update(subplot=111, overplot=overplot, clearpanel='Auto', iteration=iteration,
                   plotrange=plotrange, showflags=showflags)
    # Appearance.
    options.update(plotsymbol='o', plotcolor='blue', markersize=10, fontsize=15)
    # Output: interactive window, no figure file.
    options.update(showgui=True, figfile='')
    tasks.plotcal(**options)


def extend_flags_phasedelayrate(caltb):
    """
    If any of phase, delay, rate are flagged, then extend flags to all of them.

    Edits the FLAG column of calibration table caltb in place. Each FLAG cell is split into two
    equal halves, one per polarization. Within a half, the first three parameter rows are taken
    as phase, delay and rate. For every channel where some but not all of those three are flagged,
    all three become flagged. Further rows of a half are left unchanged; presumably this is the
    dispersive-delay row of 8-row fringe-fit tables (cf. F_0DISP/F_1DISP in mock_rpicard_inps).
    Cells with fewer than six rows are treated as a single polarization.
    """
    mytb = casac.table()
    mytb.open(caltb, nomodify=False)
    try:
        for row in range(mytb.nrows()):
            flags = np.array(mytb.getcell('FLAG', row), dtype=bool, order='C')
            if _extend_pdr_flags(flags):
                mytb.putcell('FLAG', row, flags)
        mytb.flush()
    finally:
        mytb.done()
        mytb.clearlocks()


def _extend_pdr_flags(flags):
    """
    Flag phase, delay and rate together, per polarization and channel, in one FLAG cell.

    flags (a C-contiguous boolean array whose first axis runs over parameters) is modified in
    place. Returns True if any flag was changed.
    """
    # View with the parameter axis first and all remaining axes (channels) flattened.
    rows = flags.reshape(flags.shape[0], -1)
    # Split at the middle instead of at a fixed row 3 so that cells with more than three
    # parameters per polarization keep the two polarizations apart.
    half   = rows.shape[0] // 2
    starts = (0, half) if half >= 3 else (0,)
    changed = False
    for start in starts:
        pdr     = rows[start:start + 3]
        partial = np.any(pdr, axis=0) & ~np.all(pdr, axis=0)
        if np.any(partial):
            pdr[:, partial] = True
            changed = True
    return changed


def casa_tclean(_ms, field, _imagename, _imsize, _cellsize, stokes, _deconvolver, multiscale, nterms, smallscalebias, _robust,
                _niter, gain, _threshold, nsigma, cycleniter, cyclefactor, minpsffraction, maxpsffraction, _interactive, usemask,
                _mask, sidelobethreshold, noisethreshold, lownoisethreshold, negativethreshold, smoothfactor, minbeamfrac,
                cutthreshold, growiterations, dogrowprune, minpercentchange):
    """
    Wrapper around CASA tclean that runs it twice on the corrected data of `field` in _ms.

    Both passes use mfs imaging with standard gridding, Briggs weighting with the given robust
    value, restart=True and savemodel="modelcolumn".
    Pass 1 deconvolves with int(_niter) iterations, the given mask and interactivity, and
    computes residual and PSF.
    Pass 2 uses niter=0, no mask, interactive=False and skips the residual/PSF computation;
    its purpose is to make sure the model is written correctly to the MODEL_DATA column.
    """
    def _run(niter, interactive, mask, compute_products):
        # Everything except the four varying settings is identical for both passes.
        tasks.tclean(vis=_ms, selectdata=True, field=field, spw="", timerange="", uvrange="", antenna="", scan="",
                     observation="", intent="", datacolumn="corrected",
                     imagename=_imagename, imsize=_imsize, cell=_cellsize, phasecenter="", stokes=stokes,
                     projection="SIN", startmodel="",
                     specmode="mfs", reffreq="", nchan=-1, start="", width="", outframe="LSRK", veltype="radio",
                     restfreq=[], interpolation="linear",
                     gridder="standard", facets=1, chanchunks=-1, wprojplanes=1, vptable="", aterm=True,
                     psterm=False, wbawp=True, conjbeams=True, cfcache="", computepastep=360.0,
                     rotatepastep=360.0, pblimit=0.2, normtype="flatnoise",
                     deconvolver=_deconvolver, scales=multiscale, nterms=nterms, smallscalebias=smallscalebias,
                     restoration=True, restoringbeam=[], pbcor=False, outlierfile="",
                     weighting="briggs", robust=_robust, npixels=0, uvtaper=[],
                     niter=niter, gain=gain, threshold=_threshold, nsigma=nsigma, cycleniter=cycleniter,
                     cyclefactor=cyclefactor, minpsffraction=minpsffraction, maxpsffraction=maxpsffraction,
                     interactive=interactive, usemask=usemask, mask=mask, pbmask=0.0,
                     sidelobethreshold=sidelobethreshold, noisethreshold=noisethreshold,
                     lownoisethreshold=lownoisethreshold, negativethreshold=negativethreshold,
                     smoothfactor=smoothfactor, minbeamfrac=minbeamfrac, cutthreshold=cutthreshold,
                     growiterations=growiterations, dogrowprune=dogrowprune, minpercentchange=minpercentchange,
                     restart=True, savemodel="modelcolumn", calcres=compute_products, calcpsf=compute_products,
                     parallel=False, fastnoise=False)

    _run(int(_niter), _interactive, _mask, True)
    # make sure the model is written correctly in the MODEL_DATA column:
    _run(0, False, '', False)


def gaincal_calib(ms, _ms_metadata, caltable, gaintype, field='', uvrange='', antenna='', scan='', solint='inf', combine='',
                  refant='', minblperant=4, minsnr=3.0, solnorm=False, smodel=None, calmode='ap',
                  gaintable=None, gainfield=None, interp=None, parang=False):
    """
    Generic gaincal task.

    Solves for gains of the given gaintype/calmode on ms and writes them to caltable, with the
    tables in gaintable applied on the fly. The spwmap for gaintable is obtained from
    auxiliary.get_spwmap using _ms_metadata. Fixed choices: all spws, no time/intent selection,
    preavg=-1, refantmode "flex", solmode "L1", append=False.
    smodel, gaintable, gainfield and interp default to empty lists. None is used as the default
    so that no list object is shared between calls.
    """
    smodel    = [] if smodel is None else smodel
    gaintable = [] if gaintable is None else gaintable
    gainfield = [] if gainfield is None else gainfield
    interp    = [] if interp is None else interp
    this_spwmap = auxiliary.get_spwmap(None, _ms_metadata, gaintable)
    tasks.gaincal(vis                =  ms,
                  caltable           =  caltable,
                  field              =  field,
                  spw                =  "",
                  intent             =  "",
                  selectdata         =  True,
                  timerange          =  "",
                  uvrange            =  uvrange,
                  antenna            =  antenna,
                  scan               =  scan,
                  observation        =  "",
                  msselect           =  "",
                  solint             =  solint,
                  combine            =  combine,
                  preavg             =  -1.0,
                  refant             =  str(refant),
                  refantmode         =  "flex",
                  minblperant        =  minblperant,
                  minsnr             =  minsnr,
                  solnorm            =  solnorm,
                  gaintype           =  gaintype,
                  smodel             =  smodel,
                  solmode            =  'L1',
                  rmsthresh          =  [],
                  calmode            =  calmode,
                  append             =  False,
                  splinetime         =  3600.0,
                  npointaver         =  3,
                  phasewrap          =  180.0,
                  docallib           =  False,
                  callib             =  "",
                  gaintable          =  gaintable,
                  gainfield          =  gainfield,
                  interp             =  interp,
                  spwmap             =  this_spwmap,
                  parang             =  parang
                 )


def applycal_calib(ms, _ms_metadata, gaintable=None, gainfield=None, interp=None, calwt=True,
                   field='', scan='', antenna='', parang=False, applymode='calflag'):
    """
    Generic applycal task.

    Applies the tables in gaintable to ms with CASA applycal (all spws, no time/uv selection,
    flag backup enabled). The spwmap is obtained from auxiliary.get_spwmap using _ms_metadata.
    gaintable, gainfield and interp default to empty lists. None is used as the default so that
    no list object is shared between calls.
    """
    gaintable = [] if gaintable is None else gaintable
    gainfield = [] if gainfield is None else gainfield
    interp    = [] if interp is None else interp
    this_spwmap = auxiliary.get_spwmap(None, _ms_metadata, gaintable)
    tasks.applycal(vis         = ms,
                   field       = field,
                   spw         = '',
                   intent      = '',
                   selectdata  = True,
                   timerange   = '',
                   uvrange     = '',
                   antenna     = antenna,
                   scan        = scan,
                   observation = '',
                   msselect    = '',
                   docallib    = False,
                   callib      = '',
                   gaintable   = gaintable,
                   gainfield   = gainfield,
                   interp      = interp,
                   spwmap      = this_spwmap,
                   calwt       = calwt,
                   parang      = parang,
                   applymode   = applymode,
                   flagbackup  = True
                  )


def apply_and_split(ms, gaintable=None, gainfield=None, interp=None, calwt=True,
                    field='', scan='', antenna='', parang=False, applymode='calflag', _ms_metadata=None):
    """
    Applies gaintables to DATA column (via split).
    *unused for now* (using default incremental self-cal).

    The tables are applied with applycal_calib. The MS is then moved to a temporary name, and
    split writes its corrected column (dropping flagged data) back to the original name.
    The temporary MS and the old flagversions are removed. Returns an empty list, presumably so
    the caller can reset its list of on-the-fly gaintables; verify against callers.

    _ms_metadata is forwarded to applycal_calib (and from there to auxiliary.get_spwmap).
    NOTE(review): it is not confirmed that get_spwmap accepts None; pass the metadata of ms.
    """
    # Keyword arguments: applycal_calib takes _ms_metadata as its second positional parameter,
    # so a positional call would shift every argument by one slot.
    applycal_calib(ms, _ms_metadata, gaintable=gaintable, gainfield=gainfield, interp=interp, calwt=calwt,
                   field=field, scan=scan, antenna=antenna, parang=parang, applymode=applymode)
    tmpms = auxiliary.unique_filename(ms+'.tmp')
    shutil.move(ms, tmpms)
    auxiliary.rm_dir_if_present(ms+'.flagversions')
    tasks.split(vis=tmpms, outputvis=ms, datacolumn="corrected", keepflags=False)
    auxiliary.rm_dir_if_present(tmpms)
    return []


def input23(_msg):
    """
    Prompt with _msg and return the line typed by the user.

    Uses raw_input on Python 2 (input there would evaluate the text) and input on Python 3.
    """
    if sys.version_info < (3, 0):
        return raw_input(_msg)
    return input(_msg)


def mock_rpicard_inps(switching_pol_stations, ms_name):
    """
    Build stand-ins for the rPICARD input-parameter and MS-metadata objects.

    Returns (dummy_inp_params, dummy_ms_metadata):
      - dummy_inp_params: a namedtuple with verbose=False, ms_name=False, the given
        switching_pol_stations, and F_0PHAS..F_1DISP set to the consecutive indices 0..7;
      - dummy_ms_metadata: an object whose all_spwds holds the unique DATA_DESC_ID values read
        from ms_name, and whose yield_antname always raises ValueError.
    """
    class _ms_metadata(object):
        # Minimal metadata stub: only all_spwds is populated.
        def __init__(self, ms_name):
            self.all_spwds = auxiliary.read_CASA_table(None, 'unique DATA_DESC_ID', _tablename=ms_name)

        def yield_antname(self, _id):
            raise ValueError('this is just a dummy')

    field_names  = ('verbose', 'ms_name', 'switching_pol_stations',
                    'F_0PHAS', 'F_0DELA', 'F_0RATE', 'F_0DISP',
                    'F_1PHAS', 'F_1DELA', 'F_1RATE', 'F_1DISP',
                   )
    field_values = [False, False, switching_pol_stations] + list(range(8))
    inp_params_type = namedtuple('_inp_params', field_names)
    return inp_params_type(*field_values), _ms_metadata(ms_name)
