Coverage for src/srunx/models.py: 82%

233 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-09-04 07:47 +0000

1"""Data models for SLURM job management.""" 

2 

3import os 

4import subprocess 

5import time 

6from enum import Enum 

7from pathlib import Path 

8from typing import Self 

9 

10import jinja2 

11from pydantic import BaseModel, Field, PrivateAttr, model_validator 

12 

13from srunx.exceptions import WorkflowValidationError 

14from srunx.logging import get_logger 

15 

# Module-wide logger, named after this module per srunx.logging convention.
logger = get_logger(__name__)

17 

18 

19def _get_config_defaults(): 

20 """Get configuration defaults, with lazy import to avoid circular dependencies.""" 

21 try: 

22 from srunx.config import get_config 

23 

24 return get_config() 

25 except ImportError: 

26 # Fallback if config module is not available 

27 return None 

28 

29 

def _default_nodes():
    """Default node count: taken from config when available, else 1."""
    cfg = _get_config_defaults()
    if not cfg:
        return 1
    return cfg.resources.nodes

34 

35 

def _default_gpus_per_node():
    """Default GPUs per node: taken from config when available, else 0."""
    cfg = _get_config_defaults()
    if not cfg:
        return 0
    return cfg.resources.gpus_per_node

40 

41 

def _default_ntasks_per_node():
    """Default tasks per node: taken from config when available, else 1."""
    cfg = _get_config_defaults()
    if not cfg:
        return 1
    return cfg.resources.ntasks_per_node

46 

47 

def _default_cpus_per_task():
    """Default CPUs per task: taken from config when available, else 1."""
    cfg = _get_config_defaults()
    if not cfg:
        return 1
    return cfg.resources.cpus_per_task

52 

53 

def _default_memory_per_node():
    """Default memory per node from config; None when unconfigured."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.resources.memory_per_node

58 

59 

def _default_time_limit():
    """Default time limit from config; None when unconfigured."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.resources.time_limit

64 

65 

def _default_nodelist():
    """Default nodelist from config; None when unconfigured."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.resources.nodelist

70 

71 

def _default_partition():
    """Default SLURM partition from config; None when unconfigured."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.resources.partition

76 

77 

def _default_conda():
    """Default conda environment name from config; None when unconfigured."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.environment.conda

82 

83 

def _default_venv():
    """Default virtualenv path from config; None when unconfigured."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.environment.venv

88 

89 

def _default_sqsh():
    """Default sqsh (SquashFS) image path from config; None when unconfigured."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.environment.sqsh

94 

95 

def _default_env_vars():
    """Default environment variables from config; empty dict when unconfigured."""
    cfg = _get_config_defaults()
    if not cfg:
        return {}
    return cfg.environment.env_vars

100 

101 

def _default_log_dir():
    """Default log directory: config value, else $SLURM_LOG_DIR, else 'logs'."""
    cfg = _get_config_defaults()
    if not cfg:
        return os.getenv("SLURM_LOG_DIR", "logs")
    return cfg.log_dir

106 

107 

def _default_work_dir():
    """Default working directory from config; None when unconfigured."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.work_dir

112 

113 

class JobStatus(Enum):
    """Job status enumeration for both SLURM jobs and workflow jobs.

    Values mirror the state strings reported by SLURM's ``sacct`` so that
    ``JobStatus(state)`` can be constructed directly from its output
    (see ``BaseJob.refresh``).
    """

    UNKNOWN = "UNKNOWN"  # sacct returned no (parsable) record for the job
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"  # terminal
    FAILED = "FAILED"  # terminal
    CANCELLED = "CANCELLED"  # terminal
    TIMEOUT = "TIMEOUT"  # terminal

124 

125 

class JobResource(BaseModel):
    """SLURM resource allocation requirements.

    Every field defaults to the value from the srunx config (when a config
    is available) via the module-level ``_default_*`` factories.
    """

    # Topology: how many nodes and how work is spread across them.
    nodes: int = Field(default_factory=_default_nodes, ge=1, description="Number of compute nodes")
    gpus_per_node: int = Field(default_factory=_default_gpus_per_node, ge=0, description="Number of GPUs per node")
    ntasks_per_node: int = Field(default_factory=_default_ntasks_per_node, ge=1, description="Number of jobs per node")
    cpus_per_task: int = Field(default_factory=_default_cpus_per_task, ge=1, description="Number of CPUs per task")

    # Limits and placement: optional, passed through to SLURM when set.
    memory_per_node: str | None = Field(default_factory=_default_memory_per_node, description="Memory per node (e.g., '32GB')")
    time_limit: str | None = Field(default_factory=_default_time_limit, description="Time limit (e.g., '1:00:00')")
    nodelist: str | None = Field(default_factory=_default_nodelist, description="Specific nodes to use (e.g., 'node001,node002')")
    partition: str | None = Field(default_factory=_default_partition, description="SLURM partition to use (e.g., 'gpu', 'cpu')")

162 

163 

class JobEnvironment(BaseModel):
    """Job environment configuration.

    At most one of ``conda``, ``venv``, or ``sqsh`` may be set; this is
    enforced by ``validate_environment`` after model construction.
    """

    conda: str | None = Field(default_factory=_default_conda, description="Conda environment name")
    venv: str | None = Field(default_factory=_default_venv, description="Virtual environment path")
    sqsh: str | None = Field(default_factory=_default_sqsh, description="SquashFS image path")
    env_vars: dict[str, str] = Field(default_factory=_default_env_vars, description="Environment variables")

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Reject configurations that select more than one environment kind."""
        selected = [env for env in (self.conda, self.venv, self.sqsh) if env is not None]
        if not selected:
            # Running without any virtual environment is allowed; just note it.
            logger.info("No virtual environment is set.")
        elif len(selected) > 1:
            raise ValueError(
                "Only one virtual environment (conda, venv, or sqsh) can be specified"
            )
        return self

191 

192 

class BaseJob(BaseModel):
    """Common state for all job kinds: name, SLURM id, and dependencies."""

    name: str = Field(default="job", description="Job name")
    job_id: int | None = Field(default=None, description="SLURM job ID")
    depends_on: list[str] = Field(
        default_factory=list, description="Task dependencies for workflow execution"
    )

    # Cached status; private so that reads can transparently refresh it.
    _status: JobStatus = PrivateAttr(default=JobStatus.PENDING)

    @property
    def status(self) -> JobStatus:
        """
        Accessing ``job.status`` always triggers a lightweight refresh
        (only if we have a ``job_id`` and the status isn't terminal).
        """
        if self.job_id is not None and self._status not in {
            JobStatus.COMPLETED,
            JobStatus.FAILED,
            JobStatus.CANCELLED,
            JobStatus.TIMEOUT,
        }:
            self.refresh()
        return self._status

    @status.setter
    def status(self, value: JobStatus) -> None:
        self._status = value

    def refresh(self, retries: int = 3) -> Self:
        """Query sacct and update ``_status`` in-place.

        Args:
            retries: Attempts to make when sacct has no record yet
                (freshly submitted jobs may not appear immediately).

        Returns:
            ``self``, for chaining.

        Raises:
            subprocess.CalledProcessError: If the sacct command fails.
        """
        if self.job_id is None:
            return self

        line = ""
        for retry in range(retries):
            try:
                result = subprocess.run(
                    [
                        "sacct",
                        "-j",
                        str(self.job_id),
                        "--format",
                        "JobID,State",
                        "--noheader",
                        "--parsable2",
                    ],
                    capture_output=True,
                    text=True,
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to query job {self.job_id}: {e}")
                raise

            line = result.stdout.strip().split("\n")[0] if result.stdout.strip() else ""
            if line:
                break
            # No accounting record yet; give sacct a moment and retry.
            if retry < retries - 1:
                time.sleep(1)

        if not line:
            self._status = JobStatus.UNKNOWN
            return self

        _, state = line.split("|", 1)
        # BUG FIX: sacct may append qualifiers to the state (e.g.
        # "CANCELLED by <uid>") and can report states outside our enum
        # (e.g. "NODE_FAIL"); previously JobStatus(state) raised
        # ValueError for those. Keep only the first token and fall back
        # to UNKNOWN for unrecognized states.
        state_token = state.split()[0] if state else ""
        try:
            self._status = JobStatus(state_token)
        except ValueError:
            self._status = JobStatus.UNKNOWN
        return self

    def dependencies_satisfied(self, completed_job_names: list[str]) -> bool:
        """All dependencies are completed & this job is still pending."""
        return self.status == JobStatus.PENDING and all(
            dep in completed_job_names for dep in self.depends_on
        )

264 

265 

class Job(BaseJob):
    """Represents a SLURM job with complete configuration.

    Extends ``BaseJob`` with the command to run, resource/environment
    settings, and log/working directories.
    """

    command: list[str] = Field(description="Command to execute")
    resources: JobResource = Field(default_factory=JobResource, description="Resource requirements")
    environment: JobEnvironment = Field(default_factory=JobEnvironment, description="Environment setup")
    log_dir: str = Field(default_factory=_default_log_dir, description="Directory for log files")
    # Falls back to the current working directory when no configured value exists.
    work_dir: str = Field(
        default_factory=lambda: _default_work_dir() or os.getcwd(),
        description="Working directory",
    )

284 

285 

class ShellJob(BaseJob):
    """A job defined by an existing shell script rather than an inline command."""

    path: str = Field(description="Shell script path to execute")

288 

289 

# PEP 695 type aliases (Python 3.12+).
type JobType = BaseJob | Job | ShellJob  # any job model in this module
type RunnableJobType = Job | ShellJob  # concrete jobs a workflow can hold

292 

293 

class Workflow:
    """Represents a workflow containing multiple jobs with dependencies."""

    def __init__(self, name: str, jobs: list[RunnableJobType] | None = None) -> None:
        """Initialize a workflow.

        Args:
            name: Workflow name.
            jobs: Initial job list; defaults to an empty list.
        """
        if jobs is None:
            jobs = []

        self.name = name
        self.jobs = jobs

    def add(self, job: RunnableJobType) -> None:
        """Add a job after verifying its dependencies refer to known jobs.

        Raises:
            WorkflowValidationError: If a dependency names a job not yet
                in the workflow.
        """
        # BUG FIX: dependencies are job *names* (strings); the previous
        # check tested membership against self.jobs (a list of job
        # objects), so any dependency always failed. Compare against the
        # names of already-added jobs instead.
        known_names = {existing.name for existing in self.jobs}
        for dep in job.depends_on:
            if dep not in known_names:
                raise WorkflowValidationError(
                    f"Job '{job.name}' depends on unknown job '{dep}'"
                )
        self.jobs.append(job)

    def remove(self, job: RunnableJobType) -> None:
        """Remove a job from the workflow."""
        self.jobs.remove(job)

    def get(self, name: str) -> RunnableJobType | None:
        """Get a job by name (refreshing its SLURM status), or None."""
        for job in self.jobs:
            if job.name == name:
                return job.refresh()
        return None

    def get_dependencies(self, job_name: str) -> list[str]:
        """Get dependencies for a specific job."""
        job = self.get(job_name)
        return job.depends_on if job else []

    def show(self):
        """Print a human-readable execution plan for this workflow."""
        msg = f"""\
{" PLAN ":=^80}
Workflow: {self.name}
Jobs: {len(self.jobs)}
"""

        def add_indent(indent: int, msg: str) -> str:
            return " " * indent + msg

        for job in self.jobs:
            msg += add_indent(1, f"Job: {job.name}\n")
            if isinstance(job, Job):
                msg += add_indent(
                    2, f"{'Command:': <13} {' '.join(job.command or [])}\n"
                )
                msg += add_indent(
                    2,
                    f"{'Resources:': <13} {job.resources.nodes} nodes, {job.resources.gpus_per_node} GPUs/node\n",
                )
                if job.environment.conda:
                    msg += add_indent(
                        2, f"{'Conda env:': <13} {job.environment.conda}\n"
                    )
                if job.environment.sqsh:
                    msg += add_indent(2, f"{'Sqsh:': <13} {job.environment.sqsh}\n")
                if job.environment.venv:
                    msg += add_indent(2, f"{'Venv:': <13} {job.environment.venv}\n")
            elif isinstance(job, ShellJob):
                msg += add_indent(2, f"{'Path:': <13} {job.path}\n")
            if job.depends_on:
                msg += add_indent(
                    2, f"{'Dependencies:': <13} {', '.join(job.depends_on)}\n"
                )

        msg += f"{'=' * 80}\n"
        print(msg)

    def validate(self):
        """Validate workflow job dependencies.

        Raises:
            WorkflowValidationError: On duplicate job names, unknown
                dependencies, or circular dependencies.
        """
        job_names = {job.name for job in self.jobs}

        if len(job_names) != len(self.jobs):
            raise WorkflowValidationError("Duplicate job names found in workflow")

        for job in self.jobs:
            for dependency in job.depends_on:
                if dependency not in job_names:
                    raise WorkflowValidationError(
                        f"Job '{job.name}' depends on unknown job '{dependency}'"
                    )

        # Check for circular dependencies with a DFS. Look jobs up via a
        # local map rather than self.get(): get() calls job.refresh(),
        # which would spawn a sacct subprocess per visited job during
        # pure validation.
        jobs_by_name = {job.name: job for job in self.jobs}
        visited = set()
        rec_stack = set()

        def has_cycle(job_name: str) -> bool:
            if job_name in rec_stack:
                return True
            if job_name in visited:
                return False

            visited.add(job_name)
            rec_stack.add(job_name)

            job = jobs_by_name.get(job_name)
            if job:
                for dependency in job.depends_on:
                    if has_cycle(dependency):
                        return True

            rec_stack.remove(job_name)
            return False

        for job in self.jobs:
            if has_cycle(job.name):
                raise WorkflowValidationError(
                    f"Circular dependency detected involving job '{job.name}'"
                )

408 

409 

def render_job_script(
    template_path: Path | str,
    job: Job,
    output_dir: Path | str,
    verbose: bool = False,
) -> str:
    """Render a SLURM job script from a template.

    Args:
        template_path: Path to the Jinja template file.
        job: Job configuration.
        output_dir: Directory where the generated script will be saved.
            Created (including parents) if it does not already exist.
        verbose: Whether to print the rendered content.

    Returns:
        Path to the generated SLURM batch script.

    Raises:
        FileNotFoundError: If the template file does not exist.
        jinja2.TemplateError: If template rendering fails.
    """
    template_file = Path(template_path)
    if not template_file.is_file():
        raise FileNotFoundError(f"Template file '{template_path}' not found")

    template_content = template_file.read_text(encoding="utf-8")

    # StrictUndefined makes rendering fail loudly on any template variable
    # we did not supply, instead of silently emitting an empty string.
    template = jinja2.Template(template_content, undefined=jinja2.StrictUndefined)

    # Prepare template variables (resource fields are splatted in so the
    # template can reference nodes, gpus_per_node, etc. directly).
    template_vars = {
        "job_name": job.name,
        "command": " ".join(job.command or []),
        "log_dir": job.log_dir,
        "work_dir": job.work_dir,
        "environment_setup": _build_environment_setup(job.environment),
        **job.resources.model_dump(),
    }

    # Debug: log template variables
    logger.debug(f"Template variables: {template_vars}")

    rendered_content = template.render(template_vars)

    if verbose:
        print(rendered_content)

    # ROBUSTNESS FIX: ensure the output directory exists before writing;
    # previously a missing directory raised FileNotFoundError here.
    output_path = Path(output_dir) / f"{job.name}.slurm"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(rendered_content, encoding="utf-8")

    return str(output_path)

464 

465 

def _build_environment_setup(environment: JobEnvironment) -> str:
    """Build the shell snippet that prepares the job's runtime environment.

    Emits ``export`` lines for each configured environment variable,
    followed by activation commands for whichever single environment
    (conda, venv, or sqsh) is configured.
    """
    lines: list[str] = []

    # Environment variables go first so activation scripts can see them.
    lines += [f"export {key}={value}" for key, value in environment.env_vars.items()]

    if environment.conda:
        home_dir = Path.home()
        lines += [
            f"source {str(home_dir)}/miniconda3/bin/activate",
            "conda deactivate",
            f"conda activate {environment.conda}",
        ]
    elif environment.venv:
        lines += [f"source {environment.venv}/bin/activate"]
    elif environment.sqsh:
        lines += [
            f': "${{IMAGE:={environment.sqsh}}}"',
            "declare -a CONTAINER_ARGS=(",
            ' --container-image "$IMAGE"',
            ")",
        ]

    return "\n".join(lines)