Coverage for src/alprina_cli/agent_coordinator.py: 0%
151 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
1"""
2Alprina Agent Coordinator - Intelligent agent chaining and vulnerability lifecycle management.
4Coordinates multiple security agents to work together in sophisticated workflows:
5- Red Team → Blue Team → DFIR chains
6- Vulnerability discovery → Validation → Remediation tracking
7- Automated follow-up scans
8- Agent-to-agent communication
10Reference: AI SDK Agent Coordination Patterns
11"""
13from typing import Dict, List, Any, Optional, Callable
14from enum import Enum
15from loguru import logger
16from datetime import datetime, timedelta
17import asyncio
class ChainType(Enum):
    """Kinds of multi-agent chains the coordinator can run.

    Each member's value is the string identifier used when reporting a chain.
    """

    # Red Team → Blue Team
    ATTACK_DEFENSE = "attack_defense"
    # Discovery → DFIR → Analysis
    INVESTIGATION = "investigation"
    # Red Team → Blue Team → DFIR → Remediation
    FULL_LIFECYCLE = "full_lifecycle"
    # Scan → Retest → Verify
    VALIDATION = "validation"
    # Ongoing monitoring and response
    CONTINUOUS = "continuous"
class VulnerabilityState(Enum):
    """Stages a tracked vulnerability can occupy during its lifecycle.

    Member values are the strings stored in the coordinator's registry.
    """

    # Initial discovery
    DISCOVERED = "discovered"
    # Confirmed by retesting
    VALIDATED = "validated"
    # Prioritized and assigned
    TRIAGED = "triaged"
    # Being fixed
    IN_REMEDIATION = "in_remediation"
    # Fix applied
    FIXED = "fixed"
    # Fix verified by retest
    VERIFIED = "verified"
    # Marked as false positive
    FALSE_POSITIVE = "false_positive"
    # Risk accepted by stakeholders
    ACCEPTED_RISK = "accepted_risk"
class AgentCoordinator:
    """
    Coordinates multiple security agents for sophisticated workflows.

    Capabilities:
    - Agent chaining (A → B → C)
    - Vulnerability lifecycle tracking
    - Automated follow-up scans
    - Agent-to-agent communication
    - Intelligent chain selection

    All state lives in memory: ``vulnerability_registry`` maps generated
    vulnerability IDs to lifecycle records, and ``chain_history`` holds one
    summary entry per executed chain.
    """

    # Severity buckets counted when aggregating findings; other severities are ignored.
    _SEVERITY_LEVELS = ("CRITICAL", "HIGH", "MEDIUM", "LOW", "INFO")

    def __init__(self):
        """Initialize agent coordinator."""
        self.vulnerability_registry: Dict[str, Dict[str, Any]] = {}
        self.chain_history: List[Dict[str, Any]] = []
        logger.info("Agent Coordinator initialized")

    async def execute_chain(
        self,
        chain_type: "ChainType",
        target: str,
        initial_findings: Optional[Dict[str, Any]] = None,
        safe_only: bool = True
    ) -> Dict[str, Any]:
        """
        Execute a coordinated agent chain.

        Steps run sequentially. Each step receives the merged context of all
        earlier steps. A step that raises is recorded with an ``error`` entry,
        and the chain continues with the next step.

        Args:
            chain_type: Type of coordination chain to execute
            target: Target to scan. An existing filesystem path is scanned
                locally; anything else is scanned remotely.
            initial_findings: Optional initial context to start with. It is
                not modified by the chain.
            safe_only: Only run safe checks

        Returns:
            Aggregated results from all agents in the chain, or
            ``{"error": ...}`` when the chain type has no definition.
        """
        logger.info(f"Executing {chain_type.value} chain on {target}")

        chain_def = self._get_chain_definition(chain_type)

        if not chain_def:
            logger.error(f"Unknown chain type: {chain_type}")
            return {"error": f"Unknown chain type: {chain_type}"}

        # Execute chain
        results = {
            "chain_type": chain_type.value,
            "target": target,
            "start_time": datetime.now().isoformat(),
            "steps": [],
            "vulnerabilities": []
        }

        context = initial_findings or {}

        for step in chain_def["steps"]:
            try:
                logger.info(f"Chain step: {step['name']} using {step['agent']}")

                # Execute agent with context from previous steps
                step_result = await self._execute_agent_step(
                    agent=step["agent"],
                    task=step["task"],
                    target=target,
                    context=context,
                    safe_only=safe_only
                )
                # A missing or None "findings" entry means "no findings".
                findings = step_result.get("findings") or []

                # Record step
                results["steps"].append({
                    "name": step["name"],
                    "agent": step["agent"],
                    "timestamp": datetime.now().isoformat(),
                    "findings_count": len(findings),
                    "output": step_result
                })

                # Update context for next step
                context = self._merge_context(context, step_result)

                # Track vulnerabilities
                for finding in findings:
                    self._track_vulnerability(finding, step["agent"], target)
                    results["vulnerabilities"].append(finding)

            except Exception as e:
                # Per-step boundary: one failing agent must not abort the chain.
                logger.error(f"Chain step {step['name']} failed: {e}")
                results["steps"].append({
                    "name": step["name"],
                    "agent": step["agent"],
                    "error": str(e),
                    "timestamp": datetime.now().isoformat()
                })

        results["end_time"] = datetime.now().isoformat()
        results["total_vulnerabilities"] = len(results["vulnerabilities"])

        # Record chain execution
        self.chain_history.append({
            "chain_type": chain_type.value,
            "target": target,
            "timestamp": datetime.now().isoformat(),
            "vulnerabilities_found": results["total_vulnerabilities"]
        })

        return results

    def _get_chain_definition(self, chain_type: "ChainType") -> Optional[Dict[str, Any]]:
        """
        Get chain definition for a given chain type.

        Args:
            chain_type: Type of chain

        Returns:
            Chain definition (name, description, ordered steps), or None when
            ``chain_type`` has no entry.
        """
        chains = {
            ChainType.ATTACK_DEFENSE: {
                "name": "Attack & Defense Chain",
                "description": "Red team attacks, blue team defends",
                "steps": [
                    {
                        "name": "Red Team Assessment",
                        "agent": "red_teamer",
                        "task": "offensive-security",
                        "description": "Identify exploitable vulnerabilities"
                    },
                    {
                        "name": "Blue Team Response",
                        "agent": "blue_teamer",
                        "task": "defensive-security",
                        "description": "Evaluate defenses against discovered attacks"
                    }
                ]
            },

            ChainType.INVESTIGATION: {
                "name": "Security Investigation Chain",
                "description": "Comprehensive security investigation",
                "steps": [
                    {
                        "name": "Initial Scan",
                        "agent": "codeagent",
                        "task": "code-audit",
                        "description": "Static code analysis"
                    },
                    {
                        "name": "Deep Analysis",
                        "agent": "secret_detection",
                        "task": "secret-detection",
                        "description": "Credential and secret scanning"
                    },
                    {
                        "name": "Forensic Analysis",
                        "agent": "dfir",
                        "task": "forensics",
                        "description": "DFIR investigation of findings"
                    }
                ]
            },

            ChainType.FULL_LIFECYCLE: {
                "name": "Full Security Lifecycle",
                "description": "Complete security assessment lifecycle",
                "steps": [
                    {
                        "name": "Red Team Attack",
                        "agent": "red_teamer",
                        "task": "offensive-security",
                        "description": "Offensive security testing"
                    },
                    {
                        "name": "Blue Team Defense",
                        "agent": "blue_teamer",
                        "task": "defensive-security",
                        "description": "Defensive posture evaluation"
                    },
                    {
                        "name": "Forensic Investigation",
                        "agent": "dfir",
                        "task": "forensics",
                        "description": "Evidence collection and analysis"
                    },
                    {
                        "name": "Remediation Planning",
                        "agent": "codeagent",
                        "task": "code-audit",
                        "description": "Code-level remediation guidance"
                    }
                ]
            },

            ChainType.VALIDATION: {
                "name": "Vulnerability Validation Chain",
                "description": "Discover, validate, and verify fixes",
                "steps": [
                    {
                        "name": "Initial Discovery",
                        "agent": "codeagent",
                        "task": "code-audit",
                        "description": "Find vulnerabilities"
                    },
                    {
                        "name": "Validation",
                        "agent": "bug_bounty",
                        "task": "vuln-scan",
                        "description": "Validate findings"
                    },
                    {
                        "name": "Retest",
                        "agent": "retester",
                        "task": "retest",
                        "description": "Verify fixes"
                    }
                ]
            },

            ChainType.CONTINUOUS: {
                "name": "Continuous Monitoring Chain",
                "description": "Ongoing security monitoring",
                "steps": [
                    {
                        "name": "Baseline Scan",
                        "agent": "codeagent",
                        "task": "code-audit",
                        "description": "Establish security baseline"
                    },
                    {
                        "name": "Secret Monitoring",
                        "agent": "secret_detection",
                        "task": "secret-detection",
                        "description": "Monitor for credential leaks"
                    },
                    {
                        "name": "Configuration Audit",
                        "agent": "config_audit",
                        "task": "config-audit",
                        "description": "Check configuration drift"
                    }
                ]
            }
        }

        return chains.get(chain_type)

    @staticmethod
    def _is_local_target(target: str) -> bool:
        """Return True when ``target`` names an existing filesystem path."""
        from pathlib import Path

        if not target:
            return False
        try:
            return Path(target).exists()
        except (OSError, ValueError):
            # e.g. names too long for the OS or containing NUL bytes:
            # such a string cannot be a local path, so treat it as remote.
            return False

    async def _execute_agent_step(
        self,
        agent: str,
        task: str,
        target: str,
        context: Dict[str, Any],
        safe_only: bool
    ) -> Dict[str, Any]:
        """
        Execute a single agent in the chain.

        Args:
            agent: Agent to execute
            task: Task for agent (passed to the scanner)
            target: Scan target
            context: Context from previous steps (not modified)
            safe_only: Only safe checks

        Returns:
            The scanner's result. When it is a dict, it is annotated with
            ``agent`` (unless already set), ``context_from_previous_steps``
            and ``chain_position``. If the scan raises, an error dict with
            empty ``findings`` is returned instead.
        """
        from .security_engine import run_local_scan, run_remote_scan

        try:
            # Existing filesystem paths go to the local scanner; anything else goes to the remote one.
            scan = run_local_scan if self._is_local_target(target) else run_remote_scan

            # Run the (blocking) scan function off the event loop in the loop's default executor.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(None, scan, target, task, safe_only)

            # Enhance result with context
            if isinstance(result, dict):
                # Record which agent produced this result so the merged chain_history can name it.
                result.setdefault("agent", agent)
                result["context_from_previous_steps"] = context
                result["chain_position"] = len(context.get("chain_history", [])) + 1

            return result

        except Exception as e:
            logger.error(f"Agent {agent} execution failed: {e}")
            return {
                "error": str(e),
                "agent": agent,
                "task": task,
                "findings": []
            }

    def _merge_context(
        self,
        existing_context: Dict[str, Any],
        new_results: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Merge context from previous steps with new results.

        Returns a new context. Neither ``existing_context`` nor any list/dict
        it contains is mutated. This matters because earlier contexts are
        stored in step outputs (``context_from_previous_steps``) and may be
        the caller's ``initial_findings``.

        Args:
            existing_context: Existing context
            new_results: New results to merge

        Returns:
            Merged context with ``all_findings``, ``chain_history`` and
            ``total_severity_counts`` updated.
        """
        merged = dict(existing_context)
        new_findings = new_results.get("findings") or []

        # Add findings to context (fresh list; never extend the old one in place)
        merged["all_findings"] = list(existing_context.get("all_findings") or [])
        merged["all_findings"].extend(new_findings)

        # Track chain history
        merged["chain_history"] = list(existing_context.get("chain_history") or [])
        merged["chain_history"].append({
            "agent": new_results.get("agent"),
            "findings_count": len(new_findings),
            "timestamp": datetime.now().isoformat()
        })

        # Aggregate severity counts into a copy of the previous tally
        if "total_severity_counts" in existing_context:
            counts = dict(existing_context["total_severity_counts"])
        else:
            counts = {level: 0 for level in self._SEVERITY_LEVELS}

        for finding in new_findings:
            severity = finding.get("severity", "INFO")
            if severity in counts:
                counts[severity] += 1

        merged["total_severity_counts"] = counts
        return merged

    def _track_vulnerability(
        self,
        finding: Dict[str, Any],
        discovered_by: str,
        target: str
    ):
        """
        Track vulnerability in lifecycle registry.

        A first sighting creates a DISCOVERED record; later sightings of the
        same ID only append a "rediscovered" history entry.

        Args:
            finding: Vulnerability finding
            discovered_by: Agent that discovered it
            target: Scan target
        """
        vuln_id = self._generate_vuln_id(finding, target)
        now = datetime.now().isoformat()

        if vuln_id not in self.vulnerability_registry:
            # New vulnerability
            self.vulnerability_registry[vuln_id] = {
                "id": vuln_id,
                "state": VulnerabilityState.DISCOVERED.value,
                "finding": finding,
                "target": target,
                "discovered_by": discovered_by,
                "discovered_at": now,
                "history": [{
                    "state": VulnerabilityState.DISCOVERED.value,
                    "timestamp": now,
                    "agent": discovered_by
                }]
            }
            logger.info(f"New vulnerability tracked: {vuln_id}")
        else:
            # Update existing vulnerability
            self.vulnerability_registry[vuln_id]["history"].append({
                "state": "rediscovered",
                "timestamp": now,
                "agent": discovered_by
            })
            logger.info(f"Vulnerability rediscovered: {vuln_id}")

    def _generate_vuln_id(self, finding: Dict[str, Any], target: str) -> str:
        """
        Generate a deterministic ID for a vulnerability.

        The ID is derived from the finding's ``type`` and ``location`` plus
        the target, so the same issue on the same target maps to the same ID.

        Args:
            finding: Vulnerability finding
            target: Scan target

        Returns:
            ``"vuln_"`` followed by the first 12 hex chars of an MD5 digest
        """
        import hashlib

        # MD5 is used only as a stable fingerprint here, not for security.
        vuln_string = f"{finding.get('type')}:{finding.get('location')}:{target}"
        return f"vuln_{hashlib.md5(vuln_string.encode()).hexdigest()[:12]}"

    def update_vulnerability_state(
        self,
        vuln_id: str,
        new_state: "VulnerabilityState",
        notes: str = ""
    ) -> bool:
        """
        Update vulnerability lifecycle state.

        Args:
            vuln_id: Vulnerability ID
            new_state: New state
            notes: Optional notes

        Returns:
            True if updated successfully, False if ``vuln_id`` is unknown
        """
        if vuln_id not in self.vulnerability_registry:
            logger.warning(f"Vulnerability {vuln_id} not found in registry")
            return False

        vuln = self.vulnerability_registry[vuln_id]
        old_state = vuln["state"]

        # Update state
        vuln["state"] = new_state.value
        vuln["history"].append({
            "state": new_state.value,
            "timestamp": datetime.now().isoformat(),
            "notes": notes,
            "previous_state": old_state
        })

        logger.info(f"Vulnerability {vuln_id}: {old_state} → {new_state.value}")
        return True

    def get_vulnerabilities_by_state(
        self,
        state: "VulnerabilityState"
    ) -> List[Dict[str, Any]]:
        """
        Get all vulnerabilities in a specific state.

        Args:
            state: Vulnerability state to filter by

        Returns:
            List of vulnerabilities in that state
        """
        return [
            vuln for vuln in self.vulnerability_registry.values()
            if vuln["state"] == state.value
        ]

    def get_vulnerability_metrics(self) -> Dict[str, Any]:
        """
        Get vulnerability lifecycle metrics.

        Returns:
            Metrics dictionary. ``average_time_to_fix`` is in hours, measured
            from discovery to the first VERIFIED history entry, or None when
            nothing has been verified.
        """
        metrics = {
            "total_vulnerabilities": len(self.vulnerability_registry),
            "by_state": {},
            "by_severity": {level: 0 for level in self._SEVERITY_LEVELS},
            "average_time_to_fix": None,
            "chains_executed": len(self.chain_history)
        }

        # Count by state
        for state in VulnerabilityState:
            metrics["by_state"][state.value] = len(
                self.get_vulnerabilities_by_state(state)
            )

        # Count by severity
        for vuln in self.vulnerability_registry.values():
            severity = vuln["finding"].get("severity", "INFO")
            if severity in metrics["by_severity"]:
                metrics["by_severity"][severity] += 1

        # Calculate average time to fix
        fixed_vulns = self.get_vulnerabilities_by_state(VulnerabilityState.VERIFIED)
        if fixed_vulns:
            fix_times = []
            for vuln in fixed_vulns:
                discovered = datetime.fromisoformat(vuln["discovered_at"])
                verified = next(
                    (datetime.fromisoformat(h["timestamp"])
                     for h in vuln["history"]
                     if h["state"] == VulnerabilityState.VERIFIED.value),
                    None
                )
                if verified:
                    fix_times.append((verified - discovered).total_seconds() / 3600)  # hours

            if fix_times:
                metrics["average_time_to_fix"] = sum(fix_times) / len(fix_times)

        return metrics

    async def schedule_follow_up_scan(
        self,
        vuln_id: str,
        delay_hours: int = 24
    ) -> bool:
        """
        Schedule a follow-up scan to verify vulnerability fix.

        Currently this only logs the intent and records the planned rescan
        time on the registry entry; no scan is actually triggered.

        Args:
            vuln_id: Vulnerability to retest
            delay_hours: Hours to wait before rescanning

        Returns:
            True if scheduled successfully, False if ``vuln_id`` is unknown
        """
        if vuln_id not in self.vulnerability_registry:
            logger.warning(f"Cannot schedule follow-up: {vuln_id} not found")
            return False

        vuln = self.vulnerability_registry[vuln_id]

        # In production, this would integrate with a job scheduler
        # For now, we just log the intent
        now = datetime.now()
        rescan_time = now + timedelta(hours=delay_hours)

        logger.info(
            f"Follow-up scan scheduled for {vuln_id} at {rescan_time.isoformat()}"
        )

        vuln["follow_up_scheduled"] = {
            "scheduled_at": now.isoformat(),
            "rescan_at": rescan_time.isoformat(),
            "delay_hours": delay_hours
        }

        return True

    def recommend_chain(
        self,
        findings: Dict[str, Any],
        target_type: str = "code"
    ) -> "ChainType":
        """
        Recommend appropriate agent chain based on initial findings.

        Any CRITICAL/HIGH finding → FULL_LIFECYCLE; more than 5 findings →
        ATTACK_DEFENSE; otherwise VALIDATION.

        Args:
            findings: Initial scan findings (a dict with a "findings" list)
            target_type: Type of target (code, network, etc.); currently unused

        Returns:
            Recommended chain type
        """
        findings_list = findings.get("findings") or []

        # Count critical/high findings
        critical_high = sum(
            1 for f in findings_list
            if f.get("severity") in ("CRITICAL", "HIGH")
        )

        # If critical/high findings, recommend full lifecycle
        if critical_high > 0:
            logger.info("Critical findings detected → Recommending FULL_LIFECYCLE chain")
            return ChainType.FULL_LIFECYCLE

        # If moderate findings, recommend attack/defense
        if len(findings_list) > 5:
            logger.info("Multiple findings detected → Recommending ATTACK_DEFENSE chain")
            return ChainType.ATTACK_DEFENSE

        # Otherwise, validation chain
        logger.info("Standard findings → Recommending VALIDATION chain")
        return ChainType.VALIDATION
# Process-wide coordinator, created lazily by get_coordinator().
_coordinator = None


def get_coordinator() -> AgentCoordinator:
    """Return the shared AgentCoordinator, building it on first call."""
    global _coordinator
    if _coordinator is not None:
        return _coordinator
    _coordinator = AgentCoordinator()
    return _coordinator