Coverage for src/alprina_cli/agent_coordinator.py: 0%

151 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2Alprina Agent Coordinator - Intelligent agent chaining and vulnerability lifecycle management. 

3 

4Coordinates multiple security agents to work together in sophisticated workflows: 

5- Red Team → Blue Team → DFIR chains 

6- Vulnerability discovery → Validation → Remediation tracking 

7- Automated follow-up scans 

8- Agent-to-agent communication 

9 

10Reference: AI SDK Agent Coordination Patterns 

11""" 

12 

13from typing import Dict, List, Any, Optional, Callable 

14from enum import Enum 

15from loguru import logger 

16from datetime import datetime, timedelta 

17import asyncio 

18 

19 

class ChainType(Enum):
    """
    Kinds of multi-agent chains the coordinator knows how to run.

    Members, in declaration order (value in parentheses):

    * ATTACK_DEFENSE ("attack_defense") -- red team, then blue team.
    * INVESTIGATION ("investigation") -- discovery scans, then DFIR analysis.
    * FULL_LIFECYCLE ("full_lifecycle") -- red team, blue team, DFIR, then
      remediation guidance.
    * VALIDATION ("validation") -- scan, validate, then retest/verify.
    * CONTINUOUS ("continuous") -- monitoring-oriented scan set.
    """

    ATTACK_DEFENSE = "attack_defense"
    INVESTIGATION = "investigation"
    FULL_LIFECYCLE = "full_lifecycle"
    VALIDATION = "validation"
    CONTINUOUS = "continuous"

27 

28 

class VulnerabilityState(Enum):
    """
    Lifecycle states a tracked vulnerability can be in.

    Members, in declaration order (value in parentheses):

    * DISCOVERED ("discovered") -- first reported by an agent.
    * VALIDATED ("validated") -- confirmed by a retest.
    * TRIAGED ("triaged") -- prioritized and assigned.
    * IN_REMEDIATION ("in_remediation") -- a fix is in progress.
    * FIXED ("fixed") -- a fix has been applied.
    * VERIFIED ("verified") -- the fix was confirmed by a retest.
    * FALSE_POSITIVE ("false_positive") -- dismissed as not a real issue.
    * ACCEPTED_RISK ("accepted_risk") -- stakeholders accepted the risk.
    """

    DISCOVERED = "discovered"
    VALIDATED = "validated"
    TRIAGED = "triaged"
    IN_REMEDIATION = "in_remediation"
    FIXED = "fixed"
    VERIFIED = "verified"
    FALSE_POSITIVE = "false_positive"
    ACCEPTED_RISK = "accepted_risk"

39 

40 

class AgentCoordinator:
    """
    Coordinates multiple security agents for sophisticated workflows.

    Capabilities:
    - Agent chaining (A → B → C): the steps of a chain run sequentially and
      each step receives the context accumulated by the steps before it
    - Vulnerability lifecycle tracking in an in-memory registry
    - Follow-up scan scheduling (recorded on the vulnerability; no job
      scheduler is invoked here)
    - Intelligent chain selection based on initial findings

    All state (``vulnerability_registry`` and ``chain_history``) lives on the
    instance and is not persisted.
    """

    # Severity buckets counted in the merged context and in lifecycle metrics.
    _SEVERITY_LEVELS = ("CRITICAL", "HIGH", "MEDIUM", "LOW", "INFO")

    def __init__(self):
        """Initialize agent coordinator."""
        # vuln_id -> vulnerability record (built by _track_vulnerability).
        self.vulnerability_registry: Dict[str, Dict[str, Any]] = {}
        # One summary entry per execute_chain() call.
        self.chain_history: List[Dict[str, Any]] = []
        logger.info("Agent Coordinator initialized")

    async def execute_chain(
        self,
        chain_type: ChainType,
        target: str,
        initial_findings: Optional[Dict[str, Any]] = None,
        safe_only: bool = True
    ) -> Dict[str, Any]:
        """
        Execute a coordinated agent chain.

        Steps run one after another. A step that raises is recorded with an
        ``error`` entry and the chain continues with the next step.

        Args:
            chain_type: Type of coordination chain to execute. A ChainType
                value string (e.g. ``"validation"``) is also accepted.
            target: Target to scan (a local path or a remote target)
            initial_findings: Optional initial context for the first step;
                the caller's dict and its contents are not modified
            safe_only: Only run safe checks

        Returns:
            Aggregated results from all agents in the chain, or
            ``{"error": ...}`` when the chain type is unknown.
        """
        # Accept raw value strings as well as enum members; without this an
        # unknown string would crash on `.value` below instead of producing
        # the "Unknown chain type" error result.
        if not isinstance(chain_type, ChainType):
            try:
                chain_type = ChainType(chain_type)
            except ValueError:
                logger.error(f"Unknown chain type: {chain_type}")
                return {"error": f"Unknown chain type: {chain_type}"}

        logger.info(f"Executing {chain_type.value} chain on {target}")

        chain_def = self._get_chain_definition(chain_type)

        if not chain_def:
            logger.error(f"Unknown chain type: {chain_type}")
            return {"error": f"Unknown chain type: {chain_type}"}

        results = {
            "chain_type": chain_type.value,
            "target": target,
            "start_time": datetime.now().isoformat(),
            "steps": [],
            "vulnerabilities": []
        }

        # Copy so the caller's dict is never touched by the chain.
        context: Dict[str, Any] = dict(initial_findings or {})

        for step in chain_def["steps"]:
            try:
                logger.info(f"Chain step: {step['name']} using {step['agent']}")

                # Execute agent with context from previous steps
                step_result = await self._execute_agent_step(
                    agent=step["agent"],
                    task=step["task"],
                    target=target,
                    context=context,
                    safe_only=safe_only
                )

                # Scanners may report `findings: None`; treat it as empty.
                findings = step_result.get("findings") or []

                results["steps"].append({
                    "name": step["name"],
                    "agent": step["agent"],
                    "timestamp": datetime.now().isoformat(),
                    "findings_count": len(findings),
                    "output": step_result
                })

                # Build a new context for the next step (never mutates the old one).
                context = self._merge_context(context, step_result, agent=step["agent"])

                for finding in findings:
                    self._track_vulnerability(finding, step["agent"], target)
                    results["vulnerabilities"].append(finding)

            except Exception as e:
                logger.error(f"Chain step {step['name']} failed: {e}")
                results["steps"].append({
                    "name": step["name"],
                    "agent": step["agent"],
                    "error": str(e),
                    "timestamp": datetime.now().isoformat()
                })

        results["end_time"] = datetime.now().isoformat()
        results["total_vulnerabilities"] = len(results["vulnerabilities"])

        self.chain_history.append({
            "chain_type": chain_type.value,
            "target": target,
            "timestamp": datetime.now().isoformat(),
            "vulnerabilities_found": results["total_vulnerabilities"]
        })

        return results

    def _get_chain_definition(self, chain_type: ChainType) -> Optional[Dict[str, Any]]:
        """
        Get chain definition for a given chain type.

        A fresh dict is built on every call, so callers may modify the
        returned definition freely.

        Args:
            chain_type: Type of chain

        Returns:
            Chain definition with ``name``, ``description`` and ordered
            ``steps`` (each with ``name``, ``agent``, ``task``,
            ``description``), or None if the type has no definition.
        """
        chains = {
            ChainType.ATTACK_DEFENSE: {
                "name": "Attack & Defense Chain",
                "description": "Red team attacks, blue team defends",
                "steps": [
                    {
                        "name": "Red Team Assessment",
                        "agent": "red_teamer",
                        "task": "offensive-security",
                        "description": "Identify exploitable vulnerabilities"
                    },
                    {
                        "name": "Blue Team Response",
                        "agent": "blue_teamer",
                        "task": "defensive-security",
                        "description": "Evaluate defenses against discovered attacks"
                    }
                ]
            },

            ChainType.INVESTIGATION: {
                "name": "Security Investigation Chain",
                "description": "Comprehensive security investigation",
                "steps": [
                    {
                        "name": "Initial Scan",
                        "agent": "codeagent",
                        "task": "code-audit",
                        "description": "Static code analysis"
                    },
                    {
                        "name": "Deep Analysis",
                        "agent": "secret_detection",
                        "task": "secret-detection",
                        "description": "Credential and secret scanning"
                    },
                    {
                        "name": "Forensic Analysis",
                        "agent": "dfir",
                        "task": "forensics",
                        "description": "DFIR investigation of findings"
                    }
                ]
            },

            ChainType.FULL_LIFECYCLE: {
                "name": "Full Security Lifecycle",
                "description": "Complete security assessment lifecycle",
                "steps": [
                    {
                        "name": "Red Team Attack",
                        "agent": "red_teamer",
                        "task": "offensive-security",
                        "description": "Offensive security testing"
                    },
                    {
                        "name": "Blue Team Defense",
                        "agent": "blue_teamer",
                        "task": "defensive-security",
                        "description": "Defensive posture evaluation"
                    },
                    {
                        "name": "Forensic Investigation",
                        "agent": "dfir",
                        "task": "forensics",
                        "description": "Evidence collection and analysis"
                    },
                    {
                        "name": "Remediation Planning",
                        "agent": "codeagent",
                        "task": "code-audit",
                        "description": "Code-level remediation guidance"
                    }
                ]
            },

            ChainType.VALIDATION: {
                "name": "Vulnerability Validation Chain",
                "description": "Discover, validate, and verify fixes",
                "steps": [
                    {
                        "name": "Initial Discovery",
                        "agent": "codeagent",
                        "task": "code-audit",
                        "description": "Find vulnerabilities"
                    },
                    {
                        "name": "Validation",
                        "agent": "bug_bounty",
                        "task": "vuln-scan",
                        "description": "Validate findings"
                    },
                    {
                        "name": "Retest",
                        "agent": "retester",
                        "task": "retest",
                        "description": "Verify fixes"
                    }
                ]
            },

            ChainType.CONTINUOUS: {
                "name": "Continuous Monitoring Chain",
                "description": "Ongoing security monitoring",
                "steps": [
                    {
                        "name": "Baseline Scan",
                        "agent": "codeagent",
                        "task": "code-audit",
                        "description": "Establish security baseline"
                    },
                    {
                        "name": "Secret Monitoring",
                        "agent": "secret_detection",
                        "task": "secret-detection",
                        "description": "Monitor for credential leaks"
                    },
                    {
                        "name": "Configuration Audit",
                        "agent": "config_audit",
                        "task": "config-audit",
                        "description": "Check configuration drift"
                    }
                ]
            }
        }

        return chains.get(chain_type)

    async def _execute_agent_step(
        self,
        agent: str,
        task: str,
        target: str,
        context: Dict[str, Any],
        safe_only: bool
    ) -> Dict[str, Any]:
        """
        Execute a single agent in the chain.

        Targets that exist on the local filesystem go to ``run_local_scan``;
        everything else goes to ``run_remote_scan``. Both scanners are called
        synchronously in the default executor so the event loop is not
        blocked.

        Args:
            agent: Agent to execute (used for logging and error results)
            task: Task for agent (passed to the scanner as its task argument)
            target: Scan target
            context: Context from previous steps
            safe_only: Only safe checks

        Returns:
            The scanner's result. A dict result is annotated with
            ``context_from_previous_steps`` and ``chain_position``. If the
            scan raises, an error dict with an empty ``findings`` list is
            returned instead.
        """
        from .security_engine import run_local_scan, run_remote_scan
        from pathlib import Path

        try:
            is_local = bool(target) and Path(target).exists()
        except (OSError, ValueError):
            # Not a usable filesystem path (e.g. name too long); treat as remote.
            is_local = False

        scan = run_local_scan if is_local else run_remote_scan

        try:
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(None, scan, target, task, safe_only)

            # Enhance result with context
            if isinstance(result, dict):
                result["context_from_previous_steps"] = context
                result["chain_position"] = len(context.get("chain_history") or []) + 1

            return result

        except Exception as e:
            logger.error(f"Agent {agent} execution failed: {e}")
            return {
                "error": str(e),
                "agent": agent,
                "task": task,
                "findings": []
            }

    def _merge_context(
        self,
        existing_context: Dict[str, Any],
        new_results: Dict[str, Any],
        agent: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Merge context from previous steps with new results.

        Returns a brand-new context. ``existing_context`` and the lists/dicts
        it holds are never modified, so context snapshots already embedded in
        earlier step outputs (and the caller's ``initial_findings``) stay
        intact.

        Args:
            existing_context: Existing context
            new_results: New results to merge
            agent: Agent that produced ``new_results``; when omitted,
                ``new_results["agent"]`` is used

        Returns:
            Merged context with ``all_findings``, ``chain_history`` and
            ``total_severity_counts`` updated
        """
        merged = dict(existing_context)
        new_findings = list(new_results.get("findings") or [])

        # Fresh lists: extending the previous step's lists in place would
        # retroactively change earlier snapshots.
        merged["all_findings"] = list(existing_context.get("all_findings") or []) + new_findings

        merged["chain_history"] = list(existing_context.get("chain_history") or []) + [{
            "agent": agent if agent is not None else new_results.get("agent"),
            "findings_count": len(new_findings),
            "timestamp": datetime.now().isoformat()
        }]

        existing_counts = existing_context.get("total_severity_counts")
        if existing_counts:
            severity_counts = dict(existing_counts)
        else:
            severity_counts = {level: 0 for level in self._SEVERITY_LEVELS}

        # Severities outside the known buckets are ignored.
        for finding in new_findings:
            severity = finding.get("severity", "INFO")
            if severity in severity_counts:
                severity_counts[severity] += 1

        merged["total_severity_counts"] = severity_counts

        return merged

    def _track_vulnerability(
        self,
        finding: Dict[str, Any],
        discovered_by: str,
        target: str
    ):
        """
        Track vulnerability in lifecycle registry.

        A finding whose (type, location, target) ID is already registered
        gets a ``"rediscovered"`` history entry instead of a new record.

        Args:
            finding: Vulnerability finding
            discovered_by: Agent that discovered it
            target: Scan target
        """
        vuln_id = self._generate_vuln_id(finding, target)
        now = datetime.now().isoformat()

        if vuln_id not in self.vulnerability_registry:
            # One timestamp so discovered_at matches the first history entry.
            self.vulnerability_registry[vuln_id] = {
                "id": vuln_id,
                "state": VulnerabilityState.DISCOVERED.value,
                "finding": finding,
                "target": target,
                "discovered_by": discovered_by,
                "discovered_at": now,
                "history": [{
                    "state": VulnerabilityState.DISCOVERED.value,
                    "timestamp": now,
                    "agent": discovered_by
                }]
            }
            logger.info(f"New vulnerability tracked: {vuln_id}")
        else:
            self.vulnerability_registry[vuln_id]["history"].append({
                "state": "rediscovered",
                "timestamp": now,
                "agent": discovered_by
            })
            logger.info(f"Vulnerability rediscovered: {vuln_id}")

    def _generate_vuln_id(self, finding: Dict[str, Any], target: str) -> str:
        """
        Generate a deterministic ID for a vulnerability.

        Args:
            finding: Vulnerability finding (its ``type`` and ``location`` are used)
            target: Scan target

        Returns:
            ``"vuln_"`` followed by the first 12 hex digits of an MD5 digest of
            ``"type:location:target"``. MD5 is used only as a fingerprint here,
            not for security.
        """
        import hashlib

        vuln_string = f"{finding.get('type')}:{finding.get('location')}:{target}"
        return f"vuln_{hashlib.md5(vuln_string.encode()).hexdigest()[:12]}"

    def update_vulnerability_state(
        self,
        vuln_id: str,
        new_state: VulnerabilityState,
        notes: str = ""
    ) -> bool:
        """
        Update vulnerability lifecycle state.

        No transition rules are enforced; any state may follow any other.

        Args:
            vuln_id: Vulnerability ID
            new_state: New state
            notes: Optional notes stored in the history entry

        Returns:
            True if updated, False if ``vuln_id`` is not in the registry
        """
        if vuln_id not in self.vulnerability_registry:
            logger.warning(f"Vulnerability {vuln_id} not found in registry")
            return False

        vuln = self.vulnerability_registry[vuln_id]
        old_state = vuln["state"]

        vuln["state"] = new_state.value
        vuln["history"].append({
            "state": new_state.value,
            "timestamp": datetime.now().isoformat(),
            "notes": notes,
            "previous_state": old_state
        })

        logger.info(f"Vulnerability {vuln_id}: {old_state} → {new_state.value}")
        return True

    def get_vulnerabilities_by_state(
        self,
        state: VulnerabilityState
    ) -> List[Dict[str, Any]]:
        """
        Get all vulnerabilities in a specific state.

        Args:
            state: Vulnerability state to filter by

        Returns:
            List of registry records currently in that state
        """
        return [
            vuln for vuln in self.vulnerability_registry.values()
            if vuln["state"] == state.value
        ]

    def get_vulnerability_metrics(self) -> Dict[str, Any]:
        """
        Get vulnerability lifecycle metrics.

        Returns:
            Dict with ``total_vulnerabilities``, ``by_state`` (count per
            VulnerabilityState value), ``by_severity`` (count per known
            severity), ``average_time_to_fix`` (hours from discovery to first
            VERIFIED entry, or None) and ``chains_executed``.
        """
        metrics = {
            "total_vulnerabilities": len(self.vulnerability_registry),
            "by_state": {},
            "by_severity": {level: 0 for level in self._SEVERITY_LEVELS},
            "average_time_to_fix": None,
            "chains_executed": len(self.chain_history)
        }

        for state in VulnerabilityState:
            metrics["by_state"][state.value] = len(
                self.get_vulnerabilities_by_state(state)
            )

        for vuln in self.vulnerability_registry.values():
            severity = vuln["finding"].get("severity", "INFO")
            if severity in metrics["by_severity"]:
                metrics["by_severity"][severity] += 1

        fix_times = []
        for vuln in self.get_vulnerabilities_by_state(VulnerabilityState.VERIFIED):
            discovered = datetime.fromisoformat(vuln["discovered_at"])
            verified = next(
                (datetime.fromisoformat(h["timestamp"])
                 for h in vuln["history"]
                 if h["state"] == VulnerabilityState.VERIFIED.value),
                None
            )
            if verified is not None:
                fix_times.append((verified - discovered).total_seconds() / 3600)  # hours

        if fix_times:
            metrics["average_time_to_fix"] = sum(fix_times) / len(fix_times)

        return metrics

    async def schedule_follow_up_scan(
        self,
        vuln_id: str,
        delay_hours: int = 24
    ) -> bool:
        """
        Schedule a follow-up scan to verify vulnerability fix.

        Only records the intent on the vulnerability (``follow_up_scheduled``)
        and logs it; nothing here runs the scan later.

        Args:
            vuln_id: Vulnerability to retest
            delay_hours: Hours to wait before rescanning

        Returns:
            True if recorded, False if ``vuln_id`` is not in the registry
        """
        if vuln_id not in self.vulnerability_registry:
            logger.warning(f"Cannot schedule follow-up: {vuln_id} not found")
            return False

        vuln = self.vulnerability_registry[vuln_id]

        # In production, this would integrate with a job scheduler.
        # One "now" so rescan_at - scheduled_at equals delay_hours exactly.
        now = datetime.now()
        rescan_time = now + timedelta(hours=delay_hours)

        logger.info(
            f"Follow-up scan scheduled for {vuln_id} at {rescan_time.isoformat()}"
        )

        vuln["follow_up_scheduled"] = {
            "scheduled_at": now.isoformat(),
            "rescan_at": rescan_time.isoformat(),
            "delay_hours": delay_hours
        }

        return True

    def recommend_chain(
        self,
        findings: Dict[str, Any],
        target_type: str = "code"
    ) -> ChainType:
        """
        Recommend appropriate agent chain based on initial findings.

        Rules, in order:
        - any CRITICAL or HIGH finding -> FULL_LIFECYCLE
        - more than five findings -> ATTACK_DEFENSE
        - otherwise -> VALIDATION

        Args:
            findings: Initial scan result containing a ``findings`` list
                (None or a missing/None list is treated as no findings)
            target_type: Type of target (code, network, etc.); currently
                not used by the rules above

        Returns:
            Recommended chain type
        """
        findings_list = (findings or {}).get("findings") or []

        critical_high = sum(
            1 for f in findings_list
            if f.get("severity") in ("CRITICAL", "HIGH")
        )

        if critical_high > 0:
            logger.info("Critical findings detected → Recommending FULL_LIFECYCLE chain")
            return ChainType.FULL_LIFECYCLE

        if len(findings_list) > 5:
            logger.info("Multiple findings detected → Recommending ATTACK_DEFENSE chain")
            return ChainType.ATTACK_DEFENSE

        logger.info("Standard findings → Recommending VALIDATION chain")
        return ChainType.VALIDATION

633 

634 

# Process-wide coordinator, created lazily by get_coordinator().
_coordinator = None


def get_coordinator() -> AgentCoordinator:
    """
    Return the shared AgentCoordinator, constructing it on first call.

    Later calls return the same instance.
    """
    global _coordinator
    if _coordinator is not None:
        return _coordinator
    _coordinator = AgentCoordinator()
    return _coordinator