"""
Pather with batched (multi-step) rendering
"""
from typing import Self
from collections.abc import Sequence, Mapping, MutableMapping, Callable
import copy
import logging
from collections import defaultdict
from pprint import pformat

import numpy
from numpy import pi
from numpy.typing import ArrayLike

from ..pattern import Pattern
from ..library import ILibrary
from ..error import PortError, BuildError
from ..ports import PortList, Port
from ..abstract import Abstract
from ..utils import SupportsBool
from .tools import Tool, RenderStep
from .utils import ell


logger = logging.getLogger(__name__)


class RenderPather(PortList):
    """
      `RenderPather` is an alternative to `Pather` which uses the `path`/`path_to`/`mpath`
    functions to plan out wire paths without incrementally generating the layout. Instead,
    it waits until `render` is called, at which point it draws all the planned segments
    simultaneously. This allows it to e.g. draw each wire using a single `Path` or
    `Polygon` shape instead of multiple rectangles.

      `RenderPather` calls out to `Tool.planL` and `Tool.render` to provide tool-specific
    dimensions and build the final geometry for each wire. `Tool.planL` provides the
    output port data (relative to the input) for each segment. The tool, input and output
    ports are placed into a `RenderStep`, and a sequence of `RenderStep`s is stored for
    each port. When `render` is called, it bundles `RenderStep`s into batches which use
    the same `Tool`, and passes each batch to the relevant tool's `Tool.render` to build
    the geometry.

    See `Pather` for routing examples. After routing is complete, `render` must be called
    to generate the final geometry.
    """
    __slots__ = ('pattern', 'library', 'paths', 'tools', '_dead', )

    pattern: Pattern
    """ Layout of this device """

    library: ILibrary
    """ Library from which patterns should be referenced """

    _dead: bool
    """ If True, plug()/place()/path()/path_to()/mpath() are skipped (for debugging) """

    paths: defaultdict[str, list[RenderStep]]
    """ Per-port list of planned operations, consumed (and cleared) by `render` """

    tools: dict[str | None, Tool]
    """
    Tool objects used to plan (`Tool.planL`) and draw (`Tool.render`) wire or waveguide
    segments. Keyed by port name; the `None` key holds the default tool, used for ports
    without a dedicated entry.
    """

    @property
    def ports(self) -> dict[str, Port]:
        return self.pattern.ports

    @ports.setter
    def ports(self, value: dict[str, Port]) -> None:
        self.pattern.ports = value

    def __init__(
            self,
            library: ILibrary,
            *,
            pattern: Pattern | None = None,
            ports: str | Mapping[str, Port] | None = None,
            tools: Tool | MutableMapping[str | None, Tool] | None = None,
            name: str | None = None,
            ) -> None:
        """
        Args:
            library: The library from which referenced patterns will be taken,
                and where new patterns (e.g. generated by the `tools`) will be placed.
            pattern: The pattern which will be modified by subsequent operations.
                If `None` (default), a new pattern is created.
            ports: Allows specifying the initial set of ports, if `pattern` does
                not already have any ports (or is not provided). May be a string,
                in which case it is interpreted as a name in `library`.
                Default `None` (no ports).
            tools: A mapping of {port: tool} which specifies what `Tool` should be used
                to generate waveguide or wire segments when `path`/`path_to`/`mpath`
                are called. A single `Tool` is stored as the default (key `None`).
                Relies on `Tool.planL` and `Tool.render` implementations.
            name: If specified, `library[name]` is set to `self.pattern`.

        Raises:
            BuildError if `ports` is given but `pattern` already has ports.
        """
        self._dead = False
        self.paths = defaultdict(list)
        self.library = library
        if pattern is not None:
            self.pattern = pattern
        else:
            self.pattern = Pattern()

        if ports is not None:
            if self.pattern.ports:
                raise BuildError('Ports supplied for pattern with pre-existing ports!')
            if isinstance(ports, str):
                if library is None:
                    raise BuildError('Ports given as a string, but `library` was `None`!')
                ports = library.abstract(ports).ports

            # Deep-copy so later edits to our ports don't alias the caller's objects
            self.pattern.ports.update(copy.deepcopy(dict(ports)))

        if name is not None:
            if library is None:
                raise BuildError('Name was supplied, but no library was given!')
            library[name] = self.pattern

        if tools is None:
            self.tools = {}
        elif isinstance(tools, Tool):
            self.tools = {None: tools}
        else:
            self.tools = dict(tools)

    @classmethod
    def interface(
            cls: type['RenderPather'],
            source: PortList | Mapping[str, Port] | str,
            *,
            library: ILibrary | None = None,
            tools: Tool | MutableMapping[str | None, Tool] | None = None,
            in_prefix: str = 'in_',
            out_prefix: str = '',
            port_map: dict[str, str] | Sequence[str] | None = None,
            name: str | None = None,
            ) -> 'RenderPather':
        """
        Wrapper for `Pattern.interface()`, which returns a RenderPather (or subclass) instead.

        Args:
            source: A collection of ports (e.g. Pattern, Builder, or dict)
                from which to create the interface. May be a pattern name if
                `library` is provided.
            library: Library from which existing patterns should be referenced,
                and to which the new one should be added (if named). If not provided,
                `source.library` must exist and will be used.
            tools: `Tool`s which will be used by the pather for generating new wires
                or waveguides (via `path`/`path_to`/`mpath`). If not provided,
                `source.tools` is used when it exists and is a dict.
            in_prefix: Prepended to port names for newly-created ports with
                reversed directions compared to the current device.
            out_prefix: Prepended to port names for ports which are directly
                copied from the current device.
            port_map: Specification for ports to copy into the new device:
                - If `None`, all ports are copied.
                - If a sequence, only the listed ports are copied
                - If a mapping, the listed ports (keys) are copied and
                    renamed (to the values).
            name: If specified, the new pattern is stored in `library` under this name.

        Returns:
            The new `RenderPather`, with an empty pattern and 2x as many ports as
              listed in port_map.

        Raises:
            `BuildError` if no library is given and `source.library` is unavailable.
            `PortError` if `port_map` contains port names not present in the
                current device.
            `PortError` if applying the prefixes results in duplicate port
                names.
        """
        if library is None:
            if hasattr(source, 'library') and isinstance(source.library, ILibrary):
                library = source.library
            else:
                raise BuildError('No library provided (and not present in `source.library`)')

        if tools is None and hasattr(source, 'tools') and isinstance(source.tools, dict):
            tools = source.tools

        if isinstance(source, str):
            source = library.abstract(source).ports

        pat = Pattern.interface(source, in_prefix=in_prefix, out_prefix=out_prefix, port_map=port_map)
        # Use `cls` so subclasses get an instance of their own type
        new = cls(library=library, pattern=pat, name=name, tools=tools)
        return new

    def _resolve_target(
            self,
            other: Abstract | Pattern | str,
            append: bool,
            ) -> Pattern | Abstract:
        """
        Turn the `other` argument of `plug`/`place` into the object handed to
        `Pattern.plug`/`Pattern.place`.

        Strings are looked up in `self.library` as abstracts; `Abstract`s and `Pattern`s
        are used as-is. When `append` is `True`, an `Abstract` is swapped for the
        `self.library` entry of the same name (presumably because appending needs the
        full pattern rather than just its ports).
        """
        target: Pattern | Abstract
        if isinstance(other, str):
            target = self.library.abstract(other)
        else:
            target = other
        if append and isinstance(target, Abstract):
            target = self.library[target.name]
        return target

    def plug(
            self,
            other: Abstract | Pattern | str,
            map_in: dict[str, str],
            map_out: dict[str, str | None] | None = None,
            *,
            mirrored: bool = False,
            inherit_name: bool = True,
            set_rotation: bool | None = None,
            append: bool = False,
            ) -> Self:
        """
          Wrapper for `Pattern.plug` which adds a `RenderStep` with opcode 'P'
        for any affected ports. This separates any future `RenderStep`s on the
        same port into a new batch, since the plugged device interferes with drawing.

        Args:
            other: An `Abstract`, string, or `Pattern` describing the device to be instatiated.
            map_in: dict of `{'self_port': 'other_port'}` mappings, specifying
                port connections between the two devices.
            map_out: dict of `{'old_name': 'new_name'}` mappings, specifying
                new names for ports in `other`.
            mirrored: Enables mirroring `other` across the x axis prior to
                connecting any ports.
            inherit_name: If `True`, and `map_in` specifies only a single port,
                and `map_out` is `None`, and `other` has only two ports total,
                then automatically renames the output port of `other` to the
                name of the port from `self` that appears in `map_in`. This
                makes it easy to extend a device with simple 2-port devices
                (e.g. wires) without providing `map_out` each time `plug` is
                called. Default `True`.
            set_rotation: If the necessary rotation cannot be determined from
                the ports being connected (i.e. all pairs have at least one
                port with `rotation=None`), `set_rotation` must be provided
                to indicate how much `other` should be rotated. Otherwise,
                `set_rotation` must remain `None`.
            append: If `True`, `other` is appended instead of being referenced.
                Note that this does not flatten  `other`, so its refs will still
                be refs (now inside `self`).

        Returns:
            self

        Raises:
            `PortError` if any ports specified in `map_in` or `map_out` do not
                exist in `self.ports` or `other.ports`.
            `PortError` if there are any duplicate names after `map_in` and `map_out`
                are applied.
            `PortError` if the specified port mapping is not achieveable (the ports
                do not line up)
        """
        if self._dead:
            logger.error('Skipping plug() since device is dead')
            return self

        other_tgt = self._resolve_target(other, append)

        # Break the batch for ports of `self` which are being plugged
        for kk in map_in:
            if kk in self.paths:
                self.paths[kk].append(RenderStep('P', None, self.ports[kk].copy(), self.ports[kk].copy(), None))

        # ...and for ports of `other` which will land on a name that already has planned steps
        plugged = map_in.values()
        for name, port in other_tgt.ports.items():
            if name in plugged:
                continue
            new_name = map_out.get(name, name) if map_out is not None else name
            if new_name is not None and new_name in self.paths:
                self.paths[new_name].append(RenderStep('P', None, port.copy(), port.copy(), None))

        self.pattern.plug(
            other=other_tgt,
            map_in=map_in,
            map_out=map_out,
            mirrored=mirrored,
            inherit_name=inherit_name,
            set_rotation=set_rotation,
            append=append,
            )

        return self

    def place(
            self,
            other: Abstract | Pattern | str,
            *,
            offset: ArrayLike = (0, 0),
            rotation: float = 0,
            pivot: ArrayLike = (0, 0),
            mirrored: bool = False,
            port_map: dict[str, str | None] | None = None,
            skip_port_check: bool = False,
            append: bool = False,
            ) -> Self:
        """
          Wrapper for `Pattern.place` which adds a `RenderStep` with opcode 'P'
        for any affected ports. This separates any future `RenderStep`s on the
        same port into a new batch, since the placed device interferes with drawing.

        Note that mirroring is applied before rotation; translation (`offset`) is applied last.

        Args:
            other: An `Abstract`, string, or `Pattern` describing the device to be instatiated.
            offset: Offset at which to place the instance. Default (0, 0).
            rotation: Rotation applied to the instance before placement. Default 0.
            pivot: Rotation is applied around this pivot point (default (0, 0)).
                Rotation is applied prior to translation (`offset`).
            mirrored: Whether the instance should be mirrored across the x axis.
                Mirroring is applied before translation and rotation.
            port_map: dict of `{'old_name': 'new_name'}` mappings, specifying
                new names for ports in the instantiated pattern. New names can be
                `None`, which will delete those ports.
            skip_port_check: Can be used to skip the internal call to `check_ports`,
                in case it has already been performed elsewhere.
            append: If `True`, `other` is appended instead of being referenced.
                Note that this does not flatten  `other`, so its refs will still
                be refs (now inside `self`).

        Returns:
            self

        Raises:
            `PortError` if any ports specified in `port_map` do not exist in `other.ports`.
            `PortError` if there are any duplicate names after `port_map` is applied.
        """
        if self._dead:
            logger.error('Skipping place() since device is dead')
            return self

        other_tgt = self._resolve_target(other, append)

        # Break the batch for any planned port whose name will be reused by `other`
        for name, port in other_tgt.ports.items():
            new_name = port_map.get(name, name) if port_map is not None else name
            if new_name is not None and new_name in self.paths:
                self.paths[new_name].append(RenderStep('P', None, port.copy(), port.copy(), None))

        self.pattern.place(
            other=other_tgt,
            offset=offset,
            rotation=rotation,
            pivot=pivot,
            mirrored=mirrored,
            port_map=port_map,
            skip_port_check=skip_port_check,
            append=append,
            )

        return self

    def retool(
            self,
            tool: Tool,
            keys: str | Sequence[str | None] | None = None,
            ) -> Self:
        """
        Update the `Tool` which will be used when generating `Pattern`s for the ports
        given by `keys`.

        Args:
            tool: The new `Tool` to use for the given ports.
            keys: Which ports the tool should apply to. `None` indicates the default tool,
                used when there is no matching entry in `self.tools` for the port in question.

        Returns:
            self
        """
        if keys is None or isinstance(keys, str):
            self.tools[keys] = tool
        else:
            for key in keys:
                self.tools[key] = tool
        return self

    def path(
            self,
            portspec: str,
            ccw: SupportsBool | None,
            length: float,
            **kwargs,
            ) -> Self:
        """
        Plan a "wire"/"waveguide" extending from the port `portspec`, with the aim
        of traveling exactly `length` distance.

        The wire will travel `length` distance along the port's axis, an an unspecified
        (tool-dependent) distance in the perpendicular direction. The output port will
        be rotated (or not) based on the `ccw` parameter.

        `RenderPather.render` must be called after all paths have been fully planned.

        Args:
            portspec: The name of the port into which the wire will be plugged.
            ccw: If `None`, the output should be along the same axis as the input.
                Otherwise, cast to bool and turn counterclockwise if True
                and clockwise otherwise.
            length: The total distance from input to output, along the input's axis only.
                (There may be a tool-dependent offset along the other axis.)
            **kwargs: Passed through to `Tool.planL`.

        Returns:
            self

        Raises:
            PortError if the port `portspec` has no rotation.
            KeyError if there is neither a tool for `portspec` nor a default tool.
            BuildError (from the tool) if `length` is too small to fit the bend
                (if a bend is present).
        """
        if self._dead:
            logger.error('Skipping path() since device is dead')
            return self

        port = self.pattern[portspec]
        in_ptype = port.ptype
        port_rot = port.rotation
        if port_rot is None:
            # TODO allow manually setting rotation for RenderPather.path()?
            raise PortError(f'Port {portspec} has no rotation and cannot be used for path()')

        # Port-specific tool first; only fall back to (and require) the default if needed
        tool = self.tools[portspec] if portspec in self.tools else self.tools[None]

        # ask the tool for bend size (fill missing dx or dy), check feasibility, and get out_ptype
        out_port, data = tool.planL(ccw, length, in_ptype=in_ptype, **kwargs)

        # planL returns the output port relative to the input; move it into our frame
        out_port.rotate_around((0, 0), pi + port_rot)
        out_port.translate(port.offset)

        step = RenderStep('L', tool, port.copy(), out_port.copy(), data)
        self.paths[portspec].append(step)

        self.pattern.ports[portspec] = out_port.copy()

        return self

    def path_to(
            self,
            portspec: str,
            ccw: SupportsBool | None,
            position: float | None = None,
            *,
            x: float | None = None,
            y: float | None = None,
            **kwargs,
            ) -> Self:
        """
        Plan a "wire"/"waveguide" extending from the port `portspec`, with the aim
        of ending exactly at a target position.

        The wire will travel so that the output port will be placed at exactly the target
        position along the input port's axis. There can be an unspecified (tool-dependent)
        offset in the perpendicular direction. The output port will be rotated (or not)
        based on the `ccw` parameter.

        `RenderPather.render` must be called after all paths have been fully planned.

        Args:
            portspec: The name of the port into which the wire will be plugged.
            ccw: If `None`, the output should be along the same axis as the input.
                Otherwise, cast to bool and turn counterclockwise if True
                and clockwise otherwise.
            position: The final port position, along the input's axis only.
                (There may be a tool-dependent offset along the other axis.)
                Only one of `position`, `x`, and `y` may be specified.
            x: The final port position along the x axis.
                `portspec` must refer to a horizontal port if `x` is passed, otherwise a
                BuildError will be raised.
            y: The final port position along the y axis.
                `portspec` must refer to a vertical port if `y` is passed, otherwise a
                BuildError will be raised.
            **kwargs: Passed through to `path` (and from there to `Tool.planL`).

        Returns:
            self

        Raises:
            PortError if the port has no rotation.
            BuildError if the port is not manhattan, or the target lies behind it.
            BuildError if `position`, `x`, or `y` is too close to fit the bend (if a bend
                is present).
            BuildError if `x` or `y` is specified but does not match the axis of `portspec`.
            BuildError if more than one (or none) of `x`, `y`, and `position` is specified.
        """
        if self._dead:
            logger.error('Skipping path_to() since device is dead')
            return self

        pos_count = sum(vv is not None for vv in (position, x, y))
        if pos_count > 1:
            raise BuildError('Only one of `position`, `x`, and `y` may be specified at once')
        if pos_count < 1:
            raise BuildError('One of `position`, `x`, and `y` must be specified')

        port = self.pattern[portspec]
        if port.rotation is None:
            raise PortError(f'Port {portspec} has no rotation and cannot be used for path_to()')

        if not numpy.isclose(port.rotation % (pi / 2), 0):
            raise BuildError('path_to was asked to route from non-manhattan port')

        is_horizontal = numpy.isclose(port.rotation % pi, 0)
        if is_horizontal:
            if y is not None:
                raise BuildError('Asked to path to y-coordinate, but port is horizontal')
            if position is None:
                position = x
        else:
            if x is not None:
                raise BuildError('Asked to path to x-coordinate, but port is vertical')
            if position is None:
                position = y

        # The port's rotation points into the device, so wires travel the opposite way;
        # a target on the same side as the rotation vector is "behind" the port.
        x0, y0 = port.offset
        if is_horizontal:
            if numpy.sign(numpy.cos(port.rotation)) == numpy.sign(position - x0):
                raise BuildError(f'path_to routing to behind source port: x0={x0:g} to {position:g}')
            length = numpy.abs(position - x0)
        else:
            if numpy.sign(numpy.sin(port.rotation)) == numpy.sign(position - y0):
                raise BuildError(f'path_to routing to behind source port: y0={y0:g} to {position:g}')
            length = numpy.abs(position - y0)

        return self.path(portspec, ccw, length, **kwargs)

    def mpath(
            self,
            portspec: str | Sequence[str],
            ccw: SupportsBool | None,
            *,
            spacing: float | ArrayLike | None = None,
            set_rotation: float | None = None,
            **kwargs,
            ) -> Self:
        """
        `mpath` is a superset of `path` and `path_to` which can act on bundles or buses
        of "wires or "waveguides".

        See `Pather.mpath` for details.

        Args:
            portspec: The names of the ports which are to be routed.
            ccw: If `None`, the outputs should be along the same axis as the inputs.
                Otherwise, cast to bool and turn 90 degrees counterclockwise if `True`
                and clockwise otherwise.
            spacing: Center-to-center distance between output ports along the input port's axis.
                Must be provided if (and only if) `ccw` is not `None`.
            set_rotation: If the provided ports have `rotation=None`, this can be used
                to set a rotation for them.
            **kwargs: Exactly one bound specification: either `bound_type=...` together
                with `bound=...`, or one of `emin`, `emax`, `pmin`, `pmax`, `xmin`, `xmax`,
                `ymin`, `ymax`, `min_past_furthest` set to the bound value. Other keyword
                arguments are currently not used.

        Returns:
            self

        Raises:
            BuildError if the implied length for any wire is too close to fit the bend
                (if a bend is requested).
            BuildError if `xmin`/`xmax` or `ymin`/`ymax` is specified but does not
                match the axis of `portspec`.
            BuildError if an incorrect bound type or spacing is specified.
        """
        if self._dead:
            logger.error('Skipping mpath() since device is dead')
            return self

        bound_types = set()
        if 'bound_type' in kwargs:
            bound_types.add(kwargs['bound_type'])
            bound = kwargs['bound']
        for bt in ('emin', 'emax', 'pmin', 'pmax', 'xmin', 'xmax', 'ymin', 'ymax', 'min_past_furthest'):
            if bt in kwargs:
                bound_types.add(bt)
                bound = kwargs[bt]

        if not bound_types:
            raise BuildError('No bound type specified for mpath')
        if len(bound_types) > 1:
            raise BuildError(f'Too many bound types specified for mpath: {bound_types}')
        bound_type = tuple(bound_types)[0]

        if isinstance(portspec, str):
            portspec = [portspec]
        ports = self.pattern[tuple(portspec)]

        # `ell` computes the per-port lengths; each is then planned as an ordinary `path`
        extensions = ell(ports, ccw, spacing=spacing, bound=bound, bound_type=bound_type, set_rotation=set_rotation)
        for port_name, length in extensions.items():
            self.path(port_name, ccw, length)
        return self

    def render(
            self,
            append: bool = True,
            ) -> Self:
        """
        Generate the geometry which has been planned out with `path`/`path_to`/etc.

        Consecutive planned steps on a port which share the same `Tool` are passed to
        that tool's `Tool.render` as a single batch; a 'P' step (from `plug`/`place`) or
        a change of tool starts a new batch. All planned steps are cleared afterwards.

        Args:
            append: If `True`, the rendered geometry is appended directly into
                `self.pattern` and the temporary library entries are removed. Note that
                the rendered patterns are not flattened, so only one level of hierarchy
                is eliminated: any refs inside them remain refs.
                If `False`, the rendered patterns stay in `self.library` and are referenced.

        Returns:
            self
        """
        lib = self.library
        tool_port_names = ('A', 'B')
        pat = Pattern()

        def render_batch(portspec: str, batch: list[RenderStep], append: bool) -> None:
            assert batch[0].tool is not None
            name = lib << batch[0].tool.render(batch, port_names=tool_port_names)
            pat.ports[portspec] = batch[0].start_port.copy()
            if append:
                pat.plug(lib[name], {portspec: tool_port_names[0]}, append=append)
                del lib[name]       # NOTE if the rendered pattern has refs, those are now in `pat` but not flattened
            else:
                pat.plug(lib.abstract(name), {portspec: tool_port_names[0]}, append=append)

        for portspec, steps in self.paths.items():
            batch: list[RenderStep] = []
            for step in steps:
                appendable_op = step.opcode in ('L', 'S', 'U')
                same_tool = batch and step.tool == batch[0].tool

                # If we can't continue a batch, render it
                if batch and (not appendable_op or not same_tool):
                    render_batch(portspec, batch, append)
                    batch = []

                # batch is emptied already if we couldn't continue it
                if appendable_op:
                    batch.append(step)

                # Opcodes which break the batch go below this line
                if not appendable_op and portspec in pat.ports:
                    del pat.ports[portspec]

            # Render the final batch, if any
            if batch:
                render_batch(portspec, batch, append)

        self.paths.clear()
        pat.ports.clear()
        self.pattern.append(pat)

        return self

    def _transform_steps(self, transform: Callable[[Port], object]) -> None:
        """
        Replace every planned `RenderStep` with one whose start/end ports are
        transformed copies (via `transform`, applied in place to each copy).

        Deep copies are used so the new ports share no state with the old ones.
        """
        for steps in self.paths.values():
            for ii, step in enumerate(steps):
                start = copy.deepcopy(step.start_port)
                end = copy.deepcopy(step.end_port)
                transform(start)
                transform(end)
                steps[ii] = RenderStep(step.opcode, step.tool, start, end, step.data)

    def translate(self, offset: ArrayLike) -> Self:
        """
        Translate the pattern, all ports, and all planned (not yet rendered) steps.

        Args:
            offset: (x, y) distance to translate by

        Returns:
            self
        """
        # Planned steps must follow the ports, or `render` would draw wires at the old location
        self._transform_steps(lambda port: port.translate(offset))
        self.pattern.translate_elements(offset)
        return self

    def rotate_around(self, pivot: ArrayLike, angle: float) -> Self:
        """
        Rotate the pattern, all ports, and all planned (not yet rendered) steps.

        Args:
            angle: angle (radians, counterclockwise) to rotate by
            pivot: location to rotate around

        Returns:
            self
        """
        # Planned steps must follow the ports, or `render` would draw wires at the old location
        self._transform_steps(lambda port: port.rotate_around(pivot, angle))
        self.pattern.rotate_around(pivot, angle)
        return self

    def mirror(self, axis: int) -> Self:
        """
        Mirror the pattern and all ports across the specified axis.

        Planned (not yet rendered) steps are NOT mirrored, since their tool-specific
        data was planned for the unmirrored orientation; call `render` first.
        A warning is logged if such steps are pending.

        Args:
            axis: Axis to mirror across (x=0, y=1)

        Returns:
            self
        """
        if any(step.opcode != 'P' for steps in self.paths.values() for step in steps):
            logger.warning('mirror() called with planned but unrendered paths; '
                           'they will not be mirrored. Call render() first.')
        self.pattern.mirror(axis)
        return self

    def set_dead(self) -> Self:
        """
        Disallows further changes through `plug()`, `place()`, `path()`, `path_to()`
        and `mpath()` (they log an error and return without doing anything).
        This is meant for debugging:
        ```
            dev.plug(a, ...)
            dev.set_dead()      # added for debug purposes
            dev.plug(b, ...)    # usually raises an error, but now skipped
            dev.plug(c, ...)    # also skipped
            dev.pattern.visualize()     # shows the device as of the set_dead() call
        ```

        Returns:
            self
        """
        self._dead = True
        return self

    def __repr__(self) -> str:
        s = f'<{type(self).__name__} {self.pattern} L({len(self.library)}) {pformat(self.tools)}>'
        return s


