"""
Defines the WeakForwardSimulator calculator class
"""
#***************************************************************************************************
# Copyright 2015, 2019, 2025 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
# Under the terms of Contract DE-NA0003525 with NTESS, the U.S. Government retains certain rights
# in this software.
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License.  You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0 or in the LICENSE file in the root pyGSTi directory.
#***************************************************************************************************

import numpy as _np
import time as _time

from pygsti.forwardsims.forwardsim import ForwardSimulator as _ForwardSimulator
from pygsti.baseobjs import outcomelabeldict as _ld


class WeakForwardSimulator(_ForwardSimulator):
    """
    A calculator of circuit outcome probabilities from a "weak" forward simulator
    (i.e. probabilities taken as average frequencies over a number of "shots").

    Due to their ability to only sample outcome probabilities, WeakForwardSimulators
    rely heavily on implementing the _compute_sparse_circuit_outcome_probabilities
    function of ForwardSimulators.
    """

    def __init__(self, shots, model=None, base_seed=None):
        """
        Construct a new WeakForwardSimulator object.

        Parameters
        ----------
        shots: int
            Number of times to run each circuit to obtain an approximate probability
        model : Model
            Optional parent Model to be stored with the Simulator
        base_seed: int, optional
            Base seed for RNG of probabilistic operations during circuit simulation.
            Incremented for every shot such that deterministic seeding behavior can be
            carried out with both serial or MPI execution.
            If not provided, falls back to using time.time() to get a valid seed.
            NOTE(review): the time-based fallback is evaluated independently in every
            process, so MPI ranks may end up with different base seeds; pass an
            explicit `base_seed` when reproducible parallel runs are required.
        """
        self.shots = shots
        self.base_seed = base_seed if base_seed is not None else int(_time.time())
        super().__init__(model)

    def _rand_state_for_shot(self, shot_index):
        """
        Build the RNG used for one shot of the circuit currently being simulated.

        Parameters
        ----------
        shot_index : int
            Zero-based index of the shot among all `self.shots` shots of the circuit
            (a global index, i.e. independent of how shots are split over MPI ranks).

        Returns
        -------
        numpy.random.RandomState
            A generator seeded with `self.base_seed + shot_index`, wrapped into the
            32-bit unsigned range.
        """
        # RandomState rejects integer seeds outside [0, 2**32 - 1].  `base_seed` grows by
        # `shots` after every circuit, so wrap instead of eventually raising ValueError.
        # Seeds already inside the valid range are left unchanged by the modulo.
        return _np.random.RandomState((self.base_seed + shot_index) % 2**32)

    def _tally_outcomes(self, probs, outcomes):
        """
        Add the relative frequency contribution (1 / shots) of each sampled outcome to `probs`.

        Parameters
        ----------
        probs : OutcomeLabelDict
            Dictionary of accumulated frequencies; updated in place.

        outcomes : iterable
            Outcome labels, one per simulated shot.

        Returns
        -------
        None
        """
        for outcome in outcomes:
            # The weight is computed inside the loop so that nothing is divided when no
            # shots were run (e.g. `shots == 0`).
            if outcome in probs:
                probs[outcome] += 1.0 / self.shots
            else:
                probs[outcome] = 1.0 / self.shots

    def _compute_circuit_outcome_for_shot(self, circuit, resource_alloc, time=None, rand_state=None):
        """Compute outcome for a single shot of a circuit.

        The default implementation assumes the evotype can propagate state vectors,
        like 'statevec'.

        Parameters
        ----------
        circuit : Circuit
            A tuple-like object of *simplified* gates (e.g. may include
            instrument elements like 'Imyinst_0') generated by
            OpModel.expand_instruments_and_separate_povm()

        resource_alloc: ResourceAlloc
            Currently not used; must be `None`.

        time : float, optional
            The *start* time at which `circuit` is evaluated.  Must be `None`, as
            time-dependent circuits are not supported yet.

        rand_state: RandomState, optional
            RNG object to use for probabilistic operations.  When `None`, a freshly
            constructed (unseeded) `numpy.random.RandomState` is used.

        Returns
        -------
        outcome_label: tuple
            An outcome label for the single shot sampled

        Raises
        ------
        AssertionError
            If `time` or `resource_alloc` is not `None`, or if the circuit expands to
            more than one SeparatePOVMCircuit (i.e. it contains instruments).
        ValueError
            If the effect probabilities are exhausted without reaching the sampled
            random number.
        """
        assert(time is None), "WeakForwardSimulator cannot be used to simulate time-dependent circuits yet"
        assert(resource_alloc is None), "WeakForwardSimulator cannot use a resource_alloc for one shot."

        # `rand_state` is documented as optional, but the sampling below calls methods on
        # it; fall back to a fresh generator so a `None` argument does not crash.
        if rand_state is None:
            rand_state = _np.random.RandomState()

        spc_dict = self.model.expand_instruments_and_separate_povm(circuit,
                                                                observed_outcomes=None)  # FUTURE: observed outcomes?
        assert(len(spc_dict) == 1), "Circuits with instruments are not supported by weak forward simulator (yet)"
        spc = next(iter(spc_dict.keys()))  # first & only SeparatePOVMCircuit

        prep_label = spc.circuit_without_povm[0]  # I think this is always present in SeparatePOVMCircuit objects
        op_labels = spc.circuit_without_povm[1:]

        # Prepare the state, then push it through each layer with a randomly sampled action.
        st = self.model.circuit_layer_operator(prep_label, 'prep')._rep.actionable_staterep()
        for op_label in op_labels:
            oprep = self.model._circuit_layer_operator(op_label, 'op')._rep
            st = oprep.acton_random(st, rand_state)

        povmrep = self.model.circuit_layer_operator(spc.povm_label, 'povm')._rep
        if povmrep is not None:
            # The POVM has its own representation, so let it sample the outcome directly.
            return povmrep.sample_outcome(st, rand_state)

        # No POVM-level representation: sample by walking the cumulative distribution
        # of the individual effect probabilities.
        r = rand_state.rand()  # random number in [0,1)
        x = 0
        for elbl, full_elbl in zip(spc.effect_labels, spc.full_effect_labels):
            erep = self.model._circuit_layer_operator(full_elbl, 'povm')._rep
            x += erep.probability(st)  # outcome probability
            if r <= x:
                return elbl  # outcome label
        raise ValueError("WeakForwardSimulator failure because probabilties add to %f < 1!" % x)

    def _compute_sparse_circuit_outcome_probabilities(self, circuit, resource_alloc, time=None):
        """
        Estimate outcome probabilities of `circuit` as frequencies over `self.shots` shots.

        Shot `i` (counted over all shots of this circuit) always uses the seed
        `self.base_seed + i`, whether the shots are run serially or are divided among
        the ranks of an MPI communicator.

        Parameters
        ----------
        circuit : Circuit
            The circuit to simulate.

        resource_alloc : ResourceAllocation or None
            When not `None` and its `comm` attribute is not `None`, shots are
            divided among the ranks of that communicator.

        time : float, optional
            Passed through to the per-shot simulation (which currently requires `None`).

        Returns
        -------
        OutcomeLabelDict
            Maps each outcome label that was sampled at least once to its relative
            frequency.  Outcomes that were never sampled are absent.

        Notes
        -----
        As a side effect, `self.base_seed` is advanced by `self.shots` so that
        subsequent circuits use different random numbers.
        """
        probs = _ld.OutcomeLabelDict()

        comm = None if resource_alloc is None else resource_alloc.comm
        if comm is None:
            # No MPI, just serial execution
            self._tally_outcomes(
                probs,
                (self._compute_circuit_outcome_for_shot(circuit, None, time, self._rand_state_for_shot(i))
                 for i in range(self.shots)))
        else:
            # Have a comm, so use MPI to parallelize over shots
            rank = comm.Get_rank()
            size = comm.Get_size()

            # Every rank runs `base_shots` shots; the first `remainder` ranks run one extra.
            base_shots, remainder = divmod(self.shots, size)
            shots_this_rank = base_shots + (1 if rank < remainder else 0)

            # Global index of this rank's first shot == number of shots run by lower ranks.
            # This must use the *un-incremented* `base_shots`: `min(rank, remainder)` already
            # accounts for the extra shots, and counting them twice skips some seeds while
            # reusing others on different ranks.
            first_shot = base_shots * rank + min(rank, remainder)

            # Each rank runs their set of shots (with no resource alloc since we are using it at this level)
            outcomes_this_rank = [
                self._compute_circuit_outcome_for_shot(circuit, None, time,
                                                       self._rand_state_for_shot(first_shot + i))
                for i in range(shots_this_rank)
            ]

            # Collect all outcomes on the root rank and tally them there
            gathered = comm.gather(outcomes_this_rank, root=0)
            if rank == 0:
                for rank_outcomes in gathered:
                    self._tally_outcomes(probs, rank_outcomes)

            # Distribute final probabilities
            probs = comm.bcast(probs, root=0)

        # Update seed so subsequent circuits have different RNG
        self.base_seed += self.shots

        return probs

    # For WeakForwardSimulator, provide "bulk" interface based on the sparse interface
    # This will be highly inefficient for large numbers of qubits due to the dense storage of outcome probabilities
    # Anything expanding out all effects or creating a COPALayout will be expensive
    def _compute_circuit_outcome_probabilities(self, array_to_fill, circuit, outcomes, resource_alloc, time=None):
        """
        Fill `array_to_fill` with the sampled probabilities of the requested `outcomes`.

        Parameters
        ----------
        array_to_fill : indexable
            Output container; element `i` receives the probability of `outcomes[i]`.

        circuit : Circuit
            The circuit to simulate.

        outcomes : iterable
            Outcome labels whose probabilities are requested.

        resource_alloc : ResourceAllocation or None
            Passed through to the sparse probability computation.

        time : float, optional
            Passed through to the sparse probability computation.

        Returns
        -------
        None
        """
        # TODO: Other forward sims have expand_outcomes here, check how to fit that in
        sparse_probs = self._compute_sparse_circuit_outcome_probabilities(circuit, resource_alloc, time)

        for i, outcome in enumerate(outcomes):
            # The sparse result only holds outcomes that were actually sampled; a requested
            # outcome that never occurred has an estimated probability of zero.
            array_to_fill[i] = sparse_probs[outcome] if outcome in sparse_probs else 0.0

    def bulk_probs(self, circuits, clip_to=None, resource_alloc=None, smartc=None):
        """
        Construct a dictionary containing the probabilities for an entire list of circuits.

        Parameters
        ----------
        circuits : list of Circuits
            The list of circuits.  May also be a :class:`CircuitOutcomeProbabilityArrayLayout`
            object containing pre-computed quantities that make this function run faster.

        clip_to : 2-tuple, optional
            (min,max) to clip return value if not None.  Only the probabilities of
            outcomes that were sampled (and so are present in the result) are clipped.

        resource_alloc : ResourceAllocation, optional
            A resource allocation object describing the available resources and a strategy
            for partitioning them.

        smartc : SmartCache, optional
            A cache object to cache & use previously cached values inside this
            function.  Currently not used by this simulator.

        Returns
        -------
        probs : dictionary
            A dictionary such that `probs[circuit]` is an ordered dictionary of
            outcome probabilities whose keys are outcome labels.
        """
        all_probs = {}
        for circ in circuits:
            circ_probs = self._compute_sparse_circuit_outcome_probabilities(circ, resource_alloc)
            if clip_to is not None:
                lower, upper = clip_to
                # Snapshot the keys since values are reassigned while looping.
                for outcome in list(circ_probs.keys()):
                    circ_probs[outcome] = min(max(circ_probs[outcome], lower), upper)
            all_probs[circ] = circ_probs
        return all_probs
