# Neural Piano main Python module

import os
import io
from pathlib import Path
import importlib

import librosa
import soundfile as sf

import midirenderer

from .music2latent.inference import EncoderDecoder

from .master import master_mono_piano

def neuralpiano(input_midi_file,
                output_audio_file,
                sample_rate=48000,
                denoising_steps=10,
                load_multi_instrumental_model=False,
                return_audio=False
                ):
    """Render a MIDI file to audio through the encoder/decoder and master it.

    Pipeline: the MIDI file is rendered to WAV bytes with ``midirenderer``
    using the SoundFont stored under ``~/models``, loaded with ``librosa``,
    passed through ``EncoderDecoder.encode`` / ``decode``, mastered with
    ``master_mono_piano`` and finally written with ``soundfile``.

    Args:
        input_midi_file: Path of the MIDI file; its raw bytes are handed to
            the renderer.
        output_audio_file: Destination passed to ``soundfile.write``.
        sample_rate: Forwarded to ``librosa.load`` as ``sr``. The rate
            returned by ``librosa.load`` is the one used when writing.
        denoising_steps: Forwarded to ``EncoderDecoder.decode``.
        load_multi_instrumental_model: Forwarded to the ``EncoderDecoder``
            constructor.
        return_audio: When true, return the mastered audio in addition to
            writing it to disk.

    Returns:
        The first value returned by ``master_mono_piano`` when
        ``return_audio`` is true, otherwise ``None``.

    Raises:
        FileNotFoundError: If the SoundFont or the MIDI file does not exist
            (raised by ``Path.read_bytes``).
    """

    # Neither os.path.join nor pathlib expands '~', so it has to be resolved
    # explicitly; otherwise the lookup targets a directory literally named
    # '~' relative to the current working directory.
    home_root = os.path.expanduser('~/')
    models_dir = os.path.join(home_root, "models")
    sf2_name = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
    sf2_path = os.path.join(models_dir, sf2_name)

    print('=' * 70)
    print('Neural Piano')
    print('=' * 70)

    print('Prepping model...')
    encdec = EncoderDecoder(load_multi_instrumental_model=load_multi_instrumental_model)

    print('Reading and rendering MIDI file...')
    wav_data = midirenderer.render_wave_from(
        Path(sf2_path).read_bytes(),
        Path(input_midi_file).read_bytes()
    )

    print('Loading rendered MIDI...')
    with io.BytesIO(wav_data) as byte_stream:
        # librosa returns the effective sample rate alongside the samples;
        # that value is reused below when saving.
        wv, sr = librosa.load(byte_stream, sr=sample_rate)

    print('Encoding...')
    latent = encdec.encode(wv)

    print('Rendering...')
    wv_rec = encdec.decode(latent, denoising_steps=denoising_steps)

    print('Mastering...')
    # The second return value (diagnostics) is not used here.
    stereo, _ = master_mono_piano(wv_rec)

    print('Saving final audio...')
    # NOTE(review): the squeeze + transpose presumably converts a
    # channels-first array into the frames-first layout soundfile expects;
    # confirm against master_mono_piano's output shape.
    sf.write(output_audio_file, stereo.squeeze().T, samplerate=sr)

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    if return_audio:
        return stereo
    return None