#!/usr/bin/env python 
from __future__ import print_function

from extasycoco._version import __version__

import logging as log
import sys
import os
import os.path as op
import numpy as np
import argparse
import glob
import time

from MDPlus.analysis import pca
from MDPlus.core import Fasu, Cofasu
from MDPlus.fastfitting import rmsd, fitted
from MDPlus.analysis import mapping 
from extasycoco import new_points

def coco_ui(args):
    '''
    The command line implementation of the CoCo procedure. Should be invoked
    as:
    pyCoCo -i mdfiles -t topfile -o outname [-d ndims -n npoints -g gridsize -l
    logfile -s selection --nompi --regularize --fmt format --skip nskip]
    where:
        mdfiles  is a list of one or more trajectory files
        topfile  is a compatible topology file
        outname  two options here. If a single name is given it is assumed it
                 is the basename for the structure files generated by CoCo.
                 The default format is pdb but this can be overridden by
                 the --fmt argument. There will be npoints of these; if
                 outname='out' then they will be called
                 'out0.pdb', 'out1.pdb'... etc up to 'out(npoints-1).pdb'.
                 Option two is that multiple file names are given here. In
                 that case the number of them defines the number of new
                 structures created (any npoints argument is ignored) and
                 their extensions define the format.
        format   is the output file format. Accepted options are 'pdb'
                 (default), 'rst', and 'gro'.
        ndims    specifies the number of dimensions (PCs) in the CoCo mapping
                 (default=3).
        npoints  specifies the number of frontier points to return structures
                 from (default=1)
        gridsize specifies the number of grid points per dimension in the CoCo
                 histogram (default=10)
        logfile  is an optional file with detailed analysis data.
        selection is an optional MDTraj style selection string. Only
                 selected atoms will be used in the CoCo procedure, however
                 ALL atoms will be included in the output files (all unselected
                 ones having coordinates drawn from the first frame analyzed).
                 Such structures are, obviously, only useful as targets for
                 restrained MD or EM procedures.
        cache    is an optional directory of trajectory files, created/updated
                 in a previous CoCo run, which must only feature the selected
                 atoms. In an MPI context, one file per process should work
                 best.

        nompi    specifies that CoCo should not be run in parallel
        regularize specifies that generated structures should (if possible)
                 have their bond lengths and angles regularised.
        skip     specifies the number of top eigenvectors to skip over in the
                 CoCo process, e.g. if nskip is 1, and ndims is 3, the process
                 will use the distributions in PCs 2-4.

    '''

    # --- sanity-check the numeric command line arguments ---
    if args.grid < 1:
        print('Error: gridsize must be > 0')
        sys.exit(1)
    if args.frontpoints < 1:
        print('Error - frontpoints must be > 0')
        sys.exit(1)
    if args.dims < 1:
        print('Error: dims must be > 0')
        sys.exit(1)

    # --- configure logging verbosity: -v = INFO, -vv = DEBUG ---
    if args.verbosity == 2:
        log.basicConfig(format="%(asctime)s: %(levelname)s: %(message)s",
                        level=log.DEBUG)
        log.debug("Debug output.")
    elif args.verbosity == 1:
        log.basicConfig(format="%(levelname)s: %(message)s",
                        level=log.INFO)
        log.info("Verbose output.")

    # Collect all run parameters in one place. (Renamed from 'dict' so the
    # builtin is not shadowed.)
    params = {}
    params['topfile'] = args.topfile
    # A single mdfile argument containing wildcard characters is expanded
    # here, for shells that do not perform the globbing themselves.
    if ((len(args.mdfile) == 1) and (("*" in args.mdfile[0])
                                     or ("?" in args.mdfile[0]))):
        params['mdfiles'] = glob.glob('%s' % args.mdfile[0])
    else:
        params['mdfiles'] = args.mdfile
    # Guard against an empty glob result; otherwise 'trj' below would be
    # undefined and raise a confusing NameError.
    if not params['mdfiles']:
        print('Error: no trajectory files found.')
        sys.exit(1)
    params['logfile'] = args.logfile
    params['ndims'] = args.dims
    params['npoints'] = args.frontpoints
    params['gridsize'] = args.grid
    params['selection'] = args.selection
    params['cache'] = args.cache

    # Output naming: several explicit names override npoints; a single name
    # is treated as a basename and numbered out0.pdb, out1.pdb, ...
    if len(args.output) > 1:
        params['npoints'] = len(args.output)
        params['outnames'] = args.output
    else:
        root, ext = op.splitext(args.output[0])
        params['outnames'] = ['{}{}{}'.format(root, rep, ext)
                              for rep in range(params['npoints'])]

    # --- MPI setup: fall back to serial mode if mpi4py is unavailable ---
    if args.nompi:
        comm = None
        rank = 0
        size = 1
    else:
        try:
            from mpi4py import MPI
            comm = MPI.COMM_WORLD
            rank = comm.Get_rank()
            size = comm.Get_size()
        except ImportError:
            comm = None
            rank = 0
            size = 1

    # Only rank 0 writes the optional analysis logfile.
    if params['logfile'] is not None and rank == 0:
        try:
            logfile = open(params['logfile'], 'w')
        except IOError as e:
            print(e)
            sys.exit(-1)
        logfile.write("*** pyCoCo ***\n\n")

    if rank == 0:
        log.info('creating cofasu...')
    cofasustart = time.time()
    f = []
    if args.cache:
        # First load trajectory files (if any) from the cache:
        if rank == 0:
            if not op.isdir(args.cache):
                os.mkdir(args.cache)
        # Write a reference PDB (selected atoms only) against which the
        # cached trajectories can be read back.
        ctemp = Cofasu(Fasu(params['topfile'], params['topfile'],
                            selection=params['selection']))
        cachepdb = args.cache + '/cache.pdb'
        ctemp.write(cachepdb)
        cachelist = glob.glob(params['cache'] + "/cache*.dcd")
        for atrj in cachelist:
            f.append(Fasu(cachepdb, atrj, selection='all'))

    # Now new ones:
    for trj in params['mdfiles']:
        try:
            f.append(Fasu(trj, top=params['topfile'],
                          selection=params['selection']))
        except (IOError, TypeError) as e:
            print(e)
            sys.exit(-1)

    cf = Cofasu(f, comm=comm)
    cofasutime = time.time() - cofasustart
    natoms = cf.shape[1]
    if natoms == 0:
        print('Error: the selection matches no atoms.')
        sys.exit(-1)

    # create a cofasu corresponding to the full system, and also an index
    # file for the subset.
    cref = Cofasu(Fasu(trj, top=params['topfile'], frames=slice(1)))
    xref = cref[0]
    if rank == 0:
        selndx = cf.fasulist[0].sel

    if params['logfile'] is not None and rank == 0:
        logfile.write("Trajectory files to be analysed:\n")
        for i in range(len(cf.fasulist)):
            logfile.write("{0}: frames: {1} \n".format(
                params['mdfiles'][i], cf.fasulist[i].shape[0]))

        logfile.write('\n')
    if rank == 0:
        log.info('cofasu contains {0} atoms and {1} frames'.format(natoms,
                                                                   len(cf)))
        log.info('time to load trajectory data: {:.2f} s.'.format(cofasutime))

    # Some sanity checking for situations where few input structures have
    # been given. If there is just one, just return copies of it. If there
    # are < 5, ensure ndims is reasonable, and that the total number of
    # grid points (at which new structures might be generated) is OK too.
    # Adjust both ndims and gridsize if required, giving warning messages.
    if len(cf) == 1:
        if params['logfile'] is not None and rank == 0:
            logfile.write("WARNING: Only one input structure given, CoCo\n")
            logfile.write("procedure not possible, new structures will be\n")
            logfile.write("copies of the input structure.\n")

        if rank == 0:
            log.info('Warning: only one input structure!')
            opt = cf[0]
            # Bug fix: the original indexed a non-existent 'outputnames'
            # key (KeyError) and always wrote file 0; write each of the
            # npoints requested copies instead.
            for rep in range(params['npoints']):
                cf.write(params['outnames'][rep], opt)
    else:
        if rank == 0:
            log.info('running pcazip...')
        p = pca.fromtrajectory(cf)
        if rank == 0:
            log.info('Total variance: {0:.2f}'.format(p.totvar))

        # ndims cannot exceed either the number of frames - 1 or the number
        # of eigenvectors available.
        if len(cf) <= params['ndims'] or p.n_vecs < params['ndims']:
            params['ndims'] = min(len(cf) - 1, p.n_vecs)
            if rank == 0:
                log.info("Warning - resetting ndims to {ndims}".format(**params))
                if params['logfile'] is not None:
                    logfile.write('Warning - ndims must be smaller than the\n')
                    logfile.write("number of input structures, resetting it to {ndims}\n\n".format(**params))

        ntot = params['ndims'] * params['gridsize']
        if ntot < params['npoints']:
            # Integer division (//): gridsize must be an int; under Python 3
            # the original '/' produced a float and broke the Map resolution.
            params['gridsize'] = (params['npoints'] // params['ndims']) + 1
            if rank == 0:
                log.info("Warning - resetting gridsize to {gridsize}".format(**params))
                if params['logfile'] is not None:
                    logfile.write('Warning - gridsize too small for number of\n')
                    logfile.write("output structures, resetting it to {gridsize}\n\n".format(**params))

        if params['logfile'] is not None and rank == 0:
            logfile.write("Total variance in trajectory data: {0:.2f}\n\n".format(p.totvar))
            logfile.write("Conformational sampling map will be generated in\n")
            logfile.write("{ndims} dimensions at a resolution of {gridsize} points\n".format(**params))
            logfile.write("in each dimension.\n\n")
            logfile.write("{npoints} complementary structures will be generated.\n\n".format(**params))
        dim = params['ndims']
        nskip = args.skip
        # Drop the top nskip PCs, then take the next 'dim' projections.
        projsSel = p.projs[nskip:dim + nskip].T

        if args.currentpoints is not None and rank == 0:
            np.savetxt(args.currentpoints, projsSel)

        # Build a map from the projection data.
        m = mapping.Map(projsSel, resolution=params['gridsize'], boundary=1)
        # Report on characteristics of the COCO map:

        if params['logfile'] is not None and rank == 0:
            logfile.write("Sampled volume: {0} Ang.^{1}.\n".format(m.volume, dim))
        # Find the COCO points.
        nreps = int(params['npoints'])
        if rank == 0:
            log.info('generating new points...')
        cp = new_points(m, npoints=nreps)

        if args.newpoints is not None and rank == 0:
            np.savetxt(args.newpoints, cp)

        if params['logfile'] is not None and rank == 0:
            logfile.write("\nCoordinates of new structures in PC space:\n")
            for i in range(nreps):
                logfile.write('{:4d}'.format(i))
                for j in cp[i]:
                    logfile.write(' {:6.2f}'.format(j))
                logfile.write('\n')

        outlist = []
        regularize = bool(args.regularize)

        for rep in range(nreps):
            # add zeros to start of cp if we are skipping over top EVs
            stmp = [0.0] * nskip + list(cp[rep])
            # Convert the point to a crude structure.
            e = p.scores(p.closest(stmp))
            e[:len(stmp)] = stmp
            crude = p.unmap(e, regularize=regularize)

            outlist.append(crude)
            # merge the optimised subset into the full coordinates array.
            # NOTE: xout aliases xref, so unselected atoms keep their
            # first-frame coordinates across reps while the selected subset
            # is overwritten each time - this matches the documented
            # behaviour.
            xout = xref
            if rank == 0:
                xout[selndx] = fitted(crude, xout[selndx])
                cref.write(params['outnames'][rep], xout)

        if params['logfile'] is not None and rank == 0:
            logfile.write("\nRMSD matrix for new structures:\n")
            for i in range(nreps):
                for j in range(nreps):
                    logfile.write("{0:6.2f}".format(rmsd(outlist[i], outlist[j])))
                logfile.write("\n")

        if args.cache:
            if len(cachelist) == 0:
                cachelist = [params['cache'] + '/cache{}.dcd'.format(i)
                             for i in range(size)]
                # Integer division (//): chunk boundaries are frame indices;
                # under Python 3 the original '/' produced floats and broke
                # the slicing below.
                chunksize = len(cf) // size
            else:
                chunksize = len(cf) // len(cachelist)
            for i in range(size):
                temptrj = params['cache'] + '/tmp.dcd'
                start = i * chunksize
                end = min((i + 1) * chunksize, len(cf))
                cf.write(temptrj, cf[start:end])
                if rank == 0:
                    os.rename(temptrj, cachelist[i])

            if params['logfile'] is not None and rank == 0:
                logfile.write("Cache {} updated.".format(args.cache))

    if params['logfile'] is not None and rank == 0:
        logfile.close()
################################################################################
#                                                                              #
#                                    ENTRY POINT                               #
#                                                                              #
################################################################################

if __name__ == '__main__':
    # Build the command line interface. Argument registration order is kept
    # as-is so that --help output is unchanged.
    cli = argparse.ArgumentParser()
    cli.add_argument('-g', '--grid', type=int, default=10,
                     help="Number of points along each dimension of the CoCo histogram")
    cli.add_argument('-d', '--dims', type=int, default=3,
                     help='The number of projections to consider from the input pcz file in CoCo; this will also correspond to the number of dimensions of the histogram.')
    cli.add_argument('-n', '--frontpoints', type=int, default=1,
                     help="The number of new frontier points to select through CoCo.")
    cli.add_argument('-i', '--mdfile', type=str, nargs='*', required=True,
                     help='The MD files to process.')
    cli.add_argument('-o', '--output', type=str, nargs='*', required=True,
                     help='Basename of the pdb files that will be produced.')
    cli.add_argument('-t', '--topfile', type=str, required=True,
                     help='Topology file.')
    cli.add_argument('-v', '--verbosity', action="count",
                     help="Increase output verbosity.")
    cli.add_argument('-l', '--logfile', type=str, default=None,
                     help='Optional log file.')
    cli.add_argument('-c', '--cache', type=str, default=None,
                     help='Optional cache directory.')
    cli.add_argument('-s', '--selection', type=str, default='all',
                     help='Optional atom selection string.')
    cli.add_argument('--nompi', action='store_true',
                     help='Disables any attempt to use MPI.')
    cli.add_argument('-r', '--regularize', action='store_true',
                     help='Regularize structures.')
    cli.add_argument('-V', '--version', action='version', version=__version__)
    cli.add_argument('-f', '--fmt', type=str, default=None,
                     help='Optional output format.')
    cli.add_argument('--currentpoints', type=str, default=None,
                     help='Optional file with coordinates of current points.')
    cli.add_argument('--newpoints', type=str, default=None,
                     help='Optional file with coordinates of CoCo-generated points.')
    cli.add_argument('--skip', type=int, default=0,
                     help='The number of top eigenvectors to skip over (default=0)')

    # Parse and hand off to the CoCo driver.
    coco_ui(cli.parse_args())
