#!/usr/bin/env python

import numpy as np
import dynaphopy.analysis.fitting  as fitting
import dynaphopy.interface.iofile as reading
import matplotlib.pyplot as plt

import argparse


# Command-line interface
parser = argparse.ArgumentParser(description='fitdata')

# Positional: input file with the power spectra
parser.add_argument(
    'file',
    metavar='file',
    type=str,
    help='file containing power spectra information',
)

# Band selection; several tokens are accepted and joined again later
parser.add_argument(
    '-bi',
    metavar='indices',
    type=str,
    nargs='+',
    default=None,
    help='band indices (ex: "1,2 3 4, 5")',
)

# Smoothing switch and its bandwidth fraction
parser.add_argument(
    '-sm', '--smooth',
    action='store_true',
    help='smooth before fitting (LOESS method)',
)
parser.add_argument(
    '-f',
    metavar='N',
    type=float,
    default=0.02,
    help='smooth factor [Default: 0.02]',
)

# Output options
parser.add_argument(
    '-s',
    metavar='file',
    type=str,
    help='save spectra to file',
)
parser.add_argument(
    '--silent',
    action='store_true',
    help='do not show plots',
)

# Integer selector forwarded to the fitting routine
parser.add_argument(
    '--fitting_function',
    type=int,
    default=0,
    metavar='N',
    help='Select fitting function',
)

args = parser.parse_args()


# LOESS method definition
def loess(x, y, f=2./3.):
    """Smooth ``y`` over ``x`` with a locally weighted linear regression.

    For every abscissa a straight line is fitted by weighted least squares
    using tricube weights, and the value of that line at the abscissa is
    taken as the smoothed value.

    :param x: 1-D sequence of abscissas.
    :param y: 1-D sequence of values, same length as ``x``.
    :param f: fraction of the points that sets the local bandwidth; the
        bandwidth at each point is its distance to the ``ceil(f * n)``-th
        nearest point. Expected to satisfy ``0 < f <= 1``.
    :return: numpy array of length ``n`` with the smoothed values.
    :raises numpy.linalg.LinAlgError: if a local fit is singular, e.g. when
        ``f`` is so small that only one point receives a non-zero weight.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    n = len(x)
    if n < 2:
        # A line cannot be fitted to fewer than two points; nothing to smooth.
        return y.copy()

    # The neighbour rank indexes a sorted array of n distances, so it must
    # stay below n (f = 1, or f * n rounding up to n, would overflow it).
    r = min(int(np.ceil(f * n)), n - 1)

    # Local bandwidth: distance from each point to its r-th nearest point
    h = [np.sort(np.abs(x - x[i]))[r] for i in range(n)]

    # Tricube kernel; column i holds the weights of every point around x[i]
    w = np.clip(np.abs((x[:, None] - x[None, :]) / h), 0.0, 1.0)
    w = (1.0 - w**3)**3

    y_smooth = np.zeros(n)
    for i in range(n):
        weights = w[:, i]
        # Weighted normal equations of the local fit y ~ beta[0] + beta[1]*x
        b = np.array([np.sum(weights*y), np.sum(weights*y*x)])
        a = np.array([[np.sum(weights), np.sum(weights*x)],
                      [np.sum(weights*x), np.sum(weights*x*x)]])
        beta = np.linalg.solve(a, b)
        y_smooth[i] = beta[0] + beta[1]*x[i]

    return y_smooth

# Process input data
try:
    input_file = open(args.file, "r")
except IOError:
    print("File not found")
    # Builtin exit (no dependency on the site module) with a failure status
    raise SystemExit(1)

# Read the whole table first and close the file as soon as it is parsed.
# Blank lines are skipped so they cannot produce ragged rows.
with input_file:
    initial_data = [line.split() for line in input_file if line.strip()]
initial_data = np.array(initial_data, dtype=float)

if initial_data.size == 0:
    print("File contains no data")
    raise SystemExit(1)

if args.bi is None:
    # No band selection: every column after the first is its own band.
    # The column count comes from the parsed table, so no data row is
    # consumed (and lost) just to count columns.
    data_num = initial_data.shape[1]
    degeneracy = [[i] for i in range(1, data_num)]
else:
    # Groups are separated by commas, indices inside a group by whitespace
    degeneracy = [[int(j) for j in i.split()] for i in ' '.join(args.bi).split(',')]


# First column holds the frequencies; the remaining columns are addressed
# through the indices stored in each degeneracy group
test_frequencies_range = np.array(initial_data[:, 0])

# Average the columns belonging to each group of degenerate bands
data = []
for band_columns in degeneracy:
    n_modes = len(band_columns)
    averaged = np.zeros_like(test_frequencies_range)
    for column in band_columns:
        averaged = averaged + initial_data[:, column] / n_modes
    data.append(averaged)



# LOESS smoothing (optional, enabled with -sm/--smooth)
if args.smooth:
    show_plots = not args.silent
    smoothed_bands = []
    for band_number, spectrum in enumerate(data, 1):
        smoothed = loess(test_frequencies_range, spectrum, f=args.f)
        smoothed_bands.append(smoothed)

        if show_plots:
            # One figure per band comparing the raw and smoothed spectra
            plt.figure(band_number)
            plt.title('Band {0}'.format(band_number))
            plt.suptitle('Smoothing LOESS')
            plt.plot(test_frequencies_range, spectrum, label='Original')
            plt.plot(test_frequencies_range, smoothed, label='Smooth',
                     linewidth=3)
            plt.legend()

    if show_plots:
        plt.show()

    data = smoothed_bands

# Stack the bands as columns: one row per frequency, one column per band
data = np.array(data).T

# Run Dynaphopy's phonon fitting analysis on the spectra
show_plots = not args.silent
fitting.phonon_fitting_analysis(
    data,
    test_frequencies_range,
    show_plots=show_plots,
    fitting_function_type=args.fitting_function,
    use_degeneracy=None,
)

# Write the analysed spectra to disk when an output file was requested
if args.s:
    reading.write_curve_to_file(test_frequencies_range, data, args.s)
