"""OmnibusX SDK for Python."""

import base64
import json
import os
import shutil
import tempfile
import time
import uuid
import zipfile
from collections.abc import Callable
from pathlib import Path

import httpx
import pandas as pd
from tqdm import tqdm

from .api_client import ApiClient
from .models import (
    AddFileUploadPayload,
    AddTaskParams,
    AddTaskResponse,
    BatchInfo,
    DataFormat,
    FileUploadChunk,
    ImportOmnibusXFileParams,
    PreprocessDatasetParams,
    SequencingPlatform,
    SequencingTechnology,
    Species,
    TaskLog,
    TaskStatus,
    TaskType,
    UploadFilesResponse,
    UploadProgress,
    UserGroup,
)

# Auth0 tenant host; the device-code and token endpoint URLs are built from it.
AUTH0_DOMAIN = "omnibusx.us.auth0.com"
# Sent as "client_id" in both the device-code and the token request.
CLIENT_ID = "695G9N2XeZRjme75lAbjqC80yq28cpUn"
# Sent as "audience" when requesting a device code.
API_AUDIENCE = "https://api-prod.omnibusx.com"
# Sent as "scope" when requesting a device code.
SCOPES = "openid profile email"
# Upper bound on the size of one uploaded chunk; upload_files() derives the
# per-file chunk count from it.
CHUNK_SIZE = 5 * 1024 * 1024  # 5MB in bytes


def _decode_jwt_payload(token: str) -> dict:
    """Decode JWT token payload to extract user information.

    Args:
        token: The JWT token string

    Returns:
        Dictionary containing the decoded payload

    """
    try:
        # JWT has 3 parts: header.payload.signature
        payload_part = token.split(".")[1]
        # Add padding if needed for base64 decoding
        padding = 4 - len(payload_part) % 4
        if padding != 4:
            payload_part += "=" * padding
        # Decode base64
        decoded = base64.urlsafe_b64decode(payload_part)
        return json.loads(decoded)
    except Exception as e:
        print(f"Error decoding JWT token: {e}")
        return {}


class SDKClient:
    """Base class for OmnibusX SDK client."""

    def __init__(self, server_url: str, enable_https: bool = True) -> None:
        """Initialize the SDK client with base URL and authentication token."""
        self._server_url = server_url
        self._enable_https = enable_https
        self._access_token = None
        self._user_email = None
        self._cache_path = Path.cwd().joinpath(".omnibusx_token_cache.json")
        self._client = None

    def _start_device_authorization(self) -> dict | None:
        """Request a device code from Auth0 for user authentication.

        Returns:
            The JSON body of the device-code response (``authenticate()`` and
            ``_poll_for_token()`` read ``verification_uri_complete``,
            ``user_code``, ``device_code`` and ``interval`` from it), or
            ``None`` if the request failed. Failures are printed, not raised.

        """
        url = f"https://{AUTH0_DOMAIN}/oauth/device/code"
        payload = {
            "client_id": CLIENT_ID,
            "scope": SCOPES,
            "audience": API_AUDIENCE,
        }
        try:
            response = httpx.post(url, data=payload)
            response.raise_for_status()
            return response.json()
        except (httpx.HTTPError, ValueError) as e:
            # httpx.HTTPError covers both error status codes and transport
            # failures (DNS, connection, timeout); ValueError covers a body
            # that is not valid JSON. All of them mean "no device code".
            print(f"Error starting device authorization: {e}")
            return None
    def _poll_for_token(self, device_code_info: dict) -> dict | None:
        """Poll the token endpoint until the user completes authentication.

        Args:
            device_code_info: Device-code response; ``device_code`` is
                required, ``interval`` (seconds between polls) defaults to 5.

        Returns:
            The token response body once the endpoint answers with HTTP 200,
            or ``None`` if the request fails or the endpoint reports any error
            other than ``authorization_pending`` / ``slow_down``.

        """
        url = f"https://{AUTH0_DOMAIN}/oauth/token"
        payload = {
            "client_id": CLIENT_ID,
            "device_code": device_code_info["device_code"],
            "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
        }
        interval = device_code_info.get("interval", 5)
        print("Please complete the authentication in your browser.")
        while True:
            time.sleep(interval)
            try:
                response = httpx.post(url, data=payload)
                data = response.json()
            except (httpx.HTTPError, ValueError) as e:
                # No raise_for_status() is used here because the pending and
                # slow_down answers arrive as error statuses with a JSON body.
                # What can actually fail is the transport (httpx.HTTPError)
                # or JSON decoding (ValueError).
                print(f"HTTP error during token polling: {e}")
                return None

            if response.status_code == 200:
                print("Authentication successful.")
                return data

            error = data.get("error")
            if error == "authorization_pending":
                # User has not finished yet; keep polling at the same pace.
                print(".", end="", flush=True)
                continue
            if error == "slow_down":
                interval += 5
                print(f"Slowing down polling to {interval} seconds.")
                continue

            # Any other error (denied, expired, ...) ends the flow.
            print(
                f"Error during authentication: {data.get('error_description', 'Unknown error')}"
            )
            return None

    def _load_token_from_cache(self) -> bool:
        """Load the access token from cache if it exists."""
        if self._cache_path.exists():
            try:
                with open(self._cache_path, "r") as file:
                    cached_data = json.load(file)
                    # Check if token is expired
                    if cached_data.get("expires_at", 0) > time.time() + 60:
                        self._access_token = cached_data.get("access_token")
                        self._user_email = cached_data.get("user_email")
                        return True
            except Exception as e:
                print(f"Error reading token cache: {e}")
        return False

    def _save_token_to_cache(self, token_data: dict) -> None:
        expires_at = token_data.get("expires_in", 0) + int(time.time())
        cache_content = {
            "access_token": token_data.get("access_token"),
            "user_email": self._user_email,
            "expires_at": expires_at,
        }
        with open(self._cache_path, "w") as file:
            json.dump(cache_content, file)

    def authenticate(self, cache_token: bool = True) -> bool:
        """Obtain an access token and build the API client.

        A cached token is used when allowed and still valid; otherwise the
        Auth0 device-authorization flow is started and the user is asked to
        confirm in a browser.

        Args:
            cache_token: When true, try the on-disk token cache first and
                store a freshly obtained token in it afterwards.

        Returns:
            ``True`` once an API client has been created, ``False`` when no
            device code or no access token could be obtained.

        """

        def build_client() -> ApiClient:
            # Reads the instance fields at call time, i.e. after they are set.
            return ApiClient(
                base_url=self._server_url,
                token=self._access_token,
                user_email=self._user_email,
                enable_https=self._enable_https,
            )

        # Fast path: reuse a cached, unexpired token.
        if cache_token and self._load_token_from_cache():
            self._client = build_client()
            return True

        device_info = self._start_device_authorization()
        if not device_info:
            return False

        print("=== ACTION REQUIRED ===")
        print(
            f"1. Open this URL in your browser: {device_info['verification_uri_complete']}"
        )
        print(
            f"2. Make sure the code displayed in the browser matches: {device_info['user_code']}"
        )
        print("3. Follow the instructions to complete authentication.")

        tokens = self._poll_for_token(device_info)
        if not tokens or "access_token" not in tokens:
            return False

        self._access_token = tokens["access_token"]

        # The email comes from the ID token's claims, when one was issued.
        if "id_token" not in tokens:
            print("Warning: No ID token received, email may not be available")
            self._user_email = ""
        else:
            claims = _decode_jwt_payload(tokens["id_token"])
            self._user_email = claims.get("email", "")
            if self._user_email:
                print(f"Authenticated as: {self._user_email}")

        self._client = build_client()
        if cache_token:
            self._save_token_to_cache(tokens)
        return True

    def clear_token_cache(self) -> None:
        """Clear the cached access token."""
        if self._cache_path.exists():
            try:
                self._cache_path.unlink()
                print("Token cache cleared.")
            except Exception as e:
                print(f"Error clearing token cache: {e}")
        else:
            print("No token cache found to clear.")

    def _add_task(self, task_params: AddTaskParams) -> AddTaskResponse:
        """Submit a task description to ``/api/tasks/add``.

        Returns:
            The server's reply wrapped in an ``AddTaskResponse``.
        """
        payload = task_params.model_dump()
        reply = self._client.post("/api/tasks/add", data=payload)
        return AddTaskResponse(**reply)

    def _commit_task(self, task_id: str) -> None:
        """Commit a task by its ID."""
        self._client.post("/api/tasks/commit-result", data={"task_id": task_id})

    def _get_task(self, task_id: str) -> TaskLog:
        """Fetch the current record of one task from ``/api/tasks/get``.

        Returns:
            The server's reply wrapped in a ``TaskLog``.
        """
        query = {"task_id": task_id}
        reply = self._client.get("/api/tasks/get", params=query)
        return TaskLog(**reply)

    def test_connection(self) -> bool:
        """Test the connection to the OmnibusX API."""
        try:
            response = self._client.get("/health-check")
            if response.get("status") != "ok":
                raise ValueError(response.get("message", "Connection failed"))
        except Exception as e:
            print(f"Connection test failed: {e}")
            return False
        else:
            print("Connection successful")
            return True

    def get_available_groups(self) -> list[UserGroup]:
        """Return the user groups listed by ``/api/user-groups/get``.

        Each entry's ``id``, ``name`` and ``description`` are copied into a
        ``UserGroup``.
        """
        groups: list[UserGroup] = []
        for entry in self._client.get("/api/user-groups/get"):
            groups.append(
                UserGroup(
                    user_group_id=entry["id"],
                    name=entry["name"],
                    description=entry["description"],
                )
            )
        return groups

    def import_omnibusx_file(self, omnibusx_file_path: str, group_id: str) -> str:
        """Submit a task that imports an OmnibusX file.

        Args:
            omnibusx_file_path: Path of the file to import, sent as-is.
            group_id: Group the import is associated with.

        Returns:
            ID of the submitted task.
        """
        task = AddTaskParams(
            task_type=TaskType.IMPORT_OMNIBUSX_FILE,
            params=ImportOmnibusXFileParams(
                omnibusx_file_path=omnibusx_file_path, group_id=group_id
            ),
        )
        return self._add_task(task).task_id

    def preprocess_dataset(self, params: PreprocessDatasetParams, group_id: str) -> str:
        """Submit a preprocessing task for files that already live on the server.

        The parameters are dumped to a dict, ``group_id`` is added to it, and
        the result is submitted as a ``PREPROCESS_DATASET`` task.

        Args:
            params: Dataset configuration (name, batches, reference, technology,
                platform, data format).
            group_id: The group ID to associate with the task.

        Returns:
            Task ID for monitoring the preprocessing job.

        Note:
            File paths in ``params.batches`` must be server-side paths.
            For local files use ``upload_and_preprocess_dataset()`` instead.
            For spatial data described by a metadata file, run
            ``upload_file_from_meta()`` first.

        Example (scRNA-seq):
            from omnibusx_sdk import (
                SDKClient, PreprocessDatasetParams, BatchInfo,
                Species, SequencingTechnology, SequencingPlatform, DataFormat
            )

            client = SDKClient(server_url="https://api-prod.omnibusx.com")
            client.authenticate()

            params = PreprocessDatasetParams(
                name="My scRNA-seq Dataset",
                description="Dataset description",
                batches=[
                    BatchInfo(
                        file_path="/server/path/to/data.h5ad",
                        batch_name="Batch 1"
                    )
                ],
                gene_reference_version=111,
                gene_reference_id=Species.HUMAN,
                technology=SequencingTechnology.SC_RNA_SEQ,
                platform=SequencingPlatform.ScRnaSeq.CHROMIUM_10X,
                data_format=DataFormat.SCANPY,
            )

            task_id = client.preprocess_dataset(params, group_id="your-group-id")

        Example (Xenium spatial data):
            params = PreprocessDatasetParams(
                name="Xenium Spatial Dataset",
                description="Spatial transcriptomics data",
                batches=[
                    BatchInfo(
                        file_path="/server/path/to/meta_xenium.tsv",
                        batch_name="Spatial Data"
                    )
                ],
                gene_reference_version=111,
                gene_reference_id=Species.HUMAN,
                technology=SequencingTechnology.WELL_BASED_SPATIAL,
                platform=SequencingPlatform.WellBasedSpatial.XENIUM,
                data_format=DataFormat.XENIUM,
            )

            task_id = client.preprocess_dataset(params, group_id="your-group-id")

        Example (Visium HD spatial data):
            params = PreprocessDatasetParams(
                name="Visium HD Spatial Dataset",
                description="Visium HD spatial data",
                batches=[
                    BatchInfo(
                        file_path="/server/path/to/meta_visium_hd.tsv",
                        batch_name="Visium HD Data"
                    )
                ],
                gene_reference_version=111,
                gene_reference_id=Species.HUMAN,
                technology=SequencingTechnology.WELL_BASED_SPATIAL,
                platform=SequencingPlatform.WellBasedSpatial.VISIUM_HD_10X,
                data_format=DataFormat.VISIUM_HD,
            )

            task_id = client.preprocess_dataset(params, group_id="your-group-id")

        """
        # The group travels inside the task parameters, next to the dataset
        # configuration.
        payload = {**params.model_dump(), "group_id": group_id}
        task = AddTaskParams(task_type=TaskType.PREPROCESS_DATASET, params=payload)
        return self._add_task(task).task_id

    def upload_and_preprocess_dataset(
        self,
        params: PreprocessDatasetParams,
        group_id: str,
        progress_callback: Callable[[UploadProgress], None] | None = None,
        show_progress: bool = True,
        metadata_sep: str = ",",
    ) -> str:
        """Upload local files and preprocess dataset in one step.

        This is a convenience method that handles different upload strategies based on
        the sequencing technology:
        - For scRNA-seq: Uploads data files directly
        - For spatial data: Uses metadata files to orchestrate uploads

        Workflow:
        1. Uploads all local files specified in params.batches (or, for spatial
           data, the files referenced by each metadata file plus a rewritten
           metadata TSV)
        2. Updates the file paths to point to uploaded server locations
        3. Submits the preprocessing task

        Args:
            params: PreprocessDatasetParams with LOCAL file paths
            group_id: The group ID to associate with the task
            progress_callback: Optional callback for upload progress tracking
                (used for the scRNA-seq data upload)
            show_progress: Whether to display upload progress (default: True)
            metadata_sep: Separator used in metadata files for spatial data (default: ',')

        Returns:
            Task ID for monitoring the preprocessing job

        Raises:
            ValueError: If ``params.technology`` is neither scRNA-seq nor
                well-based spatial.

        Example (scRNA-seq):
            from omnibusx_sdk import (
                SDKClient, PreprocessDatasetParams, BatchInfo,
                Species, SequencingTechnology, SequencingPlatform, DataFormat
            )

            client = SDKClient(server_url="https://api-prod.omnibusx.com")
            client.authenticate()

            # Specify LOCAL file paths
            params = PreprocessDatasetParams(
                name="My scRNA-seq Dataset",
                description="Dataset from local files",
                batches=[
                    BatchInfo(
                        file_path="/Users/me/data/sample1.h5ad",
                        batch_name="Sample 1"
                    ),
                    BatchInfo(
                        file_path="/Users/me/data/sample2.h5ad",
                        batch_name="Sample 2"
                    )
                ],
                gene_reference_version=111,
                gene_reference_id=Species.HUMAN,
                technology=SequencingTechnology.SC_RNA_SEQ,
                platform=SequencingPlatform.ScRnaSeq.CHROMIUM_10X,
                data_format=DataFormat.SCANPY,
            )

            task_id = client.upload_and_preprocess_dataset(
                params, group_id="your-group-id"
            )

        Example (Xenium spatial data):
            # Create a local metadata CSV file (samples.csv) with headers:
            # sample_id,data_path
            # sample1,/path/to/sample1_data.zarr.zip
            # sample2,/path/to/sample2_data.zarr.zip

            params = PreprocessDatasetParams(
                name="Xenium Spatial Dataset",
                description="Spatial transcriptomics",
                batches=[
                    BatchInfo(
                        file_path="/Users/me/metadata/samples.csv",
                        batch_name="Xenium Samples"
                    )
                ],
                gene_reference_version=111,
                gene_reference_id=Species.HUMAN,
                technology=SequencingTechnology.WELL_BASED_SPATIAL,
                platform=SequencingPlatform.WellBasedSpatial.XENIUM,
                data_format=DataFormat.XENIUM,
            )

            task_id = client.upload_and_preprocess_dataset(
                params, group_id="your-group-id", metadata_sep=","
            )

        """
        if params.technology == SequencingTechnology.SC_RNA_SEQ:
            # scRNA-seq workflow: Direct file upload
            local_file_paths = [batch.file_path for batch in params.batches]

            print(f"Uploading {len(local_file_paths)} file(s)...")

            upload_response = self.upload_files(
                file_paths=local_file_paths,
                group_id=group_id,
                progress_callback=progress_callback,
                show_progress=show_progress,
            )

            print(
                f"Upload complete! Files uploaded to: {upload_response.folder_path}\n"
            )

            # Every file lands in the upload folder under its own basename.
            updated_batches = [
                BatchInfo(
                    file_path=f"{upload_response.folder_path}/{Path(batch.file_path).name}",
                    batch_name=batch.batch_name,
                )
                for batch in params.batches
            ]

        elif params.technology == SequencingTechnology.WELL_BASED_SPATIAL:
            # Spatial workflow: Use metadata files to upload data
            print(
                f"Processing {len(params.batches)} metadata file(s) for spatial data...\n"
            )

            updated_batches = []
            for batch in params.batches:
                print(f"Processing metadata file for batch: {batch.batch_name}")

                # Uploads the referenced data and writes a LOCAL TSV whose
                # rows point at the server-side copies.
                local_tsv_path = self.upload_file_from_meta(
                    file_path=batch.file_path,
                    group_id=group_id,
                    technology=params.technology,
                    platform=params.platform,
                    sep=metadata_sep,
                    show_progress=show_progress,
                )

                try:
                    # The rewritten TSV itself must be uploaded too.
                    print("\nUploading metadata TSV file...")
                    upload_response = self.upload_files(
                        file_paths=[local_tsv_path],
                        group_id=group_id,
                        show_progress=show_progress,
                    )
                finally:
                    # Remove the generated TSV even when its upload fails, so
                    # failed runs do not litter the working directory.
                    Path(local_tsv_path).unlink(missing_ok=True)

                tsv_filename = Path(local_tsv_path).name
                server_metadata_path = f"{upload_response.folder_path}/{tsv_filename}"

                updated_batches.append(
                    BatchInfo(
                        file_path=server_metadata_path, batch_name=batch.batch_name
                    )
                )
                print()

        else:
            raise ValueError(
                f"Unsupported technology for upload_and_preprocess_dataset: {params.technology}"
            )

        # Create new params with updated server paths
        updated_params = PreprocessDatasetParams(
            name=params.name,
            description=params.description,
            batches=updated_batches,
            gene_reference_version=params.gene_reference_version,
            gene_reference_id=params.gene_reference_id,
            technology=params.technology,
            platform=params.platform,
            data_format=params.data_format,
        )

        # Submit preprocessing task
        print("Submitting preprocessing task...")
        task_id = self.preprocess_dataset(updated_params, group_id=group_id)

        print(f"Preprocessing task submitted! Task ID: {task_id}\n")
        return task_id

    def get_task_info(self, task_id: str, interval: int = 5) -> TaskLog:
        """Wait for a task to finish and return its last fetched record.

        The task is fetched repeatedly, printing its log each time, until its
        status is ``SUCCESS`` or ``FAILED``. A successful task that has not
        been committed yet is committed before returning.

        Args:
            task_id: ID of the task to follow.
            interval: Seconds to sleep between two fetches.

        Returns:
            The record from the final fetch (taken before any commit).
        """
        finished = (TaskStatus.SUCCESS, TaskStatus.FAILED)
        while True:
            task_info = self._get_task(task_id)
            # "\r" keeps the log on one console line across polls.
            print(task_info.log, end="\r", flush=True)
            if task_info.status in finished:
                break
            time.sleep(interval)

        if task_info.status == TaskStatus.SUCCESS and not task_info.is_committed:
            self._commit_task(task_id)
        return task_info

    def _upload_chunk_with_retry(
        self,
        chunk_data: bytes,
        chunk_id: str,
        folder_id: str,
        chunk_index: int,
        n_chunks: int,
        group_id: str = "",
        max_retries: int = 5,
    ) -> dict:
        """Upload a chunk with exponential backoff retry logic."""
        for attempt in range(max_retries):
            try:
                response = self._client.post_upload(
                    "/file-upload/upload-chunk",
                    file_data=chunk_data,
                    params={
                        "id": chunk_id,
                        "folder_id": folder_id,
                        "chunk": chunk_index,
                        "n_chunks": n_chunks,
                    },
                    group_id=group_id,
                )
                return response
            except Exception as e:
                if attempt == max_retries - 1:
                    raise
                delay = 2**attempt  # Exponential backoff: 1, 2, 4, 8, 16 seconds
                print(
                    f"Chunk upload failed (attempt {attempt + 1}/{max_retries}), "
                    f"retrying in {delay}s... Error: {e}"
                )
                time.sleep(delay)

    def upload_files(
        self,
        file_paths: list[str],
        group_id: str,
        progress_callback: Callable[[UploadProgress], None] | None = None,
        show_progress: bool = True,
    ) -> UploadFilesResponse:
        """Upload multiple files to the server using chunked upload.

        Args:
            file_paths: List of file paths to upload
            group_id: The group ID to associate with the uploaded files
            progress_callback: Optional callback function for custom progress handling
            show_progress: Whether to display built-in progress output (default: True)

        Returns:
            UploadFilesResponse containing folder_id and folder_path

        """
        folder_id = str(uuid.uuid4())

        # Calculate total size and chunks for progress tracking
        total_chunks = 0
        total_bytes = 0
        file_sizes = []
        for file_path in file_paths:
            path = Path(file_path)
            if not path.exists():
                raise FileNotFoundError(f"File not found: {file_path}")
            file_size = path.stat().st_size
            file_sizes.append(file_size)
            total_bytes += file_size
            total_chunks += (file_size + CHUNK_SIZE - 1) // CHUNK_SIZE

        total_files = len(file_paths)
        done_chunks = 0
        done_files = 0
        uploaded_bytes = 0

        upload_response = None

        # Create overall progress bar with bytes
        pbar = None
        if show_progress:
            pbar = tqdm(
                total=total_bytes,
                desc="Uploading files",
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
                bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
            )

        # Upload each file
        for file_index, file_path in enumerate(file_paths):
            path = Path(file_path)
            file_size = file_sizes[file_index]
            n_chunks = (file_size + CHUNK_SIZE - 1) // CHUNK_SIZE

            # Update progress bar description with current file
            if pbar:
                pbar.set_description(f"[{done_files + 1}/{total_files}] {path.name}")

            # Create progress object
            progress = UploadProgress(
                total_files=total_files,
                total_chunks=total_chunks,
                done_files=done_files,
                done_chunks=done_chunks,
                current_file=path.name,
            )

            # Call custom callback if provided
            if progress_callback:
                progress_callback(progress)

            # Step 1: Add file upload (get chunk IDs)
            add_file_payload = AddFileUploadPayload(
                original_filename=path.name, n_chunks=n_chunks, folder_id=folder_id
            )
            chunks_response = self._client.post(
                "/file-upload/add",
                data=add_file_payload.model_dump(),
                group_id=group_id,
            )
            chunks = [FileUploadChunk(**chunk) for chunk in chunks_response]

            # Step 2: Upload each chunk
            with open(file_path, "rb") as file:
                start = 0
                for chunk_index, chunk_info in enumerate(chunks):
                    # Read chunk data
                    chunk_end = min(start + CHUNK_SIZE, file_size)
                    file.seek(start)
                    chunk_data = file.read(chunk_end - start)
                    chunk_size_bytes = len(chunk_data)
                    start = chunk_end

                    # Upload chunk with retry logic
                    upload_response = self._upload_chunk_with_retry(
                        chunk_data=chunk_data,
                        chunk_id=chunk_info.id,
                        folder_id=folder_id,
                        chunk_index=chunk_index,
                        n_chunks=n_chunks,
                        group_id=group_id,
                    )

                    # Update progress
                    done_chunks += 1
                    uploaded_bytes += chunk_size_bytes
                    if pbar:
                        pbar.update(chunk_size_bytes)

                    # Update progress object and call callback
                    progress = UploadProgress(
                        total_files=total_files,
                        total_chunks=total_chunks,
                        done_files=done_files,
                        done_chunks=done_chunks,
                        current_file=path.name,
                    )
                    if progress_callback:
                        progress_callback(progress)

            # File upload complete
            done_files += 1

            # Final progress for this file
            progress = UploadProgress(
                total_files=total_files,
                total_chunks=total_chunks,
                done_files=done_files,
                done_chunks=done_chunks,
                current_file=path.name,
            )
            if progress_callback:
                progress_callback(progress)

        # Close progress bar
        if pbar:
            pbar.close()
            print(
                f"✓ Upload complete! All {total_files} file(s) uploaded successfully."
            )

        # Return the final response
        if upload_response is None:
            raise ValueError("No files were uploaded")

        return UploadFilesResponse(**upload_response)

    def upload_file_from_meta(
        self,
        file_path: str,
        group_id: str,
        technology: str,
        platform: str,
        sep: str = ",",
        show_progress: bool = True,
    ) -> str:
        """Upload files from a metadata file and generate updated metadata with server paths.

        This function reads a metadata CSV/TSV file containing sample information and file paths,
        uploads the referenced files to the server (folders are zipped first), and generates a
        new metadata file with server-side paths.

        Uploads that would share a file name (for example two samples that each
        have a ``spatial`` folder) are sent in separate upload calls, because
        files of one call are addressed on the server by file name only.

        Args:
            file_path: Path to the metadata file (CSV or TSV with headers)
            group_id: The group ID to associate with the uploaded files
            technology: Sequencing technology (e.g., SequencingTechnology.WELL_BASED_SPATIAL)
            platform: Sequencing platform (e.g., SequencingPlatform.WellBasedSpatial.XENIUM)
            sep: Separator used in the metadata file (default: ',')
            show_progress: Whether to display upload progress (default: True)

        Returns:
            Path to the generated TSV file with updated server paths (includes headers).
            It is written to the current working directory as
            ``<input stem>_server_paths.tsv``.

        Raises:
            ValueError: If the technology/platform is unsupported or the
                metadata file references no files.
            FileNotFoundError: If a referenced path is neither a file nor a folder.

        Metadata file format (with headers):
            - Files must include a header row (column names can vary)
            - Columns are accessed by position, not by name
            - For XENIUM: column1=sample_name, column2=file_path
            - For VISIUM_10X/VISIUM_HD_10X: column1=sample_name, column2=expression_matrix_path,
              column3=spatial_folder_path

        Example:
            # Create samples.csv with headers:
            # sample_id,data_path
            # sample1,/path/to/sample1.zarr.zip
            # sample2,/path/to/sample2.zarr.zip

            client = SDKClient(server_url="https://api-prod.omnibusx.com")
            client.authenticate()

            output_file = client.upload_file_from_meta(
                file_path="samples.csv",
                group_id="your-group-id",
                technology=SequencingTechnology.WELL_BASED_SPATIAL,
                platform=SequencingPlatform.WellBasedSpatial.XENIUM,
                sep=","
            )

        """
        # Read the metadata file with headers
        df = pd.read_csv(file_path, sep=sep, header=0)

        # Positions of the columns that hold paths (column 0 is the sample name).
        if technology != SequencingTechnology.WELL_BASED_SPATIAL:
            raise ValueError(
                f"Unsupported technology for metadata upload: {technology}"
            )
        if platform == SequencingPlatform.WellBasedSpatial.XENIUM:
            # Format: sample_name, file_path
            file_cols = [1]
        elif platform in (
            SequencingPlatform.WellBasedSpatial.VISIUM_10X,
            SequencingPlatform.WellBasedSpatial.VISIUM_HD_10X,
        ):
            # Format: sample_name, expression_matrix, spatial_folder
            file_cols = [1, 2]
        else:
            raise ValueError(f"Unsupported platform for metadata upload: {platform}")

        original_to_server_mapping: dict[str, str] = {}

        # Zipped folders live in a temporary directory that is removed as soon
        # as the uploads are done (or have failed).
        with tempfile.TemporaryDirectory() as temp_dir_name:
            temp_dir = Path(temp_dir_name)

            # Each batch maps original path -> local path to upload and holds
            # only uploads with distinct file names; batch_names mirrors the
            # names already used per batch.
            upload_batches: list[dict[str, str]] = []
            batch_names: list[set[str]] = []
            processed: set[str] = set()
            zip_count = 0

            # Rows and columns are addressed by position so a non-default
            # index cannot break the lookup.
            for row_pos in range(len(df)):
                for col_idx in file_cols:
                    file_location = os.path.normpath(str(df.iloc[row_pos, col_idx]))
                    if file_location in processed:
                        # Already handled this path
                        continue
                    file_path_obj = Path(file_location)

                    if file_path_obj.is_file():
                        upload_path = file_location
                    elif file_path_obj.is_dir():
                        # Each zip gets its own sub-directory so folders with
                        # the same name cannot overwrite each other's archive.
                        zip_dir = temp_dir / str(zip_count)
                        zip_dir.mkdir()
                        zip_count += 1
                        zip_path = zip_dir / f"{file_path_obj.name}.zip"

                        print(f"Zipping folder: {file_location}")
                        self._zip_folder(file_path_obj, zip_path)
                        upload_path = str(zip_path)
                    else:
                        raise FileNotFoundError(
                            f"File or folder not found: {file_location}"
                        )
                    processed.add(file_location)

                    # Put the upload into the first batch whose names do not
                    # clash with it; open a new batch if every one clashes.
                    upload_name = Path(upload_path).name
                    for names, batch in zip(batch_names, upload_batches):
                        if upload_name not in names:
                            break
                    else:
                        names, batch = set(), {}
                        batch_names.append(names)
                        upload_batches.append(batch)
                    names.add(upload_name)
                    batch[file_location] = upload_path

            if not upload_batches:
                raise ValueError("No files found to upload in metadata file")

            total_uploads = sum(len(batch) for batch in upload_batches)
            print(f"\nUploading {total_uploads} file(s)...")
            for batch in upload_batches:
                upload_response = self.upload_files(
                    file_paths=list(batch.values()),
                    group_id=group_id,
                    show_progress=show_progress,
                )
                # Files of one upload call sit in one server folder under
                # their own file names.
                for original_path, upload_path in batch.items():
                    filename = Path(upload_path).name
                    original_to_server_mapping[original_path] = (
                        f"{upload_response.folder_path}/{filename}"
                    )

        # Create new dataframe with server paths
        new_df = df.copy()
        for row_pos in range(len(new_df)):
            for col_idx in file_cols:
                original_path = os.path.normpath(str(new_df.iloc[row_pos, col_idx]))
                new_df.iloc[row_pos, col_idx] = original_to_server_mapping[original_path]

        # Write to TSV file in current directory with headers
        input_filename = Path(file_path).stem
        output_filename = f"{input_filename}_server_paths.tsv"
        output_path = Path.cwd() / output_filename

        new_df.to_csv(output_path, sep="\t", header=True, index=False)
        print(f"\n✓ Metadata file with server paths saved to: {output_path}")

        return str(output_path)

    def _zip_folder(self, source_folder: Path, destination_zip: Path) -> None:
        """Zip a folder to a destination path.

        Args:
            source_folder: Path to the folder to zip
            destination_zip: Path where the zip file should be created

        """
        with zipfile.ZipFile(destination_zip, "w", zipfile.ZIP_DEFLATED) as zipf:
            # Walk through the directory
            for root, _, files in os.walk(source_folder):
                for file in files:
                    file_path = os.path.join(root, file)
                    # Add file to zip with relative path
                    arcname = os.path.relpath(file_path, source_folder)
                    zipf.write(file_path, arcname)
