"""Database executor helpers."""

from __future__ import annotations

from collections.abc import Callable
from concurrent.futures.thread import _threads_queues, _worker
import sys
import threading
from typing import Any
import weakref

from homeassistant.util.executor import InterruptibleThreadPoolExecutor


def _worker_with_shutdown_hook(
    shutdown_hook: Callable[[], None],
    recorder_and_worker_thread_ids: set[int],
    *args: Any,
    **kwargs: Any,
) -> None:
    """Run the stdlib thread-pool ``_worker`` and then call ``shutdown_hook``.

    The ident of the calling thread is recorded in
    ``recorder_and_worker_thread_ids`` before ``_worker`` is entered.
    ``shutdown_hook`` runs in the same thread, only after ``_worker`` has
    returned; it is not called if ``_worker`` raises.

    Any remaining positional/keyword arguments are forwarded unchanged to
    ``concurrent.futures.thread._worker``.
    """
    current_thread_id = threading.get_ident()
    recorder_and_worker_thread_ids.add(current_thread_id)

    # Block in the standard worker loop until the executor tells it to exit.
    _worker(*args, **kwargs)

    # The loop has finished; give the owner a chance to clean up per-thread
    # state from inside this worker thread.
    shutdown_hook()


class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
    """Thread pool executor for database work that will not deadlock on shutdown.

    Worker threads are started on ``_worker_with_shutdown_hook`` instead of the
    stdlib ``_worker``. Each worker therefore records its thread ident in
    ``recorder_and_worker_thread_ids`` and calls the ``shutdown_hook`` given
    at construction time once its worker loop returns.
    """

    def __init__(
        self, recorder_and_worker_thread_ids: set[int], *args: Any, **kwargs: Any
    ) -> None:
        """Init the executor with shutdown hook support.

        Args:
            recorder_and_worker_thread_ids: Shared set to which every worker
                thread adds its ``threading.get_ident()`` when it starts.
            *args: Forwarded to the parent executor constructor.
            **kwargs: Forwarded to the parent executor constructor, except the
                required ``shutdown_hook`` keyword (a zero-argument callable),
                which is removed first.

        Raises:
            KeyError: If ``shutdown_hook`` is not passed as a keyword argument.
        """
        # Remove the hook so it is not forwarded to the parent constructor.
        self._shutdown_hook: Callable[[], None] = kwargs.pop("shutdown_hook")
        self.recorder_and_worker_thread_ids = recorder_and_worker_thread_ids
        super().__init__(*args, **kwargs)

    def _adjust_thread_count(self) -> None:
        """Overridden to add support for shutdown hook.

        Based on the CPython 3.10 implementation of
        ``ThreadPoolExecutor._adjust_thread_count``, extended to handle the
        Python 3.14 ``_worker`` signature. Threads run
        ``_worker_with_shutdown_hook`` instead of ``_worker``.
        """
        # if idle threads are available, don't spin new threads
        if self._idle_semaphore.acquire(  # pylint: disable=consider-using-with
            timeout=0
        ):
            return

        num_threads = len(self._threads)
        if num_threads >= self._max_workers:
            # Pool is already full: nothing will be spawned, so don't build
            # the worker arguments (on 3.14+ that creates a worker context).
            return

        # When the executor gets lost, the weakref callback will wake up
        # the worker threads.
        def weakref_cb(  # type: ignore[no-untyped-def]
            _: Any,
            q=self._work_queue,
        ) -> None:
            q.put(None)

        # Trailing positional arguments for concurrent.futures.thread._worker,
        # whose signature changed in Python 3.14.
        additional_args: tuple[Any, ...]
        if sys.version_info >= (3, 14):
            # _worker(executor_reference, ctx, work_queue)
            additional_args = (
                self._create_worker_context(),
                self._work_queue,
            )
        else:
            # _worker(executor_reference, work_queue, initializer, initargs)
            additional_args = (
                self._work_queue,
                self._initializer,
                self._initargs,
            )

        thread_name = f"{self._thread_name_prefix or self}_{num_threads}"
        executor_thread = threading.Thread(
            name=thread_name,
            target=_worker_with_shutdown_hook,
            args=(
                self._shutdown_hook,
                self.recorder_and_worker_thread_ids,
                weakref.ref(self, weakref_cb),
                *additional_args,
            ),
        )
        executor_thread.start()
        self._threads.add(executor_thread)  # type: ignore[attr-defined]
        _threads_queues[executor_thread] = self._work_queue  # type: ignore[index]
