# SPDX-License-Identifier: GPL-3.0-or-later
from typing import Dict, List

from sympy import symbols
from symfem.functions import VectorFunction, MatrixFunction

from topoptlab.symbolic.parametric_map import jacobian

def stifftens_isotropic(ndim: int, plane_stress: bool = True) -> MatrixFunction:
    """
    Stiffness tensor (Voigt/matrix form) of an isotropic linear elastic
    material, expressed symbolically in terms of Young's modulus E and
    Poisson's ratio nu.

    Shear strains are treated as engineering shear strains, so every shear
    diagonal entry evaluates to the shear modulus G = E/(2*(1+nu)) in all
    cases (plane stress, plane strain and 3D).

    Parameters
    ----------
    ndim : int
        number of spatial dimensions (1, 2 or 3).
    plane_stress : bool
        only used for ndim == 2. If True, return the stiffness tensor for
        plane stress, otherwise return the stiffness tensor for plane strain.

    Returns
    -------
    c : symfem.functions.MatrixFunction
        stiffness tensor of shape (1,1), (3,3) or (6,6) for ndim 1, 2 or 3.

    Raises
    ------
    ValueError
        if ndim is not 1, 2 or 3.
    """
    E, nu = symbols("E nu")
    if ndim == 1:
        return MatrixFunction([[E]])
    if ndim == 2:
        if plane_stress:
            # E/(1-nu^2) * (1-nu)/2 == E/(2(1+nu)) == G
            return E/(1-nu**2)*MatrixFunction([[1, nu, 0],
                                               [nu, 1, 0],
                                               [0, 0, (1-nu)/2]])
        # plane strain: the shear entry must be (1-2nu)/2 so that, together
        # with the prefactor E/((1+nu)(1-2nu)), it reduces to G = E/(2(1+nu))
        return E/((1+nu)*(1-2*nu))*MatrixFunction([[1-nu, nu, 0],
                                                   [nu, 1-nu, 0],
                                                   [0, 0, (1-2*nu)/2]])
    if ndim == 3:
        # same prefactor as plane strain, hence (1-2nu)/2 on the shear
        # diagonal to obtain G = E/(2(1+nu))
        return E/((1+nu)*(1-2*nu))*MatrixFunction([[1-nu, nu, nu, 0, 0, 0],
                                                   [nu, 1-nu, nu, 0, 0, 0],
                                                   [nu, nu, 1-nu, 0, 0, 0],
                                                   [0, 0, 0, (1-2*nu)/2, 0, 0],
                                                   [0, 0, 0, 0, (1-2*nu)/2, 0],
                                                   [0, 0, 0, 0, 0, (1-2*nu)/2]])
    raise ValueError(f"ndim must be 1, 2 or 3, got {ndim!r}")

def small_strain_matrix(ndim: int, nd_inds: list,
                        basis: List,
                        isoparam_kws: Dict) -> MatrixFunction:
    """
    Assemble the symbolic small strain matrix (the "B matrix") that maps the
    nodal displacements of an element to its strains in Voigt form.

    Rows are ordered as the normal strains (one per dimension) followed by
    the shear strains: (xy) in 2D and (yz, xz, xy) in 3D. Each shear row
    holds the sum of the two cross derivatives (engineering shear strain).
    Columns are interleaved by node: node 0 (x, y, ...), node 1 (x, y, ...).

    Parameters
    ----------
    ndim : int
        number of spatial dimensions.
    nd_inds : list
        node indices; only its length (number of nodes) is used here.
    basis : list
        list of basis functions as generated by base_cell.
    isoparam_kws : dictionary
        keywords forwarded to the isoparametric jacobian.

    Returns
    -------
    bmatrix : symfem.functions.MatrixFunction
        small strain matrix of shape (ndim*(ndim+1)/2, ndim*len(nd_inds)).
    """
    n_strain = ndim * (ndim + 1) // 2
    n_dof = ndim * len(nd_inds)
    # inverse jacobian of the isoparametric map, used to turn reference
    # derivatives into physical ones
    Jinv = jacobian(ndim=ndim,
                    return_J=False, return_inv=True, return_det=False,
                    **isoparam_kws)
    # row d of dN_dx holds the derivatives of all basis functions w.r.t. x_d
    # (assumes grad() yields one row per basis function - TODO confirm)
    dN_dx = (VectorFunction(basis).grad(ndim) @ Jinv.transpose()).transpose()
    rows = [[0] * n_dof for _ in range(n_strain)]
    # normal strains: eps_dd = d u_d / d x_d
    for d in range(ndim):
        rows[d][d::ndim] = dN_dx[d]
    # shear strains: gamma_ab = d u_a / d x_b + d u_b / d x_a, cycling (a, b)
    a, b = ndim - 2, ndim - 1
    for row in rows[ndim:]:
        row[a::ndim] = dN_dx[b]
        row[b::ndim] = dN_dx[a]
        a, b = (a + 1) % ndim, (b + 1) % ndim
    return MatrixFunction(rows)

