#!/usr/bin/env python
'''
Collect trajectory data from a tios simulation.
'''
from __future__ import print_function
import sys
import os
import datetime
from tios import utilities, receiver
from tios._version import __version__
import mdtraj as mdt
import curses
import re
import argparse


def set_erase_to_end_of_line(stream=None):
    """Return the terminal's "erase to end of line" control sequence.

    The sequence is used to blank the rest of the progress line when it is
    redrawn in place after a carriage return.

    Args:
        stream: output stream to query; defaults to ``sys.stdout``.

    Returns:
        The terminfo ``el`` capability with any padding specifications
        (``$<n>``) removed, or an empty string when *stream* is not a
        terminal, terminfo cannot be initialised for it, or the terminal
        does not define ``el``.
    """
    # see stackoverflow.com/questions/5419389
    if stream is None:
        stream = sys.stdout
    if not stream.isatty():
        return ""
    try:
        curses.setupterm(fd=stream.fileno())
        capability = curses.tigetstr('el')
    except curses.error:
        # e.g. TERM unset or unknown: degrade to plain carriage returns.
        return ""
    if capability is None:
        # The terminal description has no 'el' entry.
        return ""
    # Strip terminfo delay/padding markers, which are not meant to be printed.
    return re.sub(r'\$<\d+>[/*]?', '', capability.decode())
    
# Control sequence that blanks the rest of the progress line when it is redrawn.
el = set_erase_to_end_of_line()

parser = argparse.ArgumentParser(
    description='Connect to a *tios* stream and download data')

# (names/flags, keyword options) for each command-line argument, listed in the
# order they appear in --help.
_cli_arguments = [
    (('id',),
     dict(help='tios id of the job to collect data from')),
    (('-o', '--trajfile'),
     dict(help='name of trajectory file')),
    (('-l', '--logfile'),
     dict(help='name of log file')),
    (('-p', '--pdb'),
     dict(help='write a .pdb file to complement the trajectory')),
    (('-v', '--verbose'),
     dict(action='store_true',
          help='display info as collection progresses')),
    (('-t', '--time_period'),
     dict(default=0.0,
          help='time period over which to collect data (in ps or ns)')),
    (('-i', '--interval'),
     dict(default=0.0,
          help='time interval between snapshots (in ps or ns)')),
    (('-s', '--selection'),
     dict(default='all',
          help='include only these atoms in the trajectory file')),
    (('-V', '--version'),
     dict(action='version', version=__version__)),
]
for _names, _options in _cli_arguments:
    parser.add_argument(*_names, **_options)

args = parser.parse_args()

def _parse_time(value, option):
    """Convert a command-line time value to picoseconds.

    Args:
        value: a number (the argparse default) or a string such as ``'10'``,
            ``'10ps'`` or ``'2 ns'``; a bare number is taken to be in ps.
        option: the option name, used in the error message.

    Returns:
        The time in ps as a float.

    Exits through ``parser.error`` (usage message, non-zero status) if the
    string is not a valid time.
    """
    if not isinstance(value, str):
        return float(value)
    text = value.strip().lower()
    scale = 1.0
    if text.endswith('ns'):
        text, scale = text[:-2], 1000.0
    elif text.endswith('ps'):
        text = text[:-2]
    try:
        return float(text) * scale
    except ValueError:
        parser.error('invalid time for {}: {!r} (expected e.g. 10, 10ps or 2ns)'
                     .format(option, value))


# Validate the time arguments before connecting to the stream.
time_period = _parse_time(args.time_period, '--time_period')
interval = _parse_time(args.interval, '--interval')

tr = receiver.TiosReceiver(args.id, selection=args.selection)

if args.pdb is not None:
    # make_molecules_whole uses the topology's bonds, so generate them first.
    tr.trajectory.topology.create_standard_bonds()
    # mdtraj returns a modified copy unless inplace=True.
    tr.trajectory.make_molecules_whole(inplace=True)
    tr.trajectory.save(args.pdb)

if args.trajfile is None:
    # Only a PDB snapshot was requested; the interval is irrelevant.
    sys.exit(0)

# Interval between frames published by the stream, in ps.
sample_interval = float(tr._te.trate)
if interval == 0.0:
    interval = sample_interval

# Compare as a ratio rather than with int() truncation. int() would divide by
# zero for sub-picosecond sample intervals and would ignore fractional parts.
nskip = int(round(interval / sample_interval))
if nskip < 1 or abs(nskip * sample_interval - interval) > 1e-6 * sample_interval:
    print('Error: your sampling interval must be equal to,', end='', file=sys.stderr)
    print(' or a multiple of, {}ps'.format(sample_interval), file=sys.stderr)
    sys.exit(1)

ext = os.path.splitext(args.trajfile)[1].lower()
if ext not in ('.xtc', '.dcd'):
    print('Error: only .xtc and .dcd formats are supported', file=sys.stderr)
    sys.exit(1)
writer = (mdt.formats.XTCTrajectoryFile if ext == '.xtc'
          else mdt.formats.DCDTrajectoryFile)


def _current_status(rec):
    """Return *rec*'s status, reporting '*Stalled*' for a job gone quiet.

    A 'Running' job counts as stalled when its last update is older than
    ten times ``rec._te.frame_rate`` seconds.
    """
    status = rec.status
    if status == 'Running':
        quiet = datetime.datetime.utcnow() - rec._te.last_update
        if quiet.total_seconds() > 10 * rec._te.frame_rate:
            status = '*Stalled*'
    return status


def _write_frame(trajfile, traj):
    """Make molecules whole in *traj* and append its coordinates to *trajfile*."""
    # mdtraj returns a modified copy unless inplace=True.
    traj.make_molecules_whole(inplace=True)
    if ext == '.xtc':
        trajfile.write(xyz=traj.xyz, time=traj.time, box=traj.unitcell_vectors)
    else:
        # Bug in dcd writer: does not convert nm to angstroms, so scale here.
        trajfile.write(xyz=traj.xyz * 10.0,
                       cell_lengths=traj.unitcell_lengths * 10.0,
                       cell_angles=traj.unitcell_angles)


def _report(line, overwrite):
    """Write a progress line to the log file and/or the terminal.

    Args:
        line: the formatted progress message.
        overwrite: redraw the terminal status line in place instead of
            starting it fresh.
    """
    if args.logfile is not None:
        # Truncated each time, so the log holds only the latest progress line.
        with open(args.logfile, 'w') as log:
            log.write(line + '\n')
    if args.verbose:
        sys.stdout.write(('\r' + el if overwrite else '') + line)
        sys.stdout.flush()


out = 'Frames so far: {} Time point (ps): {:.2f} status: {}'
with writer(args.trajfile, 'w') as f:
    _write_frame(f, tr.trajectory)
    status = _current_status(tr)
    timepoint = tr.timepoint
    n_received = 1  # frames received from the stream, including the first
    n_written = 1   # frames written to the trajectory file
    unlimited = time_period == 0
    final_timepoint = timepoint + time_period
    _report(out.format(n_written, timepoint, status), overwrite=False)

    killer = utilities.GracefulKiller()
    while not killer.kill_now:
        if status == 'Stopped':
            tr.step(killer=killer)
        else:
            tr.step(wait=False, killer=killer)
        if killer.kill_now:
            break
        timepoint = tr.timepoint
        status = _current_status(tr)
        n_received += 1

        if not unlimited and timepoint > final_timepoint:
            break
        # Keep every nskip-th received frame counting from the first one, so
        # written frames are evenly spaced by the requested interval.
        if (n_received - 1) % nskip != 0:
            continue
        _write_frame(f, tr.trajectory)
        n_written += 1
        _report(out.format(n_written, timepoint, status), overwrite=True)
    if args.verbose:
        print('\n')

