"""
Copyright (c) 2024 Massachusetts Institute of Technology 
SPDX-License-Identifier: BSD-2-Clause
"""
import  numpy  as  np
import  cirq   as  cirq
import sys

from pyLIQTR.clam.operator_strings import  op_strings
from pyLIQTR.ProblemInstances.ProblemInstance import  ProblemInstance
from openfermion import  jordan_wigner, get_interaction_operator
from openfermion import  InteractionOperator
from openfermion import  FermionOperator

from    functools   import   cache, cached_property

import juliapkg

class ChemicalHamiltonian(ProblemInstance):

    """
    This ProblemInstance generates information for the PauliLCU and the
    double-factorized (DF) encodings of a molecular Hamiltonian.

    Args:
        - mol_ham: molecular Hamiltonian in the form of an openfermion
            InteractionOperator. This object can be generated by
            molecule.get_molecular_hamiltonian().
        - mol_name (Optional): str, name or molecular formula. Used by
            __str__; "Molecule" is shown when it is None.
        - **kwargs: forwarded to the base-class initializer (see __init__).

    Note:
        Every construction calls juliapkg to make sure a compatible Julia is
        available. The DF branch of get_alpha() and the BLISS optimization
        additionally load the QuantumMAMBO Julia package on first use.
    """

    def __init__(self, mol_ham: InteractionOperator, mol_name=None, **kwargs):
        # Make sure a compatible Julia is available. This runs on every
        # construction (including the instance built by optimize());
        # presumably juliapkg.resolve() is cheap once resolved -- TODO confirm.
        juliapkg.require_julia("~1.8,~1.9")
        juliapkg.resolve()

        # Now start the real initialization
        self._mol_ham = mol_ham
        self._mol_name = mol_name
        # Cache of DF cutoffs keyed by (df_error_threshold, sf_error_threshold);
        # populated by yield_DF_Info().
        self._df_cutoffs = dict()

        # NOTE(review): super(ProblemInstance, self) starts the MRO lookup
        # *after* ProblemInstance, so ProblemInstance.__init__ itself is not
        # called. Kept unchanged; confirm this is intentional.
        super(ProblemInstance, self).__init__(**kwargs)

    def __str__(self):
        """Return the molecule name, or "Molecule" when no name was given."""
        return "Molecule" if self._mol_name is None else f"{self._mol_name}"

    def n_qubits(self):
        """Number of qubits (spin orbitals) of the stored InteractionOperator."""
        return self._mol_ham.n_qubits

    @cached_property
    def terms_jw(self):
        """Jordan-Wigner transformed Hamiltonian as a ``{pauli_term: coeff}`` dict.

        Each key is a tuple of ``(qubit_index, pauli_letter)`` pairs; the
        identity term has the empty tuple ``()`` as its key.
        """
        return jordan_wigner(self._mol_ham).terms

    @cached_property
    def _ops(self):
        """op_strings LCU built from every non-identity Jordan-Wigner term."""
        ops = op_strings(N_qb=self.n_qubits())

        for term, coeff in self.terms_jw.items():
            # The identity (constant offset) term is not part of the LCU.
            # Skip it by key rather than by position so that a real Pauli
            # term is never dropped if the identity is not listed first.
            if term == ():
                continue
            bits = tuple(qubit for qubit, _ in term)
            op_str = ''.join(pauli for _, pauli in term)
            ops.append_tuple((bits, op_str, coeff))
        return ops

    @cached_property
    def lam(self):
        """Sum of ``|coeff|`` over all Jordan-Wigner terms, identity included.

        Computed once per instance. cached_property is used instead of
        ``property`` + ``functools.cache``, which would keep every instance
        alive in a class-level cache.
        """
        return sum(abs(coeff) for coeff in self.terms_jw.values())

    def _as_fermion_operator(self):
        """Rebuild the stored InteractionOperator as a sum of FermionOperators."""
        fermion_op = 0
        for term in self._mol_ham:
            fermion_op += FermionOperator(term, self._mol_ham[term])
        return fermion_op

    @cached_property
    def hamiltonian_tensors(self):
        """Dict with the ``'h0'``, ``'one_body_tensor'`` and ``'two_body_tensor'``
        produced by ``df_utils.to_tensors`` for this Hamiltonian."""
        from pyLIQTR.utils.df_utils import to_tensors
        h0, one_body_tensor, two_body_tensor = to_tensors(self._as_fermion_operator())
        return {'h0': h0, 'one_body_tensor': one_body_tensor, 'two_body_tensor': two_body_tensor}

    @cached_property
    def obt_fragment(self):
        """One-body fragment (``df_utils.to_OBF``) of the one-body tensor."""
        from pyLIQTR.utils.df_utils import to_OBF
        return to_OBF(self.hamiltonian_tensors['one_body_tensor'])

    def DF_fragments(self, sf_error_threshold):
        """Run ``df_utils.DF_decomposition`` with ``tol=sf_error_threshold``.

        The result is not cached, so each call recomputes the decomposition.
        """
        from pyLIQTR.utils.df_utils import DF_decomposition
        tensors = self.hamiltonian_tensors
        return DF_decomposition(tensors['h0'],
                                tensors['one_body_tensor'],
                                tensors['two_body_tensor'],
                                tol=sf_error_threshold)

    def df_cutoffs(self, df_error_threshold: float = 1e-3, sf_error_threshold: float = 1e-10):
        """Return the DF cutoffs for an error pair, computing them on first use.

        yield_DF_Info() stores its cutoffs in ``self._df_cutoffs`` as a side
        effect; that is what fills the cache here.
        """
        error_pair = (df_error_threshold, sf_error_threshold)
        if error_pair not in self._df_cutoffs:
            self.yield_DF_Info(df_error_threshold=df_error_threshold,
                               sf_error_threshold=sf_error_threshold)
        return self._df_cutoffs[error_pair]

    @staticmethod
    @cache
    def _quantum_mambo():
        """Load (installing if needed) QuantumMAMBO via juliacall and return it.

        The module handle is cached, so ``Pkg.add`` runs at most once per
        process. Raises ValueError when the ``sphinx`` module is loaded
        (presumably a documentation build), where Julia is not set up.
        """
        if 'sphinx' in sys.modules:
            raise ValueError("QuantumMAMBO is not loaded while 'sphinx' is imported")
        from juliacall import Main as jl
        jl.seval('import Pkg')
        jl.seval('Pkg.add("QuantumMAMBO")')
        jl.seval("using QuantumMAMBO")
        return jl.QuantumMAMBO

    def get_alpha(self, encoding: str = 'PauliLCU', sf_error_threshold: float = 1e-10, df_error_threshold: float = None, df_cutoffs: list = None):
        """Return the encoding normalization (alpha).

        Args:
            encoding: ``'PauliLCU'`` or ``'DF'``.
            sf_error_threshold: tolerance passed to the DF decomposition.
            df_error_threshold: DF truncation threshold used to compute the
                cutoffs (defaults to 1e-3). Mutually exclusive with df_cutoffs.
            df_cutoffs: explicit per-fragment cutoffs (number of leading
                ``C.λ`` values kept for each DF fragment).

        Returns:
            For ``'PauliLCU'``: ``self._ops.get_coeff_norm()``.
            For ``'DF'``: QuantumMAMBO ``OBF_L1`` of the corrected one-body
            fragment plus, for each DF fragment,
            ``0.5 * |coeff| * (sum of |λ| over the kept entries) ** 2``.

        Raises:
            ValueError: both df_error_threshold and df_cutoffs are given, the
                encoding is unknown, or QuantumMAMBO cannot be loaded.
        """
        if encoding == 'PauliLCU':
            return self._ops.get_coeff_norm()
        if encoding != 'DF':
            raise ValueError(f"unsupported encoding: {encoding!r} (expected 'PauliLCU' or 'DF')")

        from pyLIQTR.utils.df_utils import to_OBF

        if df_cutoffs is not None and df_error_threshold is not None:
            raise ValueError("provide only df_error_threshold or df_cutoffs")

        mambo = self._quantum_mambo()

        if df_cutoffs is None:
            if df_error_threshold is None:
                df_error_threshold = 1e-3
            df_cutoffs = self.df_cutoffs(df_error_threshold=df_error_threshold,
                                         sf_error_threshold=sf_error_threshold)

        tensors = self.hamiltonian_tensors
        one_body_tensor = tensors['one_body_tensor']
        two_body_tensor = tensors['two_body_tensor']
        DF_frags = self.DF_fragments(sf_error_threshold)

        # One-body correction: 2 * sum_r two_body_tensor[:, :, r, r], folded
        # into the one-body term before measuring its L1 norm.
        one_body_correction = 2 * sum(two_body_tensor[:, :, r, r] for r in range(two_body_tensor.shape[0]))
        one_body_fragment = to_OBF(one_body_tensor + one_body_correction)
        lambdaTprime = mambo.OBF_L1(one_body_fragment)

        lambdaDF = 0.0
        for idx, frag in enumerate(DF_frags):
            kept = frag.C.λ[:df_cutoffs[idx]]
            lambdaDF += 0.5 * abs(frag.coeff) * (sum(np.abs(kept))) ** 2
        return lambdaTprime + lambdaDF

    def optimize(self, method='BLISS', num_elecs=None):
        """Return a new ChemicalHamiltonian transformed by the given method.

        Only ``'BLISS'`` is supported: the Hamiltonian is converted to a
        QuantumMAMBO operator, passed through ``bliss_linprog`` and converted
        back. The new normalization (``lam``) and its percentage reduction
        are printed.

        Args:
            method: optimization method; only ``'BLISS'`` is supported.
            num_elecs: number of electrons passed to ``bliss_linprog``.
                Defaults to ``self.n_qubits() / 2``, the previously hard-coded
                value. NOTE(review): that default is only right when there are
                half as many electrons as qubits; pass it explicitly otherwise.

        Raises:
            ValueError: unsupported method, or QuantumMAMBO cannot be loaded.
        """
        if method != 'BLISS':
            raise ValueError(f"unsupported optimization method: {method!r} (expected 'BLISS')")

        mambo = self._quantum_mambo()
        if num_elecs is None:
            num_elecs = self.n_qubits() / 2

        h_mambo = mambo.OF_to_F_OP(self._as_fermion_operator())
        h_bliss, _ = mambo.bliss_linprog(h_mambo, num_elecs)
        bliss_ham_f = mambo.to_OF(h_bliss)

        bliss_mol_ham = get_interaction_operator(bliss_ham_f)
        # Derive the name from this instance instead of hard-coding "H2".
        bliss_mol_instance = ChemicalHamiltonian(mol_ham=bliss_mol_ham, mol_name=f"{self} w/ BLISS")
        new_lam = bliss_mol_instance.lam
        norm_percent = (1 - (new_lam / self.lam)) * 100
        print(f"New encoding normalization: {new_lam}. Normalization reduced by {norm_percent}%.")
        return bliss_mol_instance

    def yield_PauliLCU_Info(self, return_as='arrays', do_pad=0, pad_value=1.0):
        """Yield the Pauli LCU terms one at a time.

        Args:
            return_as: ``'arrays'`` yields from ``self._ops.terms(...)``;
                ``'strings'`` yields from ``self._ops.strings(...)``.
            do_pad, pad_value: forwarded unchanged to op_strings.

        Raises:
            ValueError: (on first iteration) for any other ``return_as``.
        """
        if return_as == 'arrays':
            terms = self._ops.terms(do_pad=do_pad, pad_value=pad_value)
        elif return_as == 'strings':
            terms = self._ops.strings(do_pad=do_pad, pad_value=pad_value)
        else:
            raise ValueError(f"return_as must be 'arrays' or 'strings', got {return_as!r}")

        yield from terms

    @staticmethod
    def _sign_magnitude(values):
        """Split values into ``[sign_bits, magnitudes]``.

        The sign bit is 1 for negative entries and 0 for positive or zero
        entries.
        """
        signs = []
        magnitudes = []
        for value in values:
            s = np.sign(value)
            signs.append(int((1 - s) / 2 if s else 0))
            magnitudes.append(abs(value))
        return [signs, magnitudes]

    def yield_DF_Info(self, df_error_threshold: float, sf_error_threshold: float = 1e-10):
        """Collect the data needed for the DF encoding.

        Despite its name this returns a tuple; it is not a generator. As a
        side effect it caches the cutoffs in ``self._df_cutoffs`` under
        ``(df_error_threshold, sf_error_threshold)``.

        Returns:
            T_prime_full: ``[sign_bits, magnitudes]`` of the one-body
                fragment's ``C.λ`` values (zero-padded to ``num_orbs``).
            f_p_full: one ``[sign_bits, magnitudes]`` pair per DF fragment.
            cutoffs: ``calc_xi`` applied to the DF fragments' ``|λ|`` rows.
            thetas_tsr: array of shape ``(num_frags + 1, num_orbs, num_orbs - 1)``
                holding ``U_to_Givens(frag.U[0], k)`` for each fragment (the
                one-body fragment first) and orbital ``k``.
        """
        from pyLIQTR.utils.df_utils import U_to_Givens, calc_xi

        obt = self.hamiltonian_tensors['one_body_tensor']
        DF_frags = self.DF_fragments(sf_error_threshold)

        num_frags = len(DF_frags)
        num_orbs = np.size(obt, 0)
        # Row 0 belongs to the one-body fragment, rows 1..num_frags to the DF fragments.
        mus_mat = np.zeros(shape=(num_frags + 1, num_orbs))
        thetas_tsr = np.zeros(shape=(num_frags + 1, num_orbs, num_orbs - 1))

        fragments = [self.obt_fragment, *DF_frags]
        for row, frag in enumerate(fragments):
            for i, mu in enumerate(frag.C.λ):
                mus_mat[row, i] = mu
            for k in range(num_orbs):
                for g, theta in enumerate(U_to_Givens(frag.U[0], k)):
                    thetas_tsr[row, k, g] = theta

        T_prime_full = self._sign_magnitude(mus_mat[0])
        f_p_abs = [abs(row) for row in mus_mat[1:]]
        f_p_full = [self._sign_magnitude(row) for row in mus_mat[1:]]

        error_pair = (df_error_threshold, sf_error_threshold)
        self._df_cutoffs[error_pair] = calc_xi(f_p_abs, df_error_threshold)

        return T_prime_full, f_p_full, self._df_cutoffs[error_pair], thetas_tsr
    
    
        

        
