Coverage for src/alprina_cli/agents/web3_auditor/cross_contract_analyzer.py: 21%

182 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-14 11:27 +0100

1""" 

2Cross-Contract Analysis Engine 

3 

4WEEK 3 DAY 4: Cross-Contract Analysis 

5====================================== 

6 

7Analyzes vulnerabilities across multiple interacting contracts: 

8- Dependency graph construction 

9- Cross-contract reentrancy detection 

10- Upgrade pattern vulnerabilities 

11- Attack chain identification 

12- Interface trust issues 

13 

14Background: 

15- Modern DeFi uses complex contract interactions 

16- Vulnerabilities often exist at contract boundaries 

17- Upgradeable contracts introduce proxy risks 

18 

19Author: Alprina Development Team 

20Date: 2025-11-12 

21 

22References: 

23- The DAO attack (2016): Cross-contract reentrancy 

24- Parity Wallet (2017): Delegatecall vulnerability 

25- Poly Network (2021): Cross-chain trust issues 

26""" 

27 

28import re 

29from typing import List, Dict, Any, Optional, Set, Tuple 

30from dataclasses import dataclass, field 

31from enum import Enum 

32import networkx as nx 

33 

34try: 

35 from .solidity_analyzer import SolidityVulnerability, VulnerabilityType 

36except ImportError: 

37 import sys 

38 from pathlib import Path 

39 sys.path.insert(0, str(Path(__file__).parent)) 

40 from solidity_analyzer import SolidityVulnerability, VulnerabilityType 

41 

42 

class CrossContractVulnType(Enum):
    """Categories of findings that involve more than one contract.

    Each member's value is its lowercase snake_case identifier.
    """

    # Produced by the external-call-then-state-write detector.
    CROSS_CONTRACT_REENTRANCY = "cross_contract_reentrancy"
    # Produced for unprotected upgrades and constructors in upgradeable code.
    UPGRADE_VULNERABILITY = "upgrade_vulnerability"
    # Produced for delegatecall with no preceding validation.
    DELEGATECALL_INJECTION = "delegatecall_injection"
    # Produced for calls into other contracts without result validation.
    INTERFACE_TRUST = "interface_trust"
    # NOTE(review): no detector in this module currently emits this member.
    ACCESS_CONTROL_BREACH = "access_control_breach"
    # Produced for cycles in the contract dependency graph.
    ATTACK_CHAIN = "attack_chain"

51 

52 

@dataclass
class ContractDependency:
    """One caller -> callee relationship between two contracts.

    NOTE(review): this record is not instantiated anywhere in this module;
    the analyzer stores edge data directly on its NetworkX graph instead.
    """

    from_contract: str  # contract containing the call site
    to_contract: str  # contract being called
    function_name: str  # function invoked on ``to_contract``
    call_type: str  # "call", "delegatecall", "staticcall" or "interface"
    line_number: int  # call-site line (presumably 1-based, like the analyzer's)

61 

62 

@dataclass
class AttackChain:
    """An ordered, multi-step exploitation path across contracts.

    In this module it is built from dependency-graph cycles, with one
    ``{'contract': ..., 'action': 'call'}`` step per contract in the cycle.
    """

    steps: List[Dict[str, Any]]  # ordered step descriptors
    total_impact: str  # qualitative impact label, e.g. "high"
    complexity: str  # qualitative effort label, e.g. "medium"
    description: str  # human-readable summary of the path

70 

71 

@dataclass
class CrossContractVulnerability:
    """A cross-contract finding before conversion to ``SolidityVulnerability``.

    ``confidence`` is the only optional field and defaults to 90.
    """

    vuln_type: CrossContractVulnType  # which detector category produced it
    severity: str  # "critical", "high" or "medium" as emitted in this module
    title: str
    description: str
    contracts_involved: List[str]  # every contract name the finding spans
    attack_chain: Optional[AttackChain]  # only populated for cycle findings here
    estimated_loss: Tuple[float, float]  # (low, high) estimate, rendered as "$low - $high"
    confidence: int = field(default=90)  # presumably a 0-100 score; verify with consumers

83 

84 

class CrossContractAnalyzer:
    """
    Analyze vulnerabilities across multiple contracts.

    The analysis is heuristic: Solidity sources are scanned line by line with
    regular expressions rather than parsed, so findings are leads for manual
    review, not proofs.

    Pipeline (see ``analyze_contracts``):
    1. Parse each contract (functions, external calls, proxy/upgrade hints)
    2. Build a NetworkX dependency graph between the supplied contracts
    3. Detect cross-contract reentrancy, upgrade issues, unsafe delegatecall,
       unvalidated external calls and call cycles (attack chains)
    4. Convert findings to ``SolidityVulnerability`` objects
    """

    # ``target.func(`` with optional call options, e.g. ``addr.call{value: x}(``.
    _CALL_RE = re.compile(r'(\w+)\.(\w+)\s*(?:\{[^{}]*\}\s*)?\(')

    # Built-in namespaces whose member calls (``abi.encode(...)``,
    # ``super.foo()``, ``string.concat(...)``) never reach another contract.
    _BUILTIN_CALL_TARGETS = frozenset({'abi', 'super', 'string', 'bytes'})

    # Start of a named function declaration. Visibility is searched for
    # separately because Solidity allows modifiers (virtual, override,
    # payable, custom modifiers) to precede the visibility keyword.
    _FUNCTION_RE = re.compile(r'\s*function\s+(\w+)\s*\(')
    _VISIBILITY_RE = re.compile(r'\b(public|external|internal|private)\b')

    # A contract declaration at the start of a line, so the word "contract"
    # inside comments or prose is not mistaken for one.
    _CONTRACT_DECL_RE = re.compile(r'^\s*(?:abstract\s+)?contract\s+(\w+)', re.MULTILINE)

    # Statement shapes that look like writes. The assignment pattern accepts
    # ``=``, ``+=``, ``-=``, ... but rejects ``==``, ``!=``, ``<=``, ``>=``
    # and the mapping arrow ``=>``. Local variable declarations still match,
    # so this over-approximates storage writes.
    _STATE_CHANGE_PATTERNS = (
        re.compile(r'(?<![=!<>])=(?![=>])'),
        re.compile(r'\bdelete\s'),
        re.compile(r'\+\+|--'),
        re.compile(r'\.(?:push|pop)\s*\('),
    )

    def __init__(self):
        self.dependency_graph = nx.DiGraph()
        self.contracts: Dict[str, Dict[str, Any]] = {}
        self.vulnerabilities: List[CrossContractVulnerability] = []

    def analyze_contracts(
        self,
        contracts: Dict[str, str],  # {filename: source_code}
        file_path: str = "multi-contract"
    ) -> List[SolidityVulnerability]:
        """
        Analyze multiple contracts for cross-contract vulnerabilities.

        State from any previous run (contracts, graph, findings) is reset first.

        Args:
            contracts: Mapping of contract name to source code. The keys are
                the names matched against call targets when the dependency
                graph is built.
            file_path: Value reported as ``file_path`` on every finding.

        Returns:
            List of vulnerabilities in standard format
        """
        self.vulnerabilities = []
        self.contracts = {}
        self.dependency_graph = nx.DiGraph()

        # Step 1: Parse all contracts
        for contract_name, source_code in contracts.items():
            self.contracts[contract_name] = self._parse_contract(contract_name, source_code)

        # Step 2: Build dependency graph
        self._build_dependency_graph()

        # Step 3: Detect vulnerabilities
        self._detect_cross_contract_reentrancy()
        self._detect_upgrade_vulnerabilities()
        self._detect_delegatecall_issues()
        self._detect_interface_trust_issues()
        self._identify_attack_chains()

        # Step 4: Convert to standard format
        return self._convert_to_standard_format(file_path)

    @staticmethod
    def _strip_line_comment(line: str) -> str:
        """Return ``line`` without a trailing ``//`` comment.

        Simplification: a ``//`` inside a string literal (e.g. a URL) also
        truncates the line.
        """
        return line.split('//', 1)[0]

    def _parse_contract(self, name: str, source_code: str) -> Dict[str, Any]:
        """
        Extract the information the detectors need from one contract.

        Returns a dict with keys ``name``, ``source``, ``functions``,
        ``external_calls``, ``state_variables`` (always empty; not populated
        yet), ``is_proxy``, ``is_upgradeable`` and, when a ``contract X``
        declaration is found, ``actual_name``.
        """
        source_lower = source_code.lower()
        contract_info: Dict[str, Any] = {
            'name': name,
            'source': source_code,
            'functions': [],
            'external_calls': [],
            'state_variables': [],
            # Keyword heuristics: any mention anywhere (including comments) counts.
            'is_proxy': any(
                keyword in source_lower
                for keyword in ('delegatecall', 'proxy', 'implementation')
            ),
            'is_upgradeable': any(
                keyword in source_lower
                for keyword in ('upgrade', 'initialize', 'upgradeable')
            ),
        }

        declaration = self._CONTRACT_DECL_RE.search(source_code)
        if declaration:
            contract_info['actual_name'] = declaration.group(1)

        contract_info['functions'] = self._extract_functions(source_code)
        contract_info['external_calls'] = self._extract_external_calls(source_code)

        return contract_info

    def _extract_functions(self, source_code: str) -> List[Dict[str, Any]]:
        """
        Extract named function declarations.

        Each entry has ``name``, ``visibility`` ('internal' when none is
        written) and ``line`` (1-based line of the ``function`` keyword).
        The signature may span several lines; it is read up to the body's
        ``{`` or a body-less declaration's ``;``.
        """
        functions = []
        lines = source_code.split('\n')

        for index, line in enumerate(lines):
            header = self._FUNCTION_RE.match(line)
            if header is None:
                continue

            signature = self._strip_line_comment(line)[header.end():]
            next_index = index + 1
            while not re.search(r'[{;]', signature) and next_index < len(lines):
                signature += ' ' + self._strip_line_comment(lines[next_index])
                next_index += 1
            # Only the header matters; ignore anything from the body onwards.
            signature = re.split(r'[{;]', signature, maxsplit=1)[0]

            visibility = self._VISIBILITY_RE.search(signature)
            functions.append({
                'name': header.group(1),
                'visibility': visibility.group(1) if visibility else 'internal',
                'line': index + 1
            })

        return functions

    def _extract_external_calls(self, source_code: str) -> List[Dict[str, Any]]:
        """
        Extract member calls of the form ``target.function(...)``.

        Every call on a line is recorded (not just the first). Call options
        such as ``{value: x}`` are allowed before the argument list. Calls on
        Solidity built-in namespaces and text in ``//`` comments are skipped.
        Each entry has ``target``, ``function``, ``call_type`` ('call',
        'delegatecall' or 'staticcall') and 1-based ``line``.
        """
        external_calls = []

        for index, raw_line in enumerate(source_code.split('\n')):
            line = self._strip_line_comment(raw_line)
            for match in self._CALL_RE.finditer(line):
                target, func_name = match.group(1), match.group(2)
                if target in self._BUILTIN_CALL_TARGETS:
                    continue

                if func_name in ('delegatecall', 'staticcall'):
                    call_type = func_name
                else:
                    call_type = "call"

                external_calls.append({
                    'target': target,
                    'function': func_name,
                    'call_type': call_type,
                    'line': index + 1
                })

        return external_calls

    def _build_dependency_graph(self) -> None:
        """
        Build the dependency graph between contracts.

        Every analyzed contract becomes a node. An edge A -> B is added when A
        calls a target that is an analyzed contract name or ends in
        'Interface'. Repeated calls between the same pair keep only the
        attributes of the last call seen (DiGraph edge semantics).
        """
        self.dependency_graph.add_nodes_from(self.contracts)

        for contract_name, contract_info in self.contracts.items():
            for call in contract_info['external_calls']:
                target = call['target']
                if target in self.contracts or target.endswith('Interface'):
                    self.dependency_graph.add_edge(
                        contract_name,
                        target,
                        call_type=call['call_type'],
                        function=call['function']
                    )

    def _function_end(self, lines: List[str], start: int) -> int:
        """
        Return the 0-based index of the last line of the function at ``start``.

        Braces are balanced on comment-stripped text. A declaration that hits
        ``;`` before any ``{`` (interface/abstract function) ends on that
        line. If braces never balance, the last line of the source is returned.
        """
        depth = 0
        opened = False
        for index in range(start, len(lines)):
            code = self._strip_line_comment(lines[index])
            for char in code:
                if char == '{':
                    depth += 1
                    opened = True
                elif char == '}' and opened:
                    depth -= 1
                    if depth == 0:
                        return index
            if not opened and ';' in code:
                return index
        return len(lines) - 1

    def _detect_cross_contract_reentrancy(self) -> None:
        """
        Detect reentrancy vulnerabilities across contracts.

        Pattern:
        1. A public/external function of contract A calls another contract
        2. Later in the same function, A writes state
        3. The callee can re-enter A while that state is stale

        At most one finding is reported per contract.
        """
        for contract_name, contract_info in self.contracts.items():
            lines = contract_info['source'].split('\n')

            # 1-based line numbers of calls whose target is not this contract.
            call_lines = sorted({
                call['line']
                for call in contract_info['external_calls']
                if call['target'] != contract_name
            })
            if not call_lines:
                continue

            for func in contract_info['functions']:
                if func['visibility'] not in ('public', 'external'):
                    continue

                first_line = func['line']
                last_line = self._function_end(lines, first_line - 1) + 1
                first_call = next(
                    (number for number in call_lines if first_line <= number <= last_line),
                    None
                )
                if first_call is None:
                    continue

                # 0-based indices first_call .. last_line - 1 are the lines that
                # follow the first external call inside this function.
                writes_after_call = any(
                    self._has_state_change_pattern(self._strip_line_comment(lines[idx]))
                    for idx in range(first_call, last_line)
                )

                if writes_after_call:
                    self.vulnerabilities.append(CrossContractVulnerability(
                        vuln_type=CrossContractVulnType.CROSS_CONTRACT_REENTRANCY,
                        severity="critical",
                        title=f"Cross-Contract Reentrancy in {contract_name}",
                        description=(
                            f"Contract {contract_name} makes external calls before updating state. "
                            f"This allows reentrancy attacks across contract boundaries."
                        ),
                        contracts_involved=[contract_name],
                        attack_chain=None,
                        estimated_loss=(100_000, 10_000_000),
                        confidence=85
                    ))
                    break  # one finding per contract; further ones would be identical

    def _detect_upgrade_vulnerabilities(self) -> None:
        """
        Detect vulnerabilities in proxy/upgradeable contracts.

        Patterns checked:
        1. Upgrade/setImplementation function with no access-control marker
           anywhere in the contract (contract-wide keyword check)
        2. Upgradeable contract that has a constructor but no initialize()
        """
        for contract_name, contract_info in self.contracts.items():
            if not (contract_info['is_proxy'] or contract_info['is_upgradeable']):
                continue

            source = contract_info['source']
            source_lower = source.lower()

            # Pattern 1: Upgrade without access control
            if 'function upgrade' in source_lower or 'function setimplementation' in source_lower:
                has_access_control = any(
                    keyword in source_lower
                    for keyword in ('onlyowner', 'require(msg.sender', 'onlyadmin')
                )

                if not has_access_control:
                    self.vulnerabilities.append(CrossContractVulnerability(
                        vuln_type=CrossContractVulnType.UPGRADE_VULNERABILITY,
                        severity="critical",
                        title=f"Unprotected Upgrade in {contract_name}",
                        description=(
                            f"Contract {contract_name} has upgrade functionality without access control. "
                            f"Anyone can upgrade to a malicious implementation and steal all funds."
                        ),
                        contracts_involved=[contract_name],
                        attack_chain=None,
                        estimated_loss=(1_000_000, 100_000_000),
                        confidence=95
                    ))

            # Pattern 2: Constructor instead of initialize()
            if contract_info['is_upgradeable']:
                has_initialize = 'function initialize' in source_lower
                # Accept "constructor (" as well as "constructor(".
                has_constructor = re.search(r'\bconstructor\s*\(', source) is not None

                if has_constructor and not has_initialize:
                    self.vulnerabilities.append(CrossContractVulnerability(
                        vuln_type=CrossContractVulnType.UPGRADE_VULNERABILITY,
                        severity="high",
                        title=f"Constructor in Upgradeable Contract: {contract_name}",
                        description=(
                            f"Upgradeable contract {contract_name} uses constructor instead of initialize(). "
                            f"Constructor code won't execute in proxy context, leaving contract uninitialized."
                        ),
                        contracts_involved=[contract_name],
                        attack_chain=None,
                        estimated_loss=(50_000, 5_000_000),
                        confidence=90
                    ))

    def _detect_delegatecall_issues(self) -> None:
        """
        Detect delegatecall used without prior validation.

        For each non-comment line mentioning delegatecall, the text before it
        (by position) is searched for a validation hint ('require(', 'if (',
        'whitelist', 'approved'). The first unvalidated use produces one
        finding for the contract.
        """
        for contract_name, contract_info in self.contracts.items():
            source = contract_info['source']

            if 'delegatecall' not in source.lower():
                continue

            lines = source.split('\n')
            for index, line in enumerate(lines):
                if 'delegatecall' not in self._strip_line_comment(line).lower():
                    continue

                # Slice by line index; searching for the line's text would
                # find an earlier identical line and shorten the prefix.
                preceding = '\n'.join(lines[:index])
                has_validation = any(
                    keyword in preceding
                    for keyword in ('require(', 'if (', 'whitelist', 'approved')
                )

                if not has_validation:
                    self.vulnerabilities.append(CrossContractVulnerability(
                        vuln_type=CrossContractVulnType.DELEGATECALL_INJECTION,
                        severity="critical",
                        title=f"Unsafe Delegatecall in {contract_name}",
                        description=(
                            f"Contract {contract_name} uses delegatecall without proper validation. "
                            f"Attacker can execute arbitrary code in contract context, "
                            f"leading to complete contract takeover."
                        ),
                        contracts_involved=[contract_name],
                        attack_chain=None,
                        estimated_loss=(500_000, 50_000_000),
                        confidence=95
                    ))
                    break

    def _detect_interface_trust_issues(self) -> None:
        """
        Detect calls into other contracts whose results are never validated.

        Simplified, contract-wide check: a contract containing neither
        ``require(`` nor ``revert(`` is considered to validate nothing. Each
        distinct target that ends in 'Interface' or starts with an uppercase
        letter is then reported once.
        """
        for contract_name, contract_info in self.contracts.items():
            source = contract_info['source']

            if 'require(' in source or 'revert(' in source:
                continue

            reported: Set[str] = set()
            for call in contract_info['external_calls']:
                target = call['target']
                if target in reported:
                    continue

                if target.endswith('Interface') or target[0].isupper():
                    reported.add(target)
                    self.vulnerabilities.append(CrossContractVulnerability(
                        vuln_type=CrossContractVulnType.INTERFACE_TRUST,
                        severity="medium",
                        title=f"Untrusted External Call in {contract_name}",
                        description=(
                            f"Contract {contract_name} calls external contract {target} "
                            f"without validating return values. Malicious contract can "
                            f"return unexpected values."
                        ),
                        contracts_involved=[contract_name, target],
                        attack_chain=None,
                        estimated_loss=(10_000, 1_000_000),
                        confidence=70
                    ))

    def _identify_attack_chains(self) -> None:
        """
        Report every dependency-graph cycle of two or more contracts.

        Example:
        1. Exploit Contract A → manipulate state
        2. Call Contract B → uses manipulated state
        3. Profit from price difference

        ``nx.simple_cycles`` yields nothing for an acyclic graph, so no
        exception handling is needed for the "no cycles" case.
        """
        for cycle in nx.simple_cycles(self.dependency_graph):
            # A self-loop is a contract calling itself, not a cross-contract chain.
            if len(cycle) < 2:
                continue

            path = ' → '.join(cycle)
            self.vulnerabilities.append(CrossContractVulnerability(
                vuln_type=CrossContractVulnType.ATTACK_CHAIN,
                severity="high",
                title=f"Potential Attack Chain: {path}",
                description=(
                    f"Circular dependency detected: {path}. "
                    f"This creates potential for complex reentrancy or state manipulation attacks."
                ),
                contracts_involved=cycle,
                attack_chain=AttackChain(
                    steps=[{'contract': c, 'action': 'call'} for c in cycle],
                    total_impact="high",
                    complexity="medium",
                    description=f"Attack path through: {path}"
                ),
                estimated_loss=(50_000, 5_000_000),
                confidence=75
            ))

    def _has_state_change_pattern(self, source: str) -> bool:
        """
        Return True if ``source`` contains something that looks like a write.

        Matches assignments (excluding comparisons and ``=>``), ``delete``,
        ``++``/``--`` and ``.push(``/``.pop(``. Ordering relative to external
        calls is the caller's job; pass only the text to be checked.
        """
        return any(pattern.search(source) for pattern in self._STATE_CHANGE_PATTERNS)

    def _convert_to_standard_format(self, file_path: str) -> List[SolidityVulnerability]:
        """Convert cross-contract findings to ``SolidityVulnerability`` objects."""
        vuln_type_map = {
            CrossContractVulnType.CROSS_CONTRACT_REENTRANCY: VulnerabilityType.REENTRANCY,
            CrossContractVulnType.UPGRADE_VULNERABILITY: VulnerabilityType.ACCESS_CONTROL,
            CrossContractVulnType.DELEGATECALL_INJECTION: VulnerabilityType.LOGIC_ERROR,
            CrossContractVulnType.INTERFACE_TRUST: VulnerabilityType.LOGIC_ERROR,
            CrossContractVulnType.ACCESS_CONTROL_BREACH: VulnerabilityType.ACCESS_CONTROL,
            CrossContractVulnType.ATTACK_CHAIN: VulnerabilityType.LOGIC_ERROR,
        }

        standard_vulns = []
        for vuln in self.vulnerabilities:
            vuln_type = vuln_type_map.get(vuln.vuln_type, VulnerabilityType.LOGIC_ERROR)

            # Cross-contract findings have no single code location, so the
            # snippet carries the involved contracts and loss estimate instead.
            code_snippet = (
                f"Contracts Involved: {', '.join(vuln.contracts_involved)}\n"
                f"Estimated Loss: ${vuln.estimated_loss[0]:,} - ${vuln.estimated_loss[1]:,}\n"
            )
            if vuln.attack_chain:
                code_snippet += f"Attack Chain: {vuln.attack_chain.description}\n"

            standard_vulns.append(SolidityVulnerability(
                vulnerability_type=vuln_type,
                severity=vuln.severity,
                title=f"[Cross-Contract] {vuln.title}",
                description=vuln.description,
                file_path=file_path,
                line_number=1,
                function_name="multiple",
                contract_name=', '.join(vuln.contracts_involved),
                code_snippet=code_snippet,
                remediation=self._get_remediation(vuln.vuln_type),
                confidence=vuln.confidence
            ))

        return standard_vulns

    def _get_remediation(self, vuln_type: CrossContractVulnType) -> str:
        """Return remediation advice for ``vuln_type`` (generic text if unmapped)."""
        remediation_map = {
            CrossContractVulnType.CROSS_CONTRACT_REENTRANCY: (
                "Follow Checks-Effects-Interactions pattern: update state before external calls. "
                "Use ReentrancyGuard from OpenZeppelin. Consider using pull over push for payments."
            ),
            CrossContractVulnType.UPGRADE_VULNERABILITY: (
                "Add access control to upgrade functions (onlyOwner/onlyAdmin). "
                "Use OpenZeppelin's UUPSUpgradeable or TransparentUpgradeableProxy. "
                "Implement timelock for upgrades. Use initialize() instead of constructor."
            ),
            CrossContractVulnType.DELEGATECALL_INJECTION: (
                "Never delegatecall to user-controlled addresses. "
                "Maintain a whitelist of approved implementation contracts. "
                "Use library patterns instead of delegatecall where possible."
            ),
            CrossContractVulnType.INTERFACE_TRUST: (
                "Validate all return values from external contracts. "
                "Use try/catch for external calls. "
                "Implement circuit breakers for critical operations. "
                "Verify contract addresses and implementations."
            ),
            CrossContractVulnType.ATTACK_CHAIN: (
                "Break circular dependencies where possible. "
                "Add reentrancy guards across the call chain. "
                "Validate state consistency at each step. "
                "Consider using commit-reveal schemes."
            )
        }

        return remediation_map.get(vuln_type, "Review cross-contract interactions carefully.")

    def visualize_dependency_graph(self) -> str:
        """Render the dependency graph as text: each contract and its outgoing calls."""
        if not self.dependency_graph.nodes():
            return "No dependencies found"

        output = "Contract Dependency Graph:\n"
        output += "=" * 50 + "\n\n"

        for contract in self.dependency_graph.nodes():
            output += f"{contract}\n"

            # Outgoing dependencies
            for target in self.dependency_graph.successors(contract):
                edge_data = self.dependency_graph[contract][target]
                output += f" └─> {target} ({edge_data.get('call_type', 'call')})\n"

        return output

574 

575 

# Example usage
if __name__ == "__main__":
    # Two deliberately vulnerable samples: a vault that pays out before it
    # zeroes the caller's balance, and a proxy whose upgrade() is unguarded.
    vault_source = """
    contract Vault {
        mapping(address => uint256) public balances;

        function withdraw() external {
            uint256 amount = balances[msg.sender];
            (bool success, ) = msg.sender.call{value: amount}(""); // External call!
            balances[msg.sender] = 0; // State change AFTER call - VULNERABLE!
        }
    }
    """
    proxy_source = """
    contract Proxy {
        address public implementation;

        function upgrade(address newImpl) external {
            // NO ACCESS CONTROL - VULNERABLE!
            implementation = newImpl;
        }

        fallback() external payable {
            address impl = implementation;
            assembly {
                delegatecall(gas(), impl, 0, calldatasize(), 0, 0)
            }
        }
    }
    """

    demo_analyzer = CrossContractAnalyzer()
    findings = demo_analyzer.analyze_contracts({"Vault": vault_source, "Proxy": proxy_source})

    print(f"Found {len(findings)} cross-contract vulnerabilities:\n")
    for finding in findings:
        print(f"{finding.severity.upper()}: {finding.title}")
        print(f" {finding.description}")
        print(f" {finding.code_snippet}")
        print()

    print(demo_analyzer.visualize_dependency_graph())