Coverage for src/cc_liquid/backtester.py: 53%

383 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-10-13 20:16 +0000

1"""Backtesting engine for cc-liquid. 

2 

3This module provides pure backtesting logic without any UI dependencies. 

4All display/visualization is handled by the CLI layer. 

5 

6⚠️ IMPORTANT DISCLAIMER: 

7Backtesting has inherent limitations. Past performance does not predict future results. 

8Results are hypothetical and do not account for all real-world factors including: 

9- Market impact and slippage beyond modeled estimates 

10- Changing market conditions and liquidity 

11- Technical failures and execution delays 

12- Regulatory changes and exchange rules 

13Always validate strategies with out-of-sample data and paper trading before live deployment. 

14""" 

15 

16from __future__ import annotations 

17 

18import math 

19from dataclasses import dataclass, field 

20from datetime import datetime, timedelta 

21from typing import Any, Literal 

22 

23import polars as pl 

24 

25from .portfolio import weights_from_ranks 

26 

27 

@dataclass
class BacktestConfig:
    """Configuration for backtesting."""

    # --- Input files ---
    prices_path: str = "raw_data.parquet"
    predictions_path: str = "predictions.parquet"

    # --- Price column mapping (defaults match raw_data.parquet) ---
    price_date_column: str = "date"
    price_id_column: str = "id"
    price_close_column: str = "close"

    # --- Prediction column mapping (taken from DataSourceConfig in the CLI) ---
    pred_date_column: str = "release_date"
    pred_id_column: str = "id"
    pred_value_column: str = "pred_10d"

    # Data provider (e.g. crowdcent, numerai, local)
    data_provider: str | None = None

    # --- Backtest window (None = use all available overlapping data) ---
    start_date: datetime | None = None
    end_date: datetime | None = None

    # --- Strategy parameters (mirror PortfolioConfig) ---
    num_long: int = 60
    num_short: int = 50
    target_leverage: float = 3.0  # Sum of abs(weights), matching trader.py
    # Weighting (opt-in non-EW)
    weighting_scheme: Literal["equal", "rank_power"] = "equal"
    rank_power: float = 0.0

    # --- Rebalancing cadence ---
    rebalance_every_n_days: int = 10
    prediction_lag_days: int = 1  # Trade at T using signals from T - lag

    # --- Costs (basis points) ---
    fee_bps: float = 4.0  # Trading fee
    slippage_bps: float = 50.0  # Slippage cost

    # --- Capital & misc ---
    start_capital: float = 100_000.0
    verbose: bool = False

74 

75 

@dataclass
class BacktestResult:
    """Results from a backtest run."""

    # Daily time series with columns: date, returns, equity, drawdown, turnover.
    daily: pl.DataFrame

    # Position snapshots taken at each rebalance; columns: date, id, weight.
    rebalance_positions: pl.DataFrame

    # Summary metrics keyed by stat name (sharpe_ratio, cagr, ...).
    stats: dict[str, float] = field(default_factory=dict)

    # The configuration this result was produced with.
    config: BacktestConfig | None = None

91 

92 

class Backtester:
    """Core backtesting engine."""

    def __init__(self, config: BacktestConfig):
        # Single source of truth for all data paths, strategy parameters,
        # costs, and scheduling used by run() and its helpers.
        self.config = config

98 

99 def run(self) -> BacktestResult: 

100 """Run the backtest and return results.""" 

101 

102 # 1. Load and prepare data 

103 prices_long = self._load_prices() 

104 predictions_long = self._load_predictions() 

105 

106 # 2. Compute returns matrix (use ALL available prices first) 

107 returns_wide_all = self._compute_returns_wide(prices_long) 

108 

109 # 3. Determine valid trading dates from returns + lagged predictions 

110 valid_dates = self._get_valid_trading_dates_from_returns( 

111 returns_wide_all, predictions_long 

112 ) 

113 if len(valid_dates) == 0: 

114 raise ValueError( 

115 "No valid trading dates given returns and prediction coverage" 

116 ) 

117 

118 # 4. Filter returns to valid dates only (predictions are filtered by cutoff at selection time) 

119 returns_wide = returns_wide_all.filter(pl.col("date").is_in(valid_dates)) 

120 

121 # Scope predictions to up-to-last cutoff for efficiency (do NOT drop early dates needed for lag) 

122 last_cutoff = max(valid_dates) - timedelta(days=self.config.prediction_lag_days) 

123 predictions_long = predictions_long.filter(pl.col("pred_date") <= last_cutoff) 

124 

125 # 5. Determine rebalance schedule 

126 rebalance_dates = self._compute_rebalance_dates(valid_dates) 

127 

128 # 6. Run simulation 

129 result = self._simulate( 

130 returns_wide=returns_wide, 

131 predictions_long=predictions_long, 

132 rebalance_dates=rebalance_dates, 

133 ) 

134 

135 # 7. Compute statistics 

136 stats = self._compute_stats(result["daily"]) 

137 

138 return BacktestResult( 

139 daily=result["daily"], 

140 rebalance_positions=result["positions"], 

141 stats=stats, 

142 config=self.config, 

143 ) 

144 

145 def _get_valid_trading_dates_from_returns( 

146 self, returns_wide_all: pl.DataFrame, predictions_long: pl.DataFrame 

147 ) -> list[datetime]: 

148 """Compute valid trading dates based on returns dates and lagged prediction availability. 

149 

150 Logic: 

151 - Use ALL dates present in the returns matrix as candidate trading dates 

152 - A date D is tradable if there exists any prediction with pred_date <= D - lag 

153 - Apply optional user-specified start/end bounds at the end 

154 """ 

155 if "date" not in returns_wide_all.columns or len(returns_wide_all) == 0: 

156 return [] 

157 all_trade_dates = sorted(returns_wide_all["date"].to_list()) 

158 if not all_trade_dates: 

159 return [] 

160 pred_dates = set(predictions_long["pred_date"].unique().to_list()) 

161 lag_td = timedelta(days=self.config.prediction_lag_days) 

162 valid_dates: list[datetime] = [] 

163 for d in all_trade_dates: 

164 cutoff = d - lag_td 

165 # We need at least one prediction available on or before cutoff 

166 if any(pd <= cutoff for pd in pred_dates): 

167 valid_dates.append(d) 

168 # Apply explicit user bounds last 

169 if self.config.start_date: 

170 valid_dates = [d for d in valid_dates if d >= self.config.start_date] 

171 if self.config.end_date: 

172 valid_dates = [d for d in valid_dates if d <= self.config.end_date] 

173 return valid_dates 

174 

175 def _load_prices(self) -> pl.DataFrame: 

176 """Load price data in long format.""" 

177 df = pl.read_parquet(self.config.prices_path) 

178 

179 # Select and rename columns 

180 df = df.select( 

181 [ 

182 pl.col(self.config.price_date_column).alias("date"), 

183 pl.col(self.config.price_id_column).alias("id"), 

184 pl.col(self.config.price_close_column).alias("close"), 

185 ] 

186 ) 

187 

188 # Ensure date is datetime 

189 if df["date"].dtype != pl.Datetime: 

190 df = df.with_columns(pl.col("date").cast(pl.Date).cast(pl.Datetime)) 

191 

192 # Drop nulls and sort 

193 df = df.drop_nulls().sort(["date", "id"]) 

194 

195 if self.config.verbose: 

196 print(f"Loaded {len(df):,} price records for {df['id'].n_unique()} assets") 

197 print(f"Price date range: {df['date'].min()} to {df['date'].max()}") 

198 

199 return df 

200 

201 def _load_predictions(self) -> pl.DataFrame: 

202 """Load prediction data in long format.""" 

203 df = pl.read_parquet(self.config.predictions_path) 

204 

205 # Select and rename columns 

206 df = df.select( 

207 [ 

208 pl.col(self.config.pred_date_column).alias("pred_date"), 

209 pl.col(self.config.pred_id_column).alias("id"), 

210 pl.col(self.config.pred_value_column).alias("pred"), 

211 ] 

212 ) 

213 

214 # Ensure date is datetime 

215 if df["pred_date"].dtype != pl.Datetime: 

216 df = df.with_columns(pl.col("pred_date").cast(pl.Date).cast(pl.Datetime)) 

217 

218 # Drop nulls and sort 

219 df = df.drop_nulls().sort(["pred_date", "id"]) 

220 

221 if self.config.verbose: 

222 print( 

223 f"Loaded {len(df):,} prediction records for {df['id'].n_unique()} assets" 

224 ) 

225 print( 

226 f"Prediction date range: {df['pred_date'].min()} to {df['pred_date'].max()}" 

227 ) 

228 

229 return df 

230 

231 def _get_overlapping_dates( 

232 self, prices: pl.DataFrame, predictions: pl.DataFrame 

233 ) -> list[datetime]: 

234 """Find dates where we have both prices and valid predictions (considering lag).""" 

235 

236 # Get unique dates from each dataset 

237 price_dates = set(prices["date"].unique().to_list()) 

238 pred_dates = set(predictions["pred_date"].unique().to_list()) 

239 

240 # For each price date, check if we have predictions from T-lag days before 

241 valid_dates = [] 

242 lag_td = timedelta(days=self.config.prediction_lag_days) 

243 

244 for price_date in sorted(price_dates): 

245 # We need predictions from this date or earlier (up to lag days before) 

246 required_pred_date = price_date - lag_td 

247 

248 # Check if we have predictions on or before the required date 

249 has_valid_pred = any(pd <= required_pred_date for pd in pred_dates) 

250 

251 if has_valid_pred: 

252 valid_dates.append(price_date) 

253 

254 # Apply user-specified date bounds if any 

255 if self.config.start_date: 

256 valid_dates = [d for d in valid_dates if d >= self.config.start_date] 

257 if self.config.end_date: 

258 valid_dates = [d for d in valid_dates if d <= self.config.end_date] 

259 

260 if self.config.verbose: 

261 print(f"Found {len(valid_dates)} valid trading dates with overlapping data") 

262 if valid_dates: 

263 print(f"Trading date range: {min(valid_dates)} to {max(valid_dates)}") 

264 

265 return valid_dates 

266 

267 def _compute_returns_wide(self, prices_long: pl.DataFrame) -> pl.DataFrame: 

268 """Compute returns matrix in wide format (dates as rows, assets as columns).""" 

269 

270 # Calculate returns for each asset 

271 prices_long = prices_long.sort(["id", "date"]) 

272 prices_long = prices_long.with_columns( 

273 pl.col("close").pct_change().over("id").alias("return") 

274 ) 

275 

276 # Pivot to wide format 

277 returns_wide = prices_long.pivot(index="date", on="id", values="return").sort( 

278 "date" 

279 ) 

280 

281 if self.config.verbose: 

282 n_assets = len(returns_wide.columns) - 1 # Exclude date column 

283 print(f"Computed returns for {n_assets} assets") 

284 

285 return returns_wide 

286 

287 def _compute_rebalance_dates(self, valid_dates: list[datetime]) -> list[datetime]: 

288 """Determine rebalance dates based on schedule.""" 

289 if not valid_dates: 

290 return [] 

291 

292 rebalance_dates = [] 

293 current_date = min(valid_dates) 

294 

295 while current_date <= max(valid_dates): 

296 if current_date in valid_dates: 

297 rebalance_dates.append(current_date) 

298 current_date += timedelta(days=self.config.rebalance_every_n_days) 

299 

300 if self.config.verbose: 

301 print(f"Scheduled {len(rebalance_dates)} rebalance dates") 

302 

303 return rebalance_dates 

304 

305 def _select_assets( 

306 self, 

307 predictions: pl.DataFrame, 

308 cutoff_date: datetime, 

309 available_assets: set[str], 

310 ) -> tuple[list[str], list[str], pl.DataFrame]: 

311 """Select assets and return latest predictions DataFrame for sizing.""" 

312 

313 latest_preds = ( 

314 predictions.filter(pl.col("pred_date") <= cutoff_date) 

315 .filter(pl.col("id").is_in(available_assets)) 

316 .sort("pred_date", descending=True) 

317 .group_by("id") 

318 .first() 

319 ) 

320 

321 if len(latest_preds) == 0: 

322 empty = pl.DataFrame({"id": [], "pred": []}) 

323 return [], [], empty 

324 

325 latest_sorted = latest_preds.sort("pred", descending=True) 

326 all_ids = latest_sorted["id"].to_list() 

327 

328 num_long = min(self.config.num_long, len(all_ids)) 

329 num_short = min(self.config.num_short, len(all_ids) - num_long) 

330 

331 long_assets = all_ids[:num_long] 

332 short_assets = all_ids[-num_short:] if num_short > 0 else [] 

333 

334 return long_assets, short_assets, latest_sorted.select(["id", "pred"]) 

335 

    def _simulate(
        self,
        returns_wide: pl.DataFrame,
        predictions_long: pl.DataFrame,
        rebalance_dates: list[datetime],
    ) -> dict:
        """Run the backtest simulation.

        Walks the return matrix day by day, rebalancing on scheduled dates,
        deducting turnover-based costs, and compounding equity.

        Returns a dict with:
            "daily": DataFrame of date/returns/equity/drawdown/turnover
            "positions": DataFrame of date/id/weight snapshots at rebalances
        """

        # Initialize tracking variables
        equity = self.config.start_capital
        peak_equity = equity

        daily_results = []
        position_snapshots = []
        current_weights = {}  # Asset -> weight

        # Convert rebalance dates to set for fast lookup
        rebalance_set = set(rebalance_dates)

        # Get all dates from returns
        all_dates = returns_wide["date"].to_list()

        for i, date in enumerate(all_dates):  # NOTE: `i` is unused
            # Get today's returns (single-row frame for this date)
            returns_row = returns_wide.filter(pl.col("date") == date)

            # Check if we need to rebalance
            if date in rebalance_set:
                # Determine cutoff date for predictions (T - lag)
                cutoff_date = date - timedelta(days=self.config.prediction_lag_days)

                # Get available assets (those with a non-null, non-NaN return today)
                available_assets = set()
                for col in returns_row.columns:
                    if col != "date":
                        val = returns_row[col][0]
                        if val is not None and not math.isnan(val):
                            available_assets.add(col)

                # Determine selections and convert ranks to weights
                long_assets, short_assets, latest_preds = self._select_assets(
                    predictions_long,
                    cutoff_date,
                    available_assets,
                )

                new_weights: dict[str, float] = {}
                total_positions = len(long_assets) + len(short_assets)

                if total_positions > 0 and len(latest_preds) > 0:
                    weights = weights_from_ranks(
                        latest_preds=latest_preds,
                        id_col="id",
                        pred_col="pred",
                        long_assets=long_assets,
                        short_assets=short_assets,
                        target_gross=self.config.target_leverage,
                        scheme=self.config.weighting_scheme,
                        power=self.config.rank_power,
                    )
                    new_weights = weights

                # Calculate turnover (L1 norm of weight changes)
                turnover = 0.0
                all_assets = set(current_weights.keys()) | set(new_weights.keys())
                for asset in all_assets:
                    old_w = current_weights.get(asset, 0.0)
                    new_w = new_weights.get(asset, 0.0)
                    turnover += abs(new_w - old_w)

                # Apply trading costs (fees + slippage), proportional to turnover
                total_cost_bps = self.config.fee_bps + self.config.slippage_bps
                cost = turnover * (total_cost_bps / 10_000)

                # Costs are deducted immediately on the rebalance day
                equity *= 1 - cost

                # Store position snapshot
                for asset, weight in new_weights.items():
                    position_snapshots.append(
                        {"date": date, "id": asset, "weight": weight}
                    )

                # NOTE(review): current_weights is replaced BEFORE the P&L step
                # below, so the new weights earn *today's* (close-to-close)
                # return on the rebalance day itself — confirm this same-day
                # application is intended rather than next-day effect.
                current_weights = new_weights.copy()
            else:
                turnover = 0.0

            # Calculate portfolio return using current weights
            # (on rebalance days these are the freshly set weights; see NOTE above)
            portfolio_return = 0.0

            if current_weights:
                for asset, weight in current_weights.items():
                    if asset in returns_row.columns:
                        asset_return = returns_row[asset][0]
                        # Missing/NaN returns contribute zero for that asset
                        if asset_return is not None and not math.isnan(asset_return):
                            portfolio_return += weight * asset_return

            # Update equity
            equity *= 1 + portfolio_return

            # Track peak and drawdown (drawdown is <= 0, relative to the peak)
            if equity > peak_equity:
                peak_equity = equity
            drawdown = (equity - peak_equity) / peak_equity if peak_equity > 0 else 0

            # Store daily results
            daily_results.append(
                {
                    "date": date,
                    "returns": portfolio_return,
                    "equity": equity,
                    "drawdown": drawdown,
                    "turnover": turnover,
                }
            )

        # Convert to DataFrames
        daily_df = pl.DataFrame(daily_results)
        positions_df = (
            pl.DataFrame(position_snapshots) if position_snapshots else pl.DataFrame()
        )

        return {"daily": daily_df, "positions": positions_df}

461 

    def _compute_stats(self, daily: pl.DataFrame) -> dict[str, float]:
        """Compute summary statistics from daily results.

        Annualization uses 365 days (crypto trades every day). Returns an
        empty dict when there are no daily rows.
        """

        if len(daily) == 0:
            return {}

        # Basic info
        n_days = len(daily)
        start_equity = self.config.start_capital
        final_equity = daily["equity"][-1]

        # Returns
        total_return = (final_equity / start_equity) - 1

        # Annualized metrics (assuming 365 days per year for crypto)
        years = n_days / 365.0
        # Handle negative equity (can't take fractional power of negative number)
        if years > 0 and final_equity > 0:
            cagr = (final_equity / start_equity) ** (1.0 / years) - 1
        else:
            cagr = total_return  # Fallback to simple return if equity went negative

        # Risk metrics
        returns = daily["returns"]
        daily_vol = returns.std()
        annual_vol = float(daily_vol * math.sqrt(365)) if daily_vol is not None else 0.0

        # Sharpe ratio (assuming 0 risk-free rate).
        # NOTE: numerator is CAGR, not annualized mean daily return — a
        # deliberate geometric-return convention used consistently below.
        sharpe = cagr / annual_vol if annual_vol > 0 else 0.0

        # Drawdown (drawdowns are stored as <= 0, so min() is the worst one)
        max_drawdown = daily["drawdown"].min()  # Most negative value

        # Calmar ratio (CAGR over worst drawdown magnitude)
        calmar = (
            cagr / abs(max_drawdown)
            if max_drawdown is not None and max_drawdown < 0
            else 0
        )

        # Win rate: fraction of days with a strictly positive return
        positive_days = (returns > 0).sum()
        win_rate = positive_days / n_days if n_days > 0 else 0

        # Turnover: average over rebalance days only (turnover > 0)
        avg_turnover = daily.filter(pl.col("turnover") > 0)["turnover"].mean()
        if avg_turnover is None:
            avg_turnover = 0

        # Sortino ratio (downside deviation only)
        negative_returns = returns.filter(returns < 0)
        if len(negative_returns) > 0:
            downside_vol = negative_returns.std()
            annual_downside_vol = (
                downside_vol * math.sqrt(365) if downside_vol is not None else 0
            )
            sortino = cagr / annual_downside_vol if annual_downside_vol > 0 else 0
        else:
            # No losing days: infinite Sortino if we made money, else 0
            sortino = float("inf") if cagr > 0 else 0

        return {
            "days": n_days,
            "total_return": total_return,
            "cagr": cagr,
            "annual_volatility": annual_vol,
            "sharpe_ratio": sharpe,
            "sortino_ratio": sortino,
            "max_drawdown": max_drawdown,
            "calmar_ratio": calmar,
            "win_rate": win_rate,
            "avg_turnover": avg_turnover,
            "final_equity": final_equity,
        }

535 

536 

class BacktestOptimizer:
    """Grid search optimizer for backtesting parameters with parallel execution and caching."""

    def __init__(self, base_config: BacktestConfig):
        # Template config; per-combination configs are derived from it.
        self.base_config = base_config
        # Placeholders for data reuse between runs; not read or written
        # anywhere in this module.
        self._prices_cache = None
        self._predictions_cache = None
        # JSON file persisting metric results keyed by parameter-combination hash.
        self._cache_file = ".cc_liquid_optimizer_cache.json"

545 

546 def _get_cache_key(self, params: dict) -> str: 

547 """Generate a unique cache key for a parameter combination.""" 

548 import hashlib 

549 import json 

550 

551 # Include base config settings that affect results 

552 cache_data = { 

553 "params": params, 

554 "config": { 

555 "prices_path": self.base_config.prices_path, 

556 "predictions_path": self.base_config.predictions_path, 

557 "data_provider": self.base_config.data_provider, 

558 "start_date": str(self.base_config.start_date) 

559 if self.base_config.start_date 

560 else None, 

561 "end_date": str(self.base_config.end_date) 

562 if self.base_config.end_date 

563 else None, 

564 "prediction_lag_days": self.base_config.prediction_lag_days, 

565 "fee_bps": self.base_config.fee_bps, 

566 "slippage_bps": self.base_config.slippage_bps, 

567 "start_capital": self.base_config.start_capital, 

568 "weighting_scheme": self.base_config.weighting_scheme, 

569 "rank_power": self.base_config.rank_power, 

570 }, 

571 } 

572 

573 cache_str = json.dumps(cache_data, sort_keys=True) 

574 return hashlib.md5(cache_str.encode()).hexdigest() 

575 

576 def _load_cache(self) -> dict: 

577 """Load cached results from disk.""" 

578 import json 

579 import os 

580 

581 if not os.path.exists(self._cache_file): 

582 return {} 

583 

584 try: 

585 with open(self._cache_file, "r") as f: 

586 return json.load(f) 

587 except Exception: 

588 return {} 

589 

590 def _save_cache(self, cache: dict) -> None: 

591 """Save cache to disk.""" 

592 import json 

593 

594 try: 

595 with open(self._cache_file, "w") as f: 

596 json.dump(cache, f) 

597 except Exception: 

598 pass # Silently fail if can't write cache 

599 

600 def _run_single_backtest(self, params: dict) -> dict | None: 

601 """Run a single backtest with given parameters. No cache IO here.""" 

602 # Create config for this combination 

603 config = BacktestConfig( 

604 prices_path=self.base_config.prices_path, 

605 predictions_path=self.base_config.predictions_path, 

606 price_date_column=self.base_config.price_date_column, 

607 price_id_column=self.base_config.price_id_column, 

608 price_close_column=self.base_config.price_close_column, 

609 pred_date_column=self.base_config.pred_date_column, 

610 pred_id_column=self.base_config.pred_id_column, 

611 pred_value_column=self.base_config.pred_value_column, 

612 data_provider=self.base_config.data_provider, 

613 start_date=self.base_config.start_date, 

614 end_date=self.base_config.end_date, 

615 num_long=params["num_long"], 

616 num_short=params["num_short"], 

617 target_leverage=params["leverage"], 

618 rebalance_every_n_days=params["rebalance_days"], 

619 prediction_lag_days=self.base_config.prediction_lag_days, 

620 fee_bps=self.base_config.fee_bps, 

621 slippage_bps=self.base_config.slippage_bps, 

622 start_capital=self.base_config.start_capital, 

623 verbose=False, 

624 weighting_scheme="rank_power", # Always use rank_power (power=0 is equal weight) 

625 rank_power=params["rank_power"], 

626 ) 

627 

628 try: 

629 # Run backtest 

630 backtester = Backtester(config) 

631 result = backtester.run() 

632 

633 # Store results 

634 result_data = { 

635 "num_long": params["num_long"], 

636 "num_short": params["num_short"], 

637 "leverage": params["leverage"], 

638 "rebalance_days": params["rebalance_days"], 

639 "rank_power": params["rank_power"], 

640 "sharpe": result.stats["sharpe_ratio"], 

641 "cagr": result.stats["cagr"], 

642 "calmar": result.stats["calmar_ratio"], 

643 "sortino": result.stats["sortino_ratio"], 

644 "max_dd": result.stats["max_drawdown"], 

645 "volatility": result.stats["annual_volatility"], 

646 "win_rate": result.stats["win_rate"], 

647 "final_equity": result.stats["final_equity"], 

648 } 

649 

650 return result_data 

651 

652 except Exception: 

653 return None 

654 

    def grid_search_parallel(
        self,
        num_longs: list[int] | None = None,
        num_shorts: list[int] | None = None,
        leverages: list[float] | None = None,
        rebalance_days: list[int] | None = None,
        rank_powers: list[float] | None = None,
        metric: Literal["sharpe", "cagr", "calmar"] = "sharpe",
        max_drawdown_limit: float | None = None,
        max_workers: int | None = None,
        progress_callback: Any | None = None,
    ) -> pl.DataFrame:
        """Run grid search over parameter combinations in parallel.

        Args:
            num_longs: List of long position counts to test
            num_shorts: List of short position counts to test
            leverages: List of leverage values to test
            rebalance_days: List of rebalance frequencies to test
            rank_powers: List of rank power values to test (0=equal weight)
            metric: Optimization metric
            max_drawdown_limit: Maximum drawdown constraint
            max_workers: Number of parallel workers (None = auto)
            progress_callback: Rich Progress instance for updates

        Returns:
            DataFrame with all results sorted by metric.
        """
        from concurrent.futures import ProcessPoolExecutor, as_completed
        import multiprocessing as mp
        import time

        # Default to single values from base config if not specified
        if num_longs is None:
            num_longs = [self.base_config.num_long]
        if num_shorts is None:
            num_shorts = [self.base_config.num_short]
        if leverages is None:
            leverages = [self.base_config.target_leverage]
        if rebalance_days is None:
            rebalance_days = [self.base_config.rebalance_every_n_days]
        if rank_powers is None:
            rank_powers = [self.base_config.rank_power]

        # Generate all parameter combinations (full Cartesian product)
        param_combinations = []
        for n_long in num_longs:
            for n_short in num_shorts:
                for leverage in leverages:
                    for rebal_days in rebalance_days:
                        for rank_pow in rank_powers:
                            param_combinations.append(
                                {
                                    "num_long": n_long,
                                    "num_short": n_short,
                                    "leverage": leverage,
                                    "rebalance_days": rebal_days,
                                    "rank_power": rank_pow,
                                }
                            )

        # Check cache to see which we already have (this count is only used
        # for the status message; it is recomputed below when partitioning)
        cache = self._load_cache()
        cached_count = sum(
            1 for p in param_combinations if self._get_cache_key(p) in cache
        )

        # Get cache metadata (size/age) for the status message
        cache_info = ""
        if cached_count > 0:
            import os
            import time as time_module

            if os.path.exists(self._cache_file):
                cache_size = os.path.getsize(self._cache_file) / 1024  # KB
                cache_age = time_module.time() - os.path.getmtime(self._cache_file)
                if cache_age < 3600:
                    age_str = f"{int(cache_age / 60)} min"
                elif cache_age < 86400:
                    age_str = f"{cache_age / 3600:.1f} hours"
                else:
                    age_str = f"{cache_age / 86400:.1f} days"
                cache_info = f" (cache: {cache_size:.1f}KB, {age_str} old)"

        # Always surface cache info (via progress console when available)
        if cached_count > 0:
            msg = f"Found {cached_count}/{len(param_combinations)} results in cache{cache_info}"
            if progress_callback is not None and hasattr(progress_callback, "console"):
                progress_callback.console.print(msg)
            else:
                print(msg)

        # Set up parallel execution
        if max_workers is None:
            max_workers = min(mp.cpu_count(), 24)  # Cap at 24 for memory reasons

        results = []
        best_so_far = None

        # Track rate (only for non-cached backtests)
        start_time = None
        non_cached_completed = 0

        # Separate cached and to-run combinations (load cache instantly before starting progress)
        to_run: list[dict] = []
        cached_count = 0  # reset: now counts consumed cache entries exactly
        for params in param_combinations:
            key = self._get_cache_key(params)
            cached = cache.get(key)
            if cached is not None:
                # Respect drawdown filter (filtered entries still count as
                # cached so the progress bar starts at the right position)
                if (
                    max_drawdown_limit is not None
                    and cached.get("max_dd", 0) < -max_drawdown_limit
                ):
                    cached_count += 1
                    continue
                results.append(cached)
                if best_so_far is None or cached[metric] > best_so_far[metric]:
                    best_so_far = cached.copy()
                cached_count += 1
            else:
                to_run.append(params)

        # Create task AFTER loading cached results, so timer starts fresh
        task_id = None
        if progress_callback:
            task_id = progress_callback.add_task(
                "[cyan]Running backtests...",
                total=len(param_combinations),
                completed=cached_count  # Already completed from cache
            )

        # Run remaining in parallel.
        # NOTE(review): the three progress-update branches below are nearly
        # identical and a refactor target; kept verbatim here.
        if to_run:
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                future_to_params = {
                    executor.submit(self._run_single_backtest, params): params
                    for params in to_run
                }
                for future in as_completed(future_to_params):
                    params = future_to_params[future]

                    # Start timer on first actual backtest completion
                    if start_time is None:
                        start_time = time.time()

                    try:
                        result = future.result()
                        if result is not None:
                            # Respect drawdown filter
                            if (
                                max_drawdown_limit is not None
                                and result["max_dd"] < -max_drawdown_limit
                            ):
                                if progress_callback and task_id is not None:
                                    progress_callback.update(task_id, advance=1)
                                continue
                            results.append(result)
                            # Update best so far
                            if (
                                best_so_far is None
                                or result[metric] > best_so_far[metric]
                            ):
                                best_so_far = result.copy()
                            # Save to cache (main process only).
                            # NOTE(review): rewrites the whole cache file on
                            # every completion — crash-safe but O(n) writes.
                            key = self._get_cache_key(params)
                            cache[key] = result
                            self._save_cache(cache)
                            # Progress update
                            if progress_callback and task_id is not None:
                                non_cached_completed += 1
                                elapsed = max(time.time() - start_time, 1e-6)
                                rate = non_cached_completed / elapsed
                                rate_str = "instant" if rate > 999 else f"{rate:.1f}/s"
                                progress_callback.update(
                                    task_id,
                                    advance=1,
                                    description=f"[cyan]Backtests[/cyan] [dim]│[/dim] Best {metric}: {best_so_far[metric]:.3f} [dim]| {rate_str}[/dim]",
                                )
                        else:
                            # Failed/skipped backtest: advance progress only
                            if progress_callback and task_id is not None:
                                non_cached_completed += 1
                                elapsed = max(time.time() - start_time, 1e-6)
                                rate = non_cached_completed / elapsed
                                rate_str = "instant" if rate > 999 else f"{rate:.1f}/s"
                                progress_callback.update(
                                    task_id,
                                    advance=1,
                                    description=f"[cyan]Backtests[/cyan] [dim]| {rate_str}[/dim]",
                                )
                    except Exception:
                        # Worker crashed; swallow and advance progress
                        if progress_callback and task_id is not None:
                            non_cached_completed += 1
                            elapsed = max(time.time() - start_time, 1e-6)
                            rate = non_cached_completed / elapsed
                            rate_str = "instant" if rate > 999 else f"{rate:.1f}/s"
                            progress_callback.update(
                                task_id,
                                advance=1,
                                description=f"[cyan]Backtests[/cyan] [dim]| {rate_str}[/dim]",
                            )

        # Convert to DataFrame and sort by metric
        if not results:
            return pl.DataFrame()

        df = pl.DataFrame(results)

        # Cast numeric columns to proper float types (handles complex numbers, NaN, inf)
        numeric_cols = ["sharpe", "cagr", "calmar", "sortino", "max_dd", "volatility", "win_rate", "final_equity"]
        df = df.with_columns([
            pl.col(col).cast(pl.Float64, strict=False).fill_nan(0.0)
            for col in numeric_cols if col in df.columns
        ])

        df = df.sort(metric, descending=True)

        return df

874 

875 def clear_cache(self) -> None: 

876 """Clear the optimization cache.""" 

877 import os 

878 

879 if os.path.exists(self._cache_file): 

880 os.remove(self._cache_file) 

881 

882 def grid_search( 

883 self, 

884 num_longs: list[int] | None = None, 

885 num_shorts: list[int] | None = None, 

886 leverages: list[float] | None = None, 

887 rebalance_days: list[int] | None = None, 

888 metric: Literal["sharpe", "cagr", "calmar"] = "sharpe", 

889 max_drawdown_limit: float | None = None, 

890 ) -> pl.DataFrame: 

891 """Legacy sequential grid search (kept for backwards compatibility).""" 

892 # Just call the parallel version with max_workers=1 

893 return self.grid_search_parallel( 

894 num_longs=num_longs, 

895 num_shorts=num_shorts, 

896 leverages=leverages, 

897 rebalance_days=rebalance_days, 

898 metric=metric, 

899 max_drawdown_limit=max_drawdown_limit, 

900 max_workers=1, 

901 progress_callback=None, 

902 ) 

903 

904 def get_best_params( 

905 self, 

906 results_df: pl.DataFrame, 

907 metric: Literal["sharpe", "cagr", "calmar"] = "sharpe", 

908 ) -> dict[str, Any] | None: 

909 """Extract best parameters from results DataFrame.""" 

910 

911 if len(results_df) == 0: 

912 return None 

913 

914 # Get best row 

915 best_row = results_df.sort(metric, descending=True).head(1) 

916 

917 return { 

918 "num_long": int(best_row["num_long"][0]), 

919 "num_short": int(best_row["num_short"][0]), 

920 "target_leverage": float(best_row["leverage"][0]), 

921 "rebalance_every_n_days": int(best_row["rebalance_days"][0]), 

922 "rank_power": float(best_row["rank_power"][0]), 

923 }