from collections.abc import Sequence
from pathlib import Path

import numpy as np
import torch

from eir.data_load.data_preparation_modules.common import (
    process_tensor_to_length,
)
from eir.setup.input_setup_modules.setup_sequence import (
    ComputedSequenceInputInfo,
    al_encode_funcs,
    get_sequence_split_function,
)
from eir.setup.schemas import SequenceInputDataConfig


def sequence_load_wrapper(
    data_pointer: Path | str | int | np.ndarray,
    input_source: str,
    split_on: str | None,
    encode_func: al_encode_funcs,
) -> np.ndarray:
    """
    Load a single tokenized sequence sample.

    In the case of .csv input sources, we have already loaded and tokenized
    the data, so ``data_pointer`` is the token array itself and is returned
    unchanged. For any other source, ``data_pointer`` is a path to a text
    file. That file is read, split with the function selected by
    ``split_on``, encoded with ``encode_func`` and returned as a NumPy array.

    Args:
        data_pointer: Pre-tokenized array (``.csv`` sources) or a file path
            (``str`` / ``Path``) to the raw sequence text.
        input_source: The configured input source; only its ``.csv``
            suffix is inspected here.
        split_on: Passed to ``get_sequence_split_function`` to choose how
            raw text is split before encoding.
        encode_func: Callable mapping the split content to token ids.

    Returns:
        The tokenized sequence as a NumPy array.
    """
    if input_source.endswith(".csv"):
        assert isinstance(data_pointer, np.ndarray)
        return data_pointer

    assert isinstance(data_pointer, str | Path)

    # Resolve the splitter only when there is raw text to split; the .csv
    # branch above never uses it.
    split_func = get_sequence_split_function(split_on=split_on)
    content = load_sequence_from_disk(sequence_file_path=data_pointer)

    file_content_split = split_func(content)
    file_content_encoded = encode_func(file_content_split)

    return np.array(file_content_encoded)


def load_sequence_from_disk(sequence_file_path: Path | str) -> str:
    with open(sequence_file_path) as infile:
        return infile.read().strip()


def prepare_sequence_data(
    sequence_input_object: "ComputedSequenceInputInfo",
    cur_file_content_tokenized: np.ndarray,
    test_mode: bool,
) -> torch.Tensor:
    """
    Turn one tokenized sequence into a fixed-length int64 tensor.

    ``torch.tensor`` always copies its input, so the returned tensor never
    shares memory with ``cur_file_content_tokenized``. In contrast,
    ``torch.from_numpy`` shares memory, which would let later in-place
    modifications leak back into the original data. The tensor is then
    truncated or padded to ``computed_max_length`` via
    ``process_tensor_to_length``.

    Args:
        sequence_input_object: Computed info for this sequence input. Its
            input config, tokenizer, encode function and computed max
            length are read here.
        cur_file_content_tokenized: Token ids for a single sample.
        test_mode: When True, the configured
            ``sampling_strategy_if_longer`` is overridden with
            ``"from_start"``.

    Returns:
        The padded/truncated token tensor.
    """

    sio = sequence_input_object
    input_type_info = sio.input_config.input_type_info
    assert isinstance(input_type_info, SequenceInputDataConfig)

    # One call both copies (detaching from the source array) and casts to
    # int64, so no separate ndarray.copy()/clone() is needed.
    cur_tokens_as_tensor = torch.tensor(
        cur_file_content_tokenized,
        dtype=torch.long,
    )

    sampling_strategy = input_type_info.sampling_strategy_if_longer
    if test_mode:
        sampling_strategy = "from_start"

    # Falls back to "<pad>" only when the tokenizer lacks a pad_token
    # attribute. NOTE(review): if the attribute exists but is None, None is
    # passed on unchanged — confirm tokenizers always set it.
    padding_token = getattr(sio.tokenizer, "pad_token", "<pad>")
    padding_token_parsed = parse_padding_token_encode_func_input(
        split_on=input_type_info.split_on, padding_token=padding_token
    )
    # The encode function returns a sequence of ids; the first id is the
    # pad token's id.
    padding_value = sio.encode_func(padding_token_parsed)[0]
    cur_tokens_padded = process_tensor_to_length(
        tensor=cur_tokens_as_tensor,
        max_length=sio.computed_max_length,
        sampling_strategy_if_longer=sampling_strategy,
        padding_value=padding_value,
    )

    return cur_tokens_padded


def parse_padding_token_encode_func_input(
    split_on: str | None, padding_token: str
) -> Sequence[str] | str:
    parsed_token: Sequence[str] | str

    parsed_token = padding_token if split_on is None else [padding_token]

    return parsed_token
