//! This module contains routines needed for normalizing a circuit
//! into a form that can be encoded as a pytket legacy circuit.
//!
//! This is a best-effort attempt, and may not always succeed.

use derive_more::{Display, Error, From};
use hugr::{HugrView, Node};
use itertools::Itertools;

use crate::serialize::pytket::OpConvertError;
use crate::Circuit;

use super::find_tuple_unpack_rewrites;

/// Try to lower a circuit to a form that can be encoded as a pytket legacy circuit.
///
/// The circuit's dataflow region is first extracted into a new, owned
/// [`Circuit`]. Every tuple pack/unpack sequence reported by
/// `find_tuple_unpack_rewrites` (typically generated by guppy) is then
/// rewritten away. The input circuit is only borrowed and is left unchanged.
///
/// # Errors
///
/// Returns [`PytketLoweringError::NonLocalOperations`] if `circ.extract_dfg()`
/// fails. The underlying extraction error is discarded.
///
/// # Panics
///
/// Panics if one of the tuple pack/unpack rewrites computed on the extracted
/// circuit fails to apply.
pub fn lower_to_pytket<T: HugrView<Node = Node>>(
    circ: &Circuit<T>,
) -> Result<Circuit, PytketLoweringError> {
    let mut circ = circ
        .extract_dfg()
        .map_err(|_| PytketLoweringError::NonLocalOperations)?;

    // Remove sequences of tuple pack-unpack operations,
    // typically generated by guppy.
    //
    // All rewrites are collected before any is applied, so the search over
    // `circ` has finished before `circ` is mutated.
    // NOTE(review): this assumes the collected rewrites are independent, i.e.
    // applying one does not invalidate the others — confirm in
    // `find_tuple_unpack_rewrites`.
    let rewrites = find_tuple_unpack_rewrites(&circ).collect_vec();
    for rewrite in rewrites {
        rewrite
            .apply(&mut circ)
            .expect("tuple pack/unpack rewrite computed on this circuit failed to apply");
    }

    Ok(circ)
}

/// Ways in which lowering a circuit to a pytket-encodable form can fail.
#[derive(Debug, Display, Error, From)]
#[non_exhaustive]
pub enum PytketLoweringError {
    /// Converting one of the circuit's operations failed.
    ///
    /// Wraps the underlying [`OpConvertError`], from which this variant can be
    /// built via `From`.
    #[from]
    #[display("operation conversion error: {_0}")]
    OpConversionError(OpConvertError),
    /// The circuit is not fully contained in a single region, so it could not
    /// be extracted as a standalone dataflow graph.
    ///
    /// Function calls are not supported.
    #[display("Non-local operations found. Function calls are not supported.")]
    NonLocalOperations,
}

#[cfg(test)]
mod test {

    use crate::TketOp;

    use super::*;
    use hugr::builder::{CFGBuilder, Dataflow, HugrBuilder};
    use hugr::extension::prelude::{qb_t, MakeTuple, UnpackTuple};

    use hugr::hugr::hugrmut::HugrMut;
    use hugr::ops::handle::NodeHandle;
    use hugr::ops::{OpTag, OpTrait, OpType, Tag};
    use hugr::types::{Signature, TypeRow};
    use hugr::HugrView;
    use rstest::{fixture, rstest};

    /// Builds a circuit in the style of guppy's output.
    ///
    /// This is composed of a `Module`, containing a `FuncDefn`, containing a
    /// `CFG`, containing an `Exit` and a `DataflowBlock` with the actual
    /// circuit.
    ///
    /// The hugr's entrypoint is set to the `DataflowBlock`, not to the CFG.
    #[fixture]
    fn guppy_like_circuit() -> Circuit {
        fn build() -> Result<Circuit, hugr::builder::BuildError> {
            // The circuit maps two qubits to two qubits.
            let two_qbs = TypeRow::from(vec![qb_t(), qb_t()]);
            let circ_signature = Signature::new_endo(two_qbs.clone());
            let mut cfg = CFGBuilder::new(circ_signature)?;
            // Entry block of the CFG, holding the actual circuit.
            let circ = {
                let mut dfg = cfg.simple_entry_builder(two_qbs.clone(), 1)?;
                let [q1, q2] = dfg.input_wires_arr();

                // H on the first qubit, then CX between the two.
                let [q1] = dfg.add_dataflow_op(TketOp::H, [q1])?.outputs_arr();
                let [q1, q2] = dfg.add_dataflow_op(TketOp::CX, [q1, q2])?.outputs_arr();

                // A redundant MakeTuple/UnpackTuple round-trip, the kind of
                // pack-unpack sequence that `lower_to_pytket` removes.
                let [tup] = dfg
                    .add_dataflow_op(MakeTuple::new(two_qbs.clone()), [q1, q2])?
                    .outputs_arr();
                let [q1, q2] = dfg
                    .add_dataflow_op(UnpackTuple::new(two_qbs), [tup])?
                    .outputs_arr();

                // Adds an empty Unit branch.
                let [branch] = dfg
                    .add_dataflow_op(Tag::new(0, vec![TypeRow::new()]), [])?
                    .outputs_arr();

                dfg.finish_with_outputs(branch, [q1, q2])?
            };
            // Branch 0 of the entry block leads to the exit block.
            cfg.branch(&circ, 0, &cfg.exit_block())?;

            let mut hugr = cfg.finish_hugr()?;
            // Point the entrypoint at the dataflow block rather than the CFG.
            hugr.set_entrypoint(circ.node());
            Ok(Circuit::new(hugr))
        }
        build().unwrap()
    }

    /// Lowers each input circuit and checks that the result is a valid
    /// dataflow region whose signature matches the original, up to the
    /// removal of the CFG branch output.
    #[rstest]
    #[case::guppy_like_circuit(guppy_like_circuit())]
    fn test_pytket_lowering(#[case] circ: Circuit) {
        use cool_asserts::assert_matches;

        let lowered_circ = lower_to_pytket(&circ).unwrap();
        lowered_circ.hugr().validate().unwrap();

        // The lowered entrypoint must be a dataflow parent with proper
        // Input/Output nodes.
        let parent_tag = lowered_circ.hugr().entrypoint_optype().tag();
        assert!(OpTag::DataflowParent.is_superset(parent_tag));
        assert_matches!(
            lowered_circ.hugr().get_optype(lowered_circ.input_node()),
            OpType::Input(_)
        );
        assert_matches!(
            lowered_circ.hugr().get_optype(lowered_circ.output_node()),
            OpType::Output(_)
        );
        // Lowering must not change the operation count reported by
        // `num_operations`.
        assert_eq!(lowered_circ.num_operations(), circ.num_operations());

        // Check that the circuit signature is preserved.
        let original_sig = circ.circuit_signature();
        let lowered_sig = lowered_circ.circuit_signature();
        assert_eq!(lowered_sig.input(), original_sig.input());

        // The output signature may have changed due to CFG branch tag removal.
        // At most one leading output (the branch tag) may have been dropped;
        // the remaining outputs must match the original's tail.
        let output_count_diff =
            original_sig.output().len() as isize - lowered_sig.output().len() as isize;
        assert!(
            output_count_diff == 0 || output_count_diff == 1,
            "Output count mismatch. Original: {original_sig}, Lowered: {lowered_sig}"
        );
        assert_eq!(
            lowered_sig.output()[..],
            original_sig.output()[output_count_diff as usize..]
        );

        // Check that the output node was successfully updated
        let output_sig = lowered_circ
            .hugr()
            .signature(lowered_circ.output_node())
            .unwrap();
        assert_eq!(lowered_sig.output(), output_sig.input());
    }
}
