# Coverage for src/alprina_cli/agents/web3_auditor/cross_contract_analyzer.py: 21%
# 182 statements
# coverage.py v7.11.3, created at 2025-11-14 11:27 +0100
1"""
2Cross-Contract Analysis Engine
4WEEK 3 DAY 4: Cross-Contract Analysis
5======================================
7Analyzes vulnerabilities across multiple interacting contracts:
8- Dependency graph construction
9- Cross-contract reentrancy detection
10- Upgrade pattern vulnerabilities
11- Attack chain identification
12- Interface trust issues
14Background:
15- Modern DeFi uses complex contract interactions
16- Vulnerabilities often exist at contract boundaries
17- Upgradeable contracts introduce proxy risks
19Author: Alprina Development Team
20Date: 2025-11-12
22References:
23- The DAO attack (2016): Cross-contract reentrancy
24- Parity Wallet (2017): Delegatecall vulnerability
25- Poly Network (2021): Cross-chain trust issues
26"""
import logging
import re
from typing import List, Dict, Any, Optional, Set, Tuple
from dataclasses import dataclass, field
from enum import Enum

import networkx as nx

try:
    from .solidity_analyzer import SolidityVulnerability, VulnerabilityType
except ImportError:
    import sys
    from pathlib import Path
    sys.path.insert(0, str(Path(__file__).parent))
    from solidity_analyzer import SolidityVulnerability, VulnerabilityType
class CrossContractVulnType(Enum):
    """Types of cross-contract vulnerabilities.

    The string values are the serialized identifiers for each category.
    """

    CROSS_CONTRACT_REENTRANCY = "cross_contract_reentrancy"
    UPGRADE_VULNERABILITY = "upgrade_vulnerability"
    DELEGATECALL_INJECTION = "delegatecall_injection"
    INTERFACE_TRUST = "interface_trust"
    ACCESS_CONTROL_BREACH = "access_control_breach"
    ATTACK_CHAIN = "attack_chain"
@dataclass
class ContractDependency:
    """Represents a dependency between contracts"""

    from_contract: str  # contract that performs the call
    to_contract: str  # contract (or interface) being called
    function_name: str  # name of the function invoked on the target
    call_type: str  # "call", "delegatecall", "staticcall", "interface"
    line_number: int  # line of the call site in the calling contract
@dataclass
class AttackChain:
    """Represents a multi-step attack sequence"""

    steps: List[Dict[str, Any]]  # ordered steps, one dict per step
    total_impact: str  # overall impact label, e.g. "high"
    complexity: str  # attack complexity label, e.g. "medium"
    description: str  # human-readable summary of the attack path
@dataclass
class CrossContractVulnerability:
    """Cross-contract vulnerability"""

    vuln_type: "CrossContractVulnType"  # category of the finding
    severity: str  # severity label, e.g. "critical", "high", "medium"
    title: str  # short headline for the finding
    description: str  # longer explanation of the issue
    contracts_involved: List[str]  # names of every contract taking part
    attack_chain: Optional["AttackChain"]  # multi-step path, when one was identified
    estimated_loss: Tuple[float, float]  # (low, high) loss estimate
    confidence: int = 90  # detector confidence; defaults to 90
class CrossContractAnalyzer:
    """
    Analyze vulnerabilities across multiple contracts

    Week 3 Day 4 Implementation:
    1. Build dependency graphs between contracts
    2. Detect cross-contract reentrancy
    3. Identify upgrade vulnerabilities
    4. Find attack chains

    Features:
    - Dependency graph construction using NetworkX
    - Multi-contract reentrancy detection
    - Proxy pattern analysis
    - Attack path discovery

    Every detector is a regex/keyword heuristic over raw Solidity text; no
    real parsing is performed, so findings are indicative rather than proven.
    """

    # Contract declaration at the start of a line (keeps the word "contract"
    # inside comments or strings from being taken for a declaration).
    _CONTRACT_RE = re.compile(r'\s*(?:abstract\s+)?contract\s+(\w+)')
    # Start of a function definition, up to and including the opening "(".
    _FUNCTION_RE = re.compile(r'\s*function\s+(\w+)\s*\(')
    _VISIBILITY_RE = re.compile(r'\b(public|external|internal|private)\b')
    # "target.function(" with optional call options: "target.call{value: x}(".
    _CALL_RE = re.compile(r'(\w+)\.(\w+)\s*(?:\{[^}]*\})?\s*\(')
    # Upper bound on how many lines a wrapped function signature may span.
    _MAX_SIGNATURE_LINES = 20

    def __init__(self):
        self.dependency_graph = nx.DiGraph()
        self.contracts: Dict[str, Dict[str, Any]] = {}
        self.vulnerabilities: List[CrossContractVulnerability] = []

    def analyze_contracts(
        self,
        contracts: Dict[str, str],  # {filename: source_code}
        file_path: str = "multi-contract"
    ) -> List["SolidityVulnerability"]:
        """
        Analyze multiple contracts for cross-contract vulnerabilities

        All state from a previous run is discarded first, so one analyzer
        instance can be reused.

        Args:
            contracts: Dictionary of contract names to source code
            file_path: Path label attached to every reported finding

        Returns:
            List of vulnerabilities in standard format
        """
        self.vulnerabilities = []
        self.contracts = {}
        self.dependency_graph = nx.DiGraph()

        # Step 1: Parse all contracts
        for contract_name, source_code in contracts.items():
            self.contracts[contract_name] = self._parse_contract(contract_name, source_code)

        # Step 2: Build dependency graph
        self._build_dependency_graph()

        # Step 3: Detect vulnerabilities
        self._detect_cross_contract_reentrancy()
        self._detect_upgrade_vulnerabilities()
        self._detect_delegatecall_issues()
        self._detect_interface_trust_issues()
        self._identify_attack_chains()

        # Convert to standard format
        return self._convert_to_standard_format(file_path)

    def _parse_contract(self, name: str, source_code: str) -> Dict[str, Any]:
        """
        Parse a contract and extract relevant information

        Args:
            name: Key the contract was supplied under
            source_code: Raw Solidity source

        Returns:
            Dict with keys 'name', 'actual_name', 'source', 'functions',
            'external_calls', 'state_variables', 'is_proxy', 'is_upgradeable'
        """
        lowered = source_code.lower()

        contract_info = {
            'name': name,
            # Falls back to the supplied name when no declaration is found,
            # so consumers can always read this key.
            'actual_name': name,
            'source': source_code,
            'functions': [],
            'external_calls': [],
            'state_variables': [],
            # Keyword heuristics only: any mention of these words counts.
            'is_proxy': any(
                keyword in lowered
                for keyword in ['delegatecall', 'proxy', 'implementation']
            ),
            'is_upgradeable': any(
                keyword in lowered
                for keyword in ['upgrade', 'initialize', 'upgradeable']
            ),
        }

        # Extract contract name from the first declaration in the code
        for line in source_code.split('\n'):
            contract_match = self._CONTRACT_RE.match(line)
            if contract_match:
                contract_info['actual_name'] = contract_match.group(1)
                break

        contract_info['functions'] = self._extract_functions(source_code)
        contract_info['external_calls'] = self._extract_external_calls(source_code)

        return contract_info

    def _signature_tail(self, lines: List[str], index: int, start: int) -> str:
        """
        Return the modifier section of a function signature

        That is the text between the ")" closing the parameter list and the
        "{" or ";" that ends the signature.

        Args:
            lines: All source lines
            index: Index of the line holding the "function" keyword
            start: Offset in that line just past the parameter list's "("

        Returns:
            The modifier text, or '' when the parameter list never closes
        """
        text = lines[index][start:]
        cursor = index + 1
        limit = min(len(lines), index + self._MAX_SIGNATURE_LINES)

        # Signatures may wrap over several lines; pull lines in until the
        # body opens or the declaration ends.
        while '{' not in text and ';' not in text and cursor < limit:
            text += '\n' + lines[cursor]
            cursor += 1
        text = re.split(r'[{;]', text, maxsplit=1)[0]

        # Track nesting so parentheses inside parameter types are skipped.
        depth = 1
        for position, char in enumerate(text):
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    return text[position + 1:]
        return ''

    def _extract_functions(self, source_code: str) -> List[Dict[str, Any]]:
        """
        Extract function definitions

        Visibility is looked up anywhere in the modifier section, so
        "payable external" and signatures wrapped over several lines are
        handled. When no visibility keyword is present 'internal' is recorded.

        Returns:
            One dict per function with 'name', 'visibility' and 1-based 'line'
        """
        functions = []
        lines = source_code.split('\n')

        for i, line in enumerate(lines):
            func_match = self._FUNCTION_RE.match(line)
            if not func_match:
                continue

            modifiers = self._signature_tail(lines, i, func_match.end())
            visibility_match = self._VISIBILITY_RE.search(modifiers)
            functions.append({
                'name': func_match.group(1),
                'visibility': visibility_match.group(1) if visibility_match else 'internal',
                'line': i + 1
            })

        return functions

    def _extract_external_calls(self, source_code: str) -> List[Dict[str, Any]]:
        """
        Extract external contract calls

        Every "target.function(" occurrence on a line is recorded, including
        the call-options form "target.call{value: x}(". Text after "//" is
        ignored so commented-out calls are not reported.

        Returns:
            One dict per call with 'target', 'function', 'call_type' and
            1-based 'line'
        """
        external_calls = []

        for i, raw_line in enumerate(source_code.split('\n')):
            line = raw_line.split('//', 1)[0]
            matches = list(self._CALL_RE.finditer(line))

            for call_match in matches:
                target = call_match.group(1)
                function = call_match.group(2)

                # Prefer the callee name. For a line with a single match, fall
                # back to the keyword appearing anywhere on the line, e.g.
                # "address(x).delegatecall(abi.encode(y))".
                if function in ('delegatecall', 'staticcall'):
                    call_type = function
                elif len(matches) == 1 and 'delegatecall' in line:
                    call_type = "delegatecall"
                elif len(matches) == 1 and 'staticcall' in line:
                    call_type = "staticcall"
                else:
                    call_type = "call"

                external_calls.append({
                    'target': target,
                    'function': function,
                    'call_type': call_type,
                    'line': i + 1
                })

        return external_calls

    def _build_dependency_graph(self):
        """
        Build dependency graph between contracts

        Nodes are the keys of self.contracts. A call target is resolved
        through either the supplied key or the declared contract name;
        targets ending in 'Interface' are added as nodes of their own.
        """
        self.dependency_graph.add_nodes_from(self.contracts)

        # Map declared names to keys; keys win when the two collide.
        aliases = {
            info.get('actual_name', key): key
            for key, info in self.contracts.items()
        }
        aliases.update({key: key for key in self.contracts})

        for contract_name, contract_info in self.contracts.items():
            for call in contract_info['external_calls']:
                target = call['target']

                if target in aliases:
                    resolved = aliases[target]
                elif target.endswith('Interface'):
                    resolved = target
                else:
                    continue

                self.dependency_graph.add_edge(
                    contract_name,
                    resolved,
                    call_type=call['call_type'],
                    function=call['function']
                )

    def _detect_cross_contract_reentrancy(self):
        """
        Detect reentrancy vulnerabilities across contracts

        Pattern:
        1. Contract A calls Contract B (external call)
        2. Contract B calls back to Contract A before state update
        3. State inconsistency allows exploitation

        Each public/external function is inspected on its own: it is flagged
        when a state-change pattern appears after its first external call.
        At most one finding is produced per contract.
        """
        for contract_name, contract_info in self.contracts.items():
            lines = contract_info['source'].split('\n')
            own_names = {contract_name, contract_info.get('actual_name', contract_name)}
            functions = sorted(contract_info['functions'], key=lambda f: f['line'])

            risky_functions = []
            for position, func in enumerate(functions):
                if func['visibility'] not in ['public', 'external']:
                    continue

                # Approximate the body as everything up to the next function
                # (or the end of the source for the last one).
                if position + 1 < len(functions):
                    end_line = functions[position + 1]['line'] - 1
                else:
                    end_line = len(lines)

                call_lines = [
                    call['line']
                    for call in contract_info['external_calls']
                    if func['line'] <= call['line'] <= end_line
                    and call['target'] not in own_names
                ]
                if not call_lines:
                    continue

                # Only text after the first external call is relevant.
                after_call = '\n'.join(lines[min(call_lines):end_line])
                if self._has_state_change_pattern(after_call):
                    risky_functions.append(func['name'])

            if risky_functions:
                self.vulnerabilities.append(CrossContractVulnerability(
                    vuln_type=CrossContractVulnType.CROSS_CONTRACT_REENTRANCY,
                    severity="critical",
                    title=f"Cross-Contract Reentrancy in {contract_name}",
                    description=(
                        f"Contract {contract_name} makes external calls before updating state "
                        f"(functions: {', '.join(risky_functions)}). "
                        f"This allows reentrancy attacks across contract boundaries."
                    ),
                    contracts_involved=[contract_name],
                    attack_chain=None,
                    estimated_loss=(100_000, 10_000_000),
                    confidence=85
                ))

    def _detect_upgrade_vulnerabilities(self):
        """
        Detect vulnerabilities in upgradeable contracts

        Patterns checked:
        1. Upgrade function with no access-control keyword in the contract
        2. Constructor present but no initialize() in an upgradeable contract
        """
        access_keywords = (
            'onlyowner', 'require(msg.sender', 'onlyadmin',
            # Role modifiers and custom-error style sender checks.
            'onlyrole', 'msg.sender ==', 'msg.sender !=',
        )

        for contract_name, contract_info in self.contracts.items():
            if not contract_info['is_proxy'] and not contract_info['is_upgradeable']:
                continue

            source = contract_info['source']
            lowered = source.lower()

            # Pattern 1: Upgrade without access control
            if 'function upgrade' in lowered or 'function setimplementation' in lowered:
                has_access_control = any(keyword in lowered for keyword in access_keywords)

                if not has_access_control:
                    self.vulnerabilities.append(CrossContractVulnerability(
                        vuln_type=CrossContractVulnType.UPGRADE_VULNERABILITY,
                        severity="critical",
                        title=f"Unprotected Upgrade in {contract_name}",
                        description=(
                            f"Contract {contract_name} has upgrade functionality without access control. "
                            f"Anyone can upgrade to a malicious implementation and steal all funds."
                        ),
                        contracts_involved=[contract_name],
                        attack_chain=None,
                        estimated_loss=(1_000_000, 100_000_000),
                        confidence=95
                    ))

            # Pattern 2: Missing initialize function
            if contract_info['is_upgradeable']:
                has_initialize = 'function initialize' in lowered
                # Tolerates whitespace between "constructor" and "(".
                has_constructor = re.search(r'\bconstructor\s*\(', source) is not None

                if has_constructor and not has_initialize:
                    self.vulnerabilities.append(CrossContractVulnerability(
                        vuln_type=CrossContractVulnType.UPGRADE_VULNERABILITY,
                        severity="high",
                        title=f"Constructor in Upgradeable Contract: {contract_name}",
                        description=(
                            f"Upgradeable contract {contract_name} uses constructor instead of initialize(). "
                            f"Constructor code won't execute in proxy context, leaving contract uninitialized."
                        ),
                        contracts_involved=[contract_name],
                        attack_chain=None,
                        estimated_loss=(50_000, 5_000_000),
                        confidence=90
                    ))

    def _detect_delegatecall_issues(self):
        """
        Detect delegatecall vulnerabilities

        A contract is flagged at its first delegatecall line that has no
        validation keyword anywhere in the source preceding it. Lines that
        are comments are ignored. At most one finding per contract.
        """
        validation_keywords = ('require(', 'if (', 'whitelist', 'approved')

        for contract_name, contract_info in self.contracts.items():
            source = contract_info['source']

            if 'delegatecall' not in source.lower():
                continue

            lines = source.split('\n')
            for i, line in enumerate(lines):
                if 'delegatecall' not in line.lower():
                    continue
                if line.lstrip().startswith(('//', '/*', '*')):
                    continue

                # Slice by line index: searching for the line's text would
                # find an earlier identical line and shorten the prefix.
                preceding = '\n'.join(lines[:i])
                if any(keyword in preceding for keyword in validation_keywords):
                    continue

                self.vulnerabilities.append(CrossContractVulnerability(
                    vuln_type=CrossContractVulnType.DELEGATECALL_INJECTION,
                    severity="critical",
                    title=f"Unsafe Delegatecall in {contract_name}",
                    description=(
                        f"Contract {contract_name} uses delegatecall without proper validation. "
                        f"Attacker can execute arbitrary code in contract context, "
                        f"leading to complete contract takeover."
                    ),
                    contracts_involved=[contract_name],
                    attack_chain=None,
                    estimated_loss=(500_000, 50_000_000),
                    confidence=95
                ))
                break

    def _detect_interface_trust_issues(self):
        """
        Detect trust issues with external contract interfaces

        A contract containing neither 'require(' nor 'revert(' is treated as
        performing no validation; each capitalised or '*Interface' call
        target in such a contract is reported once.
        """
        for contract_name, contract_info in self.contracts.items():
            source = contract_info['source']

            # This is a simplified check: any require/revert counts as validation
            if 'require(' in source or 'revert(' in source:
                continue

            reported_targets = set()
            for call in contract_info['external_calls']:
                target = call['target']
                if target in reported_targets:
                    continue
                # Only interface-like or contract-like (capitalised) targets
                if not (target.endswith('Interface') or target[0].isupper()):
                    continue

                reported_targets.add(target)
                self.vulnerabilities.append(CrossContractVulnerability(
                    vuln_type=CrossContractVulnType.INTERFACE_TRUST,
                    severity="medium",
                    title=f"Untrusted External Call in {contract_name}",
                    description=(
                        f"Contract {contract_name} calls external contract {target} "
                        f"without validating return values. Malicious contract can "
                        f"return unexpected values."
                    ),
                    contracts_involved=[contract_name, target],
                    attack_chain=None,
                    estimated_loss=(10_000, 1_000_000),
                    confidence=70
                ))

    def _identify_attack_chains(self):
        """
        Identify multi-step attack chains across contracts

        Example:
        1. Exploit Contract A → manipulate state
        2. Call Contract B → uses manipulated state
        3. Profit from price difference

        Every cycle of two or more contracts in the dependency graph is
        reported as a potential attack chain.
        """
        # Find cycles in dependency graph (potential reentrancy chains)
        try:
            cycles = list(nx.simple_cycles(self.dependency_graph))
        except nx.NetworkXException as exc:
            # Best effort: skip attack-chain reporting but leave a trace
            # instead of silently hiding the failure.
            logging.getLogger(__name__).warning("Cycle detection failed: %s", exc)
            return

        for cycle in cycles:
            if len(cycle) <= 1:
                continue  # self-loops are not cross-contract chains

            path = ' → '.join(cycle)
            self.vulnerabilities.append(CrossContractVulnerability(
                vuln_type=CrossContractVulnType.ATTACK_CHAIN,
                severity="high",
                title=f"Potential Attack Chain: {path}",
                description=(
                    f"Circular dependency detected: {path}. "
                    f"This creates potential for complex reentrancy or state manipulation attacks."
                ),
                contracts_involved=list(cycle),
                attack_chain=AttackChain(
                    steps=[{'contract': c, 'action': 'call'} for c in cycle],
                    total_impact="high",
                    complexity="medium",
                    description=f"Attack path through: {path}"
                ),
                estimated_loss=(50_000, 5_000_000),
                confidence=75
            ))

    def _has_state_change_pattern(self, source: str) -> bool:
        """
        Check whether the given source text contains a state-change pattern

        Args:
            source: Solidity text to scan (callers pass the code that
                follows an external call)

        Returns:
            True when an assignment, a balances[...] access, or a
            .transfer(/.send( call is present
        """
        # Simplified check for state changes
        state_change_patterns = [
            # Assignment, including compound forms such as "+=", but not the
            # comparison operators "==", "!=", "<=", ">=" or the "=>" arrow.
            r'(?<![=!<>])=(?![=>])\s*\w+',
            r'balances\[',
            r'\.transfer\(',
            r'\.send\(',
        ]

        return any(re.search(pattern, source) for pattern in state_change_patterns)

    def _convert_to_standard_format(self, file_path: str) -> List["SolidityVulnerability"]:
        """
        Convert cross-contract vulnerabilities to standard format

        Args:
            file_path: Path label stored on every converted finding

        Returns:
            One SolidityVulnerability per entry in self.vulnerabilities
        """
        # Map to standard vulnerability types
        vuln_type_map = {
            CrossContractVulnType.CROSS_CONTRACT_REENTRANCY: VulnerabilityType.REENTRANCY,
            CrossContractVulnType.UPGRADE_VULNERABILITY: VulnerabilityType.ACCESS_CONTROL,
            CrossContractVulnType.DELEGATECALL_INJECTION: VulnerabilityType.LOGIC_ERROR,
            CrossContractVulnType.INTERFACE_TRUST: VulnerabilityType.LOGIC_ERROR,
            CrossContractVulnType.ACCESS_CONTROL_BREACH: VulnerabilityType.ACCESS_CONTROL,
            CrossContractVulnType.ATTACK_CHAIN: VulnerabilityType.LOGIC_ERROR,
        }

        standard_vulns = []

        for vuln in self.vulnerabilities:
            vuln_type = vuln_type_map.get(vuln.vuln_type, VulnerabilityType.LOGIC_ERROR)

            # Create code snippet with cross-contract details
            code_snippet = (
                f"Contracts Involved: {', '.join(vuln.contracts_involved)}\n"
                f"Estimated Loss: ${vuln.estimated_loss[0]:,} - ${vuln.estimated_loss[1]:,}\n"
            )

            if vuln.attack_chain:
                code_snippet += f"Attack Chain: {vuln.attack_chain.description}\n"

            standard_vulns.append(SolidityVulnerability(
                vulnerability_type=vuln_type,
                severity=vuln.severity,
                title=f"[Cross-Contract] {vuln.title}",
                description=vuln.description,
                file_path=file_path,
                # Findings span several contracts, so no single line or
                # function is reported.
                line_number=1,
                function_name="multiple",
                contract_name=', '.join(vuln.contracts_involved),
                code_snippet=code_snippet,
                remediation=self._get_remediation(vuln.vuln_type),
                confidence=vuln.confidence
            ))

        return standard_vulns

    def _get_remediation(self, vuln_type: "CrossContractVulnType") -> str:
        """
        Get remediation advice

        Returns:
            Advice text for the category, or a generic sentence for
            categories without a dedicated entry
        """
        remediation_map = {
            CrossContractVulnType.CROSS_CONTRACT_REENTRANCY: (
                "Follow Checks-Effects-Interactions pattern: update state before external calls. "
                "Use ReentrancyGuard from OpenZeppelin. Consider using pull over push for payments."
            ),
            CrossContractVulnType.UPGRADE_VULNERABILITY: (
                "Add access control to upgrade functions (onlyOwner/onlyAdmin). "
                "Use OpenZeppelin's UUPSUpgradeable or TransparentUpgradeableProxy. "
                "Implement timelock for upgrades. Use initialize() instead of constructor."
            ),
            CrossContractVulnType.DELEGATECALL_INJECTION: (
                "Never delegatecall to user-controlled addresses. "
                "Maintain a whitelist of approved implementation contracts. "
                "Use library patterns instead of delegatecall where possible."
            ),
            CrossContractVulnType.INTERFACE_TRUST: (
                "Validate all return values from external contracts. "
                "Use try/catch for external calls. "
                "Implement circuit breakers for critical operations. "
                "Verify contract addresses and implementations."
            ),
            CrossContractVulnType.ATTACK_CHAIN: (
                "Break circular dependencies where possible. "
                "Add reentrancy guards across the call chain. "
                "Validate state consistency at each step. "
                "Consider using commit-reveal schemes."
            )
        }

        return remediation_map.get(vuln_type, "Review cross-contract interactions carefully.")

    def visualize_dependency_graph(self) -> str:
        """
        Generate text representation of dependency graph

        Returns:
            A listing of every node followed by its outgoing edges, or
            "No dependencies found" when the graph has no nodes
        """
        if not self.dependency_graph.nodes():
            return "No dependencies found"

        parts = ["Contract Dependency Graph:\n", "=" * 50 + "\n\n"]

        for contract in self.dependency_graph.nodes():
            parts.append(f"{contract}\n")

            # Outgoing dependencies
            for target in self.dependency_graph.successors(contract):
                edge_data = self.dependency_graph[contract][target]
                parts.append(f"  └─> {target} ({edge_data.get('call_type', 'call')})\n")

        return "".join(parts)
# Example usage
if __name__ == "__main__":
    analyzer = CrossContractAnalyzer()

    # Test with multiple contracts: each snippet carries a deliberate flaw
    # so the demo run produces findings.
    contracts = {
        "Vault": """
        contract Vault {
            mapping(address => uint256) public balances;

            function withdraw() external {
                uint256 amount = balances[msg.sender];
                (bool success, ) = msg.sender.call{value: amount}("");  // External call!
                balances[msg.sender] = 0;  // State change AFTER call - VULNERABLE!
            }
        }
        """,
        "Proxy": """
        contract Proxy {
            address public implementation;

            function upgrade(address newImpl) external {
                // NO ACCESS CONTROL - VULNERABLE!
                implementation = newImpl;
            }

            fallback() external payable {
                address impl = implementation;
                assembly {
                    delegatecall(gas(), impl, 0, calldatasize(), 0, 0)
                }
            }
        }
        """
    }

    vulns = analyzer.analyze_contracts(contracts)

    print(f"Found {len(vulns)} cross-contract vulnerabilities:\n")
    for vuln in vulns:
        print(f"{vuln.severity.upper()}: {vuln.title}")
        print(f"  {vuln.description}")
        print(f"  {vuln.code_snippet}")
        print()

    print(analyzer.visualize_dependency_graph())