Coverage for src/cc_liquid/backtester.py: 53% (383 statements) — coverage.py v7.10.3, created at 2025-10-13 20:16 +0000
1"""Backtesting engine for cc-liquid.
3This module provides pure backtesting logic without any UI dependencies.
4All display/visualization is handled by the CLI layer.
6⚠️ IMPORTANT DISCLAIMER:
7Backtesting has inherent limitations. Past performance does not predict future results.
8Results are hypothetical and do not account for all real-world factors including:
9- Market impact and slippage beyond modeled estimates
10- Changing market conditions and liquidity
11- Technical failures and execution delays
12- Regulatory changes and exchange rules
13Always validate strategies with out-of-sample data and paper trading before live deployment.
14"""
16from __future__ import annotations
18import math
19from dataclasses import dataclass, field
20from datetime import datetime, timedelta
21from typing import Any, Literal
23import polars as pl
25from .portfolio import weights_from_ranks
@dataclass
class BacktestConfig:
    """Configuration for backtesting.

    All parameters have defaults, so ``BacktestConfig()`` is a runnable
    baseline; the CLI layer overrides fields from user config.
    """

    # Data paths
    prices_path: str = "raw_data.parquet"
    predictions_path: str = "predictions.parquet"

    # Price data columns (defaults match raw_data.parquet)
    price_date_column: str = "date"
    price_id_column: str = "id"
    price_close_column: str = "close"

    # Prediction columns (will be taken from DataSourceConfig in CLI)
    pred_date_column: str = "release_date"
    pred_id_column: str = "id"
    pred_value_column: str = "pred_10d"

    # Data provider (e.g., crowdcent, numerai, local)
    data_provider: str | None = None

    # Date range (None = use all available overlapping data)
    start_date: datetime | None = None
    end_date: datetime | None = None

    # Strategy parameters (match PortfolioConfig)
    num_long: int = 60
    num_short: int = 50
    target_leverage: float = 3.0  # Sum of abs(weights), matching trader.py
    # Weighting (opt-in non-EW); rank_power=0 reproduces equal weighting
    weighting_scheme: Literal["equal", "rank_power"] = "equal"
    rank_power: float = 0.0

    # Rebalancing
    rebalance_every_n_days: int = 10
    prediction_lag_days: int = 1  # Use T-lag signals to trade at T

    # Costs (in basis points, charged on turnover at each rebalance)
    fee_bps: float = 4.0  # Trading fee
    slippage_bps: float = 50.0  # Slippage cost

    # Initial capital
    start_capital: float = 100_000.0

    # Options
    verbose: bool = False
@dataclass
class BacktestResult:
    """Results from a backtest run.

    Produced by :meth:`Backtester.run`; consumed by the CLI layer for
    display/visualization.
    """

    # Daily time series
    daily: pl.DataFrame  # columns: date, returns, equity, drawdown, turnover

    # Position snapshots at rebalance dates
    rebalance_positions: pl.DataFrame  # columns: date, id, weight

    # Summary statistics (keys produced by Backtester._compute_stats)
    stats: dict[str, float] = field(default_factory=dict)

    # Config used for this run (kept for reproducibility)
    config: BacktestConfig | None = None
class Backtester:
    """Core backtesting engine.

    Simulates a long/short portfolio rebalanced on a fixed calendar using
    lagged predictions, charges linear trading costs on turnover, and
    produces a daily equity/turnover series plus summary statistics.
    """

    def __init__(self, config: BacktestConfig):
        self.config = config

    def run(self) -> BacktestResult:
        """Run the backtest and return results.

        Raises:
            ValueError: if returns and (lagged) predictions never overlap.
        """
        # 1. Load and prepare data
        prices_long = self._load_prices()
        predictions_long = self._load_predictions()

        # 2. Compute returns matrix (use ALL available prices first)
        returns_wide_all = self._compute_returns_wide(prices_long)

        # 3. Determine valid trading dates from returns + lagged predictions
        valid_dates = self._get_valid_trading_dates_from_returns(
            returns_wide_all, predictions_long
        )
        if len(valid_dates) == 0:
            raise ValueError(
                "No valid trading dates given returns and prediction coverage"
            )

        # 4. Filter returns to valid dates only (predictions are filtered by cutoff at selection time)
        returns_wide = returns_wide_all.filter(pl.col("date").is_in(valid_dates))

        # Scope predictions to up-to-last cutoff for efficiency (do NOT drop early dates needed for lag)
        last_cutoff = max(valid_dates) - timedelta(days=self.config.prediction_lag_days)
        predictions_long = predictions_long.filter(pl.col("pred_date") <= last_cutoff)

        # 5. Determine rebalance schedule
        rebalance_dates = self._compute_rebalance_dates(valid_dates)

        # 6. Run simulation
        result = self._simulate(
            returns_wide=returns_wide,
            predictions_long=predictions_long,
            rebalance_dates=rebalance_dates,
        )

        # 7. Compute statistics
        stats = self._compute_stats(result["daily"])

        return BacktestResult(
            daily=result["daily"],
            rebalance_positions=result["positions"],
            stats=stats,
            config=self.config,
        )

    def _get_valid_trading_dates_from_returns(
        self, returns_wide_all: pl.DataFrame, predictions_long: pl.DataFrame
    ) -> list[datetime]:
        """Compute valid trading dates based on returns dates and lagged prediction availability.

        Logic:
        - Use ALL dates present in the returns matrix as candidate trading dates
        - A date D is tradable if there exists any prediction with pred_date <= D - lag
        - Apply optional user-specified start/end bounds at the end
        """
        if "date" not in returns_wide_all.columns or len(returns_wide_all) == 0:
            return []
        all_trade_dates = sorted(returns_wide_all["date"].to_list())
        if not all_trade_dates:
            return []
        # "Any prediction with pred_date <= D - lag" is equivalent to
        # "earliest pred_date <= D - lag", so one min() replaces the
        # O(dates x predictions) membership scan.
        earliest_pred = predictions_long["pred_date"].min()
        if earliest_pred is None:  # no predictions at all -> nothing tradable
            return []
        lag_td = timedelta(days=self.config.prediction_lag_days)
        valid_dates: list[datetime] = [
            d for d in all_trade_dates if d - lag_td >= earliest_pred
        ]
        # Apply explicit user bounds last
        if self.config.start_date:
            valid_dates = [d for d in valid_dates if d >= self.config.start_date]
        if self.config.end_date:
            valid_dates = [d for d in valid_dates if d <= self.config.end_date]
        return valid_dates

    def _load_prices(self) -> pl.DataFrame:
        """Load price data in long format (columns: date, id, close)."""
        df = pl.read_parquet(self.config.prices_path)

        # Select and rename columns to the canonical internal names
        df = df.select(
            [
                pl.col(self.config.price_date_column).alias("date"),
                pl.col(self.config.price_id_column).alias("id"),
                pl.col(self.config.price_close_column).alias("close"),
            ]
        )

        # Ensure date is datetime (Date -> Datetime keeps midnight timestamps)
        if df["date"].dtype != pl.Datetime:
            df = df.with_columns(pl.col("date").cast(pl.Date).cast(pl.Datetime))

        # Drop nulls and sort
        df = df.drop_nulls().sort(["date", "id"])

        if self.config.verbose:
            print(f"Loaded {len(df):,} price records for {df['id'].n_unique()} assets")
            print(f"Price date range: {df['date'].min()} to {df['date'].max()}")

        return df

    def _load_predictions(self) -> pl.DataFrame:
        """Load prediction data in long format (columns: pred_date, id, pred)."""
        df = pl.read_parquet(self.config.predictions_path)

        # Select and rename columns to the canonical internal names
        df = df.select(
            [
                pl.col(self.config.pred_date_column).alias("pred_date"),
                pl.col(self.config.pred_id_column).alias("id"),
                pl.col(self.config.pred_value_column).alias("pred"),
            ]
        )

        # Ensure date is datetime
        if df["pred_date"].dtype != pl.Datetime:
            df = df.with_columns(pl.col("pred_date").cast(pl.Date).cast(pl.Datetime))

        # Drop nulls and sort
        df = df.drop_nulls().sort(["pred_date", "id"])

        if self.config.verbose:
            print(
                f"Loaded {len(df):,} prediction records for {df['id'].n_unique()} assets"
            )
            print(
                f"Prediction date range: {df['pred_date'].min()} to {df['pred_date'].max()}"
            )

        return df

    def _get_overlapping_dates(
        self, prices: pl.DataFrame, predictions: pl.DataFrame
    ) -> list[datetime]:
        """Find dates where we have both prices and valid predictions (considering lag)."""
        # Candidate trading dates come from the price data
        price_dates = sorted(set(prices["date"].unique().to_list()))
        lag_td = timedelta(days=self.config.prediction_lag_days)

        # A price date D qualifies iff some prediction exists at or before
        # D - lag, i.e. iff the earliest pred_date <= D - lag (O(D) check).
        earliest_pred = predictions["pred_date"].min()
        if earliest_pred is None:
            valid_dates: list[datetime] = []
        else:
            valid_dates = [d for d in price_dates if d - lag_td >= earliest_pred]

        # Apply user-specified date bounds if any
        if self.config.start_date:
            valid_dates = [d for d in valid_dates if d >= self.config.start_date]
        if self.config.end_date:
            valid_dates = [d for d in valid_dates if d <= self.config.end_date]

        if self.config.verbose:
            print(f"Found {len(valid_dates)} valid trading dates with overlapping data")
            if valid_dates:
                print(f"Trading date range: {min(valid_dates)} to {max(valid_dates)}")

        return valid_dates

    def _compute_returns_wide(self, prices_long: pl.DataFrame) -> pl.DataFrame:
        """Compute returns matrix in wide format (dates as rows, assets as columns)."""
        # Calculate close-to-close returns for each asset independently
        prices_long = prices_long.sort(["id", "date"])
        prices_long = prices_long.with_columns(
            pl.col("close").pct_change().over("id").alias("return")
        )

        # Pivot to wide format: one row per date, one column per asset id
        returns_wide = prices_long.pivot(index="date", on="id", values="return").sort(
            "date"
        )

        if self.config.verbose:
            n_assets = len(returns_wide.columns) - 1  # Exclude date column
            print(f"Computed returns for {n_assets} assets")

        return returns_wide

    def _compute_rebalance_dates(self, valid_dates: list[datetime]) -> list[datetime]:
        """Determine rebalance dates: every N calendar days from the first
        valid date, keeping only dates that are actually tradable."""
        if not valid_dates:
            return []

        valid_set = set(valid_dates)  # O(1) membership instead of list scans
        last_date = max(valid_dates)

        rebalance_dates = []
        current_date = min(valid_dates)

        while current_date <= last_date:
            # Skip scheduled dates with no return data (e.g. gaps in prices)
            if current_date in valid_set:
                rebalance_dates.append(current_date)
            current_date += timedelta(days=self.config.rebalance_every_n_days)

        if self.config.verbose:
            print(f"Scheduled {len(rebalance_dates)} rebalance dates")

        return rebalance_dates

    def _select_assets(
        self,
        predictions: pl.DataFrame,
        cutoff_date: datetime,
        available_assets: set[str],
    ) -> tuple[list[str], list[str], pl.DataFrame]:
        """Select long/short assets and return the latest predictions for sizing.

        For each available asset, takes its most recent prediction at or
        before ``cutoff_date``; longs are the top ``num_long`` by prediction,
        shorts the bottom ``num_short`` (capped so the sets never overlap).
        """
        latest_preds = (
            predictions.filter(pl.col("pred_date") <= cutoff_date)
            .filter(pl.col("id").is_in(available_assets))
            .sort("pred_date", descending=True)
            .group_by("id")
            .first()
        )

        if len(latest_preds) == 0:
            # Typed empty frame so downstream schemas stay stable
            empty = pl.DataFrame(schema={"id": pl.Utf8, "pred": pl.Float64})
            return [], [], empty

        latest_sorted = latest_preds.sort("pred", descending=True)
        all_ids = latest_sorted["id"].to_list()

        # Cap shorts by what's left after longs so the slices cannot overlap
        num_long = min(self.config.num_long, len(all_ids))
        num_short = min(self.config.num_short, len(all_ids) - num_long)

        long_assets = all_ids[:num_long]
        short_assets = all_ids[-num_short:] if num_short > 0 else []

        return long_assets, short_assets, latest_sorted.select(["id", "pred"])

    def _simulate(
        self,
        returns_wide: pl.DataFrame,
        predictions_long: pl.DataFrame,
        rebalance_dates: list[datetime],
    ) -> dict:
        """Run the backtest simulation.

        Walks the daily return matrix in date order. On rebalance dates,
        positions are re-selected from the latest lagged predictions and
        linear costs (fees + slippage) are charged on turnover; on every
        date the portfolio earns the weighted sum of its assets' returns.

        Returns:
            dict with "daily" (per-date returns/equity/drawdown/turnover)
            and "positions" (weight snapshots at each rebalance).
        """
        # Initialize tracking variables
        equity = self.config.start_capital
        peak_equity = equity

        daily_results = []
        position_snapshots = []
        current_weights = {}  # asset id -> signed portfolio weight

        # Convert rebalance dates to set for fast lookup
        rebalance_set = set(rebalance_dates)

        # Get all dates from returns
        all_dates = returns_wide["date"].to_list()

        for date in all_dates:
            # Get today's returns
            returns_row = returns_wide.filter(pl.col("date") == date)

            # Check if we need to rebalance
            if date in rebalance_set:
                # Determine cutoff date for predictions (T - lag)
                cutoff_date = date - timedelta(days=self.config.prediction_lag_days)

                # Get available assets (those with a non-null, non-NaN return today)
                available_assets = set()
                for col in returns_row.columns:
                    if col != "date":
                        val = returns_row[col][0]
                        if val is not None and not math.isnan(val):
                            available_assets.add(col)

                # Determine selections and convert ranks to weights
                long_assets, short_assets, latest_preds = self._select_assets(
                    predictions_long,
                    cutoff_date,
                    available_assets,
                )

                new_weights: dict[str, float] = {}
                total_positions = len(long_assets) + len(short_assets)

                if total_positions > 0 and len(latest_preds) > 0:
                    weights = weights_from_ranks(
                        latest_preds=latest_preds,
                        id_col="id",
                        pred_col="pred",
                        long_assets=long_assets,
                        short_assets=short_assets,
                        target_gross=self.config.target_leverage,
                        scheme=self.config.weighting_scheme,
                        power=self.config.rank_power,
                    )
                    new_weights = weights

                # Calculate turnover (L1 norm of weight changes)
                turnover = 0.0
                all_assets = set(current_weights.keys()) | set(new_weights.keys())
                for asset in all_assets:
                    old_w = current_weights.get(asset, 0.0)
                    new_w = new_weights.get(asset, 0.0)
                    turnover += abs(new_w - old_w)

                # Apply trading costs (linear in turnover)
                total_cost_bps = self.config.fee_bps + self.config.slippage_bps
                cost = turnover * (total_cost_bps / 10_000)

                # Deduct costs immediately; the new weights then earn today's
                # close-to-close return (signals are already lagged by
                # prediction_lag_days, so there is no look-ahead on the
                # signal side).
                equity *= 1 - cost

                # Store position snapshot
                for asset, weight in new_weights.items():
                    position_snapshots.append(
                        {"date": date, "id": asset, "weight": weight}
                    )

                # Update current weights for subsequent days
                current_weights = new_weights.copy()
            else:
                turnover = 0.0

            # Calculate portfolio return using the currently held weights
            portfolio_return = 0.0

            if current_weights:
                for asset, weight in current_weights.items():
                    if asset in returns_row.columns:
                        asset_return = returns_row[asset][0]
                        if asset_return is not None and not math.isnan(asset_return):
                            portfolio_return += weight * asset_return

            # Update equity
            equity *= 1 + portfolio_return

            # Track peak and drawdown
            if equity > peak_equity:
                peak_equity = equity
            drawdown = (equity - peak_equity) / peak_equity if peak_equity > 0 else 0

            # Store daily results
            daily_results.append(
                {
                    "date": date,
                    "returns": portfolio_return,
                    "equity": equity,
                    "drawdown": drawdown,
                    "turnover": turnover,
                }
            )

        # Convert to DataFrames
        daily_df = pl.DataFrame(daily_results)
        positions_df = (
            pl.DataFrame(position_snapshots) if position_snapshots else pl.DataFrame()
        )

        return {"daily": daily_df, "positions": positions_df}

    def _compute_stats(self, daily: pl.DataFrame) -> dict[str, float]:
        """Compute summary statistics from daily results.

        Annualization assumes 365 trading days (crypto trades every day).
        Returns an empty dict when there are no daily rows.
        """
        if len(daily) == 0:
            return {}

        # Basic info
        n_days = len(daily)
        start_equity = self.config.start_capital
        final_equity = daily["equity"][-1]

        # Returns
        total_return = (final_equity / start_equity) - 1

        # Annualized metrics (assuming 365 days per year for crypto)
        years = n_days / 365.0
        # Handle negative equity (can't take fractional power of negative number)
        if years > 0 and final_equity > 0:
            cagr = (final_equity / start_equity) ** (1.0 / years) - 1
        else:
            cagr = total_return  # Fallback to simple return if equity went negative

        # Risk metrics
        returns = daily["returns"]
        daily_vol = returns.std()
        annual_vol = float(daily_vol * math.sqrt(365)) if daily_vol is not None else 0.0

        # Sharpe ratio (assuming 0 risk-free rate)
        sharpe = cagr / annual_vol if annual_vol > 0 else 0.0

        # Drawdown (stored as negative fractions, so min is the deepest)
        max_drawdown = daily["drawdown"].min()

        # Calmar ratio
        calmar = (
            cagr / abs(max_drawdown)
            if max_drawdown is not None and max_drawdown < 0
            else 0
        )

        # Win rate
        positive_days = (returns > 0).sum()
        win_rate = positive_days / n_days if n_days > 0 else 0

        # Average turnover across rebalance days only (zero-turnover days excluded)
        avg_turnover = daily.filter(pl.col("turnover") > 0)["turnover"].mean()
        if avg_turnover is None:
            avg_turnover = 0

        # Sortino ratio (downside deviation only)
        negative_returns = returns.filter(returns < 0)
        if len(negative_returns) > 0:
            downside_vol = negative_returns.std()
            annual_downside_vol = (
                downside_vol * math.sqrt(365) if downside_vol is not None else 0
            )
            sortino = cagr / annual_downside_vol if annual_downside_vol > 0 else 0
        else:
            sortino = float("inf") if cagr > 0 else 0

        return {
            "days": n_days,
            "total_return": total_return,
            "cagr": cagr,
            "annual_volatility": annual_vol,
            "sharpe_ratio": sharpe,
            "sortino_ratio": sortino,
            "max_drawdown": max_drawdown,
            "calmar_ratio": calmar,
            "win_rate": win_rate,
            "avg_turnover": avg_turnover,
            "final_equity": final_equity,
        }
class BacktestOptimizer:
    """Grid search optimizer for backtesting parameters with parallel execution and caching.

    Results for each parameter combination are cached on disk (JSON file in
    the working directory) keyed by a hash of the parameters plus the cost/
    date settings of the base config, so repeated searches are incremental.
    """

    def __init__(self, base_config: BacktestConfig):
        # Base config supplies data paths, costs, dates; grid params override
        # the strategy fields per combination.
        self.base_config = base_config
        self._prices_cache = None
        self._predictions_cache = None
        self._cache_file = ".cc_liquid_optimizer_cache.json"

    def _get_cache_key(self, params: dict) -> str:
        """Generate a unique cache key for a parameter combination.

        The key hashes both the grid params and every base-config field that
        affects backtest results, so stale entries are never reused after a
        config change. MD5 is used only as a fingerprint, not for security.
        """
        import hashlib
        import json

        # Include base config settings that affect results
        cache_data = {
            "params": params,
            "config": {
                "prices_path": self.base_config.prices_path,
                "predictions_path": self.base_config.predictions_path,
                "data_provider": self.base_config.data_provider,
                "start_date": str(self.base_config.start_date)
                if self.base_config.start_date
                else None,
                "end_date": str(self.base_config.end_date)
                if self.base_config.end_date
                else None,
                "prediction_lag_days": self.base_config.prediction_lag_days,
                "fee_bps": self.base_config.fee_bps,
                "slippage_bps": self.base_config.slippage_bps,
                "start_capital": self.base_config.start_capital,
                "weighting_scheme": self.base_config.weighting_scheme,
                "rank_power": self.base_config.rank_power,
            },
        }

        # sort_keys makes the serialization (and therefore the hash) stable
        cache_str = json.dumps(cache_data, sort_keys=True)
        return hashlib.md5(cache_str.encode()).hexdigest()

    def _load_cache(self) -> dict:
        """Load cached results from disk; returns {} on any read/parse failure."""
        import json
        import os

        if not os.path.exists(self._cache_file):
            return {}

        try:
            with open(self._cache_file, "r") as f:
                return json.load(f)
        except Exception:
            # Corrupt/unreadable cache is treated as empty (best-effort cache)
            return {}

    def _save_cache(self, cache: dict) -> None:
        """Save cache to disk (best-effort: failures are intentionally ignored)."""
        import json

        try:
            with open(self._cache_file, "w") as f:
                json.dump(cache, f)
        except Exception:
            pass  # Silently fail if can't write cache

    def _run_single_backtest(self, params: dict) -> dict | None:
        """Run a single backtest with given parameters. No cache IO here.

        Returns a flat dict of params + key stats, or None if the backtest
        raised (e.g. no valid trading dates for this combination).
        Runs in worker processes, so it must stay side-effect free.
        """
        # Create config for this combination
        config = BacktestConfig(
            prices_path=self.base_config.prices_path,
            predictions_path=self.base_config.predictions_path,
            price_date_column=self.base_config.price_date_column,
            price_id_column=self.base_config.price_id_column,
            price_close_column=self.base_config.price_close_column,
            pred_date_column=self.base_config.pred_date_column,
            pred_id_column=self.base_config.pred_id_column,
            pred_value_column=self.base_config.pred_value_column,
            data_provider=self.base_config.data_provider,
            start_date=self.base_config.start_date,
            end_date=self.base_config.end_date,
            num_long=params["num_long"],
            num_short=params["num_short"],
            target_leverage=params["leverage"],
            rebalance_every_n_days=params["rebalance_days"],
            prediction_lag_days=self.base_config.prediction_lag_days,
            fee_bps=self.base_config.fee_bps,
            slippage_bps=self.base_config.slippage_bps,
            start_capital=self.base_config.start_capital,
            verbose=False,
            weighting_scheme="rank_power",  # Always use rank_power (power=0 is equal weight)
            rank_power=params["rank_power"],
        )

        try:
            # Run backtest
            backtester = Backtester(config)
            result = backtester.run()

            # Store results
            result_data = {
                "num_long": params["num_long"],
                "num_short": params["num_short"],
                "leverage": params["leverage"],
                "rebalance_days": params["rebalance_days"],
                "rank_power": params["rank_power"],
                "sharpe": result.stats["sharpe_ratio"],
                "cagr": result.stats["cagr"],
                "calmar": result.stats["calmar_ratio"],
                "sortino": result.stats["sortino_ratio"],
                "max_dd": result.stats["max_drawdown"],
                "volatility": result.stats["annual_volatility"],
                "win_rate": result.stats["win_rate"],
                "final_equity": result.stats["final_equity"],
            }

            return result_data

        except Exception:
            # A failed combination is simply dropped from the grid results
            return None

    def grid_search_parallel(
        self,
        num_longs: list[int] | None = None,
        num_shorts: list[int] | None = None,
        leverages: list[float] | None = None,
        rebalance_days: list[int] | None = None,
        rank_powers: list[float] | None = None,
        metric: Literal["sharpe", "cagr", "calmar"] = "sharpe",
        max_drawdown_limit: float | None = None,
        max_workers: int | None = None,
        progress_callback: Any | None = None,
    ) -> pl.DataFrame:
        """Run grid search over parameter combinations in parallel.

        Args:
            num_longs: List of long position counts to test
            num_shorts: List of short position counts to test
            leverages: List of leverage values to test
            rebalance_days: List of rebalance frequencies to test
            rank_powers: List of rank power values to test (0=equal weight)
            metric: Optimization metric
            max_drawdown_limit: Maximum drawdown constraint (combinations whose
                max drawdown is worse than -limit are excluded from results)
            max_workers: Number of parallel workers (None = auto)
            progress_callback: Rich Progress instance for updates

        Returns:
            DataFrame with all results sorted by metric.
        """
        from concurrent.futures import ProcessPoolExecutor, as_completed
        import multiprocessing as mp
        import time

        # Default to single values from base config if not specified
        if num_longs is None:
            num_longs = [self.base_config.num_long]
        if num_shorts is None:
            num_shorts = [self.base_config.num_short]
        if leverages is None:
            leverages = [self.base_config.target_leverage]
        if rebalance_days is None:
            rebalance_days = [self.base_config.rebalance_every_n_days]
        if rank_powers is None:
            rank_powers = [self.base_config.rank_power]

        # Generate all parameter combinations (full Cartesian product)
        param_combinations = []
        for n_long in num_longs:
            for n_short in num_shorts:
                for leverage in leverages:
                    for rebal_days in rebalance_days:
                        for rank_pow in rank_powers:
                            param_combinations.append(
                                {
                                    "num_long": n_long,
                                    "num_short": n_short,
                                    "leverage": leverage,
                                    "rebalance_days": rebal_days,
                                    "rank_power": rank_pow,
                                }
                            )

        # Check cache to see which we already have (count only, for the message)
        cache = self._load_cache()
        cached_count = sum(
            1 for p in param_combinations if self._get_cache_key(p) in cache
        )

        # Get cache metadata for a human-readable size/age summary
        cache_info = ""
        if cached_count > 0:
            import os
            import time as time_module

            if os.path.exists(self._cache_file):
                cache_size = os.path.getsize(self._cache_file) / 1024  # KB
                cache_age = time_module.time() - os.path.getmtime(self._cache_file)
                if cache_age < 3600:
                    age_str = f"{int(cache_age / 60)} min"
                elif cache_age < 86400:
                    age_str = f"{cache_age / 3600:.1f} hours"
                else:
                    age_str = f"{cache_age / 86400:.1f} days"
                cache_info = f" (cache: {cache_size:.1f}KB, {age_str} old)"

        # Always surface cache info (via progress console when available)
        if cached_count > 0:
            msg = f"Found {cached_count}/{len(param_combinations)} results in cache{cache_info}"
            if progress_callback is not None and hasattr(progress_callback, "console"):
                progress_callback.console.print(msg)
            else:
                print(msg)

        # Set up parallel execution
        if max_workers is None:
            max_workers = min(mp.cpu_count(), 24)  # Cap at 24 for memory reasons

        results = []
        best_so_far = None

        # Track rate (only for non-cached backtests)
        start_time = None
        non_cached_completed = 0

        # Separate cached and to-run combinations (load cache instantly before starting progress)
        to_run: list[dict] = []
        cached_count = 0  # recounted here; drawdown-filtered hits still count as done
        for params in param_combinations:
            key = self._get_cache_key(params)
            cached = cache.get(key)
            if cached is not None:
                # Respect drawdown filter
                if (
                    max_drawdown_limit is not None
                    and cached.get("max_dd", 0) < -max_drawdown_limit
                ):
                    cached_count += 1
                    continue
                results.append(cached)
                if best_so_far is None or cached[metric] > best_so_far[metric]:
                    best_so_far = cached.copy()
                cached_count += 1
            else:
                to_run.append(params)

        # Create task AFTER loading cached results, so timer starts fresh
        task_id = None
        if progress_callback:
            task_id = progress_callback.add_task(
                "[cyan]Running backtests...",
                total=len(param_combinations),
                completed=cached_count  # Already completed from cache
            )

        # Run remaining in parallel
        if to_run:
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                future_to_params = {
                    executor.submit(self._run_single_backtest, params): params
                    for params in to_run
                }
                for future in as_completed(future_to_params):
                    params = future_to_params[future]

                    # Start timer on first actual backtest completion
                    if start_time is None:
                        start_time = time.time()

                    try:
                        result = future.result()
                        if result is not None:
                            # Respect drawdown filter
                            if (
                                max_drawdown_limit is not None
                                and result["max_dd"] < -max_drawdown_limit
                            ):
                                if progress_callback and task_id is not None:
                                    progress_callback.update(task_id, advance=1)
                                continue
                            results.append(result)
                            # Update best so far
                            if (
                                best_so_far is None
                                or result[metric] > best_so_far[metric]
                            ):
                                best_so_far = result.copy()
                            # Save to cache (main process only); written after
                            # every result so progress survives interruption
                            key = self._get_cache_key(params)
                            cache[key] = result
                            self._save_cache(cache)
                            # Progress update
                            if progress_callback and task_id is not None:
                                non_cached_completed += 1
                                elapsed = max(time.time() - start_time, 1e-6)
                                rate = non_cached_completed / elapsed
                                rate_str = "instant" if rate > 999 else f"{rate:.1f}/s"
                                progress_callback.update(
                                    task_id,
                                    advance=1,
                                    description=f"[cyan]Backtests[/cyan] [dim]│[/dim] Best {metric}: {best_so_far[metric]:.3f} [dim]| {rate_str}[/dim]",
                                )
                        else:
                            # Failed combination: advance progress without a result
                            if progress_callback and task_id is not None:
                                non_cached_completed += 1
                                elapsed = max(time.time() - start_time, 1e-6)
                                rate = non_cached_completed / elapsed
                                rate_str = "instant" if rate > 999 else f"{rate:.1f}/s"
                                progress_callback.update(
                                    task_id,
                                    advance=1,
                                    description=f"[cyan]Backtests[/cyan] [dim]| {rate_str}[/dim]",
                                )
                    except Exception:
                        # Worker crashed: advance progress and keep going
                        if progress_callback and task_id is not None:
                            non_cached_completed += 1
                            elapsed = max(time.time() - start_time, 1e-6)
                            rate = non_cached_completed / elapsed
                            rate_str = "instant" if rate > 999 else f"{rate:.1f}/s"
                            progress_callback.update(
                                task_id,
                                advance=1,
                                description=f"[cyan]Backtests[/cyan] [dim]| {rate_str}[/dim]",
                            )

        # Convert to DataFrame and sort by metric
        if not results:
            return pl.DataFrame()

        df = pl.DataFrame(results)

        # Cast numeric columns to proper float types (handles complex numbers, NaN, inf)
        numeric_cols = ["sharpe", "cagr", "calmar", "sortino", "max_dd", "volatility", "win_rate", "final_equity"]
        df = df.with_columns([
            pl.col(col).cast(pl.Float64, strict=False).fill_nan(0.0)
            for col in numeric_cols if col in df.columns
        ])

        df = df.sort(metric, descending=True)

        return df

    def clear_cache(self) -> None:
        """Clear the optimization cache by deleting the cache file if present."""
        import os

        if os.path.exists(self._cache_file):
            os.remove(self._cache_file)

    def grid_search(
        self,
        num_longs: list[int] | None = None,
        num_shorts: list[int] | None = None,
        leverages: list[float] | None = None,
        rebalance_days: list[int] | None = None,
        metric: Literal["sharpe", "cagr", "calmar"] = "sharpe",
        max_drawdown_limit: float | None = None,
    ) -> pl.DataFrame:
        """Legacy sequential grid search (kept for backwards compatibility).

        NOTE: does not expose rank_powers; the parallel version falls back to
        the base config's rank_power in that case.
        """
        # Just call the parallel version with max_workers=1
        return self.grid_search_parallel(
            num_longs=num_longs,
            num_shorts=num_shorts,
            leverages=leverages,
            rebalance_days=rebalance_days,
            metric=metric,
            max_drawdown_limit=max_drawdown_limit,
            max_workers=1,
            progress_callback=None,
        )

    def get_best_params(
        self,
        results_df: pl.DataFrame,
        metric: Literal["sharpe", "cagr", "calmar"] = "sharpe",
    ) -> dict[str, Any] | None:
        """Extract best parameters from results DataFrame.

        Returns a dict keyed by BacktestConfig field names (ready to apply
        to a config), or None when the results frame is empty.
        """
        if len(results_df) == 0:
            return None

        # Get best row by the chosen metric
        best_row = results_df.sort(metric, descending=True).head(1)

        return {
            "num_long": int(best_row["num_long"][0]),
            "num_short": int(best_row["num_short"][0]),
            "target_leverage": float(best_row["leverage"][0]),
            "rebalance_every_n_days": int(best_row["rebalance_days"][0]),
            "rank_power": float(best_row["rank_power"][0]),
        }