import inspect
import io
import json
import os
from contextlib import asynccontextmanager
from typing import Awaitable, Callable

from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from jinja2 import Environment, TemplateNotFound, select_autoescape
from sqlalchemy import and_, insert, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from starlette.routing import _DefaultLifespan

from .api.model_rest_api import ModelRestApi
from .apis import (
    AuthApi,
    InfoApi,
    PermissionsApi,
    PermissionViewApi,
    RolesApi,
    UsersApi,
    ViewsMenusApi,
)
from .auth import AuthDict
from .cli.commands import upgrade
from .const import BASE_APIS, DEFAULT_STATIC_FOLDER, DEFAULT_TEMPLATE_FOLDER, logger
from .db import session_manager
from .dependencies import set_global_user
from .globals import GlobalsMiddleware, g
from .models import Api, Permission, PermissionApi, Role
from .security import SecurityManager
from .utils import safe_call, smart_run
from .version import __version__

__all__ = ["FastAPIReactToolkit"]


class FastAPIReactToolkit:
    """
    The main class for the `FastAPIReactToolkit` library.

    This class provides a set of methods to initialize a FastAPI application, add APIs, manage permissions and roles,
    and initialize the database with permissions, APIs, roles, and their relationships.

    Args:
        `app` (FastAPI | None, optional): The FastAPI application instance. If set, the `initialize` method will be called with this instance. Defaults to None.
        `auth` (AuthDict | None, optional): The authentication configuration. Set this if you want to customize the authentication system. See the `AuthDict` type for more details. Defaults to None.
        `create_tables` (bool, optional): Whether to create tables in the database. Not needed if you use a migration system like Alembic. Defaults to False.
        `exclude_apis` (list[BASE_APIS] | None, optional): List of APIs to exclude from the initialization. Defaults to None.
        `global_user_dependency` (bool, optional): Whether to add the `set_global_user` dependency to the FastAPI application. This allows you to access the current user with the `g.user` object. Defaults to True.
        `upgrade_db` (bool, optional): Whether to upgrade the database automatically with the `upgrade` command. Same as running `fastapi-rtk db upgrade`. Defaults to False.
        `on_startup` (Callable[..., Awaitable[None] | None] | None, optional): Function to call when the app is starting up. It is called with the FastAPI app if it declares at least one parameter, otherwise with no arguments; its result is passed through `safe_call`, so it may be a regular or a coroutine function. Defaults to None.
        `on_shutdown` (Callable[..., Awaitable[None] | None] | None, optional): Function to call when the app is shutting down. Same calling convention as `on_startup`. Defaults to None.

    ## Example:

    ```python
    import logging

    from fastapi import FastAPI, Request, Response
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi_rtk import FastAPIReactToolkit, User
    from fastapi_rtk.manager import UserManager

    from .base_data import add_base_data

    logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(message)s")
    logging.getLogger().setLevel(logging.INFO)


    class CustomUserManager(UserManager):
        async def on_after_login(
            self,
            user: User,
            request: Request | None = None,
            response: Response | None = None,
        ) -> None:
            await super().on_after_login(user, request, response)
            print("User logged in: ", user)


    async def on_startup(app: FastAPI):
        await add_base_data()
        print("base data added")


    app = FastAPI(docs_url="/openapi/v1")
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:6006"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    toolkit = FastAPIReactToolkit(
        auth={
            "user_manager": CustomUserManager,
            # "password_helper": FABPasswordHelper(),  #! Add this line to use old password hash
        },
        create_tables=True,
        on_startup=on_startup,
    )
    toolkit.config.from_pyfile("./app/config.py")
    toolkit.initialize(app)

    from .apis import *
    ```
    """

    app: FastAPI | None = None
    apis: list[ModelRestApi] = None
    initialized: bool = False
    create_tables: bool = False
    exclude_apis: list[BASE_APIS] = None
    security: SecurityManager = None
    global_user_dependency = True
    upgrade_db = False
    # Hooks may take no arguments or the FastAPI app; sync or async.
    on_startup: Callable[..., Awaitable[None] | None] | None = None
    on_shutdown: Callable[..., Awaitable[None] | None] | None = None
    # Set to True during lifespan startup, once API routers are being mounted.
    _mounted = False

    def __init__(
        self,
        *,
        app: FastAPI | None = None,
        auth: AuthDict | None = None,
        create_tables: bool = False,
        exclude_apis: list[BASE_APIS] | None = None,
        global_user_dependency: bool = True,
        upgrade_db: bool = False,
        on_startup: Callable[..., Awaitable[None] | None] | None = None,
        on_shutdown: Callable[..., Awaitable[None] | None] | None = None,
    ) -> None:
        # Register this toolkit globally so other modules can reach it through `g`.
        g.current_app = self
        self.apis = []
        self.exclude_apis = exclude_apis or []
        self.global_user_dependency = global_user_dependency
        self.upgrade_db = upgrade_db
        self.security = SecurityManager(self)
        self.create_tables = create_tables
        self.on_startup = on_startup
        self.on_shutdown = on_shutdown

        # Copy each provided auth option onto the global auth configuration.
        if auth:
            for key, value in auth.items():
                setattr(g.auth, key, value)

        if app:
            self.initialize(app)

    def initialize(self, app: FastAPI) -> None:
        """
        Initializes the FastAPI application.

        Calling this more than once is a no-op after the first call.

        Args:
            app (FastAPI): The FastAPI application instance.

        Returns:
            None

        Raises:
            ValueError: If the application already has a custom lifespan set.
        """
        if self.initialized:
            return

        self.initialized = True
        self.app = app

        # Read config once
        self._init_config()

        # Initialize the lifespan
        self._init_lifespan()

        # Add the GlobalsMiddleware
        self.app.add_middleware(GlobalsMiddleware)
        if self.global_user_dependency:
            self.app.router.dependencies.append(Depends(set_global_user()))

        # Add the APIs
        self._init_basic_apis()

    def add_api(self, api: ModelRestApi | type[ModelRestApi]):
        """
        Adds the specified API to the FastAPI application.

        If an API with the same `resource_name` was already added, it is replaced.

        Parameters:
        - api (ModelRestApi | type[ModelRestApi]): The API instance, or an API class that will be instantiated with no arguments.

        Returns:
        - None

        Raises:
        - ValueError: If the API is added after the application's lifespan has started mounting the API routers.
        """
        if self._mounted:
            raise ValueError(
                "API added after the APIs were mounted on app startup, please add APIs before the application starts"
            )

        api = api if isinstance(api, ModelRestApi) else api()
        previous_api = next(
            (a for a in self.apis if a.resource_name == api.resource_name), None
        )
        if previous_api:
            logger.warning(
                f"API {api.resource_name} already exists, replacing with new API"
            )
            self.apis.remove(previous_api)
        self.apis.append(api)
        api.toolkit = self

    def total_permissions(self) -> list[str]:
        """
        Returns the total list of permissions required by all APIs.

        Duplicates are removed; the order follows first appearance across `self.apis`.

        Returns:
        - list[str]: The total list of permissions.
        """
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        return list(
            dict.fromkeys(
                permission
                for api in self.apis
                for permission in (getattr(api, "permissions", None) or [])
            )
        )

    def connect_to_database(self):
        """
        Connects to the database using the configured SQLAlchemy database URI.

        This method initializes the database session maker with the `SQLALCHEMY_DATABASE_URI`
        (and the optional `SQLALCHEMY_BINDS`) from the configuration. If no URI is configured,
        a warning is logged and the connection is skipped; any later database operation will fail.

        Passwords embedded in the URIs are masked before the URIs are logged.
        """
        uri = g.config.get("SQLALCHEMY_DATABASE_URI")
        if not uri:
            logger.warning(
                "SQLALCHEMY_DATABASE_URI is not set in the configuration, skipping database connection, any database related operation will fail"
            )
            return

        binds = g.config.get("SQLALCHEMY_BINDS")
        session_manager.init_db(uri, binds)
        logger.info("Connected to database")
        # Never write credentials to the logs.
        logger.info(f"URI: {self._redact_uri(uri)}")
        if isinstance(binds, dict):
            logged_binds = {key: self._redact_uri(value) for key, value in binds.items()}
        else:
            logged_binds = binds
        logger.info(f"Binds: {logged_binds}")

    @staticmethod
    def _redact_uri(uri) -> str:
        """Return `uri` rendered with its password masked, suitable for logging."""
        from sqlalchemy.engine import make_url
        from sqlalchemy.exc import ArgumentError

        try:
            return make_url(uri).render_as_string(hide_password=True)
        except (ArgumentError, AttributeError, TypeError, ValueError):
            # Never fall back to the raw value: it may contain credentials.
            return "<unparsable database URI>"

    async def init_database(self):
        """
        Initializes the database by inserting permissions, APIs, roles, and their relationships.

        The initialization process is as follows:
        1. Inserts permissions into the database.
        2. Inserts APIs into the database.
        3. Inserts roles into the database.
        4. Inserts the relationship between permissions and APIs into the database.
        5. Inserts the relationship between permissions, APIs, and roles into the database.

        Skipped, with a warning, when no database engine has been initialized.

        Returns:
            None
        """
        if not session_manager._engine:
            logger.warning(
                "Database not connected, skipping database initialization, any database related operation will fail"
            )
            return

        async with session_manager.session() as db:
            logger.info("INITIALIZING DATABASE")
            await self._insert_permissions(db)
            await self._insert_apis(db)
            await self._insert_roles(db)
            await self._associate_permission_with_api(db)
            await self._associate_permission_api_with_role(db)
            logger.info("DATABASE INITIALIZED")

    async def _insert_missing_named(
        self, db: AsyncSession | Session, model, names: list, label: str
    ):
        """
        Inserts one `model(name=...)` row for every name in `names` not already stored.

        Names are de-duplicated first, so repeated names never cause duplicate rows.
        Commits only when at least one row was added.
        """
        wanted = list(dict.fromkeys(names))
        if not wanted:
            return

        stmt = select(model).where(model.name.in_(wanted))
        result = await safe_call(db.scalars(stmt))
        existing = {row.name for row in result.all()}
        missing = [name for name in wanted if name not in existing]
        if not missing:
            return

        for name in missing:
            obj = model(name=name)
            logger.info(f"ADDING {label} {obj}")
            db.add(obj)
        await safe_call(db.commit())

    async def _insert_permissions(self, db: AsyncSession | Session):
        """Inserts every permission required by the registered APIs that is missing."""
        await self._insert_missing_named(
            db, Permission, self.total_permissions(), "PERMISSION"
        )

    async def _insert_apis(self, db: AsyncSession | Session):
        """Inserts an `Api` row (named after the API class) for each missing registered API."""
        await self._insert_missing_named(
            db, Api, [api.__class__.__name__ for api in self.apis], "API"
        )

    async def _insert_roles(self, db: AsyncSession | Session):
        """Inserts the admin and public roles (`g.admin_role`, `g.public_role`) if missing."""
        await self._insert_missing_named(
            db, Role, [g.admin_role, g.public_role], "ROLE"
        )

    async def _associate_permission_with_api(self, db: AsyncSession | Session):
        """
        Links each registered API's permissions to its `Api` row via `PermissionApi`.

        Only links that do not exist yet are inserted.

        Raises:
            ValueError: If an API with permissions has no `Api` row.
        """
        for api in self.apis:
            permission_names = getattr(api, "permissions", [])
            if not permission_names:
                continue

            # Get the api object
            api_name = api.__class__.__name__
            api_stmt = select(Api).where(Api.name == api_name)
            api_obj = await safe_call(db.scalar(api_stmt))

            if not api_obj:
                raise ValueError(f"API {api_name} not found")

            # Only permissions that are not yet linked to this API.
            linked_permission_ids = [p.permission_id for p in api_obj.permissions]
            permission_stmt = select(Permission).where(
                and_(
                    Permission.name.in_(permission_names),
                    ~Permission.id.in_(linked_permission_ids),
                )
            )
            permission_result = await safe_call(db.scalars(permission_stmt))
            unlinked_permissions = permission_result.all()

            if not unlinked_permissions:
                continue

            for permission in unlinked_permissions:
                permission_api_stmt = insert(PermissionApi).values(
                    permission_id=permission.id, api_id=api_obj.id
                )
                await safe_call(db.execute(permission_api_stmt))
                logger.info(f"ASSOCIATING PERMISSION {permission} WITH API {api_obj}")
            await safe_call(db.commit())

    async def _associate_permission_api_with_role(self, db: AsyncSession | Session):
        """
        Applies role-based permissions from the configuration, then grants the admin role the rest.

        `ROLES` (or the legacy `FAB_ROLES`) maps a role name to `(api_name, permission_name)`
        pairs. Missing roles, permissions, APIs and permission-API links are created as
        needed. Afterwards, every permission-API whose API is not mentioned in that config
        is associated with `g.admin_role` (if that role exists).
        """
        # Read config based roles; fall back to {} even if a key is explicitly None.
        roles_dict = g.config.get("ROLES") or g.config.get("FAB_ROLES") or {}
        # NOTE(review): every API named in any configured role is excluded from the
        # automatic admin association below — confirm this is intended.
        admin_ignored_apis: list[str] = []

        for role_name, role_permissions in roles_dict.items():
            role_stmt = select(Role).where(Role.name == role_name)
            role_result = await safe_call(db.scalars(role_stmt))
            role = role_result.first()
            if not role:
                role = Role(name=role_name)
                logger.info(f"ADDING ROLE {role}")
                db.add(role)

            for api_name, permission_name in role_permissions:
                admin_ignored_apis.append(api_name)
                permission_api_stmt = (
                    select(PermissionApi)
                    .where(
                        and_(Api.name == api_name, Permission.name == permission_name)
                    )
                    .join(Permission)
                    .join(Api)
                )
                permission_api = await safe_call(db.scalar(permission_api_stmt))
                if not permission_api:
                    # Create whichever pieces of the permission-API link are missing.
                    permission_stmt = select(Permission).where(
                        Permission.name == permission_name
                    )
                    permission = await safe_call(db.scalar(permission_stmt))
                    if not permission:
                        permission = Permission(name=permission_name)
                        logger.info(f"ADDING PERMISSION {permission}")
                        db.add(permission)

                    api_stmt = select(Api).where(Api.name == api_name)
                    api = await safe_call(db.scalar(api_stmt))
                    if not api:
                        api = Api(name=api_name)
                        logger.info(f"ADDING API {api}")
                        db.add(api)

                    permission_api = PermissionApi(permission=permission, api=api)
                    logger.info(f"ADDING PERMISSION-API {permission_api}")
                    db.add(permission_api)

                # Associate role with permission-api
                if role not in permission_api.roles:
                    permission_api.roles.append(role)
                    logger.info(
                        f"ASSOCIATING {role} WITH PERMISSION-API {permission_api}"
                    )

                await safe_call(db.commit())

        # Get admin role
        admin_role_stmt = select(Role).where(Role.name == g.admin_role)
        admin_role = await safe_call(db.scalar(admin_role_stmt))

        if admin_role:
            # Permission-APIs the admin role does not have yet, excluding config-managed APIs.
            unassigned_stmt = (
                select(PermissionApi)
                .where(
                    and_(
                        ~PermissionApi.roles.contains(admin_role),
                        ~Api.name.in_(admin_ignored_apis),
                    )
                )
                .join(Api)
            )
            result = await safe_call(db.scalars(unassigned_stmt))

            # Add admin role to all permission-api objects
            for permission_api in result.all():
                permission_api.roles.append(admin_role)
                logger.info(
                    f"ASSOCIATING {admin_role} WITH PERMISSION-API {permission_api}"
                )
            await safe_call(db.commit())

    def _mount_static_folder(self):
        """
        Mounts the static folder specified in the configuration at `/static`.

        The folder (`STATIC_FOLDER`, default `DEFAULT_STATIC_FOLDER`) is created if missing.

        Returns:
            None
        """
        static_folder = g.config.get("STATIC_FOLDER", DEFAULT_STATIC_FOLDER)
        # StaticFiles requires the directory to exist, so create it if needed.
        os.makedirs(static_folder, exist_ok=True)
        self.app.mount("/static", StaticFiles(directory=static_folder), name="static")

    def _mount_template_folder(self):
        """
        Registers a catch-all GET route that renders `index.html` from the template folder.

        The folder (`TEMPLATE_FOLDER`, default `DEFAULT_TEMPLATE_FOLDER`) is created if
        missing. The route responds with 404 when `index.html` does not exist.

        Returns:
            None
        """
        template_folder = g.config.get("TEMPLATE_FOLDER", DEFAULT_TEMPLATE_FOLDER)
        # If the folder does not exist, create it
        os.makedirs(template_folder, exist_ok=True)

        templates = Jinja2Templates(directory=template_folder)

        @self.app.get("/{full_path:path}", response_class=HTMLResponse)
        def index(request: Request):
            try:
                return templates.TemplateResponse(
                    request=request,
                    name="index.html",
                    context={"base_path": g.config.get("BASE_PATH", "/")},
                )
            except TemplateNotFound:
                raise HTTPException(status_code=404, detail="Not Found")

    # -----------------------------------------
    #      INIT FUNCTIONS
    # -----------------------------------------

    def _init_config(self):
        """
        Initializes the configuration for the FastAPI application.

        This method reads the configuration values from the `g.config` dictionary and sets the corresponding attributes
        of the FastAPI application. `APP_NAME` and `APP_VERSION` fall back to toolkit defaults; the other keys
        fall back to the app's current values.
        """
        if self.app:
            self.app.title = g.config.get("APP_NAME", "FastAPI React Toolkit")
            self.app.summary = g.config.get("APP_SUMMARY", self.app.summary)
            self.app.description = g.config.get("APP_DESCRIPTION", self.app.description)
            self.app.version = g.config.get("APP_VERSION", __version__)
            # NOTE(review): FastAPI registers its OpenAPI route when the app is
            # constructed, so changing openapi_url here may not move that route — verify.
            self.app.openapi_url = g.config.get("APP_OPENAPI_URL", self.app.openapi_url)

    def _init_lifespan(self):
        """
        Installs the toolkit's lifespan handler on the FastAPI application.

        On startup the handler connects to the database, mounts every API router and the
        static/template folders, prepares the database, and runs `on_startup`. On shutdown
        it runs `on_shutdown` and always closes the session manager.
        Nothing is installed when `g.is_migrate` is set.

        Raises:
            ValueError: If the application already has a non-default lifespan.
        """
        if g.is_migrate:
            return

        # Check whether lifespan is already set
        if not isinstance(self.app.router.lifespan_context, _DefaultLifespan):
            raise ValueError(
                "Lifespan already set, please do not set lifespan directly in the FastAPI app"
            )

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            # Initialize the database connection
            self.connect_to_database()

            # Integrate the router for each API
            for api in self.apis:
                api.integrate_router(app)

            # Add the JS manifest route
            self._init_js_manifest()

            # Mount the static and template folders. The template folder adds a
            # catch-all GET route; Starlette matches routes in registration order,
            # so it comes after the API routes.
            self._mounted = True
            self._mount_static_folder()
            self._mount_template_folder()

            await session_manager.init_fastapi_rtk_tables()

            if self.upgrade_db:
                await smart_run(upgrade)

            if self.create_tables and session_manager._engine:
                await session_manager.create_all()

            # Creating permission, apis, roles, and connecting them
            await self.init_database()

            await self._run_hook(self.on_startup, app)

            yield

            # Release database resources even if the shutdown hook raises.
            try:
                await self._run_hook(self.on_shutdown, app)
            finally:
                await session_manager.close()

        self.app.router.lifespan_context = lifespan

    @staticmethod
    async def _run_hook(hook, app: FastAPI):
        """Calls a startup/shutdown hook, passing `app` only if the hook declares parameters."""
        if not hook:
            return
        if inspect.signature(hook).parameters:
            await safe_call(hook(app))
        else:
            await safe_call(hook())

    def _init_basic_apis(self):
        """Adds the built-in APIs, skipping any whose class name is in `exclude_apis`."""
        apis = [
            AuthApi,
            InfoApi,
            PermissionsApi,
            PermissionViewApi,
            RolesApi,
            UsersApi,
            ViewsMenusApi,
        ]
        for api in apis:
            if api.__name__ in self.exclude_apis:
                continue
            self.add_api(api)

    def _init_js_manifest(self):
        """
        Registers the `/server-config.js` route.

        The route renders `FAB_REACT_CONFIG` into a `window.fab_react_config = ...` script.
        """
        # The environment and template never change, so build them once here
        # rather than on every request.
        env = Environment(autoescape=select_autoescape(["html", "xml"]))
        template = env.from_string(
            "window.fab_react_config = {{ react_vars |tojson }}"
        )

        @self.app.get("/server-config.js", response_class=StreamingResponse)
        def js_manifest():
            # NOTE(review): the config is encoded by json.dumps and again by `tojson`,
            # so the browser receives a JSON *string*, not an object. Presumably the
            # frontend parses it — confirm before changing.
            rendered_string = template.render(
                react_vars=json.dumps(g.config.get("FAB_REACT_CONFIG", {}))
            )
            return StreamingResponse(
                io.BytesIO(rendered_string.encode("utf-8")),
                media_type="application/javascript",
            )

    # -----------------------------------------
    #      PROPERTY FUNCTIONS
    # -----------------------------------------

    @property
    def config(self):
        """The global configuration object (`g.config`)."""
        return g.config
