from typing import List, Optional, Union

from psynet.trial.chain import ChainNetwork, ChainNode, ChainTrial, ChainTrialMaker

from ..utils import NoArgumentProvided, deep_copy, get_logger
from .main import Trial

# Module-level logger, obtained via psynet's ``get_logger`` helper.
logger = get_logger()


class StaticTrial(ChainTrial):
    """
    A Trial class for static experiments.

    The user must override the ``time_estimate`` class attribute,
    providing the estimated duration of the trial in seconds.
    This is used for predicting the participant's reward
    and for constructing the progress bar.
    The user must also override :meth:`show_trial`, which by default
    raises ``NotImplementedError``.

    Attributes
    ----------

    time_estimate : numeric
        The estimated duration of the trial (including any feedback), in seconds.
        This should generally correspond to the (sum of the) ``time_estimate`` parameters in
        the page(s) generated by ``show_trial``, plus the ``time_estimate`` parameter in
        the page generated by ``show_feedback`` (if defined).
        This is used for predicting the participant's reward
        and for constructing the progress bar.

    participant_id : int
        The ID of the associated participant.
        The user should not typically change this directly.
        Stored in ``property1`` in the database.

    complete : bool
        Whether the trial has been completed (i.e. received a response
        from the participant). The user should not typically change this directly.
        Stored in ``property2`` in the database.

    answer : Object
        The response returned by the participant. This is serialised
        to JSON, so it shouldn't be too big.
        The user should not typically change this directly.
        Stored in ``details`` in the database.

    earliest_async_process_start_time : Optional[datetime]
        Time at which the earliest pending async process was called.

    definition
        A dictionary of parameters defining the trial,
        deep-copied from the ``definition`` of the associated
        :class:`~psynet.trial.static.StaticNode` (see :meth:`make_definition`).

    participant_group
        The associated participant group.

    block
        The block in which the trial is situated.
    """

    # Copied so that any additions made for this class don't leak back
    # into the parent ``Trial`` class's collection.
    __extra_vars__ = Trial.__extra_vars__.copy()

    def generate_asset_key(self, asset):
        """
        Build the storage key for an asset attached to this trial.

        The key is namespaced by trial maker ID, block, node ID and trial ID,
        and ends with the asset's ``local_key`` followed by its ``extension``.
        """
        return f"{self.trial_maker_id}/block_{self.block}__node_{self.node_id}__trial_{self.id}__{asset.local_key}{asset.extension}"

    def show_trial(self, experiment, participant):
        """
        Return the page(s) presenting this trial to the participant.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def make_definition(self, experiment, participant):
        """
        Build the trial's definition from its parent node.

        Every asset of the parent node is also registered on the trial
        under the same key. The node's definition is deep-copied, so
        mutating the trial's definition cannot alter the node's.
        """
        for k, v in self.node.assets.items():
            self.assets[k] = v
        return deep_copy(self.node.definition)


class StaticTrialMaker(ChainTrialMaker):
    """
    Administers a sequence of trials in a static experiment.
    The class is intended for use with the
    :class:`~psynet.trial.static.StaticTrial` helper class,
    which should be customised to show the relevant node
    for the experimental paradigm.

    The user may also override the following methods, if desired:

    * :meth:`~psynet.trial.static.StaticTrialMaker.choose_block_order`;
      chooses the order of blocks in the experiment. By default the blocks
      are ordered randomly.

    * :meth:`~psynet.trial.static.StaticTrialMaker.choose_participant_group`;
      only relevant if the trial maker uses nodes with non-default participant groups.
      In this case the experimenter is expected to supply a function that takes participant as an argument
      and returns the chosen participant group for that trial maker.

    * :meth:`~psynet.trial.main.TrialMaker.on_complete`;
      run once the sequence of trials is complete.

    * :meth:`~psynet.trial.main.TrialMaker.performance_check`;
      checks the performance of the participant
      with a view to rejecting poor-performing participants.

    * :meth:`~psynet.trial.main.TrialMaker.compute_performance_reward`;
      computes the final performance reward to assign to the participant.

    Further customisable options are available in the constructor's parameter list,
    documented below.

    Parameters
    ----------

    id_
        Identifier for the trial maker, passed on to
        :class:`~psynet.trial.chain.ChainTrialMaker`.

    trial_class
        The class object for trials administered by this maker
        (should subclass :class:`~psynet.trial.static.StaticTrial`).

    nodes
        The nodes to be administered to the participants. This can be provided as
        a list of :class:`~psynet.trial.static.StaticNode` objects,
        or as a function (taking no arguments) that can be called to generate such
        a list. The latter is useful for generating nodes based on local files
        (e.g. large media assets) that are not available on the deployed server.

    expected_trials_per_participant
        Expected number of trials that each participant will complete.
        This is used for timeline/progress estimation purposes.
        Must not be ``None`` if ``nodes`` is a function. If ``nodes`` is a list,
        this may not exceed ``len(nodes)`` unless ``allow_repeated_nodes=True``.

    max_trials_per_participant
        Maximum number of trials that each participant may complete;
        once this number is reached, the participant will move on
        to the next stage in the timeline.

    recruit_mode
        Selects a recruitment criterion for determining whether to recruit
        another participant. The built-in criteria are ``"n_participants"``
        and ``"n_trials"``.

    target_n_participants
        Target number of participants to recruit for the experiment. All
        participants must successfully finish the experiment to count
        towards this quota. This target is only relevant if
        ``recruit_mode="n_participants"``.

    target_trials_per_node
        Target number of trials to recruit for each node in the experiment. This target is only relevant if
        ``recruit_mode="n_trials"``. If not provided, an effectively unlimited value (``1e6``) is used.

    max_trials_per_block
        Determines the maximum number of trials that a participant will be allowed to experience in each block,
        including failed trials. Note that this number does not include repeat trials.

    allow_repeated_nodes
        Determines whether the participant can be administered the same node more than once.
        If ``True``, ``max_trials_per_participant`` and ``max_trials_per_block`` may not
        both be ``None``.

    balance_across_nodes
        If ``True`` (default), active balancing across participants is enabled, meaning that
        node selection favours nodes that have been presented fewest times to any participant
        in the experiment, excluding failed trials.

    check_performance_at_end
        If ``True``, the participant's performance
        is evaluated at the end of the series of trials.
        Defaults to ``False``.
        See :meth:`~psynet.trial.main.TrialMaker.performance_check`
        for implementing performance checks.

    check_performance_every_trial
        If ``True``, the participant's performance
        is evaluated after each trial.
        Defaults to ``False``.
        See :meth:`~psynet.trial.main.TrialMaker.performance_check`
        for implementing performance checks.

    fail_trials_on_premature_exit
        If ``True``, a participant's trials are marked as failed
        if they leave the experiment prematurely.
        Defaults to ``True``.

    fail_trials_on_participant_performance_check
        If ``True``, a participant's trials are marked as failed
        if the participant fails a performance check.
        Defaults to ``True``.

    n_repeat_trials
        Number of repeat trials to present to the participant. These trials
        are typically used to estimate the reliability of the participant's
        responses. Repeat trials are presented at the end of the trial maker,
        after all blocks have been completed.
        Defaults to 0.

    assets
        Optional assets for the trial maker; passed unchanged to
        :class:`~psynet.trial.chain.ChainTrialMaker`.

    choose_participant_group
        Only relevant if the trial maker uses nodes with non-default participant groups.
        In this case the experimenter is expected to supply a function that takes participant as an argument
        and returns the chosen participant group for that trial maker.

    sync_group_type
        Optional SyncGroup type to use for synchronizing participant allocation to nodes.
        When this is set, then the ordinary node allocation logic will only apply to the 'leader'
        of each SyncGroup. The other members of this SyncGroup will follow that leader around,
        so that in every given trial the SyncGroup works on the same node together.

    Attributes
    ----------

    check_timeout_interval_sec : float
        How often to check for trials that have timed out, in seconds (default = 30).
        Users are invited to override this.

    response_timeout_sec : float
        How long until a trial's response times out, in seconds (default = 60)
        (i.e. how long PsyNet will wait for the participant's response to a trial).
        This is a lower bound on the actual timeout
        time, which depends on when the timeout daemon next runs,
        which in turn depends on :attr:`~psynet.trial.main.TrialMaker.check_timeout_interval_sec`.
        Users are invited to override this.

    async_timeout_sec : float
        How long until an async process times out, in seconds (default = 300).
        This is a lower bound on the actual timeout
        time, which depends on when the timeout daemon next runs,
        which in turn depends on :attr:`~psynet.trial.main.TrialMaker.check_timeout_interval_sec`.
        Users are invited to override this.

    network_query : sqlalchemy.orm.Query
        An SQLAlchemy query for retrieving all networks owned by the current trial maker.
        Can be used for operations such as the following: ``self.network_query.count()``.

    n_networks : int
        Returns the number of networks owned by the trial maker.

    networks : list
        Returns the networks owned by the trial maker.

    performance_threshold : float
        Score threshold used by the default performance check method, defaults to 0.0.
        By default, corresponds to the minimum proportion of non-failed trials that
        the participant must achieve to pass the performance check.

    end_performance_check_waits : bool
        If ``True`` (default), then the final performance check waits until all trials no
        longer have any pending asynchronous processes.
    """

    def __init__(
        self,
        *,
        id_: str,
        trial_class,
        nodes: Optional[Union[callable, List["StaticNode"]]],
        expected_trials_per_participant: int,
        max_trials_per_participant: Optional[int] = NoArgumentProvided,
        recruit_mode: Optional[str] = None,
        target_n_participants: Optional[int] = None,
        target_trials_per_node: Optional[int] = None,
        max_trials_per_block: Optional[int] = None,
        allow_repeated_nodes: bool = False,
        balance_across_nodes: bool = True,
        check_performance_at_end: bool = False,
        check_performance_every_trial: bool = False,
        fail_trials_on_premature_exit: bool = True,
        fail_trials_on_participant_performance_check: bool = True,
        n_repeat_trials: int = 0,
        assets=None,
        choose_participant_group: Optional[callable] = None,
        sync_group_type: Optional[str] = None,
    ):
        if callable(nodes):
            # The node list is generated lazily, so the number of nodes (and hence
            # the number of chains) cannot be determined or validated here.
            if expected_trials_per_participant is None:
                raise ValueError(
                    "If nodes is a function, expected_trials_per_participant must be explicitly provided."
                )
            chains_per_experiment = None
        else:
            assert isinstance(
                nodes, list
            ), f"nodes must be a list or a callable, not {type(nodes).__name__}"
            if (
                expected_trials_per_participant > len(nodes)
                and not allow_repeated_nodes
            ):
                raise ValueError(
                    f"expected_trials_per_participant ({expected_trials_per_participant}) "
                    f"may not exceed len(nodes) ({len(nodes)}) "
                    "unless allow_repeated_nodes = True."
                )
            # Each static node becomes the start node of its own single-node chain.
            chains_per_experiment = len(nodes)

        if allow_repeated_nodes:
            # With repeated nodes allowed, some explicit cap is needed on how many
            # trials a participant can take.
            assert (
                max_trials_per_participant is not None
                or max_trials_per_block is not None
            ), (
                "If allow_repeated_nodes = True, then max_trials_per_participant "
                "and max_trials_per_block may not both be None."
            )

        super().__init__(
            id_=id_,
            start_nodes=nodes,
            trial_class=trial_class,
            network_class=StaticNetwork,
            node_class=StaticNode,
            recruit_mode=recruit_mode,
            target_n_participants=target_n_participants,
            expected_trials_per_participant=expected_trials_per_participant,
            max_trials_per_participant=max_trials_per_participant,
            max_trials_per_block=max_trials_per_block,
            chain_type="across",
            chains_per_participant=None,
            chains_per_experiment=chains_per_experiment,
            # Static networks never grow beyond their start node.
            max_nodes_per_chain=1,
            # Without an explicit target, use an effectively unlimited number of trials per node.
            trials_per_node=target_trials_per_node if target_trials_per_node else 1e6,
            balance_across_chains=balance_across_nodes,
            allow_revisiting_networks_in_across_chains=allow_repeated_nodes,
            check_performance_at_end=check_performance_at_end,
            check_performance_every_trial=check_performance_every_trial,
            fail_trials_on_premature_exit=fail_trials_on_premature_exit,
            fail_trials_on_participant_performance_check=fail_trials_on_participant_performance_check,
            n_repeat_trials=n_repeat_trials,
            assets=assets,
            choose_participant_group=choose_participant_group,
            sync_group_type=sync_group_type,
        )


class StaticNetwork(ChainNetwork):
    """
    Network class used by :class:`StaticTrialMaker`.

    It adds nothing on top of :class:`~psynet.trial.chain.ChainNetwork`.
    """


class StaticNode(ChainNode):
    """
    Node class used by :class:`StaticTrialMaker`.

    :class:`StaticTrialMaker` caps each network at a single node
    (``max_nodes_per_chain=1``). The two hooks below therefore do nothing
    and return ``None``. Presumably they are only needed when a chain
    grows; confirm against ``ChainNode``.
    """

    def summarize_trials(self, trials: list, experiment, participant):
        """No-op; always returns ``None``."""
        return

    def create_definition_from_seed(self, seed, experiment, participant):
        """No-op; always returns ``None``."""
        return
