#
# Finite Volume discretisation class
#
import numpy as np
from scipy.sparse import (
    coo_matrix,
    csr_matrix,
    diags,
    eye,
    hstack,
    kron,
    lil_matrix,
    spdiags,
    vstack,
)

import pybamm


class FiniteVolume(pybamm.SpatialMethod):
    """
    A class which implements the steps specific to the finite volume method during
    discretisation.

    For broadcast and mass_matrix, we follow the default behaviour from SpatialMethod.

    Parameters
    ----------
    options : dict-like, optional
        A dictionary of options to be passed to the spatial method. The only option
        currently available is "extrapolation", which has options for "order" and "use_bcs".
        It sets the order separately for `pybamm.BoundaryValue` and `pybamm.BoundaryGradient`.
        Default is "linear" for the value and quadratic for the gradient.
    """

    def __init__(self, options=None):
        """
        Initialise the finite volume spatial method.

        Parameters
        ----------
        options : dict-like, optional
            Options forwarded unchanged to the parent spatial method.
        """
        super(FiniteVolume, self).__init__(options)

    def build(self, mesh):
        """
        Attach the mesh via the parent class, then record for every submesh how
        many points a broadcast to nodes needs.

        For finite volumes this is simply the number of nodes (``npts``) of each
        submesh.

        Parameters
        ----------
        mesh : :class:`pybamm.Mesh`
            Mapping from domain keys to submeshes; each submesh gains an
            ``npts_for_broadcast_to_nodes`` attribute.
        """
        super().build(mesh)

        for domain_key in mesh.keys():
            submesh = mesh[domain_key]
            submesh.npts_for_broadcast_to_nodes = submesh.npts

    def spatial_variable(self, symbol):
        """
        Discretise a spatial variable for the finite volume method.

        The variable takes the node positions (or edge positions, if it evaluates
        on edges in its primary dimension), tiled once per point of the auxiliary
        domains. When the submesh carries a ``length`` attribute the positions
        are built symbolically from the scaled submesh coordinates.

        Parameters
        ----------
        symbol : :class:`pybamm.SpatialVariable`
            The spatial variable to be discretised.

        Returns
        -------
        :class:`pybamm.Symbol`
            A :class:`pybamm.Vector` for fixed meshes, or a symbolic Kronecker
            product for meshes with a symbolic length.
        """
        submesh = self.mesh[symbol.domain]
        n_repeats = self._get_auxiliary_domain_repeats(symbol.domains)
        on_edges = symbol.evaluates_on_edges("primary")

        # Fixed (numeric) mesh: tile the raw coordinates directly
        if not hasattr(submesh, "length"):
            coordinates = submesh.edges if on_edges else submesh.nodes
            return pybamm.Vector(
                np.tile(coordinates, n_repeats), domains=symbol.domains
            )

        # Symbolic mesh: build scaled coordinates, then repeat via Kronecker product
        primary_domains = symbol.domains["primary"]
        if on_edges:
            coordinates = self._get_edges_symbolic_mesh(primary_domains)
        else:
            coordinates = self._get_nodes_symbolic_mesh(primary_domains)
        tiled = pybamm.kronecker_product(
            pybamm.Matrix(np.ones(n_repeats)), coordinates
        )
        tiled.domains = symbol.domains
        return tiled

    def gradient(self, symbol, discretised_symbol, boundary_conditions):
        """Matrix-vector multiplication to implement the gradient operator.

        Dirichlet conditions are imposed through ghost nodes before applying the
        gradient matrix; Neumann conditions are inserted into the result
        afterwards.

        See :meth:`pybamm.SpatialMethod.gradient`
        """
        domain = symbol.domain
        if symbol in boundary_conditions:
            symbol_bcs = boundary_conditions[symbol]
        else:
            symbol_bcs = {}

        # Dirichlet: extend the vector (and its domain) with ghost nodes
        if any(bc[1] == "Dirichlet" for bc in symbol_bcs.values()):
            discretised_symbol, domain = self.add_ghost_nodes(
                symbol, discretised_symbol, symbol_bcs
            )

        # In 1D the gradient is the same for cartesian, cylindrical and spherical
        result = self.gradient_matrix(domain, symbol.domains) @ discretised_symbol

        # Neumann: write the prescribed boundary fluxes into the result
        if any(bc[1] == "Neumann" for bc in symbol_bcs.values()):
            result = self.add_neumann_values(symbol, result, symbol_bcs, domain)

        return result

    def _get_edges_symbolic_mesh(self, domains: list[str]):
        """
        Concatenate the cell edges of consecutive submeshes into one vector.

        Every submesh after the first contributes its edges without the leading
        one, which is assumed to duplicate the previous submesh's last edge.
        Submeshes with a ``length`` attribute have their edges mapped through
        ``edges * length + min``, so the result may be a symbolic expression.
        """
        submeshes = [self.mesh[name] for name in domains]
        pieces = []
        for position, sub in enumerate(submeshes):
            raw_edges = sub.edges[1:] if position > 0 else sub.edges
            piece = pybamm.Vector(raw_edges)
            if hasattr(sub, "length"):
                piece = piece * sub.length + sub.min
            pieces.append(piece)
        return pybamm.numpy_concatenation(*pieces)

    def _get_nodes_symbolic_mesh(self, domains: list[str]):
        """
        Concatenate the node positions of consecutive submeshes into one vector.

        Submeshes with a ``length`` attribute have their nodes mapped through
        ``nodes * length + min``; others are used as-is.
        """
        submeshes = [self.mesh[name] for name in domains]
        scaled_nodes = []
        for sub in submeshes:
            node_vector = pybamm.Vector(sub.nodes)
            if hasattr(sub, "length"):
                scaled_nodes.append(node_vector * sub.length + sub.min)
            else:
                scaled_nodes.append(node_vector)
        return pybamm.numpy_concatenation(*scaled_nodes)

    def _get_d_nodes_symbolic_mesh(self, domains: list[str]):
        """Spacing between consecutive nodes, ``x[i + 1] - x[i]``, as a symbol."""
        node_vector = self._get_nodes_symbolic_mesh(domains)
        n_nodes = node_vector.size
        # Forward-difference operator: row i picks out x[i + 1] - x[i]
        forward_difference = pybamm.Matrix(
            diags([-1, 1], [0, 1], shape=(n_nodes - 1, n_nodes))
        )
        return forward_difference @ node_vector

    def _get_d_edges_symbolic_mesh(self, domains: list[str]):
        """Width of each cell, ``edge[i + 1] - edge[i]``, as a symbol."""
        edge_vector = self._get_edges_symbolic_mesh(domains)
        n_edges = edge_vector.size
        # Forward-difference operator: row i picks out edge[i + 1] - edge[i]
        forward_difference = pybamm.Matrix(
            diags([-1, 1], [0, 1], shape=(n_edges - 1, n_edges))
        )
        return forward_difference @ edge_vector

    def _get_first_node(self, domains: list[str]):
        """Position of the first node of the first submesh in ``domains``.

        The value is scaled by ``length`` and shifted by ``min`` when the
        submesh carries those attributes.
        """
        leading_submesh = self.mesh[domains[0]]
        position = pybamm.Scalar(leading_submesh.nodes[0])
        if hasattr(leading_submesh, "length"):
            position = position * leading_submesh.length + leading_submesh.min
        return position

    def _get_last_node(self, domains: list[str]):
        """Position of the last node of the last submesh in ``domains``.

        The value is scaled by ``length`` and shifted by ``min`` when the
        submesh carries those attributes.
        """
        trailing_submesh = self.mesh[domains[-1]]
        position = pybamm.Scalar(trailing_submesh.nodes[-1])
        if hasattr(trailing_submesh, "length"):
            position = position * trailing_submesh.length + trailing_submesh.min
        return position

    def _get_edges_left_right_symbolic_mesh(self, domains: list[str]):
        """
        Left and right edges of every cell across consecutive submeshes.

        Right edges are ``edges[1:]`` of every submesh. For left edges, the
        first submesh keeps its leading edge and the others skip it. Every
        submesh except the last keeps its trailing edge, which serves as the
        left edge of the next submesh's first cell. Submeshes with a ``length``
        attribute are mapped through ``edges * length + min``.

        Returns
        -------
        tuple of :class:`pybamm.Symbol`
            ``(edges_left, edges_right)``, each a concatenation over submeshes.
        """
        submeshes = [self.mesh[name] for name in domains]
        final_index = len(submeshes) - 1
        left_parts = []
        right_parts = []
        for index, sub in enumerate(submeshes):
            start = 0 if index == 0 else 1
            stop = -1 if index == final_index else None
            left_vector = pybamm.Vector(sub.edges[start:stop])
            right_vector = pybamm.Vector(sub.edges[1:])
            if hasattr(sub, "length"):
                left_vector = left_vector * sub.length + sub.min
                right_vector = right_vector * sub.length + sub.min
            left_parts.append(left_vector)
            right_parts.append(right_vector)
        return (
            pybamm.numpy_concatenation(*left_parts),
            pybamm.numpy_concatenation(*right_parts),
        )

    def gradient_matrix(self, domain, domains):
        """
        Gradient matrix for finite volumes in the appropriate domain.
        Equivalent to grad(y) = (y[1:] - y[:-1])/dx

        Parameters
        ----------
        domain : list
            The primary domain(s) in which to compute the gradient matrix,
            including any ghost-cell domains added for Dirichlet conditions
        domains : dict
            The full domains of the symbol; the auxiliary domains determine how
            many times the 1D gradient matrix is repeated

        Returns
        -------
        :class:`pybamm.Symbol`
            The (sparse) finite volume gradient matrix for the domain. It is
            block-diagonal, with one 1D gradient block per auxiliary point.
        """
        # Create appropriate submesh by combining submeshes in primary domain
        submesh = self.mesh[domain]
        if hasattr(submesh, "length"):
            d_nodes = self._get_d_nodes_symbolic_mesh(domain)
        else:
            # Wrap as a Vector, matching divergence_matrix and the symbolic branch
            d_nodes = pybamm.Vector(submesh.d_nodes)
        # Row i of the gradient is scaled by 1 / (x[i + 1] - x[i])
        e = 1 / d_nodes

        # Create 1D matrix using submesh
        n = submesh.npts
        sub_matrix_minus = pybamm.Matrix(diags([-1], [0], shape=(n - 1, n)))
        sub_matrix_plus = pybamm.Matrix(diags([1], [1], shape=(n - 1, n)))
        sub_matrix = (sub_matrix_minus + sub_matrix_plus) * e

        # Repeat the 1D block once per point in the auxiliary domains
        second_dim_repeats = self._get_auxiliary_domain_repeats(domains)
        matrix = pybamm.kronecker_product(
            pybamm.Matrix(eye(second_dim_repeats)), sub_matrix
        )
        return matrix

    def divergence(self, symbol, discretised_symbol, boundary_conditions):
        """Matrix-vector multiplication to implement the divergence operator.

        In cylindrical (spherical) polar coordinates the flux is first weighted
        by r (r**2) evaluated at the cell edges.

        See :meth:`pybamm.SpatialMethod.divergence`
        """
        submesh = self.mesh[symbol.domain]
        div_matrix = self.divergence_matrix(symbol.domains)

        coord_sys = submesh.coord_sys
        if coord_sys not in ["cylindrical polar", "spherical polar"]:
            return div_matrix @ discretised_symbol

        n_repeats = self._get_auxiliary_domain_repeats(symbol.domains)
        if hasattr(submesh, "length"):
            edge_positions = self._get_edges_symbolic_mesh(symbol.domains["primary"])
        else:
            edge_positions = submesh.edges
        # Edge radii repeated once per auxiliary point
        r_edges = pybamm.kronecker_product(
            pybamm.Matrix(np.ones(n_repeats)), edge_positions
        )

        if coord_sys == "spherical polar":
            weighted_flux = (r_edges**2) * discretised_symbol
        else:
            weighted_flux = r_edges * discretised_symbol
        return div_matrix @ weighted_flux

    def divergence_matrix(self, domains):
        """
        Divergence matrix for finite volumes in the appropriate domain.
        Equivalent to div(N) = (N[1:] - N[:-1])/dx, where in cylindrical
        (spherical) polar coordinates dx is replaced by (r_R**2 - r_L**2)/2
        ((r_R**3 - r_L**3)/3) for each cell with left/right edges r_L, r_R.

        Parameters
        ----------
        domains : dict
            The domain(s) and auxiliary domain in which to compute the divergence matrix

        Returns
        -------
        :class:`pybamm.Symbol`
            The (sparse) finite volume divergence matrix for the domain,
            block-diagonal with one block per auxiliary point
        """
        # Create appropriate submesh by combining submeshes in domain
        submesh = self.mesh[domains["primary"]]
        is_symbolic = hasattr(submesh, "length")

        if submesh.coord_sys in ["cylindrical polar", "spherical polar"]:
            if is_symbolic:
                r_edges_left, r_edges_right = self._get_edges_left_right_symbolic_mesh(
                    domains["primary"]
                )
            else:
                # Wrapped as Vectors for consistency with definite_integral_matrix
                r_edges_left = pybamm.Vector(submesh.edges[:-1])
                r_edges_right = pybamm.Vector(submesh.edges[1:])
            if submesh.coord_sys == "spherical polar":
                d_edges = (r_edges_right**3 - r_edges_left**3) / 3
            else:
                d_edges = (r_edges_right**2 - r_edges_left**2) / 2
        elif is_symbolic:
            d_edges = self._get_d_edges_symbolic_mesh(domains["primary"])
        else:
            d_edges = pybamm.Vector(submesh.d_edges)
        e = 1 / d_edges

        # Create matrix using submesh: n cells need n + 1 edge fluxes
        n = submesh.npts + 1
        sub_matrix_minus = pybamm.Matrix(diags([-1], [0], shape=(n - 1, n)))
        sub_matrix_plus = pybamm.Matrix(diags([1], [1], shape=(n - 1, n)))
        sub_matrix = (sub_matrix_minus + sub_matrix_plus) * e

        # repeat matrix for each node in secondary dimensions
        second_dim_repeats = self._get_auxiliary_domain_repeats(domains)
        # generate full matrix from the submatrix
        matrix = pybamm.kronecker_product(
            pybamm.Matrix(eye(second_dim_repeats)), sub_matrix
        )
        return matrix

    def laplacian(self, symbol, discretised_symbol, boundary_conditions):
        """
        Laplacian operator, built as the divergence of the discrete gradient.

        See :meth:`pybamm.SpatialMethod.laplacian`
        """
        flux = self.gradient(symbol, discretised_symbol, boundary_conditions)
        # The discretised gradient stands in for both the symbol and its
        # discretisation when taking the divergence
        return self.divergence(flux, flux, boundary_conditions)

    def integral(
        self, child, discretised_child, integration_dimension, integration_variable
    ):
        """Definite integral as a product with the finite volume integration row.

        ``integration_variable`` is accepted for interface compatibility but is
        not used here.
        """
        weights = self.definite_integral_matrix(
            child, integration_dimension=integration_dimension
        )
        return weights @ discretised_child

    def definite_integral_matrix(
        self, child, vector_type="row", integration_dimension="primary"
    ):
        """
        Matrix for finite-volume implementation of the definite integral in the
        requested dimension

        .. math::
            I = \\int_{a}^{b}\\!f(s)\\,ds

        for where :math:`a` and :math:`b` are the left-hand and right-hand boundaries of
        the domain respectively

        Parameters
        ----------
        child : :class:`pybamm.Symbol`
            The symbol being integrated
        vector_type : str, optional
            Whether to return a row or column vector in the primary dimension
            (default is row)
        integration_dimension : str, optional
            The dimension in which to integrate (default is "primary")

        Returns
        -------
        :class:`pybamm.Matrix`
            The finite volume integral matrix for the domain

        Raises
        ------
        NotImplementedError
            If a non-"row" vector is requested for a non-primary dimension
        """
        domains = child.domains
        if vector_type != "row" and integration_dimension != "primary":
            raise NotImplementedError(
                f"Integral in {integration_dimension} vector only implemented in 'row' form"
            )

        # Mesh of the dimension being integrated over. All cell widths/volumes
        # below must come from this mesh (not the primary one), so that
        # integrating over e.g. the secondary dimension uses its own cells.
        domain = domains[integration_dimension]
        submesh = self.mesh[domain]

        # check coordinate system
        if submesh.coord_sys in ["cylindrical polar", "spherical polar"]:
            if hasattr(submesh, "length"):
                r_edges_left, r_edges_right = self._get_edges_left_right_symbolic_mesh(
                    domain
                )
            else:
                r_edges_left = pybamm.Vector(submesh.edges[:-1])
                r_edges_right = pybamm.Vector(submesh.edges[1:])
            if submesh.coord_sys == "spherical polar":
                # integral of 4*pi*r**2 dr over each cell
                d_edges = 4 * np.pi * (r_edges_right**3 - r_edges_left**3) / 3
            elif submesh.coord_sys == "cylindrical polar":
                # integral of 2*pi*r dr over each cell
                d_edges = 2 * np.pi * (r_edges_right**2 - r_edges_left**2) / 2
        elif hasattr(submesh, "length"):
            d_edges = self._get_d_edges_symbolic_mesh(domain)
        else:
            d_edges = pybamm.Vector(submesh.d_edges)

        possible_dimensions = ["primary", "secondary", "tertiary", "quaternary"]
        if integration_dimension == "primary":
            # d_edges is a column of cell widths; transpose for a row vector
            if vector_type == "row":
                d_edges = pybamm.Transpose(d_edges)

            # repeat matrix for each node in secondary dimensions
            second_dim_repeats = self._get_auxiliary_domain_repeats(domains)
            # generate full matrix from the submatrix
            matrix = pybamm.kronecker_product(
                pybamm.Matrix(eye(second_dim_repeats)), d_edges
            )
        elif integration_dimension in possible_dimensions[1:]:
            this_dimension_index = possible_dimensions.index(integration_dimension)
            # get lower dimensions and the corresponding domains, i.e. if integration_dimension is "secondary",
            # lower_dimensions is ["primary"] and lower_domains is [child.domains["primary"]]
            lower_dimensions = possible_dimensions[:this_dimension_index]
            lower_domains = [child.domains[dimension] for dimension in lower_dimensions]
            # get higher dimensions, i.e. if integration_dimension is "secondary",
            # higher_dimensions is ["tertiary", "quaternary"]
            higher_dimensions = possible_dimensions[this_dimension_index + 1 :]
            n_lower_pts = 1
            #  Lower dimensions should be repeated, so add them to the eye matrix
            for lower_domain, lower_dimension in zip(
                lower_domains, lower_dimensions, strict=False
            ):
                lower_submesh = self.mesh[lower_domain]
                if child.evaluates_on_edges(lower_dimension):
                    n_lower_pts *= lower_submesh.npts + 1
                else:
                    n_lower_pts *= lower_submesh.npts
            if d_edges.shape[0] == 1:
                int_matrix = pybamm.kronecker_product(
                    d_edges, pybamm.Matrix(eye(n_lower_pts))
                )
            else:
                int_matrix = pybamm.kronecker_product(
                    pybamm.Transpose(d_edges), pybamm.Matrix(eye(n_lower_pts))
                )

            # Higher dimensions should be tiled, so repeat the matrix for each higher dimension.
            higher_repeats = self._get_auxiliary_domain_repeats(
                {k: v for k, v in domains.items() if (k in higher_dimensions)}
            )
            matrix = pybamm.kronecker_product(
                pybamm.Matrix(eye(higher_repeats)), int_matrix
            )
        return matrix

    def indefinite_integral(self, child, discretised_child, direction):
        """Implementation of the indefinite integral operator.

        An integrand on edges is integrated onto nodes, and an integrand on
        nodes is integrated onto edges. The node-based case is not implemented
        for cylindrical or spherical polar domains and raises
        ``NotImplementedError``.
        """
        if child.evaluates_on_edges("primary"):
            integration_matrix = self.indefinite_integral_matrix_edges(
                child.domains, direction
            )
        else:
            # A polar indefinite integral for node-based integrands would need
            # its own definition; reject it until one is worked out
            coord_sys = self.mesh[child.domain].coord_sys
            if coord_sys in ["cylindrical polar", "spherical polar"]:
                raise NotImplementedError(
                    f"Indefinite integral on a {coord_sys} domain is not "
                    "implemented"
                )
            integration_matrix = self.indefinite_integral_matrix_nodes(
                child.domains, direction
            )

        result = integration_matrix @ discretised_child
        result.copy_domains(child)
        return result

    @staticmethod
    def _get_integral_node_edge_matrix(vector, n):
        """
        Stack ``n`` copies of the transposed ``vector`` with
        :class:`pybamm.SparseStack`.

        The same ``Transpose`` object is reused for every copy.
        """
        row = pybamm.Transpose(vector)
        return pybamm.SparseStack(*([row] * n))

    def indefinite_integral_matrix_edges(self, domains, direction):
        """
        Matrix for finite-volume implementation of the indefinite integral where the
        integrand is evaluated on mesh edges (shape (n+1, 1)).
        The integral will then be evaluated on mesh nodes (shape (n, 1)).

        Parameters
        ----------
        domains : dict
            The domain(s) and auxiliary domains of integration
        direction : str
            The direction of integration (forward or backward). See notes.

        Returns
        -------
        :class:`pybamm.Matrix`
            The finite volume integral matrix for the domain

        Raises
        ------
        ValueError
            If ``direction`` is neither "forward" nor "backward"

        Notes
        -----

        **Forward integral**

        .. math::
            F(x) = \\int_0^x\\!f(u)\\,du

        The indefinite integral must satisfy the following conditions:

        - :math:`F(0) = 0`
        - :math:`f(x) = \\frac{dF}{dx}`

        or, in discrete form,

        - `BoundaryValue(F, "left") = 0`, i.e. :math:`3*F_0 - F_1 = 0`
        - :math:`f_{i+1/2} = (F_{i+1} - F_i) / dx_{i+1/2}`

        Hence we must have

        - :math:`F_0 = du_{1/2} * f_{1/2} / 2`
        - :math:`F_{i+1} = F_i + du_{i+1/2} * f_{i+1/2}`

        Note that :math:`f_{-1/2}` and :math:`f_{end+1/2}` are included in the discrete
        integrand vector `f`, so we add a column of zeros at each end of the
        indefinite integral matrix to ignore these.

        **Backward integral**

        .. math::
            F(x) = \\int_x^{end}\\!f(u)\\,du

        The indefinite integral must satisfy the following conditions:

        - :math:`F(end) = 0`
        - :math:`f(x) = -\\frac{dF}{dx}`

        or, in discrete form,

        - `BoundaryValue(F, "right") = 0`, i.e. :math:`3*F_{end} - F_{end-1} = 0`
        - :math:`f_{i+1/2} = -(F_{i+1} - F_i) / dx_{i+1/2}`

        Hence we must have

        - :math:`F_{end} = du_{end-1/2} * f_{end-1/2} / 2`
        - :math:`F_{i-1} = F_i + du_{i-1/2} * f_{i-1/2}`

        Note that :math:`f_{-1/2}` and :math:`f_{end+1/2}` are included in the discrete
        integrand vector `f`, so we add a column of zeros at each end of the
        indefinite integral matrix to ignore these.
        """
        # Fail early with a clear message instead of an UnboundLocalError later
        if direction not in ("forward", "backward"):
            raise ValueError(
                f"direction must be 'forward' or 'backward', not {direction!r}"
            )

        # Create appropriate submesh by combining submeshes in domain
        submesh = self.mesh[domains["primary"]]
        n = submesh.npts
        second_dim_repeats = self._get_auxiliary_domain_repeats(domains)
        if hasattr(submesh, "length"):
            d_nodes = self._get_d_nodes_symbolic_mesh(domains["primary"])
        else:
            d_nodes = pybamm.Vector(submesh.d_nodes)
        # (n, n - 1) stack whose every row is d_nodes; multiplying it
        # elementwise with a 0/1 pattern matrix gives entries d_nodes[j]
        d_nodes_matrix = self._get_integral_node_edge_matrix(d_nodes, n)
        if direction == "forward":
            # Strictly lower-triangular pattern: F_i sums interior edges j < i
            du_entries = [np.ones(d_nodes.size)] * (n - 1)
            offset = -np.arange(1, n, 1)
            main_integral_matrix = d_nodes_matrix * pybamm.Matrix(
                spdiags(du_entries, offset, n, n - 1)
            )
            # Half-cell contribution from the first interior edge
            bc_offset_matrix = lil_matrix((n, n - 1))
            bc_offset_matrix[:, 0] = 1.0
            bc_offset_matrix = d_nodes_matrix * pybamm.Matrix(bc_offset_matrix) / 2
        else:
            # Upper-triangular pattern: F_i sums interior edges j >= i
            # (offsets beyond the matrix width are ignored by spdiags)
            du_entries = [np.ones(d_nodes.size)] * (n + 1)
            offset = np.arange(n, -1, -1)
            main_integral_matrix = d_nodes_matrix * pybamm.Matrix(
                spdiags(du_entries, offset, n, n - 1)
            )
            # Half-cell contribution from the last interior edge
            bc_offset_matrix = lil_matrix((n, n - 1))
            bc_offset_matrix[:, -1] = 1.0
            bc_offset_matrix = d_nodes_matrix * pybamm.Matrix(bc_offset_matrix) / 2
        sub_matrix = main_integral_matrix + bc_offset_matrix
        # Add a column of zeros at each end so the boundary fluxes
        # f_{-1/2} and f_{end+1/2} are ignored
        zero_col = pybamm.Transpose(pybamm.Matrix(csr_matrix((n, 1))))
        sub_matrix_transposed = pybamm.Transpose(sub_matrix)
        sub_matrix = pybamm.SparseStack(zero_col, sub_matrix_transposed, zero_col)
        sub_matrix = pybamm.Transpose(sub_matrix)
        # Repeat the 1D block once per auxiliary point
        matrix = pybamm.kronecker_product(
            pybamm.Matrix(eye(second_dim_repeats)), sub_matrix
        )
        return matrix

    def indefinite_integral_matrix_nodes(self, domains, direction):
        """
        Matrix for finite-volume implementation of the indefinite integral
        where the integrand is evaluated on mesh nodes (shape (n, 1)).
        The integral will then be evaluated on mesh edges (shape (n+1, 1)).
        This is a cumulative sum of the integrand weighted by the cell widths.
        It accumulates from the left for "forward" and from the right for
        "backward".

        Parameters
        ----------
        domains : dict
            The domain(s) and auxiliary domains of integration
        direction : str
            The direction of integration (forward or backward)

        Returns
        -------
        :class:`pybamm.Matrix`
            The finite volume integral matrix for the domain

        Raises
        ------
        ValueError
            If ``direction`` is neither "forward" nor "backward"
        """
        # Create appropriate submesh by combining submeshes in domain
        submesh = self.mesh[domains["primary"]]
        n = submesh.npts

        if direction == "forward":
            # Sub-diagonals -1 ... -n: edge i accumulates nodes j < i
            offset = -np.arange(1, n + 1, 1)
        elif direction == "backward":
            # Diagonals n-1 ... 0: edge i accumulates nodes j >= i
            offset = np.arange(n - 1, -1, -1)
        else:
            raise ValueError(
                f"direction must be 'forward' or 'backward', not {direction!r}"
            )

        second_dim_repeats = self._get_auxiliary_domain_repeats(domains)
        # Choose symbolic or fixed cell widths (previously the symbolic version
        # unconditionally overwrote this choice, making the branch dead code)
        if hasattr(submesh, "length"):
            d_edges = self._get_d_edges_symbolic_mesh(domains["primary"])
        else:
            d_edges = pybamm.Vector(submesh.d_edges)
        # (n + 1, n) stack whose every row is d_edges; weights the 0/1 pattern
        d_edges_matrix = self._get_integral_node_edge_matrix(d_edges, n + 1)
        du_entries = [np.ones(d_edges.size)] * n
        sub_matrix = d_edges_matrix * pybamm.Matrix(
            spdiags(du_entries, offset, n + 1, n)
        )
        # Repeat the 1D block once per auxiliary point
        matrix = pybamm.kronecker_product(
            pybamm.Matrix(eye(second_dim_repeats)), sub_matrix
        )
        return matrix

    def delta_function(self, symbol, discretised_symbol):
        """
        Delta function. Implemented as a vector whose only non-zero element is the
        first (if symbol.side = "left") or last (if symbol.side = "right"), with
        appropriate value so that the integral of the delta function across the whole
        domain is the same as the integral of the discretised symbol across the whole
        domain.

        The value at the chosen node is scaled by ``domain_width / dx``. Here
        ``dx`` is the spacing between the two nodes nearest that boundary.

        See :meth:`pybamm.SpatialMethod.delta_function`
        """
        submesh = self.mesh[symbol.domain]
        n_primary = submesh.npts
        n_aux = self._get_auxiliary_domain_repeats(symbol.domains)

        if hasattr(submesh, "length"):
            d_nodes = self._get_d_nodes_symbolic_mesh(symbol.domain)
        else:
            d_nodes = pybamm.Vector(submesh.d_nodes)

        # Indices of the boundary node spacing and of the boundary node itself
        if symbol.side == "left":
            spacing_row, node_row = 0, 0
        elif symbol.side == "right":
            spacing_row, node_row = d_nodes.size - 1, n_primary - 1

        # One-hot selectors picking out dx and the boundary node
        spacing_picker = pybamm.Vector(
            csr_matrix(
                ([1], ([spacing_row], [0])), shape=(d_nodes.size, 1)
            ).toarray()
        )
        dx = pybamm.Transpose(spacing_picker) @ d_nodes
        node_picker = pybamm.Matrix(
            csr_matrix(([1], ([node_row], [0])), shape=(n_primary, 1)).toarray()
        )

        # Domain width keeps the integral of the delta equal to that of the child
        domain_width = pybamm.Scalar(submesh.edges[-1] - submesh.edges[0])
        if hasattr(submesh, "length"):
            domain_width = domain_width * submesh.length

        # One selector block per auxiliary point
        selector = pybamm.kronecker_product(
            pybamm.Matrix(eye(n_aux).toarray()), node_picker
        )

        delta_fn = domain_width / dx * selector * discretised_symbol
        delta_fn.copy_domains(symbol)
        return delta_fn

    def internal_neumann_condition(
        self, left_symbol_disc, right_symbol_disc, left_mesh, right_mesh
    ):
        """
        A method to find the internal Neumann conditions between two symbols
        on adjacent subdomains.

        The result is the finite-difference gradient across the interface:
        the first value of the right symbol minus the last value of the left
        symbol, divided by the distance between the corresponding nodes.

        Parameters
        ----------
        left_symbol_disc : :class:`pybamm.Symbol`
            The discretised symbol on the left subdomain
        right_symbol_disc : :class:`pybamm.Symbol`
            The discretised symbol on the right subdomain
        left_mesh : submesh
            The mesh on the left subdomain (only ``npts`` is read)
        right_mesh : submesh
            The mesh on the right subdomain (only ``npts`` is read)

        Raises
        ------
        :class:`pybamm.DomainError`
            If the two symbols have different numbers of auxiliary points
        """
        n_left = left_mesh.npts
        n_right = right_mesh.npts

        n_aux = self._get_auxiliary_domain_repeats(left_symbol_disc.domains)
        if n_aux != self._get_auxiliary_domain_repeats(right_symbol_disc.domains):
            raise pybamm.DomainError(
                """Number of secondary points in subdomains do not match"""
            )

        # Selector of the last entry of the left symbol (per auxiliary point)
        left_selector = np.zeros((1, n_left))
        left_selector[0, n_left - 1] = 1
        left_matrix = pybamm.Matrix(csr_matrix(kron(eye(n_aux), left_selector)))

        # Selector of the first entry of the right symbol (per auxiliary point)
        right_selector = np.zeros((1, n_right))
        right_selector[0, 0] = 1
        right_matrix = pybamm.Matrix(csr_matrix(kron(eye(n_aux), right_selector)))

        # Distance between the two nodes either side of the interface
        dx = self._get_first_node(right_symbol_disc.domain) - self._get_last_node(
            left_symbol_disc.domain
        )

        # Domains are cleared so the left and right terms can be combined
        right_term = (right_matrix / dx) @ right_symbol_disc
        right_term.clear_domains()
        left_term = (left_matrix / dx) @ left_symbol_disc
        left_term.clear_domains()

        return right_term - left_term

    def add_ghost_nodes(self, symbol, discretised_symbol, bcs):
        """
        Pad a discretised symbol with ghost nodes for its Dirichlet boundary
        conditions.

        A Dirichlet condition "y = a" on the left boundary prepends a ghost node
        whose value is ``2*a - y1``, where ``y1`` is the first node value (so the
        mean of the ghost node and ``y1`` is ``a``). The right boundary is treated
        the same way using the last node. Neumann conditions, or a side that is
        absent from ``bcs``, add no ghost node. Neumann values are instead inserted
        into the gradient by :meth:`pybamm.FiniteVolume.add_neumann_values`.

        Parameters
        ----------
        symbol : :class:`pybamm.Symbol`
            The undiscretised variable. Its domain selects the submesh and its
            auxiliary domains give the number of repeats.
        discretised_symbol : :class:`pybamm.Symbol`
            The discretised form of ``symbol``.
        bcs : dict
            Maps "left" and/or "right" to a ``(value, type)`` tuple, where the type
            is "Dirichlet" or "Neumann". At least one side must be present; a single
            side is allowed (e.g. for upwind/downwind).

        Returns
        -------
        tuple
            ``(Matrix @ discretised_symbol + bcs_vector, domain)``. The first entry
            evaluates to ``discretised_symbol`` with the ghost nodes attached.
            ``domain`` is the symbol's domain list, extended with
            "<first domain>_left ghost cell" and/or
            "<last domain>_right ghost cell" for each Dirichlet side.

        Raises
        ------
        ValueError
            If neither side has a boundary condition, or if a boundary condition
            type is not "Dirichlet", "Neumann" or absent.
        """
        domain = symbol.domain
        submesh = self.mesh[domain]
        n = submesh.npts
        second_dim_repeats = self._get_auxiliary_domain_repeats(symbol.domains)

        if all(side not in bcs for side in ("left", "right")):
            raise ValueError(f"No boundary conditions have been provided for {symbol}")

        # Either side may be missing (upwind/downwind only supply one)
        lbc_value, lbc_type = bcs.get("left", (None, None))
        rbc_value, rbc_type = bcs.get("right", (None, None))

        left_is_dirichlet = lbc_type == "Dirichlet"
        right_is_dirichlet = rbc_type == "Dirichlet"

        # Extend the domain with a ghost-cell entry for every Dirichlet side
        if left_is_dirichlet:
            domain = [domain[0] + "_left ghost cell", *domain]
        if right_is_dirichlet:
            domain = [*domain, domain[-1] + "_right ghost cell"]
        n_padded = n + int(left_is_dirichlet) + int(right_is_dirichlet)

        def boundary_contribution(bc_value, bc_type, row):
            # Constant part (2 * bc value) of the ghost node stored at `row` for a
            # Dirichlet side; an all-zero vector when the side adds no ghost node
            if bc_type == "Dirichlet":
                selector = coo_matrix(([1], ([row], [0])), shape=(n_padded, 1))
                selector = csr_matrix(kron(eye(second_dim_repeats), selector))
                if bc_value.evaluates_to_number():
                    ghost_constant = (
                        2 * bc_value * pybamm.Vector(np.ones(second_dim_repeats))
                    )
                else:
                    ghost_constant = 2 * bc_value
                return pybamm.Matrix(selector) @ ghost_constant
            if bc_type in ["Neumann", None]:
                return pybamm.Vector(np.zeros(n_padded * second_dim_repeats))
            raise ValueError(
                f"boundary condition must be Dirichlet or Neumann, not '{bc_type}'"
            )

        left_part = boundary_contribution(lbc_value, lbc_type, 0)
        right_part = boundary_contribution(rbc_value, rbc_type, n_padded - 1)
        bcs_vector = left_part + right_part
        # The bc values may live on another domain (e.g. electrode-level values
        # for a particle variable), so adopt the domains of the discretised symbol
        bcs_vector.copy_domains(discretised_symbol)

        # Each ghost row takes -1 times its neighbouring interior node; the identity
        # block in the middle copies the interior nodes unchanged
        if left_is_dirichlet:
            left_row = coo_matrix(([-1], ([0], [0])), shape=(1, n))
        else:
            left_row = None
        if right_is_dirichlet:
            right_row = coo_matrix(([-1], ([0], [n - 1])), shape=(1, n))
        else:
            right_row = None
        padding = vstack([left_row, eye(n), right_row])

        # Repeat the padding for every auxiliary-domain point. csr format keeps
        # row slicing cheap (column slicing becomes slower, which is acceptable)
        full_matrix = csr_matrix(kron(eye(second_dim_repeats), padding))

        padded_symbol = pybamm.Matrix(full_matrix) @ discretised_symbol + bcs_vector
        return padded_symbol, domain

    def add_neumann_values(self, symbol, discretised_gradient, bcs, domain):
        """
        Add the known values of the gradient from Neumann boundary conditions to
        the discretised gradient.

        Dirichlet bcs are implemented using ghost nodes, see
        :meth:`pybamm.FiniteVolume.add_ghost_nodes`.

        Parameters
        ----------
        symbol : :class:`pybamm.Symbol`
            The variable being discretised. Its auxiliary domains give the number
            of repeats.
        discretised_gradient : :class:`pybamm.Symbol`
            Contains the discretised gradient of symbol
        bcs : dict of tuples (:class:`pybamm.Scalar`, str)
            Dictionary with keys "left" and "right" (both required) of boundary
            conditions. Each boundary condition consists of a value and a flag
            indicating its type ("Dirichlet" or "Neumann").
        domain : list of strings
            The domain of the gradient of the symbol (may include ghost nodes)

        Returns
        -------
        :class:`pybamm.Symbol`
            `Matrix @ discretised_gradient + bcs_vector`. When evaluated, this gives the
            discretised_gradient, with the values of the Neumann boundary conditions
            concatenated at each end (if given).

        Raises
        ------
        ValueError
            If either boundary condition type is neither "Dirichlet" nor "Neumann".
            The message names the offending side's own type.
        """
        # get relevant grid points
        submesh = self.mesh[domain]

        # The discretised gradient has one entry fewer than `domain` has nodes
        # (per auxiliary-domain point), see the eye(n) block below
        n = submesh.npts - 1
        second_dim_repeats = self._get_auxiliary_domain_repeats(symbol.domains)

        lbc_value, lbc_type = bcs["left"]
        rbc_value, rbc_type = bcs["right"]

        # Every Neumann side contributes one extra (known) gradient value
        left_is_neumann = lbc_type == "Neumann"
        right_is_neumann = rbc_type == "Neumann"
        n_padded = n + int(left_is_neumann) + int(right_is_neumann)

        def boundary_contribution(bc_value, bc_type, row):
            # Vector holding the Neumann value at `row` (one copy per auxiliary
            # point), or zeros for Dirichlet / homogeneous Neumann sides.
            if bc_type == "Neumann" and bc_value != 0:
                selector = coo_matrix(([1], ([row], [0])), shape=(n_padded, 1))
                selector = csr_matrix(kron(eye(second_dim_repeats), selector))
                if bc_value.evaluates_to_number():
                    bc_entries = bc_value * pybamm.Vector(np.ones(second_dim_repeats))
                else:
                    bc_entries = bc_value
                return pybamm.Matrix(selector) @ bc_entries
            if bc_type == "Dirichlet" or (bc_type == "Neumann" and bc_value == 0):
                return pybamm.Vector(np.zeros(n_padded * second_dim_repeats))
            # Report the type of the side actually being checked
            raise ValueError(
                f"boundary condition must be Dirichlet or Neumann, not '{bc_type}'"
            )

        lbc_vector = boundary_contribution(lbc_value, lbc_type, 0)
        rbc_vector = boundary_contribution(rbc_value, rbc_type, n_padded - 1)

        bcs_vector = lbc_vector + rbc_vector
        # Need to match the domain. E.g. in the case of the boundary condition
        # on the particle, the gradient has domain particle but the bcs_vector
        # has domain electrode, since it is a function of the macroscopic variables
        bcs_vector.copy_domains(discretised_gradient)

        # Make matrix which makes "gaps" in the discretised gradient into
        # which the known Neumann values will be added. E.g. in 1D if the left
        # boundary condition is Dirichlet and the right Neumann, this matrix will
        # act to append a zero to the end of the discretised gradient
        left_vector = csr_matrix((1, n)) if left_is_neumann else None
        right_vector = csr_matrix((1, n)) if right_is_neumann else None
        sub_matrix = vstack([left_vector, eye(n), right_vector])

        # repeat matrix for secondary dimensions
        # Convert to csr_matrix so that we can take the index (row-slicing), which is
        # not supported by the default kron format
        # Note that this makes column-slicing inefficient, but this should not be an
        # issue
        matrix = csr_matrix(kron(eye(second_dim_repeats), sub_matrix))

        new_gradient = pybamm.Matrix(matrix) @ discretised_gradient + bcs_vector

        return new_gradient

    def _get_boundary_submesh_length(self, side: str, domains: list[str]):
        if side == "left":
            return self.mesh[domains[0]].length
        elif side == "right":
            return self.mesh[domains[-1]].length

    def _boundary_mesh_size(self, child, side):
        """
        Get the mesh size at the boundary of a variable's domain.

        Parameters
        ----------
        child : :class:`pybamm.Symbol`
            The variable whose domain's submesh is used.
        side : str
            "left" uses the first node spacing, "right" the last.

        Returns
        -------
        :class:`pybamm.Symbol`
            The boundary node spacing as a :class:`pybamm.Scalar`. When the submesh
            has a ``length`` attribute (symbolic submeshes), it is multiplied by the
            length of the submesh that contains that boundary.

        Raises
        ------
        ValueError
            If ``side`` is neither "left" nor "right".
        """
        submesh = self.mesh[child.domain]
        if side == "left":
            node_spacing = submesh.d_nodes[0]
        elif side == "right":
            node_spacing = submesh.d_nodes[-1]
        else:
            raise ValueError(f"Invalid side: {side}")

        if hasattr(submesh, "length"):
            # Scale by the submesh on the requested side (the last domain for
            # "right"), not always the first one
            length = self._get_boundary_submesh_length(side, child.domain)
        else:
            length = 1
        return length * pybamm.Scalar(node_spacing)

    def boundary_value_or_flux(self, symbol, discretised_child, bcs=None):
        """
        Uses extrapolation to get the boundary value or flux of a variable in the
        Finite Volume Method.

        See :meth:`pybamm.SpatialMethod.boundary_value`

        Parameters
        ----------
        symbol : :class:`pybamm.BoundaryValue`, :class:`pybamm.BoundaryGradient`
        or :class:`pybamm.BoundaryMeshSize`
            The boundary operator being discretised. Its ``side`` attribute selects
            the boundary. An optional ``order`` attribute overrides the
            extrapolation order given in ``self.options``.
        discretised_child : :class:`pybamm.Symbol`
            The discretised child of ``symbol``.
        bcs : dict, optional
            Boundary conditions indexed as ``bcs[child][side] -> (value, type)``.
            Only consulted when the "use bcs" extrapolation option is set.

        Returns
        -------
        :class:`pybamm.Symbol`
            ``(M * multiplicative) @ discretised_child
            + additive * additive_multiplicative``, carrying the domains of
            ``symbol``. For :class:`pybamm.BoundaryMeshSize` the result of
            ``_boundary_mesh_size`` is returned instead.

        Notes
        -----
        Submeshes with a ``length`` attribute (symbolic meshes) need their spacings
        scaled by that length. Gradient stencils are divided by it, and terms that
        multiply a Neumann bc value by a spacing are multiplied by it.
        """

        # Submesh of the child's (possibly combined) domain
        submesh = self.mesh[discretised_child.domain]

        prim_pts = submesh.npts
        repeats = self._get_auxiliary_domain_repeats(discretised_child.domains)

        if bcs is None:
            bcs = {}

        extrap_order_gradient = (
            getattr(symbol, "order", None)
            or self.options["extrapolation"]["order"]["gradient"]
        )
        extrap_order_value = (
            getattr(symbol, "order", None)
            or self.options["extrapolation"]["order"]["value"]
        )
        use_bcs = self.options["extrapolation"]["use bcs"]

        nodes = submesh.nodes
        edges = submesh.edges
        d_nodes = submesh.d_nodes
        n_spacings = len(d_nodes)

        def node_spacing(index):
            # Spacings away from the boundary only exist on meshes with enough
            # nodes. Return nan instead of raising here, so that schemes which never
            # use them (constant, Dirichlet bc, linear on 2 nodes) still work on
            # coarse meshes. Schemes that do need a missing spacing also reference
            # a column outside the (1, prim_pts) sub-matrix, which scipy rejects.
            if -n_spacings <= index < n_spacings:
                return d_nodes[index]
            return np.nan

        dx0 = nodes[0] - edges[0]
        dx1 = node_spacing(0)
        dx2 = node_spacing(1)

        dxN = edges[-1] - nodes[-1]
        dxNm1 = node_spacing(-1)
        dxNm2 = node_spacing(-2)

        child = symbol.child

        # Create submatrix to compute boundary values or fluxes
        # Derivation of extrapolation formula can be found at:
        # https://github.com/Scottmar93/extrapolation-coefficents/tree/master
        if isinstance(symbol, pybamm.BoundaryMeshSize):
            return self._boundary_mesh_size(child, symbol.side)
        elif isinstance(symbol, pybamm.BoundaryValue):
            if use_bcs and pybamm.has_bc_of_form(child, symbol.side, bcs, "Dirichlet"):
                # just use the value from the bc: f(x*)
                sub_matrix = csr_matrix((1, prim_pts))
                additive = bcs[child][symbol.side][0]
                additive_multiplicative = pybamm.Scalar(1)
                multiplicative = pybamm.Scalar(1)

            elif symbol.side == "left":
                if extrap_order_value == "linear":
                    # to find value at x* use formula:
                    # f(x*) = f_1 - (dx0 / dx1) (f_2 - f_1)

                    if use_bcs and pybamm.has_bc_of_form(
                        child, symbol.side, bcs, "Neumann"
                    ):
                        sub_matrix = csr_matrix(([1], ([0], [0])), shape=(1, prim_pts))

                        additive = -dx0 * bcs[child][symbol.side][0]
                        if hasattr(submesh, "length"):
                            additive_multiplicative = self._get_boundary_submesh_length(
                                "left", child.domain
                            )
                        else:
                            additive_multiplicative = pybamm.Scalar(1)
                        multiplicative = pybamm.Scalar(1)

                    else:
                        sub_matrix = csr_matrix(
                            ([1 + (dx0 / dx1), -(dx0 / dx1)], ([0, 0], [0, 1])),
                            shape=(1, prim_pts),
                        )
                        additive = pybamm.Scalar(0)
                        additive_multiplicative = pybamm.Scalar(1)
                        multiplicative = pybamm.Scalar(1)

                elif extrap_order_value == "quadratic":
                    if use_bcs and pybamm.has_bc_of_form(
                        child, symbol.side, bcs, "Neumann"
                    ):
                        a = (dx0 + dx1) ** 2 / (dx1 * (2 * dx0 + dx1))
                        b = -(dx0**2) / (2 * dx0 * dx1 + dx1**2)
                        alpha = -(dx0 * (dx0 + dx1)) / (2 * dx0 + dx1)

                        sub_matrix = csr_matrix(
                            ([a, b], ([0, 0], [0, 1])), shape=(1, prim_pts)
                        )
                        additive = alpha * bcs[child][symbol.side][0]
                        if hasattr(submesh, "length"):
                            additive_multiplicative = self._get_boundary_submesh_length(
                                "left", child.domain
                            )
                        else:
                            additive_multiplicative = pybamm.Scalar(1)
                        multiplicative = pybamm.Scalar(1)

                    else:
                        a = (dx0 + dx1) * (dx0 + dx1 + dx2) / (dx1 * (dx1 + dx2))
                        b = -dx0 * (dx0 + dx1 + dx2) / (dx1 * dx2)
                        c = dx0 * (dx0 + dx1) / (dx2 * (dx1 + dx2))

                        sub_matrix = csr_matrix(
                            ([a, b, c], ([0, 0, 0], [0, 1, 2])), shape=(1, prim_pts)
                        )

                        additive = pybamm.Scalar(0)
                        additive_multiplicative = pybamm.Scalar(1)
                        multiplicative = pybamm.Scalar(1)

                elif extrap_order_value == "constant":
                    sub_matrix = csr_matrix(
                        ([1], ([0], [0])),
                        shape=(1, prim_pts),
                    )
                    additive = pybamm.Scalar(0)
                    additive_multiplicative = pybamm.Scalar(1)
                    multiplicative = pybamm.Scalar(1)
                else:
                    raise NotImplementedError

            elif symbol.side == "right":
                if extrap_order_value == "linear":
                    if use_bcs and pybamm.has_bc_of_form(
                        child, symbol.side, bcs, "Neumann"
                    ):
                        # use formula:
                        # f(x*) = fN + dxN * f'(x*)
                        sub_matrix = csr_matrix(
                            ([1], ([0], [prim_pts - 1])), shape=(1, prim_pts)
                        )
                        additive = dxN * bcs[child][symbol.side][0]
                        # Only the spacing-weighted bc term is rescaled by the
                        # submesh length; f_N itself must not be (this matches the
                        # left-hand and quadratic Neumann branches)
                        if hasattr(submesh, "length"):
                            additive_multiplicative = self._get_boundary_submesh_length(
                                "right", child.domain
                            )
                        else:
                            additive_multiplicative = pybamm.Scalar(1)
                        multiplicative = pybamm.Scalar(1)
                    else:
                        # to find value at x* use formula:
                        # f(x*) = f_N - (dxN / dxNm1) (f_N - f_Nm1)
                        sub_matrix = csr_matrix(
                            (
                                [-(dxN / dxNm1), 1 + (dxN / dxNm1)],
                                ([0, 0], [prim_pts - 2, prim_pts - 1]),
                            ),
                            shape=(1, prim_pts),
                        )
                        additive = pybamm.Scalar(0)
                        additive_multiplicative = pybamm.Scalar(1)
                        multiplicative = pybamm.Scalar(1)
                elif extrap_order_value == "quadratic":
                    if use_bcs and pybamm.has_bc_of_form(
                        child, symbol.side, bcs, "Neumann"
                    ):
                        a = (dxN + dxNm1) ** 2 / (dxNm1 * (2 * dxN + dxNm1))
                        b = -(dxN**2) / (2 * dxN * dxNm1 + dxNm1**2)
                        alpha = dxN * (dxN + dxNm1) / (2 * dxN + dxNm1)
                        sub_matrix = csr_matrix(
                            ([b, a], ([0, 0], [prim_pts - 2, prim_pts - 1])),
                            shape=(1, prim_pts),
                        )

                        additive = alpha * bcs[child][symbol.side][0]
                        if hasattr(submesh, "length"):
                            additive_multiplicative = self._get_boundary_submesh_length(
                                "right", child.domain
                            )
                        else:
                            additive_multiplicative = pybamm.Scalar(1)
                        multiplicative = pybamm.Scalar(1)
                    else:
                        a = (
                            (dxN + dxNm1)
                            * (dxN + dxNm1 + dxNm2)
                            / (dxNm1 * (dxNm1 + dxNm2))
                        )
                        b = -dxN * (dxN + dxNm1 + dxNm2) / (dxNm1 * dxNm2)
                        c = dxN * (dxN + dxNm1) / (dxNm2 * (dxNm1 + dxNm2))

                        sub_matrix = csr_matrix(
                            (
                                [c, b, a],
                                ([0, 0, 0], [prim_pts - 3, prim_pts - 2, prim_pts - 1]),
                            ),
                            shape=(1, prim_pts),
                        )
                        additive = pybamm.Scalar(0)
                        additive_multiplicative = pybamm.Scalar(1)
                        multiplicative = pybamm.Scalar(1)
                elif extrap_order_value == "constant":
                    sub_matrix = csr_matrix(
                        ([1], ([0], [prim_pts - 1])),
                        shape=(1, prim_pts),
                    )
                    additive = pybamm.Scalar(0)
                    additive_multiplicative = pybamm.Scalar(1)
                    multiplicative = pybamm.Scalar(1)
                else:
                    raise NotImplementedError

        elif isinstance(symbol, pybamm.BoundaryGradient):
            if use_bcs and pybamm.has_bc_of_form(child, symbol.side, bcs, "Neumann"):
                # just use the value from the bc: f'(x*)
                sub_matrix = csr_matrix((1, prim_pts))
                additive = bcs[child][symbol.side][0]
                additive_multiplicative = pybamm.Scalar(1)
                multiplicative = pybamm.Scalar(1)

            elif symbol.side == "left":
                if extrap_order_gradient == "linear":
                    # f'(x*) = (f_2 - f_1) / dx1
                    sub_matrix = (1 / dx1) * csr_matrix(
                        ([-1, 1], ([0, 0], [0, 1])), shape=(1, prim_pts)
                    )
                    additive = pybamm.Scalar(0)
                    additive_multiplicative = pybamm.Scalar(1)
                    if hasattr(submesh, "length"):
                        multiplicative = 1 / self._get_boundary_submesh_length(
                            "left", child.domain
                        )
                    else:
                        multiplicative = pybamm.Scalar(1)

                elif extrap_order_gradient == "quadratic":
                    a = -(2 * dx0 + 2 * dx1 + dx2) / (dx1**2 + dx1 * dx2)
                    b = (2 * dx0 + dx1 + dx2) / (dx1 * dx2)
                    c = -(2 * dx0 + dx1) / (dx1 * dx2 + dx2**2)

                    sub_matrix = csr_matrix(
                        ([a, b, c], ([0, 0, 0], [0, 1, 2])), shape=(1, prim_pts)
                    )
                    additive = pybamm.Scalar(0)
                    additive_multiplicative = pybamm.Scalar(1)
                    if hasattr(submesh, "length"):
                        multiplicative = 1 / self._get_boundary_submesh_length(
                            "left", child.domain
                        )
                    else:
                        multiplicative = pybamm.Scalar(1)

                else:
                    raise NotImplementedError

            elif symbol.side == "right":
                if extrap_order_gradient == "linear":
                    # use formula:
                    # f'(x*) = (f_N - f_Nm1) / dxNm1
                    sub_matrix = (1 / dxNm1) * csr_matrix(
                        ([-1, 1], ([0, 0], [prim_pts - 2, prim_pts - 1])),
                        shape=(1, prim_pts),
                    )
                    additive = pybamm.Scalar(0)
                    additive_multiplicative = pybamm.Scalar(1)
                    if hasattr(submesh, "length"):
                        multiplicative = 1 / self._get_boundary_submesh_length(
                            "right", child.domain
                        )
                    else:
                        multiplicative = pybamm.Scalar(1)

                elif extrap_order_gradient == "quadratic":
                    a = (2 * dxN + 2 * dxNm1 + dxNm2) / (dxNm1**2 + dxNm1 * dxNm2)
                    b = -(2 * dxN + dxNm1 + dxNm2) / (dxNm1 * dxNm2)
                    c = (2 * dxN + dxNm1) / (dxNm1 * dxNm2 + dxNm2**2)

                    sub_matrix = csr_matrix(
                        (
                            [c, b, a],
                            ([0, 0, 0], [prim_pts - 3, prim_pts - 2, prim_pts - 1]),
                        ),
                        shape=(1, prim_pts),
                    )
                    additive = pybamm.Scalar(0)
                    additive_multiplicative = pybamm.Scalar(1)
                    if hasattr(submesh, "length"):
                        multiplicative = 1 / self._get_boundary_submesh_length(
                            "right", child.domain
                        )
                    else:
                        multiplicative = pybamm.Scalar(1)
                else:
                    raise NotImplementedError

        # Generate full matrix from the submatrix
        # Convert to csr_matrix so that we can take the index (row-slicing), which is
        # not supported by the default kron format
        # Note that this makes column-slicing inefficient, but this should not be an
        # issue
        matrix = csr_matrix(kron(eye(repeats), sub_matrix))

        # Return boundary value with domain given by symbol
        matrix = pybamm.Matrix(matrix) * multiplicative
        boundary_value = matrix @ discretised_child
        boundary_value.copy_domains(symbol)

        # NOTE(review): when `additive` was taken directly from `bcs` (the
        # Dirichlet-value / Neumann-gradient shortcuts), this changes the domains
        # of that bc symbol in place
        additive.copy_domains(symbol)
        boundary_value += additive * additive_multiplicative

        return boundary_value

    def evaluate_at(self, symbol, discretised_child, position):
        """
        Select the value of a discretised symbol at the mesh point nearest to a
        given position.

        Parameters
        ----------
        symbol: :class:`pybamm.Symbol`
            The evaluate-at symbol. Its first child decides whether cell edges or
            cell nodes are searched.
        discretised_child : :class:`pybamm.StateVector`
            The discretised variable to pick the value from
        position : :class:`pybamm.Scalar`
            The point in one-dimensional space at which to evaluate the symbol.

        Returns
        -------
        :class:`pybamm.MatrixMultiplication`
            A one-entry selection (repeated over auxiliary domains) of
            ``discretised_child`` with all domains cleared.

        Raises
        ------
        NotImplementedError
            If the submesh is symbolic (has a ``length`` attribute).
        """
        domain = discretised_child.domain
        submesh = self.mesh[domain]
        on_edges = symbol.children[0].evaluates_on_edges("primary")
        points = submesh.edges if on_edges else submesh.nodes

        if hasattr(submesh, "length"):
            raise NotImplementedError(
                f"The symbolic submesh does not support `EvaluateAt` because we are unable to find the position of the node in the symbolic submesh. Please use one of the other submeshes for domain {domain}"
            )

        repeats = self._get_auxiliary_domain_repeats(discretised_child.domains)

        # Mesh point with the smallest distance to the requested position
        closest = np.argmin(np.abs(points - position.value))

        # One-hot selector row, repeated block-diagonally over auxiliary domains
        selector_row = csr_matrix(([1], ([0], [closest])), shape=(1, len(points)))
        selector = csr_matrix(kron(eye(repeats), selector_row))

        picked = pybamm.Matrix(selector) @ discretised_child
        # The result of `EvaluateAt` carries no domain
        picked.clear_domains()
        return picked

    def process_binary_operators(self, bin_op, left, right, disc_left, disc_right):
        """Discretise a binary operator, moving children between edges and nodes so
        that their discretised sizes agree.

        For :class:`pybamm.Inner`, any child that lives on edges is averaged onto
        the nodes. Otherwise, if exactly one child lives on edges, the other child
        is mapped onto edges. The harmonic mean [1] is used when the edge-valued
        child is a :class:`pybamm.Gradient` (i.e. the operator forms a flux), and
        the arithmetic mean otherwise.

        [1] Recktenwald, Gerald. "The control-volume finite-difference approximation to
        the diffusion equation." (2012).

        Parameters
        ----------
        bin_op : :class:`pybamm.BinaryOperator`
            Binary operator to discretise
        left : :class:`pybamm.Symbol`
            The left child of `bin_op`
        right : :class:`pybamm.Symbol`
            The right child of `bin_op`
        disc_left : :class:`pybamm.Symbol`
            The discretised left child of `bin_op`
        disc_right : :class:`pybamm.Symbol`
            The discretised right child of `bin_op`

        Returns
        -------
        :class:`pybamm.BinaryOperator`
            Copy of ``bin_op`` built from the (possibly shifted) discretised
            children, simplified if constant.
        """
        left_on_edges = left.evaluates_on_edges("primary")
        right_on_edges = right.evaluates_on_edges("primary")

        if isinstance(bin_op, pybamm.Inner):
            # Inner products work on nodes: pull edge-valued children back
            if left_on_edges:
                disc_left = self.edge_to_node(disc_left)
            if right_on_edges:
                disc_right = self.edge_to_node(disc_right)
        elif left_on_edges and not right_on_edges:
            mean = "harmonic" if isinstance(left, pybamm.Gradient) else "arithmetic"
            disc_right = self.node_to_edge(disc_right, method=mean)
        elif right_on_edges and not left_on_edges:
            mean = "harmonic" if isinstance(right, pybamm.Gradient) else "arithmetic"
            disc_left = self.node_to_edge(disc_left, method=mean)
        # Remaining case: both children on nodes or both on edges, nothing to shift

        return pybamm.simplify_if_constant(bin_op.create_copy([disc_left, disc_right]))

    def concatenation(self, disc_children):
        """Discrete concatenation, taking `edge_to_node` for children that evaluate on
        edges.

        See :meth:`pybamm.SpatialMethod.concatenation`

        ``disc_children`` is not modified. A new list of children, with edge-valued
        children averaged onto the nodes, is passed to
        :func:`pybamm.domain_concatenation`.

        Parameters
        ----------
        disc_children : iterable of :class:`pybamm.Symbol`
            The discretised children to concatenate.

        Returns
        -------
        :class:`pybamm.Symbol`
            The result of ``pybamm.domain_concatenation`` on the node-valued
            children and ``self.mesh``.

        Raises
        ------
        pybamm.ShapeError
            If a child's size matches neither the number of nodes nor the number
            of edges of its mesh (times the auxiliary-domain repeats).
        """
        node_children = []
        for child in disc_children:
            submesh = self.mesh[child.domain]
            repeats = self._get_auxiliary_domain_repeats(child.domains)
            n_nodes = len(submesh.nodes) * repeats
            n_edges = len(submesh.edges) * repeats
            child_size = child.size
            if child_size == n_nodes:
                node_children.append(child)
            elif child_size == n_edges:
                # Average children that evaluate on the edges onto the nodes, so
                # that concatenation works properly
                node_children.append(self.edge_to_node(child))
            else:
                raise pybamm.ShapeError(
                    "child must have size n_nodes (number of nodes in the mesh) "
                    "or n_edges (number of edges in the mesh)"
                )
        return pybamm.domain_concatenation(node_children, self.mesh)

    def edge_to_node(self, discretised_symbol, method="arithmetic"):
        """
        Convert a discretised symbol evaluated on the cell edges to a discretised symbol
        evaluated on the cell nodes.
        See :meth:`pybamm.FiniteVolume.shift`
        """
        return self.shift(discretised_symbol, "edge to node", method)

    def node_to_edge(self, discretised_symbol, method="arithmetic"):
        """
        Convert a discretised symbol evaluated on the cell nodes to a discretised symbol
        evaluated on the cell edges.
        See :meth:`pybamm.FiniteVolume.shift`
        """
        return self.shift(discretised_symbol, "node to edge", method)

    def shift(self, discretised_symbol, shift_key, method):
        """
        Convert a discretised symbol evaluated at edges/nodes, to a discretised symbol
        evaluated at nodes/edges. Can be the arithmetic mean or the harmonic mean.

        Note: when computing fluxes at cell edges it is better to take the
        harmonic mean based on [1].

        [1] Recktenwald, Gerald. "The control-volume finite-difference approximation to
        the diffusion equation." (2012).

        Parameters
        ----------
        discretised_symbol : :class:`pybamm.Symbol`
            Symbol to be averaged. When evaluated, this symbol returns either a scalar
            or an array of shape (n,) or (n+1,), where n is the number of points in the
            mesh for the symbol's domain (n = self.mesh[symbol.domain].npts). If the
            symbol has auxiliary domains, this primary-dimension pattern is repeated
            once per auxiliary point (the operators are built as
            ``kron(eye(repeats), sub_matrix)``).
        shift_key : str
            Whether to shift from nodes to edges ("node to edge"), or from edges to
            nodes ("edge to node")
        method : str
            Whether to use the "arithmetic" or "harmonic" mean

        Returns
        -------
        :class:`pybamm.Symbol`
            Averaged symbol. When evaluated, this returns either a scalar or an array of
            shape (n+1,) (if `shift_key = "node to edge"`) or (n,) (if
            `shift_key = "edge to node"`), per auxiliary repeat. A symbol of size 1
            is returned unchanged.

        Raises
        ------
        ValueError
            If `method` or `shift_key` is not recognised. Neither is validated when
            `discretised_symbol` has size 1, because it is returned as-is.
        """

        def arithmetic_mean(array):
            """
            Calculate the arithmetic mean of an array using matrix multiplication.

            Interior values are the mean of the two neighbouring entries. For
            "node to edge", the two exterior edges are obtained from the two nodes
            nearest each boundary with fixed coefficients (1.5, -0.5).
            """
            # Create appropriate submesh by combining submeshes in domain
            submesh = self.mesh[array.domain]

            # Create 1D matrix using submesh
            n = submesh.npts

            if shift_key == "node to edge":
                # Exterior edges: 1.5 * (boundary node) - 0.5 * (next node). These
                # coefficients equal linear extrapolation only on a locally uniform
                # mesh (boundary node half a spacing from the edge); they do not
                # depend on the actual spacing.
                # NOTE(review): column indices 1 and n - 2 assume n >= 2.
                sub_matrix_left = csr_matrix(
                    ([1.5, -0.5], ([0, 0], [0, 1])), shape=(1, n)
                )
                # Interior edges: mean of the two adjacent nodes
                sub_matrix_center = diags([0.5, 0.5], [0, 1], shape=(n - 1, n))
                sub_matrix_right = csr_matrix(
                    ([-0.5, 1.5], ([0, 0], [n - 2, n - 1])), shape=(1, n)
                )
                sub_matrix = vstack(
                    [sub_matrix_left, sub_matrix_center, sub_matrix_right]
                )
            elif shift_key == "edge to node":
                # Each node value is the mean of its two bounding edges
                sub_matrix = diags([0.5, 0.5], [0, 1], shape=(n, n + 1))
            else:
                raise ValueError(f"shift key '{shift_key}' not recognised")
            # Number of times the primary-dimension block is repeated (one per
            # point of the auxiliary domains)
            second_dim_repeats = self._get_auxiliary_domain_repeats(
                discretised_symbol.domains
            )

            # Generate full matrix from the submatrix
            # Convert to csr_matrix so that we can take the index (row-slicing), which
            # is not supported by the default kron format
            # Note that this makes column-slicing inefficient, but this should not be an
            # issue
            matrix = csr_matrix(kron(eye(second_dim_repeats), sub_matrix))

            return pybamm.Matrix(matrix) @ array

        def harmonic_mean(array):
            """
            Calculate the harmonic mean of an array using matrix multiplication.
            The harmonic mean is computed as

            .. math::
                D_{eff} = \\frac{1}{\\frac{\\beta}{D_1} + \\frac{1 - \\beta}{D_2}},

            where

            .. math::
                \\beta = \\frac{\\Delta x_1}{\\Delta x_2 + \\Delta x_1}

            accounts for the difference in the control volume widths. This is the
            definition from [1], which is the same as that in [2] but with slightly
            different notation.

            For "node to edge", only the interior edges use the harmonic mean; the
            two exterior edges use the same fixed (1.5, -0.5) extrapolation as the
            arithmetic mean.

            [1] Torchio, M et al. "LIONSIMBA: A Matlab Framework Based on a Finite
            Volume Model Suitable for Li-Ion Battery Design, Simulation, and Control."
            (2016).
            [2] Recktenwald, Gerald. "The control-volume finite-difference
            approximation to the diffusion equation." (2012).
            """
            # Create appropriate submesh by combining submeshes in domain
            submesh = self.mesh[array.domain]

            # Get second dimension length for use later
            second_dim_repeats = self._get_auxiliary_domain_repeats(
                discretised_symbol.domains
            )

            # Create 1D matrix using submesh
            n = submesh.npts

            if shift_key == "node to edge":
                # Matrix to compute values at the exterior edges (interior rows are
                # left empty here and filled by the harmonic mean below)
                # NOTE(review): column indices 1 and n - 2 assume n >= 2.
                edges_sub_matrix_left = csr_matrix(
                    ([1.5, -0.5], ([0, 0], [0, 1])), shape=(1, n)
                )
                edges_sub_matrix_center = csr_matrix((n - 1, n))
                edges_sub_matrix_right = csr_matrix(
                    ([-0.5, 1.5], ([0, 0], [n - 2, n - 1])), shape=(1, n)
                )
                edges_sub_matrix = vstack(
                    [
                        edges_sub_matrix_left,
                        edges_sub_matrix_center,
                        edges_sub_matrix_right,
                    ]
                )

                # Generate full matrix from the submatrix
                # Convert to csr_matrix so that we can take the index (row-slicing),
                # which is not supported by the default kron format
                # Note that this makes column-slicing inefficient, but this should
                # not be an issue
                edges_matrix = csr_matrix(
                    kron(eye(second_dim_repeats), edges_sub_matrix)
                )

                # Matrix to extract the node values running from the first node
                # to the penultimate node in the primary dimension (D_1 in the
                # definition of the harmonic mean)
                sub_matrix_D1 = hstack([eye(n - 1), csr_matrix((n - 1, 1))])
                matrix_D1 = csr_matrix(kron(eye(second_dim_repeats), sub_matrix_D1))
                D1 = pybamm.Matrix(matrix_D1) @ array

                # Matrix to extract the node values running from the second node
                # to the final node in the primary dimension  (D_2 in the
                # definition of the harmonic mean)
                sub_matrix_D2 = hstack([csr_matrix((n - 1, 1)), eye(n - 1)])
                matrix_D2 = csr_matrix(kron(eye(second_dim_repeats), sub_matrix_D2))
                D2 = pybamm.Matrix(matrix_D2) @ array
                # Compute weight beta from consecutive cell widths:
                # beta[i] = d_edges[i] / (d_edges[i] + d_edges[i + 1])
                if hasattr(submesh, "length"):
                    # Symbolic-length mesh: select d_edges[:-1] and d_edges[1:]
                    # with shift matrices, since the widths are symbols here
                    d_edges = self._get_d_edges_symbolic_mesh(
                        discretised_symbol.domains["primary"]
                    )
                    left_index_matrix = diags(
                        [1], [0], shape=(d_edges.size - 1, d_edges.size)
                    )
                    left_dx = pybamm.Matrix(left_index_matrix) @ d_edges
                    right_index_matrix = diags(
                        [1], [1], shape=(d_edges.size - 1, d_edges.size)
                    )
                    right_dx = pybamm.Matrix(right_index_matrix) @ d_edges
                else:
                    dx = submesh.d_edges
                    left_dx = pybamm.Vector(dx[:-1])
                    right_dx = pybamm.Vector(dx[1:])
                sub_beta = left_dx / (left_dx + right_dx)
                # Repeat the primary-dimension weights for every auxiliary point
                beta = pybamm.kronecker_product(
                    pybamm.Matrix(np.ones((second_dim_repeats, 1))), sub_beta
                )

                # dx_real = dx * length, therefore, beta is unchanged
                # Compute harmonic mean on internal edges
                D_eff = 1 / (beta / D1 + (1 - beta) / D2)

                # Matrix to pad zeros at the beginning and end of the array where
                # the exterior edge values will be added
                sub_matrix = vstack(
                    [csr_matrix((1, n - 1)), eye(n - 1), csr_matrix((1, n - 1))]
                )

                # Generate full matrix from the submatrix
                # Convert to csr_matrix so that we can take the index (row-slicing),
                # which is not supported by the default kron format
                # Note that this makes column-slicing inefficient, but this should
                # not be an issue
                matrix = csr_matrix(kron(eye(second_dim_repeats), sub_matrix))

                # Exterior edges from extrapolation + interior edges from D_eff
                return (
                    pybamm.Matrix(edges_matrix) @ array + pybamm.Matrix(matrix) @ D_eff
                )

            elif shift_key == "edge to node":
                # Matrix to extract the edge values running from the first edge
                # to the penultimate edge in the primary dimension (D_1 in the
                # definition of the harmonic mean)
                sub_matrix_D1 = hstack([eye(n), csr_matrix((n, 1))])
                matrix_D1 = csr_matrix(kron(eye(second_dim_repeats), sub_matrix_D1))
                D1 = pybamm.Matrix(matrix_D1) @ array

                # Matrix to extract the edge values running from the second edge
                # to the final edge in the primary dimension  (D_2 in the
                # definition of the harmonic mean)
                sub_matrix_D2 = hstack([csr_matrix((n, 1)), eye(n)])
                matrix_D2 = csr_matrix(kron(eye(second_dim_repeats), sub_matrix_D2))
                D2 = pybamm.Matrix(matrix_D2) @ array

                # Compute weight beta
                # dx = [dx0, d_nodes..., dxN] has n + 1 entries (one per edge);
                # for node i, left_dx = dx[i] and right_dx = dx[i + 1].
                # NOTE(review): at the first/last node this pairs a half-spacing
                # (dx0 / dxN) with a full node spacing, so beta != 0.5 there even
                # on a uniform mesh -- confirm this weighting is intended.
                dx0 = pybamm.Scalar(
                    submesh.nodes[0] - submesh.edges[0]
                )  # first edge to node
                if hasattr(submesh, "length"):
                    dx0 = dx0 * self._get_boundary_submesh_length(
                        "left", discretised_symbol.domain
                    )
                dxN = pybamm.Scalar(
                    submesh.edges[-1] - submesh.nodes[-1]
                )  # last node to edge
                if hasattr(submesh, "length"):
                    dxN = dxN * self._get_boundary_submesh_length(
                        "right", discretised_symbol.domain
                    )
                if hasattr(submesh, "length"):
                    d_nodes = self._get_d_nodes_symbolic_mesh(discretised_symbol.domain)
                else:
                    d_nodes = pybamm.Vector(submesh.d_nodes)
                dx = pybamm.numpy_concatenation(dx0, d_nodes, dxN)
                # Shift matrices selecting dx[:-1] (left) and dx[1:] (right)
                left_dx_matrix = diags([1], [0], shape=(dx.size - 1, dx.size))
                right_dx_matrix = diags([1], [1], shape=(dx.size - 1, dx.size))
                left_dx = pybamm.Matrix(left_dx_matrix) @ dx
                right_dx = pybamm.Matrix(right_dx_matrix) @ dx
                sub_beta = left_dx / (left_dx + right_dx)
                # Repeat the primary-dimension weights for every auxiliary point
                beta = pybamm.kronecker_product(
                    pybamm.Matrix(np.ones((second_dim_repeats, 1))), sub_beta
                )

                # Compute harmonic mean on nodes
                D_eff = 1 / (beta / D1 + (1 - beta) / D2)

                return D_eff

            else:
                raise ValueError(f"shift key '{shift_key}' not recognised")

        # If discretised_symbol evaluates to number there is no need to average
        # (note: shift_key and method are not validated on this path)
        if discretised_symbol.size == 1:
            out = discretised_symbol
        elif method == "arithmetic":
            out = arithmetic_mean(discretised_symbol)
        elif method == "harmonic":
            out = harmonic_mean(discretised_symbol)
        else:
            raise ValueError(f"method '{method}' not recognised")
        return out

    def upwind_or_downwind(self, symbol, discretised_symbol, bcs, direction):
        """
        Implement an upwinding (or downwinding) operator. Currently, this requires
        the symbol to have a Dirichlet boundary condition on the left side (for
        upwinding) or right side (for downwinding).

        Parameters
        ----------
        symbol : :class:`pybamm.SpatialVariable`
            The variable to be discretised
        discretised_symbol : :class:`pybamm.Vector`
            The discretised form of ``symbol``
        bcs : dict of tuples (:class:`pybamm.Scalar`, str)
            Dictionary (with keys "left" and "right") of boundary conditions. Each
            boundary condition consists of a value and a flag indicating its type
            (e.g. "Dirichlet")
        direction : str
            Direction in which to apply the operator: "upwind" or "downwind"

        Returns
        -------
        :class:`pybamm.Symbol`
            The first output of :meth:`add_ghost_nodes`, called with only the
            boundary condition on the side required by ``direction``.

        Raises
        ------
        :class:`pybamm.ModelError`
            If no boundary conditions are provided for ``symbol``, or the boundary
            condition on the required side is missing or not of Dirichlet type.
        ValueError
            If ``direction`` is neither "upwind" nor "downwind".
        """
        if symbol not in bcs:
            raise pybamm.ModelError(
                f"Boundary conditions must be provided for {direction}ing '{symbol}'"
            )

        if direction == "upwind":
            bc_side = "left"
        elif direction == "downwind":
            bc_side = "right"
        else:
            # Previously fell through and left bc_side unbound (UnboundLocalError)
            raise ValueError(
                f"direction '{direction}' not recognised; "
                "expected 'upwind' or 'downwind'"
            )

        # A missing bc on the required side is reported like a non-Dirichlet one
        # rather than surfacing as a bare KeyError
        side_bc = bcs[symbol].get(bc_side)
        if side_bc is None or side_bc[1] != "Dirichlet":
            raise pybamm.ModelError(
                "Dirichlet boundary conditions must be provided for "
                f"{direction}ing '{symbol}'"
            )

        # Extract only the relevant boundary condition as the model might have both
        bc_subset = {bc_side: side_bc}
        symbol_out, _ = self.add_ghost_nodes(symbol, discretised_symbol, bc_subset)
        return symbol_out
