#!/usr/bin/env python
# coding: utf-8
# Copyright (c) atomate Development Team.

from __future__ import division, unicode_literals, print_function

import argparse
import os
import sys
import yaml
import ast
from datetime import datetime

from monty.serialization import loadfn

from fireworks import LaunchPad

from atomate.utils.utils import get_wf_from_spec_dict, load_class
from atomate.vasp.powerups import add_namefile, add_tags
from atomate.vasp.workflows.presets import core

from pymatgen import Structure, Lattice, MPRester
from pymatgen.util.testing import PymatgenTest

# Default workflow spec. It is parsed with yaml and handed to
# get_wf_from_spec_dict when neither a spec file nor a preset is requested
# (see _get_wf), and for the default workflows of the test suite
# (see submit_test_suite).
# The "parents" values presumably index into the "fireworks" list
# (0 = the OptimizeFW, 1 = the StaticFW) - confirm against
# get_wf_from_spec_dict.
# NOTE(review): YAML has no "None" keyword, so "db_file: None" is read as the
# string "None" rather than a null value - confirm that the downstream code
# expects this.
default_yaml = """fireworks:
- fw: atomate.vasp.fireworks.core.OptimizeFW
- fw: atomate.vasp.fireworks.core.StaticFW
  params:
    parents: 0
- fw: atomate.vasp.fireworks.core.NonSCFFW
  params:
    parents: 1
    mode: uniform
- fw: atomate.vasp.fireworks.core.NonSCFFW
  params:
    parents: 1
    mode: line
common_params:
  db_file: None
"""

# Absolute path of the directory that holds the presets module ``core``
# imported above; _get_wf resolves library spec files relative to it.
presets_dir = os.path.join(os.path.dirname(os.path.abspath(core.__file__)))
# LaunchPad shared by every sub-command in this script. It is created as soon
# as this module is executed or imported, not lazily.
lpad = LaunchPad.auto_load()


def add_to_lpad(workflow, write_namefile=False):
    """
    Insert a workflow into the module-level LaunchPad.

    Args:
        workflow (Workflow): the workflow to store in the database
        write_namefile (bool): when True, the ``add_namefile`` powerup is
            applied to the workflow before insertion, so that an empty file
            named "FW--<fw.name>" is written to the launch directory
    """
    if write_namefile:
        workflow = add_namefile(workflow)
    lpad.add_wf(workflow)


def _get_wf(args, structure):
    if args.spec_file:
        spec_path = args.spec_file
        if args.library:
            if args.library.lower() == "vasp":
                spec_path = os.path.join(presets_dir, "..", "base",
                                         "library", spec_path)
            else:
                raise ValueError("Unknown library: {}".format(args.library))
        d = loadfn(spec_path)
        return get_wf_from_spec_dict(structure, d, args.common_param_updates)

    elif args.preset:
        if args.library:
            if args.library.lower() == "vasp":
                modname = "atomate.vasp.workflows.presets.core"
                funcname = args.preset
            else:
                modname, funcname = args.preset.rsplit(".", 1)

        mod = __import__(modname, globals(), locals(), [str(funcname)], 0)
        func = getattr(mod, funcname)
        return func(structure)

    else:
        d = yaml.load(default_yaml)
        return get_wf_from_spec_dict(structure, d, args.common_param_updates)


def add_wf(args):
    """
    Create a workflow for every entry of ``args.files`` and add it to the
    LaunchPad.

    Each entry is treated as a Materials Project id when ``args.mp`` is set
    and as a structure file path otherwise.

    Args:
        args (argparse.Namespace): parsed options of the "add" sub-command
    """
    for entry in args.files:
        if args.mp:
            structure = MPRester().get_structure_by_material_id(entry)
        else:
            structure = Structure.from_file(entry)
        add_to_lpad(_get_wf(args, structure), write_namefile=False)


def submit_test_suite(args):
    """
    Creates a test suite of workflows and adds it to the LaunchPad.

    The suite consists of the default workflow (``default_yaml``) for each of
    the standard structures plus one workflow per callable in the presets
    ``core`` module whose name starts with "wf". Every workflow is tagged
    with "test set <UTC timestamp>" so that verify_test_suite can find it.

    Args:
        args (argparse.Namespace): parsed options of the "test" sub-command;
            if ``args.reset`` is set the LaunchPad is reset first, without
            asking for a password.
    """
    dt = datetime.utcnow()
    if args.reset:
        lpad.reset(password='', require_password=False)

    # Structures for standard workflow
    compounds = ["Si", "CsCl"]
    structs = [PymatgenTest.get_structure(c) for c in compounds]
    structs.append(Structure.from_spacegroup("Fm-3m", Lattice.cubic(4.204),
                                             ["Ni", "O"],
                                             [[0, 0, 0], [0.5, 0.5, 0.5]]))
    structs = [s.get_primitive_structure() for s in structs]
    # Perturbed copy of the first structure, used as the second NEB endpoint
    scopy = structs[0].copy()
    scopy.perturb(0.1)
    # Positional arguments for specific workflows; any preset that is not
    # listed here is called with the first standard structure only
    custom_args = {
        "wf_elastic_constant": [PymatgenTest.get_structure("Sn")],
        "wf_piezoelectric_constant": [PymatgenTest.get_structure("SrTiO3")],
        "wf_nudged_elastic_band": [[structs[0], scopy], structs[0]],
    }

    # Add default workflows. safe_load is enough for the plain mapping in
    # default_yaml and, unlike yaml.load without a Loader, works on every
    # PyYAML release.
    d = yaml.safe_load(default_yaml)
    wfs = [get_wf_from_spec_dict(s, d) for s in structs]
    # Add preset workflows
    for name, wf_func in core.__dict__.items():
        if callable(wf_func) and name.startswith("wf"):
            if name in custom_args:
                wfs.append(wf_func(*custom_args[name]))
            else:
                wfs.append(wf_func(structs[0]))
    wfs = [add_tags(wf, "test set {}".format(dt)) for wf in wfs]
    for wf in wfs:
        add_to_lpad(wf, write_namefile=False)


def verify_test_suite(args):
    """
    Print the status of every test set found in the LaunchPad.

    Test sets are identified by the distinct "spec.tags" values of fireworks
    that carry a tag matching "test set". For each one the fireworks are
    counted per state and a one-line verdict is printed.

    Args:
        args (argparse.Namespace): parsed options of the "verify"
            sub-command; if ``args.report`` is set the per-state counts are
            printed as well.
    """
    tags = lpad.fireworks.distinct("spec.tags",
                                   {"spec.tags": {"$regex": "test set"}})
    for tag in tags:
        pipeline = [{"$match": {"spec.tags": tag}},
                    {"$project": {"state": 1}},
                    {"$group": {"_id": "$state", "count": {"$sum": 1}}}]
        rep_dict = {k["_id"]: k["count"]
                    for k in lpad.fireworks.aggregate(pipeline)}
        # The tag is "test set <date> <time>"; join with a space so the date
        # and the time are not run together in the output.
        dt = ' '.join(tag.split()[2:])
        print("Test set submitted at {}".format(dt))
        if args.report:
            print("   Test fw state summary:")
            for state, n_in_state in rep_dict.items():
                print("   -{}: {}".format(state, n_in_state))
        total = sum(rep_dict.values())
        if rep_dict.get('COMPLETED') == total:
            print("   All test workflows successfully completed")
        elif not rep_dict.get('FIZZLED'):
            print("   No fizzled workflows, testing not yet complete.")
        else:
            print("   {} fireworks have fizzled.".format(rep_dict["FIZZLED"]))


def powerup_workflow(args):
    """
    Apply a powerup function to workflows that are already in the database.

    The workflows are selected either by a single id (``args.wf_id``) or by a
    query given as a Python-literal string (``args.query``); exactly one of
    the two must be supplied. The powerup is loaded from ``args.module`` /
    ``args.name`` and called with the keyword arguments parsed from
    ``args.powerup_kwargs``. Afterwards the "_tasks" entry of every
    firework's spec is updated in the LaunchPad.

    Args:
        args (argparse.Namespace): parsed options of the "powerup"
            sub-command

    Raises:
        ValueError: if both or neither of ``wf_id`` and ``query`` are given
    """
    if args.wf_id and args.query:
        raise ValueError("Only one of --wf_id and --query may be specified")

    if args.wf_id:
        wfs = [lpad.get_wf_by_fw_id(int(args.wf_id))]
    elif args.query:
        selector = ast.literal_eval(args.query)
        wfs = []
        for matched_id in lpad.get_wf_ids(selector):
            wfs.append(lpad.get_wf_by_fw_id(matched_id))
    else:
        raise ValueError("At least one of --wf_id or --query must be specified")

    powerup_fn = load_class(args.module, args.name)
    powerup_kwargs = ast.literal_eval(args.powerup_kwargs)
    for original_wf in wfs:
        powered_wf = powerup_fn(original_wf, **powerup_kwargs)
        for fw in powered_wf.fws:
            target_ids = [fw.fw_id]
            # Only the task list is written back to the stored spec
            new_tasks = fw.as_dict()['spec']['_tasks']
            lpad.update_spec(target_ids, {"_tasks": new_tasks})


# Command-line entry point: builds the "add", "test", "verify" and "powerup"
# sub-commands and dispatches to the function registered for the chosen one.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="atwf is a convenient script to add workflows using a "
                    "simple YAML spec.",
        epilog="Author: Shyue Ping Ong")

    subparsers = parser.add_subparsers()

    padd = subparsers.add_parser("add", help="Add workflows.")
    # The adjacent string literals are concatenated, so each piece needs its
    # own trailing space to keep the words apart in the help output.
    padd.add_argument("-l", "--library", dest="library", type=str,
                      help="If this option is set, the path to the spec file "
                           "is taken with respect to the atomate base library. "
                           "Use 'vasp' for the VASP library of workflows")
    padd.add_argument("-s", "--spec",
                      dest="spec_file", type=str,
                      help="Specify workflow type using YAML/JSON spec file.")
    padd.add_argument("-p", "--preset",
                      dest="preset", type=str,
                      help="Specify workflow type using preset function")
    padd.add_argument("-m", "--mp", dest="mp", action='store_true',
                      help="If this option is set, the files argument is "
                           "interpreted as a list of Materials Project IDS. "
                           "Note that your MAPI_KEY environment variable must "
                           "be set to get structures from the Materials "
                           "Project.")
    padd.add_argument("-c", "--common_params", dest="common_param_updates",
                      help="Set to a dict-like string, e.g. '{\"a\":\"b\"}', to set common params")
    padd.add_argument("files", metavar="files", type=str, nargs="+",
                      help="Structures to add workflows for.")
    padd.set_defaults(func=add_wf, common_param_updates="{}")

    ptest = subparsers.add_parser("test", help="Add test suite.")
    ptest.add_argument("-r", "--reset", dest="reset", action='store_true',
                       help="If this option is set, launchpad will be reset.")
    ptest.set_defaults(func=submit_test_suite)

    pverify = subparsers.add_parser("verify", help="verify test results.")

    pverify.add_argument("-r", "--report", dest="report", action="store_true",
                         help="Use this option to print summary of test workflow "
                              "states.")
    pverify.set_defaults(func=verify_test_suite)

    ppowerup = subparsers.add_parser("powerup", help="Apply a powerup to a workflow")
    ppowerup.add_argument("-i", "--wf_id", help="Workflow id to powerup")
    ppowerup.add_argument("-q", "--query", help="Query for workflows")
    ppowerup.add_argument("-n", "--name", help="Name of powerup to apply, "
                                               "e.g. add_modify_incar")
    ppowerup.add_argument("-m", "--module", default="atomate.vasp.powerups",
                          help="Module containing powerup")
    ppowerup.add_argument("-pk", "--powerup_kwargs", default='{}',
                          help="powerup keyword arguments, e.g. "
                               "'{\"incar_update\": {\"ENCUT\": 700}}'")
    ppowerup.set_defaults(func=powerup_workflow)

    args = parser.parse_args()

    # Only the "add" sub-command defines common_param_updates; it arrives as
    # a string and is converted to a dict here.
    if hasattr(args, "common_param_updates"):
        args.common_param_updates = ast.literal_eval(args.common_param_updates)  # str->dict

    # Without a sub-command no handler is registered: show the help and exit.
    if not hasattr(args, "func"):
        parser.print_help()
        sys.exit(0)
    args.func(args)

