#!/usr/bin/env python

"""
Parallel sync is a feature designed to optimize
throughput by utilizing available CPUs/threads, particularly beneficial
in environments experiencing high network latency.

Scenario & Challenge:
In instances where your PG database, Elasticsearch/OpenSearch, and PGSync
servers operate on divergent networks, a delay in request/response time is
noticeable. The primary constraint emerges from the database query's roundtrip,
which even server-side cursors can address only to a limited extent by fetching
a certain number of records at a time. The consequent delay in fetching the
next cursor significantly hampers the overall synchronization speed.

Solution:
To mitigate this, the strategy is to conduct an initial fast/parallel sync,
thereby populating Elasticsearch/OpenSearch in a single iteration.
Post this, the regular pgsync can continue running as a daemon.

Approach and Technical Implementation:
The approach relies on the tuple identifier (TID) of each table row.
Every table has a system column, "ctid", of type "tid", which identifies
the page (block) number and the row's position within that page.
This makes it possible to paginate the sync process.

Technically, pagination implies dividing each paged record amongst the
available CPUs/threads. This division enables the parallel execution of
Elasticsearch/OpenSearch bulk inserts. The "ctid" serves as a tuple
(for instance, (1, 5)), pinpointing the row in a disk page.

By leveraging this method, all paged row records are retrieved upfront and
allocated as work units across the worker threads/CPUs.
Each work unit, defined by the BLOCK_SIZE, denotes the number of root node
records assigned for each worker to process.

Subsequently, the workers execute queries for each assigned chunk of work,
filtered based on the page number and row numbers.
This systematic and parallel approach optimizes the synchronization process,
especially in environments challenged by network latency.
"""
import asyncio
import multiprocessing
import os
import re
import sys
import typing as t
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from queue import Queue
from threading import Thread

import click
import sqlalchemy as sa

from pgsync.settings import (
    BLOCK_SIZE,
    CHECKPOINT_PATH,
    S3_SCHEMA_URL,
    SCHEMA,
    SCHEMA_URL,
)
from pgsync.sync import Sync
from pgsync.utils import (
    config_loader,
    MutuallyExclusiveOption,
    show_settings,
    timeit,
    validate_config,
)


def save_ctid(page: int, row: int, filename: str) -> None:
    """
    Persist a ctid checkpoint (page and row) to disk.

    The checkpoint is stored as a single ``page,row`` line in the hidden file
    ``.<filename>.ctid`` under CHECKPOINT_PATH; any previous content is
    overwritten.

    Args:
        page (int): The page number to record.
        row (int): The row number to record.
        filename (str): Base name used to build the checkpoint file name.
    """
    checkpoint_file: str = os.path.join(CHECKPOINT_PATH, f".{filename}.ctid")
    line: str = f"{page},{row}\n"
    with open(checkpoint_file, "w+") as handle:
        handle.write(line)


def read_ctid(filename: str) -> t.Tuple[t.Optional[int], t.Optional[int]]:
    """
    Read the ctid checkpoint stored for the given name.

    The checkpoint lives in the hidden file ``.<filename>.ctid`` under
    CHECKPOINT_PATH and holds a single ``page,row`` line.

    Args:
        filename (str): Base name used to build the checkpoint file name.

    Returns:
        tuple: ``(page, row)`` as integers. ``(None, None)`` is returned when
        the checkpoint file does not exist or contains no data (for example
        when a previous run was interrupted right after truncating it).

    Raises:
        ValueError: If the stored page or row is not an integer.
        IndexError: If the stored line does not contain both a page and a row.
    """
    filepath: str = os.path.join(CHECKPOINT_PATH, f".{filename}.ctid")
    try:
        # Open directly rather than checking for existence first, so a file
        # removed in between cannot trigger an unexpected error.
        with open(filepath, "r") as fp:
            tokens: t.List[str] = fp.read().split()
    except FileNotFoundError:
        return None, None

    if not tokens:
        # An empty checkpoint carries no position: behave as if it is absent.
        return None, None

    pairs: t.List[str] = tokens[0].split(",")
    page: int = int(pairs[0])
    row: int = int(pairs[1])
    return page, row


def logical_slot_changes(
    doc: dict, verbose: bool = False, validate: bool = False
) -> None:
    """
    Replay logical slot changes up to the current transaction id.

    Meant to run after the bulk phase so that anything missed while it was
    running is picked up, after which the checkpoint is moved forward.

    Args:
        doc (dict): Schema document handed to ``Sync``.
        verbose (bool): Forwarded to ``Sync``.
        validate (bool): Forwarded to ``Sync``.
    """
    syncer: Sync = Sync(doc, verbose=verbose, validate=validate)
    lower: int = syncer.checkpoint
    upper: int = syncer.txid_current
    # Catch up on everything between the stored checkpoint and "now".
    syncer.logical_slot_changes(txmin=lower, txmax=upper)
    # Fall back to a freshly read transaction id when ``upper`` is falsy.
    syncer.checkpoint = upper if upper else syncer.txid_current


@dataclass
class Task:
    """
    Unit of work executed by the process pool.

    Holds the arguments needed to build a ``Sync`` object; ``process`` builds
    a new one on every call and indexes a single batch of ctids with it.
    """

    # Schema document handed to ``Sync``.
    doc: dict
    # Forwarded to ``Sync`` as ``verbose``.
    verbose: bool = False
    # Forwarded to ``Sync`` as ``validate``.
    validate: bool = False

    def process(self, task: dict) -> None:
        """
        Sync one batch of ctids and report completion on stdout.

        Args:
            task (dict): Batch of ctids, passed to ``Sync.sync`` as ``ctid``.
        """
        syncer: Sync = Sync(
            self.doc, verbose=self.verbose, validate=self.validate
        )
        lower: int = syncer.checkpoint
        upper: int = syncer.txid_current
        syncer.search_client.bulk(
            syncer.index,
            syncer.sync(ctid=task, txmin=lower, txmax=upper),
        )
        message: str = f"Process pid: {os.getpid()} complete.\n"
        sys.stdout.write(message)
        sys.stdout.flush()


def _ctid_part(position: int) -> sa.sql.expression.ColumnElement:
    """
    Build a SQL expression extracting one integer component of ``ctid``.

    The ctid is cast to text (e.g. ``(12,7)``), the parentheses are removed
    and the requested comma separated field is cast to an integer.

    Args:
        position (int): 1 for the page number, 2 for the row number.
    """
    return sa.cast(
        sa.func.SPLIT_PART(
            sa.func.REPLACE(
                sa.func.REPLACE(
                    sa.cast(sa.column("ctid"), sa.Text),
                    "(",
                    "",
                ),
                ")",
                "",
            ),
            ",",
            position,
        ),
        sa.Integer,
    )


@timeit
def fetch_tasks(
    doc: dict,
    block_size: t.Optional[int] = None,
) -> t.Generator:
    """
    Yield batches of root table ctids to be processed by the workers.

    Each batch is a dict mapping a page number to the list of row numbers
    found on that page, holding at most ``block_size`` rows in total.
    When a ctid checkpoint exists, only rows located after it are returned:
    rows on a later page, or rows further down the checkpoint page itself.

    Args:
        doc (dict): Schema document handed to ``Sync``.
        block_size (int, optional): Rows per batch; defaults to BLOCK_SIZE.

    Yields:
        dict: ``{page: [row, ...]}``. The last batch is always yielded and
        may be empty (no rows at all, or a row count that is an exact
        multiple of ``block_size``).
    """
    block_size = block_size or BLOCK_SIZE
    sync: Sync = Sync(doc)
    filename: str = re.sub(
        "[^0-9a-zA-Z_]+", "", f"{sync.database.lower()}_{sync.index}"
    )
    last_page, last_row = read_ctid(filename)
    statement: sa.sql.Select = sa.select(
        *[
            sa.literal_column("1").label("x"),
            sa.literal_column("1").label("y"),
            sa.column("ctid"),
        ]
    ).select_from(sync.tree.root.model)

    # Resume after the checkpoint. Compare against None explicitly because
    # page 0 and row 0 are falsy yet valid positions.
    if last_page is not None:
        page_number = _ctid_part(1)
        resume = page_number > last_page
        if last_row is not None:
            # The row bound only applies to the checkpoint page itself:
            # applying it to every page would drop low-numbered rows of all
            # later pages.
            resume = sa.or_(
                resume,
                sa.and_(
                    page_number == last_page,
                    _ctid_part(2) > last_row,
                ),
            )
        statement = statement.where(resume)

    pages: dict = {}
    for count, (_, _, ctid) in enumerate(sync.fetchmany(statement), start=1):
        # ``ctid[0]`` is parsed as text of the form "(page,row)".
        page_text, row_text = ctid[0].strip("()").split(",")
        pages.setdefault(int(page_text), []).append(int(row_text))
        if count % block_size == 0:
            yield pages
            pages = {}
    # Always emit the trailing batch, even when empty, so consumers receive
    # at least one task.
    yield pages


@timeit
def synchronous(
    tasks: t.Generator,
    doc: dict,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """
    Index every batch sequentially in the calling thread.

    A single ``Sync`` object and a single (txmin, txmax) window are used for
    all batches; logical slot changes are replayed once at the end.

    Args:
        tasks (Generator): Batches of ctids, each passed to ``Sync.sync``.
        doc (dict): Schema document handed to ``Sync``.
        verbose (bool): Forwarded to ``Sync``.
        validate (bool): Forwarded to ``Sync``.
    """
    sys.stdout.write("Synchronous\n")
    syncer: Sync = Sync(doc, verbose=verbose, validate=validate)
    lower: int = syncer.checkpoint
    upper: int = syncer.txid_current
    target: str = syncer.index
    for batch in tasks:
        syncer.search_client.bulk(
            target,
            syncer.sync(ctid=batch, txmin=lower, txmax=upper),
        )
    # Pick up whatever changed while the batches were being indexed.
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multithreaded(
    tasks: t.Generator,
    doc: dict,
    nthreads: t.Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """
    Index the batches with a pool of daemon threads fed through a queue.

    All threads share one ``Sync`` object and one (txmin, txmax) window, read
    once in the calling thread before the workers start. Logical slot changes
    are replayed after every queued batch has been handled.

    Args:
        tasks (Generator): Batches of ctids, each passed to ``Sync.sync``.
        doc (dict): Schema document handed to ``Sync``.
        nthreads (int, optional): Number of worker threads; defaults to 1.
        verbose (bool): Forwarded to ``Sync``.
        validate (bool): Forwarded to ``Sync``.

    Raises:
        Exception: The first error raised by a worker, re-raised once the
        queue has drained and before logical slot changes are replayed.
    """
    sys.stdout.write("Multithreaded\n")

    nthreads = nthreads or 1
    queue: Queue = Queue()
    sync: Sync = Sync(doc, verbose=verbose, validate=validate)
    # Read the bounds up front: a failure here surfaces in the caller rather
    # than silently killing a worker thread.
    txmin: int = sync.checkpoint
    txmax: int = sync.txid_current
    errors: t.List[Exception] = []

    def worker() -> None:
        while True:
            task: dict = queue.get()
            try:
                sync.search_client.bulk(
                    sync.index,
                    sync.sync(ctid=task, txmin=txmin, txmax=txmax),
                )
            except Exception as e:
                # Keep the thread alive so the remaining batches are still
                # consumed; the failure is reported to the caller below.
                errors.append(e)
                sys.stdout.write(f"Exception: {e}\n")
            finally:
                # Always acknowledge the item, otherwise queue.join() would
                # block forever after a failed batch.
                queue.task_done()

    for _ in range(nthreads):
        thread: Thread = Thread(target=worker)
        thread.daemon = True
        thread.start()
    for task in tasks:
        queue.put(task)

    queue.join()  # block until all tasks are done
    if errors:
        raise errors[0]
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multiprocess(
    tasks: t.Generator,
    doc: dict,
    ncpus: t.Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """
    Index the batches with a pool of worker processes.

    Every batch is handed to ``Task.process``. An exception surfacing while
    the results are collected is written to stdout and collection stops;
    logical slot changes are replayed afterwards in either case.

    Args:
        tasks (Generator): Batches of ctids, each passed to ``Task.process``.
        doc (dict): Schema document handed to ``Task``.
        ncpus (int, optional): Passed as ``max_workers`` to the pool.
        verbose (bool): Forwarded to ``Task``.
        validate (bool): Forwarded to ``Task``.
    """
    sys.stdout.write("Multiprocess\n")
    job: Task = Task(doc, verbose=verbose, validate=validate)
    with ProcessPoolExecutor(max_workers=ncpus) as pool:
        try:
            # Drain the result iterator so that worker errors are raised here.
            for _ in pool.map(job.process, tasks):
                pass
        except Exception as error:
            sys.stdout.write(f"Exception: {error}\n")
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multithreaded_async(
    tasks: t.Generator,
    doc: dict,
    nthreads: t.Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """
    Index the batches on a thread pool driven by an asyncio event loop.

    The work is delegated to ``run_tasks``; logical slot changes are replayed
    once it has completed.

    Args:
        tasks (Generator): Batches of ctids forwarded to ``run_tasks``.
        doc (dict): Schema document forwarded to ``run_tasks``.
        nthreads (int, optional): Passed as ``max_workers`` to the pool.
        verbose (bool): Forwarded to ``run_tasks``.
        validate (bool): Forwarded to ``run_tasks``.
    """
    sys.stdout.write("Multi-threaded async\n")
    # Use a private loop instead of asyncio.get_event_loop(): implicit loop
    # creation outside a running loop is deprecated, and a private loop can
    # be closed without affecting any other user of the default loop.
    event_loop = asyncio.new_event_loop()
    try:
        # The context manager shuts the pool down (joining its threads) even
        # when run_tasks raises.
        with ThreadPoolExecutor(max_workers=nthreads) as executor:
            event_loop.run_until_complete(
                run_tasks(
                    executor, tasks, doc, verbose=verbose, validate=validate
                )
            )
    finally:
        event_loop.close()
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multiprocess_async(
    tasks: t.Generator,
    doc: dict,
    ncpus: t.Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """
    Index the batches on a process pool driven by an asyncio event loop.

    The work is delegated to ``run_tasks``; logical slot changes are replayed
    afterwards, including when the run was cut short by Ctrl-C.

    Args:
        tasks (Generator): Batches of ctids forwarded to ``run_tasks``.
        doc (dict): Schema document forwarded to ``run_tasks``.
        ncpus (int, optional): Passed as ``max_workers`` to the pool.
        verbose (bool): Forwarded to ``run_tasks``.
        validate (bool): Forwarded to ``run_tasks``.
    """
    sys.stdout.write("Multi-process async\n")
    executor: ProcessPoolExecutor = ProcessPoolExecutor(max_workers=ncpus)
    # Use a private loop instead of asyncio.get_event_loop(): implicit loop
    # creation outside a running loop is deprecated, and a private loop can
    # be closed without affecting any other user of the default loop.
    event_loop = asyncio.new_event_loop()
    try:
        event_loop.run_until_complete(
            run_tasks(executor, tasks, doc, verbose=verbose, validate=validate)
        )
    except KeyboardInterrupt:
        # NOTE(review): the interrupt is swallowed and the logical slot
        # replay below still runs after a partial sync - confirm this is
        # the intended behavior.
        pass
    finally:
        try:
            # Release the worker processes before closing the loop so that
            # pending completion callbacks never target a closed loop.
            executor.shutdown()
        finally:
            event_loop.close()
    logical_slot_changes(doc, verbose=verbose, validate=validate)


async def run_tasks(
    executor: t.Union[ThreadPoolExecutor, ProcessPoolExecutor],
    tasks: t.Generator,
    doc: dict,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """
    Schedule ``run_task`` on the executor for every batch and await them all.

    With a thread pool, a single ``Sync`` object is created here and shared
    by all calls; otherwise ``None`` is passed and each ``run_task`` call
    builds its own. The collected results are printed once everything has
    finished.

    Args:
        executor: Thread or process pool used to run the batches.
        tasks (Generator): Batches of ctids, one ``run_task`` call each.
        doc (dict): Schema document used to build ``Sync`` objects.
        verbose (bool): Forwarded to ``Sync`` / ``run_task``.
        validate (bool): Forwarded to ``Sync`` / ``run_task``.

    Raises:
        Exception: Whatever a failed ``run_task`` call raised, re-raised when
        its result is collected.
    """
    sync: t.Optional[Sync] = None
    if isinstance(executor, ThreadPoolExecutor):
        # threads can share a common Sync object
        sync = Sync(doc, verbose=verbose, validate=validate)
    event_loop = asyncio.get_running_loop()
    futures: list = [
        event_loop.run_in_executor(
            executor, run_task, task, sync, doc, verbose, validate
        )
        for task in tasks
    ]
    results: list = []
    # asyncio.wait() raises ValueError when given nothing to wait for, so
    # skip it when there are no batches.
    if futures:
        completed, _ = await asyncio.wait(futures)
        results = [future.result() for future in completed]
    print("results: {!r}".format(results))
    print("exiting")


def run_task(
    task: dict,
    sync: t.Optional[Sync] = None,
    doc: t.Optional[dict] = None,
    verbose: bool = False,
    validate: bool = False,
) -> int:
    """
    Index one batch of ctids and record it as the latest ctid checkpoint.

    Args:
        task (dict): Batch mapping page numbers to lists of row numbers.
        sync (Sync, optional): Object to reuse; when omitted a new one is
            built from ``doc``, ``verbose`` and ``validate``.
        doc (dict, optional): Schema document used when ``sync`` is omitted.
        verbose (bool): Forwarded to ``Sync`` when one is built here.
        validate (bool): Forwarded to ``Sync`` when one is built here.

    Returns:
        int: Always 1.
    """
    if sync is None:
        sync = Sync(doc, verbose=verbose, validate=validate)
    lower: int = sync.checkpoint
    upper: int = sync.txid_current
    sync.search_client.bulk(
        sync.index,
        sync.sync(ctid=task, txmin=lower, txmax=upper),
    )
    if len(task) == 0:
        # Nothing was processed, so there is no position to checkpoint.
        return 1

    # Checkpoint the highest page of the batch and the highest row on it.
    last_page: int = max(task.keys())
    last_row: int = max(task[last_page])
    checkpoint_name: str = re.sub(
        "[^0-9a-zA-Z_]+", "", f"{sync.database.lower()}_{sync.index}"
    )
    save_ctid(last_page, last_row, checkpoint_name)
    return 1


@click.command()
@click.option(
    "--config",
    "-c",
    help="Schema config",
    type=click.Path(exists=True),
    default=SCHEMA,
    show_default=True,
    cls=MutuallyExclusiveOption,
    mutually_exclusive=["s3_schema_url", "schema_url"],
)
@click.option(
    "--schema_url",
    help="URL for schema config",
    type=click.STRING,
    default=SCHEMA_URL,
    show_default=True,
    cls=MutuallyExclusiveOption,
    mutually_exclusive=["config", "s3_schema_url"],
)
@click.option(
    "--s3_schema_url",
    help="S3 URL for schema config",
    type=click.STRING,
    default=S3_SCHEMA_URL,
    show_default=True,
    cls=MutuallyExclusiveOption,
    mutually_exclusive=["config", "schema_url"],
)
@click.option(
    "--verbose",
    "-v",
    is_flag=True,
    default=False,
    help="Turn on verbosity",
)
@click.option(
    "--nprocs",
    "-n",
    help="Number of threads/process",
    type=int,
    default=multiprocessing.cpu_count() * 2,
)
@click.option(
    "--mode",
    "-m",
    help="Sync mode",
    type=click.Choice(
        [
            "synchronous",
            "multithreaded",
            "multiprocess",
            "multithreaded_async",
            "multiprocess_async",
        ],
        case_sensitive=False,
    ),
    default="multiprocess_async",
)
def main(
    config: str,
    schema_url: str,
    s3_schema_url: str,
    nprocs: int,
    mode: str,
    verbose: bool,
) -> None:
    """
    TODO:
    - Track progress across cpus/threads
    - Handle KeyboardInterrupt Exception
    """

    # The three schema sources are always passed together to the helpers.
    sources: dict = {
        "config": config,
        "schema_url": schema_url,
        "s3_schema_url": s3_schema_url,
    }

    # One entry per sync mode; each forwards the worker count under the
    # keyword expected by the corresponding strategy function.
    runners: dict = {
        "synchronous": lambda work, document: synchronous(
            work, document, verbose=verbose
        ),
        "multithreaded": lambda work, document: multithreaded(
            work, document, nthreads=nprocs, verbose=verbose
        ),
        "multiprocess": lambda work, document: multiprocess(
            work, document, ncpus=nprocs, verbose=verbose
        ),
        "multithreaded_async": lambda work, document: multithreaded_async(
            work, document, nthreads=nprocs, verbose=verbose
        ),
        "multiprocess_async": lambda work, document: multiprocess_async(
            work, document, ncpus=nprocs, verbose=verbose
        ),
    }

    validate_config(**sources)

    show_settings(**sources)

    for doc in config_loader(**sources):
        tasks: t.Generator = fetch_tasks(doc)
        runner = runners.get(mode)
        # An unrecognized mode results in no work being done for the doc.
        if runner is not None:
            runner(tasks, doc)

# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
