# spiderforce4ai/__init__.py

import asyncio
import aiohttp
import json
import logging
from typing import List, Dict, Union, Optional, Tuple
from dataclasses import dataclass, asdict
from urllib.parse import urljoin, urlparse
from pathlib import Path
import time
import xml.etree.ElementTree as ET
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import re
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
from rich.console import Console
import aiofiles
import httpx
import requests
from multiprocessing import Pool

console = Console()

def slugify(url: str) -> str:
    """Turn a URL into a filesystem-friendly slug.

    The scheme, query string and fragment are dropped; only the network
    location and path are kept. Every character that is not a word character
    or a dash becomes an underscore, runs of underscores are collapsed, and
    leading/trailing underscores are trimmed.
    """
    parts = urlparse(url)
    raw = parts.netloc + parts.path
    # First neutralise anything unsafe, then squeeze repeated separators.
    sanitized = re.sub(r'[^\w\-]', '_', raw)
    collapsed = re.sub(r'_+', '_', sanitized)
    return collapsed.strip('_')

@dataclass
class CrawlResult:
    """Store the outcome of crawling a single URL.

    `timestamp` is filled with the current local time (ISO 8601) in
    `__post_init__` when the caller does not supply one.
    """
    url: str  # URL that was crawled
    status: str  # 'success' or 'failed'
    markdown: Optional[str] = None  # converted content (set on success)
    error: Optional[str] = None  # error description (set on failure)
    timestamp: Optional[str] = None  # ISO timestamp; auto-filled when omitted
    config: Optional[Dict] = None  # selector settings sent with the request

    def __post_init__(self):
        # Any falsy timestamp (None or "") is replaced with the creation time.
        if not self.timestamp:
            self.timestamp = datetime.now().isoformat()

@dataclass
class CrawlConfig:
    """Settings that control how URLs are crawled, stored and reported."""
    target_selector: Optional[str] = None  # element to extract, if any
    remove_selectors: Optional[List[str]] = None  # elements to strip out
    remove_selectors_regex: Optional[List[str]] = None  # regex patterns to strip out
    max_concurrent_requests: int = 1  # single request at a time by default
    request_delay: float = 0.5  # pause between requests
    timeout: int = 30  # request timeout
    output_dir: Path = Path("spiderforce_reports")  # where markdown files are written
    webhook_url: Optional[str] = None  # webhook endpoint, if any
    webhook_timeout: int = 10  # webhook request timeout
    webhook_headers: Optional[Dict[str, str]] = None  # extra headers for the webhook
    webhook_payload_template: Optional[str] = None  # custom webhook payload template
    save_reports: bool = False  # whether a crawl report should be written
    report_file: Optional[Path] = None  # report location (resolved only when save_reports is True)

    def __post_init__(self):
        """Normalise optional fields and create the output directory."""
        # Falsy collections are replaced with fresh empty containers.
        if not self.remove_selectors:
            self.remove_selectors = []
        if not self.remove_selectors_regex:
            self.remove_selectors_regex = []
        if not self.webhook_headers:
            self.webhook_headers = {}

        # Accept str or Path for output_dir and make sure the directory exists.
        self.output_dir = Path(self.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # The report path is only resolved when reporting is switched on.
        if not self.save_reports:
            return
        if self.report_file is None:
            self.report_file = self.output_dir / "crawl_report.json"
        else:
            self.report_file = Path(self.report_file)

    def to_dict(self) -> Dict:
        """Return the selector options that are set, as an API request payload."""
        candidates = (
            ("target_selector", self.target_selector),
            ("remove_selectors", self.remove_selectors),
            ("remove_selectors_regex", self.remove_selectors_regex),
        )
        # Unset (falsy) options are left out of the payload entirely.
        return {key: value for key, value in candidates if value}


def _send_webhook_sync(result: CrawlResult, config: CrawlConfig) -> None:
    """Send a crawl result to the configured webhook (blocking, best effort).

    Does nothing when `config.webhook_url` is not set. When
    `config.webhook_payload_template` is set, it is filled in with
    `str.format` (fields: url, status, markdown, error, timestamp, config)
    and the result is parsed as JSON; otherwise a default payload dict is
    used. Any failure while building or sending the payload is reported as a
    printed warning and never raised, so a broken webhook cannot abort a crawl.

    Args:
        result: The crawl result to report.
        config: Crawl configuration holding the webhook settings.
    """
    if not config.webhook_url:
        return

    markdown = result.markdown if result.status == "success" else None
    error = result.error if result.status == "failed" else None

    def _json_safe(value):
        # Template fields are substituted into JSON text, so string values
        # need their quotes, backslashes and newlines escaped or json.loads
        # rejects the payload. Non-string values (None, dict) keep the plain
        # str() form they had before.
        if isinstance(value, str):
            return json.dumps(value)[1:-1]
        return value

    try:
        if config.webhook_payload_template:
            # Built inside the try so a malformed template only produces a warning.
            payload_str = config.webhook_payload_template.format(
                url=_json_safe(result.url),
                status=_json_safe(result.status),
                markdown=_json_safe(markdown),
                error=_json_safe(error),
                timestamp=_json_safe(result.timestamp),
                config=config.to_dict()
            )
            payload = json.loads(payload_str)  # Parse the formatted JSON string
        else:
            # Default payload format
            payload = {
                "url": result.url,
                "status": result.status,
                "markdown": markdown,
                "error": error,
                "timestamp": result.timestamp,
                "config": config.to_dict()
            }

        response = requests.post(
            config.webhook_url,
            json=payload,
            headers=config.webhook_headers,
            timeout=config.webhook_timeout
        )
        response.raise_for_status()
    except Exception as e:
        print(f"Warning: Failed to send webhook for {result.url}: {str(e)}")

# Module level function for multiprocessing
def _process_url_parallel(args: Tuple[str, str, CrawlConfig]) -> CrawlResult:
    """Convert one URL through the `/convert` endpoint and report the outcome.

    `args` is a `(url, base_url, config)` tuple so the function can be mapped
    over a process pool. On HTTP 200 the response body is written to
    `config.output_dir` as markdown and a success result is returned; any
    other status, or any exception, produces a failed result. The webhook
    sender is invoked for every outcome.
    """
    url, base_url, config = args

    def _report(outcome):
        # Every outcome goes through the webhook sender before being returned.
        _send_webhook_sync(outcome, config)
        return outcome

    try:
        endpoint = f"{base_url}/convert"
        request_body = {"url": url, **config.to_dict()}
        response = requests.post(endpoint, json=request_body, timeout=config.timeout)

        if response.status_code != 200:
            return _report(CrawlResult(
                url=url,
                status="failed",
                error=f"HTTP {response.status_code}: {response.text}",
                config=config.to_dict()
            ))

        markdown = response.text

        # Persist the converted page when an output directory is configured.
        if config.output_dir:
            target = config.output_dir / f"{slugify(url)}.md"
            with open(target, 'w', encoding='utf-8') as handle:
                handle.write(markdown)

        outcome = _report(CrawlResult(
            url=url,
            status="success",
            markdown=markdown,
            config=config.to_dict()
        ))

        # Throttle: the configured delay is applied after a successful conversion.
        if config.request_delay:
            time.sleep(config.request_delay)

        return outcome

    except Exception as e:
        return _report(CrawlResult(
            url=url,
            status="failed",
            error=str(e),
            config=config.to_dict()
        ))

class SpiderForce4AI:
    """Main class for interacting with SpiderForce4AI service."""

    def __init__(self, base_url: str):
        self.base_url = base_url.rstrip('/')
        self.session = None
        self._executor = ThreadPoolExecutor()
        self.crawl_results: List[CrawlResult] = []

    async def _ensure_session(self):
        """Ensure aiohttp session exists."""
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()

    async def _close_session(self):
        """Close aiohttp session."""
        if self.session and not self.session.closed:
            await self.session.close()

    async def _save_markdown(self, url: str, markdown: str, output_dir: Path):
        """Write `markdown` to `<output_dir>/<slug of url>.md` and return that path."""
        filepath = output_dir / f"{slugify(url)}.md"
        async with aiofiles.open(filepath, 'w', encoding='utf-8') as handle:
            await handle.write(markdown)
        return filepath



    def crawl_sitemap_server_parallel(self, sitemap_url: str, config: CrawlConfig) -> List[CrawlResult]:
        """
        Crawl sitemap URLs using server-side parallel processing.

        Downloads and parses the sitemap, then hands the extracted URLs to
        `crawl_urls_server_parallel`.

        Args:
            sitemap_url: Location of the XML sitemap.
            config: Crawl configuration (its `timeout` is used for the fetch).

        Returns:
            One CrawlResult per URL, as produced by `crawl_urls_server_parallel`.

        Raises:
            Exception: Re-raises any error from fetching or parsing the sitemap
                after printing it.
        """
        print(f"Fetching sitemap from {sitemap_url}...")

        # Fetch sitemap
        try:
            response = requests.get(sitemap_url, timeout=config.timeout)
            response.raise_for_status()
            sitemap_text = response.text
        except Exception as e:
            print(f"Error fetching sitemap: {str(e)}")
            raise

        # Parse sitemap
        try:
            root = ET.fromstring(sitemap_text)
            if root.tag.startswith('{'):
                # Namespaced sitemap: look up <loc> in the root element's namespace.
                namespace = {'ns': root.tag[1:].split('}', 1)[0]}
                loc_elements = root.findall('.//ns:loc', namespace)
            else:
                # Sitemap without an XML namespace: <loc> tags are unqualified.
                loc_elements = root.findall('.//loc')
            # <loc> text is often padded with whitespace; skip empty entries.
            urls = [
                loc.text.strip()
                for loc in loc_elements
                if loc.text and loc.text.strip()
            ]
            print(f"Found {len(urls)} URLs in sitemap")
        except Exception as e:
            print(f"Error parsing sitemap: {str(e)}")
            raise

        # Process URLs using server-side parallel endpoint
        return self.crawl_urls_server_parallel(urls, config)


    def crawl_urls_server_parallel(self, urls: List[str], config: CrawlConfig) -> List[CrawlResult]:
        """
        Crawl a batch of URLs through the server's `/convert_parallel` endpoint.

        The whole URL list is posted in one request and the server is expected
        to answer with a JSON array of per-URL results. Successful results are
        written to `config.output_dir`, webhooks are sent when configured, and
        a report is saved when `config.save_reports` is set. If anything in
        this process raises, every URL is returned as failed with that error.
        """
        print(f"Sending {len(urls)} URLs for parallel processing...")

        try:
            endpoint = f"{self.base_url}/convert_parallel"
            request_body = {"urls": urls, **config.to_dict()}

            response = requests.post(endpoint, json=request_body, timeout=config.timeout)
            response.raise_for_status()

            results = []
            # The server is assumed to return a JSON array of result objects.
            for entry in response.json():
                outcome = CrawlResult(
                    url=entry["url"],
                    status=entry.get("status", "failed"),
                    markdown=entry.get("markdown"),
                    error=entry.get("error"),
                    config=config.to_dict()
                )

                # Only successful conversions with content are written to disk.
                should_write = outcome.status == "success" and config.output_dir and outcome.markdown
                if should_write:
                    target = config.output_dir / f"{slugify(outcome.url)}.md"
                    with open(target, 'w', encoding='utf-8') as handle:
                        handle.write(outcome.markdown)

                if config.webhook_url:
                    _send_webhook_sync(outcome, config)

                results.append(outcome)

            if config.save_reports:
                self._save_report_sync(results, config)
                print(f"\nReport saved to: {config.report_file}")

            # Summary counts
            successful = sum(1 for r in results if r.status == "success")
            failed = sum(1 for r in results if r.status == "failed")
            print(f"\nParallel processing completed:")
            print(f"✓ Successful: {successful}")
            print(f"✗ Failed: {failed}")

            return results

        except Exception as e:
            print(f"Error during parallel processing: {str(e)}")
            # The batch is all-or-nothing: mark every URL as failed.
            failures = []
            for url in urls:
                failures.append(CrawlResult(
                    url=url,
                    status="failed",
                    error=str(e),
                    config=config.to_dict()
                ))
            return failures


    async def _send_webhook(self, result: CrawlResult, config: CrawlConfig):
        """Send webhook with crawl results."""
        if not config.webhook_url:
            return

        payload = {
            "url": result.url,
            "status": result.status,
            "markdown": result.markdown if result.status == "success" else None,
            "error": result.error if result.status == "failed" else None,
            "timestamp": result.timestamp,
            "config": config.to_dict()
        }

        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    config.webhook_url,
                    json=payload,
                    timeout=config.webhook_timeout
                )
                response.raise_for_status()
        except Exception as e:
            console.print(f"[yellow]Warning: Failed to send webhook for {result.url}: {str(e)}[/yellow]")

    def _save_report_sync(self, results: List[CrawlResult], config: CrawlConfig) -> None:
        """Save crawl report synchronously."""
        report = {
            "timestamp": datetime.now().isoformat(),
            "config": config.to_dict(),
            "results": {
                "successful": [asdict(r) for r in results if r.status == "success"],
                "failed": [asdict(r) for r in results if r.status == "failed"]
            },
            "summary": {
                "total": len(results),
                "successful": len([r for r in results if r.status == "success"]),
                "failed": len([r for r in results if r.status == "failed"])
            }
        }

        with open(config.report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2)

    async def _save_report(self, config: CrawlConfig):
        """Save crawl report to JSON file."""
        if not config.report_file:
            return

        report = {
            "timestamp": datetime.now().isoformat(),
            "config": config.to_dict(),
            "results": {
                "successful": [asdict(r) for r in self.crawl_results if r.status == "success"],
                "failed": [asdict(r) for r in self.crawl_results if r.status == "failed"]
            },
            "summary": {
                "total": len(self.crawl_results),
                "successful": len([r for r in self.crawl_results if r.status == "success"]),
                "failed": len([r for r in self.crawl_results if r.status == "failed"])
            }
        }

        async with aiofiles.open(config.report_file, 'w', encoding='utf-8') as f:
            await f.write(json.dumps(report, indent=2))

    async def crawl_url_async(self, url: str, config: CrawlConfig) -> CrawlResult:
        """Crawl a single URL asynchronously.

        Posts the URL to the `/convert` endpoint. On HTTP 200 the response
        body is stored as markdown (and written to `config.output_dir`); any
        other status, or any exception, yields a failed result. The webhook is
        sent for every outcome, matching the multiprocessing crawl path, and
        the result is appended to `self.crawl_results`.

        Args:
            url: The page to convert.
            config: Crawl configuration (selectors, timeout, output, webhook).

        Returns:
            The CrawlResult describing success or failure.
        """
        await self._ensure_session()

        try:
            endpoint = f"{self.base_url}/convert"
            payload = {
                "url": url,
                **config.to_dict()
            }
            # aiohttp deprecates bare numbers as timeouts; wrap in ClientTimeout.
            timeout = aiohttp.ClientTimeout(total=config.timeout)

            async with self.session.post(endpoint, json=payload, timeout=timeout) as response:
                body = await response.text()
                if response.status != 200:
                    result = CrawlResult(
                        url=url,
                        status="failed",
                        error=f"HTTP {response.status}: {body}",
                        config=config.to_dict()
                    )
                else:
                    if config.output_dir:
                        await self._save_markdown(url, body, config.output_dir)
                    result = CrawlResult(
                        url=url,
                        status="success",
                        markdown=body,
                        config=config.to_dict()
                    )
        except Exception as e:
            result = CrawlResult(
                url=url,
                status="failed",
                error=str(e),
                config=config.to_dict()
            )

        # Report failures as well as successes; _send_webhook is a no-op
        # without a webhook URL and handles its own delivery errors.
        await self._send_webhook(result, config)
        self.crawl_results.append(result)
        return result

    def crawl_url(self, url: str, config: CrawlConfig) -> CrawlResult:
        """Synchronous version of crawl_url_async."""
        return asyncio.run(self.crawl_url_async(url, config))

    async def _retry_failed_urls(self, failed_results: List[CrawlResult], config: CrawlConfig, progress=None) -> List[CrawlResult]:
        """Retry failed URLs once."""
        if not failed_results:
            return []

        failed_count = len(failed_results)
        total_count = len([r for r in self.crawl_results])
        failure_ratio = (failed_count / total_count) * 100
        
        console.print(f"\n[yellow]Retrying failed URLs: {failed_count} ({failure_ratio:.1f}% failed)[/yellow]")
        retry_results = []
        
        # Create a new progress bar if one wasn't provided
        should_close_progress = progress is None
        if progress is None:
            progress = Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                console=console
            )
            progress.start()

        retry_task = progress.add_task("[yellow]Retrying failed URLs...", total=len(failed_results))

        for result in failed_results:
            progress.update(retry_task, description=f"[yellow]Retrying: {result.url}")
            
            try:
                new_result = await self.crawl_url_async(result.url, config)
                if new_result.status == "success":
                    console.print(f"[green]✓ Retry successful: {result.url}[/green]")
                else:
                    console.print(f"[red]✗ Retry failed: {result.url} - {new_result.error}[/red]")
                retry_results.append(new_result)
            except Exception as e:
                console.print(f"[red]✗ Retry error: {result.url} - {str(e)}[/red]")
                retry_results.append(CrawlResult(
                    url=result.url,
                    status="failed",
                    error=f"Retry error: {str(e)}",
                    config=config.to_dict()
                ))
            
            progress.update(retry_task, advance=1)
            await asyncio.sleep(config.request_delay)

        if should_close_progress:
            progress.stop()

        return retry_results

    async def crawl_urls_async(self, urls: List[str], config: CrawlConfig) -> List[CrawlResult]:
        """Crawl multiple URLs asynchronously with progress bar."""
        await self._ensure_session()
        
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console
        ) as progress:
            task = progress.add_task("[cyan]Crawling URLs...", total=len(urls))
            
            async def crawl_with_progress(url):
                result = await self.crawl_url_async(url, config)
                progress.update(task, advance=1, description=f"[cyan]Crawled: {url}")
                return result

            semaphore = asyncio.Semaphore(config.max_concurrent_requests)
            async def crawl_with_semaphore(url):
                async with semaphore:
                    result = await crawl_with_progress(url)
                    await asyncio.sleep(config.request_delay)
                    return result

            initial_results = await asyncio.gather(*[crawl_with_semaphore(url) for url in urls])
            
            # Identify failed URLs
            failed_results = [r for r in initial_results if r.status == "failed"]
            
            # Calculate initial failure ratio
            initial_failed = len(failed_results)
            total_urls = len(urls)
            failure_ratio = (initial_failed / total_urls) * 100

            # Retry failed URLs if ratio is acceptable
            if failed_results:
                if failure_ratio > 20:
                    console.print(f"\n[red]Failure ratio too high ({failure_ratio:.1f}%) - aborting retry due to possible server overload[/red]")
                    results = initial_results
                else:
                    retry_results = await self._retry_failed_urls(failed_results, config, progress)
                    # Update results list by replacing failed results with successful retries
                    results = initial_results.copy()
                    for retry_result in retry_results:
                        for i, result in enumerate(results):
                            if result.url == retry_result.url:
                                results[i] = retry_result
                                break
            else:
                results = initial_results
            
            # Save final report
            await self._save_report(config)
            
            # Calculate final statistics
            final_successful = len([r for r in results if r.status == "success"])
            final_failed = len([r for r in results if r.status == "failed"])
            
            # Print detailed summary
            console.print(f"\n[green]Crawling Summary:[/green]")
            console.print(f"Total URLs processed: {total_urls}")
            console.print(f"Initial failures: {initial_failed} ({failure_ratio:.1f}%)")
            console.print(f"Final results:")
            console.print(f"  ✓ Successful: {final_successful}")
            console.print(f"  ✗ Failed: {final_failed}")
            
            if initial_failed > 0:
                retry_successful = initial_failed - final_failed
                console.print(f"Retry success rate: {retry_successful}/{initial_failed} ({(retry_successful/initial_failed)*100:.1f}%)")
            
            if config.report_file:
                console.print(f"📊 Report saved to: {config.report_file}")
            
            return results

    def crawl_urls(self, urls: List[str], config: CrawlConfig) -> List[CrawlResult]:
        """Synchronous version of crawl_urls_async."""
        return asyncio.run(self.crawl_urls_async(urls, config))

    async def crawl_sitemap_async(self, sitemap_url: str, config: CrawlConfig) -> List[CrawlResult]:
        """Crawl URLs from a sitemap asynchronously.

        Downloads and parses the sitemap, then passes the extracted URLs to
        `crawl_urls_async`.

        Args:
            sitemap_url: Location of the XML sitemap.
            config: Crawl configuration (its `timeout` is used for the fetch).

        Returns:
            The results of `crawl_urls_async` for the sitemap's URLs.

        Raises:
            Exception: Re-raises any error from fetching (including non-2xx
                responses) or parsing the sitemap after printing it.
        """
        await self._ensure_session()

        try:
            console.print(f"[cyan]Fetching sitemap from {sitemap_url}...[/cyan]")
            # aiohttp deprecates bare numbers as timeouts; wrap in ClientTimeout.
            timeout = aiohttp.ClientTimeout(total=config.timeout)
            async with self.session.get(sitemap_url, timeout=timeout) as response:
                # Fail here on 4xx/5xx, as the synchronous sitemap fetchers do,
                # rather than trying to parse an error page as XML.
                response.raise_for_status()
                sitemap_text = await response.text()
        except Exception as e:
            console.print(f"[red]Error fetching sitemap: {str(e)}[/red]")
            raise

        try:
            root = ET.fromstring(sitemap_text)
            if root.tag.startswith('{'):
                # Namespaced sitemap: look up <loc> in the root element's namespace.
                namespace = {'ns': root.tag[1:].split('}', 1)[0]}
                loc_elements = root.findall('.//ns:loc', namespace)
            else:
                # Sitemap without an XML namespace: <loc> tags are unqualified.
                loc_elements = root.findall('.//loc')
            # <loc> text is often padded with whitespace; skip empty entries.
            urls = [
                loc.text.strip()
                for loc in loc_elements
                if loc.text and loc.text.strip()
            ]
            console.print(f"[green]Found {len(urls)} URLs in sitemap[/green]")
        except Exception as e:
            console.print(f"[red]Error parsing sitemap: {str(e)}[/red]")
            raise

        return await self.crawl_urls_async(urls, config)

    def crawl_sitemap(self, sitemap_url: str, config: CrawlConfig) -> List[CrawlResult]:
        """Synchronous version of crawl_sitemap_async."""
        return asyncio.run(self.crawl_sitemap_async(sitemap_url, config))

    def crawl_sitemap_parallel(self, sitemap_url: str, config: CrawlConfig) -> List[CrawlResult]:
        """Crawl sitemap URLs in parallel using multiprocessing (no asyncio required).

        Fetches and parses the sitemap, crawls every URL in a process pool of
        `config.max_concurrent_requests` workers, retries failures once unless
        more than 20% failed, then saves the report (when `config.report_file`
        is set) and prints a summary.

        Args:
            sitemap_url: Location of the XML sitemap.
            config: Crawl configuration.

        Returns:
            One CrawlResult per sitemap URL (pool completion order), with
            failed entries replaced by their retry result. An empty sitemap
            returns `[]`.

        Raises:
            Exception: Re-raises any error from fetching or parsing the sitemap
                after printing it.
        """
        print(f"Fetching sitemap from {sitemap_url}...")

        # Fetch sitemap
        try:
            response = requests.get(sitemap_url, timeout=config.timeout)
            response.raise_for_status()
            sitemap_text = response.text
        except Exception as e:
            print(f"Error fetching sitemap: {str(e)}")
            raise

        # Parse sitemap
        try:
            root = ET.fromstring(sitemap_text)
            if root.tag.startswith('{'):
                # Namespaced sitemap: look up <loc> in the root element's namespace.
                namespace = {'ns': root.tag[1:].split('}', 1)[0]}
                loc_elements = root.findall('.//ns:loc', namespace)
            else:
                # Sitemap without an XML namespace: <loc> tags are unqualified.
                loc_elements = root.findall('.//loc')
            # <loc> text is often padded with whitespace; skip empty entries.
            urls = [
                loc.text.strip()
                for loc in loc_elements
                if loc.text and loc.text.strip()
            ]
            print(f"Found {len(urls)} URLs in sitemap")
        except Exception as e:
            print(f"Error parsing sitemap: {str(e)}")
            raise

        if not urls:
            # Nothing to crawl: skip the pool and the ratio math (division by zero).
            return []

        # Prepare arguments for parallel processing
        process_args = [(url, self.base_url, config) for url in urls]

        # Create process pool and execute crawls
        results = []

        # Pool rejects a process count below 1, so clamp a zero/negative setting.
        with Pool(processes=max(1, config.max_concurrent_requests)) as pool:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
                TextColumn("({task.completed}/{task.total})"),
            ) as progress:
                task = progress.add_task("Crawling URLs...", total=len(urls))

                for result in pool.imap_unordered(_process_url_parallel, process_args):
                    results.append(result)
                    status = "✓" if result.status == "success" else "✗"
                    # The description belongs to the task and must go through update().
                    progress.update(task, advance=1, description=f"Last: {status} {result.url}")

        # Calculate initial failure statistics
        failed_results = [r for r in results if r.status == "failed"]
        initial_failed = len(failed_results)
        total_urls = len(urls)
        failure_ratio = (initial_failed / total_urls) * 100

        # Retry failed URLs if ratio is acceptable
        if failed_results:
            if failure_ratio > 20:
                console.print(f"\n[red]Failure ratio too high ({failure_ratio:.1f}%) - aborting retry due to possible server overload[/red]")
            else:
                console.print(f"\n[yellow]Retrying failed URLs: {initial_failed} ({failure_ratio:.1f}% failed)[/yellow]")
                for result in failed_results:
                    # _process_url_parallel already writes the markdown file and sends
                    # the webhook for every outcome; repeating either here would
                    # duplicate the notification.
                    new_result = _process_url_parallel((result.url, self.base_url, config))

                    if new_result.status == "success":
                        console.print(f"[green]✓ Retry successful: {result.url}[/green]")
                    else:
                        console.print(f"[red]✗ Retry failed: {result.url} - {new_result.error}[/red]")

                    # Swap by identity so duplicate URLs in the sitemap cannot
                    # cause the wrong entry to be replaced.
                    for i, existing in enumerate(results):
                        if existing is result:
                            results[i] = new_result
                            break

        # Save the report only after retries so it reflects the final outcomes.
        if config.report_file:
            self._save_report_sync(results, config)
            print(f"\nReport saved to: {config.report_file}")

        # Calculate final statistics
        final_successful = len([r for r in results if r.status == "success"])
        final_failed = len([r for r in results if r.status == "failed"])

        # Print detailed summary
        console.print(f"\n[green]Crawling Summary:[/green]")
        console.print(f"Total URLs processed: {total_urls}")
        console.print(f"Initial failures: {initial_failed} ({failure_ratio:.1f}%)")
        console.print(f"Final results:")
        console.print(f"  ✓ Successful: {final_successful}")
        console.print(f"  ✗ Failed: {final_failed}")

        if initial_failed > 0:
            retry_successful = initial_failed - final_failed
            console.print(f"Retry success rate: {retry_successful}/{initial_failed} ({(retry_successful/initial_failed)*100:.1f}%)")

        return results

    async def __aenter__(self):
        """Async context manager entry."""
        await self._ensure_session()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self._close_session()

    def __enter__(self):
        """Enter the sync context; no setup is performed, the client itself is returned."""
        client = self
        return client

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Sync context manager exit."""
        self._executor.shutdown(wait=True)

