#
# MUSIC𝄞NTWRK
#
# A python library for pitch class set and rhythmic sequences classification and manipulation,
# the generation of networks in generalized music and sound spaces, and the sonification of arbitrary data
#
# Copyright (C) 2018 Marco Buongiorno Nardelli
# http://www.materialssoundmusic.com, mbn@unt.edu
#
# This file is distributed under the terms of the
# GNU General Public License. See the file `License'
# in the root directory of the present distribution,
# or http://www.gnu.org/copyleft/gpl.txt .
#

import sys,re,os
import pandas as pd
import numpy as np
import music21 as m21

from ..musicntwrk import PCSet
from ..utils.communications import *
from ..utils.load_balancing import *

# Optional MPI support: when mpi4py can be imported and queried, rank/size describe
# this process within COMM_WORLD; otherwise the module falls back to a serial setup.
try:
    from mpi4py import MPI
    # initialize parallel execution
    comm=MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    para = True
except Exception:
    # best-effort fallback to serial execution; unlike a bare except this still
    # lets KeyboardInterrupt and SystemExit propagate
    rank = 0
    size = 1
    para = False

_DIGITS = re.compile(r'\d+')

def _parse_ints(text):
    # return every run of decimal digits found in text as an array of integers
    return np.asarray([int(token) for token in _DIGITS.findall(text)])

def pcsNetwork(dictionary,thup,thdw,distance,prob,write,pcslabel,TET):
    
    '''
    •	generate the network of pcs based on distances between interval vectors
    •	dictionary – table of pcs (converted with np.asarray to a 2D array of strings):
    	column 0 is used as node label, column 1 is parsed as the pitch classes,
    	column 2 is parsed as the interval vector
    •	thup, thdw (float)– upper and lower thresholds (both inclusive) for edge creation
    •	distance (str)– currently ignored: the Euclidean norm is always used
    •	col = 2 – metric based on interval vector
    •	prob (float)– if ≠ 1, defines the probability of acceptance of any given edge
    •	write (bool)– if True, writes nodes.csv and edges.csv as separate files in csv format
    •	pcslabel (bool)– if True, nodes are labelled with the pitch names of the pcs
    	instead of the content of column 0
    •	TET – accepted for interface compatibility, not used by this function
    •	returns (dnodes, dedges): the 'Label' table of nodes and the
    	'Source','Target','Weight' table of edges, with Weight = 1/distance.
    	In a parallel run rank 0 returns the merged edges, the other ranks their partial list.
    •	raises ZeroDivisionError if a pair at zero distance falls within the thresholds
    	(this includes every node paired with itself when thdw <= 0)
    '''

    col = 2
    df = np.asarray(dictionary)
    # float() keeps accepting thresholds passed as numeric strings
    thup = float(thup)
    thdw = float(thdw)
    
    dim = _parse_ints(df[0,col]).shape[0]
    
    # build the table of nodes
    if pcslabel:
        labels = []
        for row in df:
            p = PCSet(_parse_ints(row[1]))
            # single-note sets are labelled from pcs directly, all others from the normal order
            pitches = p.pcs if p.pcs.shape[0] == 1 else p.normalOrder()
            labels.append(''.join(m21.chord.Chord(pitches.tolist()).pitchNames))
        dnodes = pd.DataFrame(labels,columns=['Label'])
    else:
        dnodes = pd.DataFrame(df[:,0],columns=['Label'])
    # every rank holds the same table: only one of them writes it
    if write and rank == 0: dnodes.to_csv('nodes.csv',index=False)

    if para: comm.Barrier()
    
    # find edges according to a metric
    
    N = df.shape[0]
    vector = np.zeros((N,dim))
    for i in range(N):
        vector[i] = _parse_ints(df[i,col])
    index = np.arange(N)
    # parallelize over interval vectors: this rank handles rows ini..end-1
    ini,end = load_balancing(size, rank, N)
    nsize = end-ini
    vaux = scatter_array(vector)
    frames = []
    for i in range(nsize):
        source = i+ini
        weight = np.sqrt(np.sum((vaux[i,:]-vector)**2,axis=1))
        # each unordered pair is evaluated once, from its lower-numbered endpoint,
        # so the acceptance draw below is applied exactly once per edge
        keep = (weight <= thup) & (weight >= thdw) & (index >= source)
        if prob != 1:
            # accept each candidate edge independently with probability prob; the
            # generator is not re-seeded here, callers may seed np.random themselves
            keep &= (np.random.rand(N) < prob)
        targets = index[keep]
        if targets.shape[0] == 0:
            continue
        frames.append(pd.DataFrame({'Source': np.full(targets.shape[0],source,dtype=int),
                                    'Target': targets,
                                    'Weight': weight[keep]}))

    # a single concatenation avoids re-copying the growing table at every row
    if frames:
        dedges = pd.concat(frames,ignore_index=True)
    else:
        dedges = pd.DataFrame(None,columns=['Source','Target','Weight'])

    # edge weight is the inverse of the distance
    if (dedges['Weight'] == 0).any():
        raise ZeroDivisionError('zero distance within the thresholds: use thdw > 0')
    dedges['Weight'] = 1.0/dedges['Weight']

    if size == 1:
        # serial run: no partial files, and nothing on disk is touched unless write is set
        if write: dedges.to_csv('edges.csv',index=False)
        return(dnodes,dedges)

    # parallel run: the partial edge lists are gathered on rank 0 through csv files
    dedges.to_csv('edges'+str(rank)+'.csv',index=False)
    if para: comm.Barrier()
    
    if rank == 0:
        parts = []
        for i in range(size):
            fname = 'edges'+str(i)+'.csv'
            parts.append(pd.read_csv(fname))
            os.remove(fname)
        dedges = pd.concat(parts,ignore_index=True)
        # do some cleaning: order the endpoints and drop pairs reported more than once
        cond = dedges.Source > dedges.Target
        dedges.loc[cond, ['Source', 'Target']] = dedges.loc[cond, ['Target', 'Source']].values
        dedges = dedges.drop_duplicates(subset=['Source', 'Target'])
        # write csv for edges
        if write: dedges.to_csv('edges.csv',index=False)

    return(dnodes,dedges)
