Coverage for src/srunx/models.py: 82%
233 statements
« prev ^ index » next — coverage.py v7.9.1, created at 2025-09-04 07:47 +0000
1"""Data models for SLURM job management."""
3import os
4import subprocess
5import time
6from enum import Enum
7from pathlib import Path
8from typing import Self
10import jinja2
11from pydantic import BaseModel, Field, PrivateAttr, model_validator
13from srunx.exceptions import WorkflowValidationError
14from srunx.logging import get_logger
16logger = get_logger(__name__)
19def _get_config_defaults():
20 """Get configuration defaults, with lazy import to avoid circular dependencies."""
21 try:
22 from srunx.config import get_config
24 return get_config()
25 except ImportError:
26 # Fallback if config module is not available
27 return None
def _default_nodes():
    """Default number of compute nodes (1 when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return 1
    return cfg.resources.nodes
def _default_gpus_per_node():
    """Default GPUs per node (0 when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return 0
    return cfg.resources.gpus_per_node
def _default_ntasks_per_node():
    """Default tasks per node (1 when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return 1
    return cfg.resources.ntasks_per_node
def _default_cpus_per_task():
    """Default CPUs per task (1 when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return 1
    return cfg.resources.cpus_per_task
def _default_memory_per_node():
    """Default memory per node (None when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.resources.memory_per_node
def _default_time_limit():
    """Default time limit (None when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.resources.time_limit
def _default_nodelist():
    """Default nodelist (None when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.resources.nodelist
def _default_partition():
    """Default SLURM partition (None when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.resources.partition
def _default_conda():
    """Default conda environment name (None when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.environment.conda
def _default_venv():
    """Default virtualenv path (None when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.environment.venv
def _default_sqsh():
    """Default SquashFS image path (None when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.environment.sqsh
def _default_env_vars():
    """Default environment variables (fresh empty dict when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return {}
    return cfg.environment.env_vars
def _default_log_dir():
    """Default log directory: config value, else $SLURM_LOG_DIR, else 'logs'."""
    cfg = _get_config_defaults()
    if not cfg:
        return os.getenv("SLURM_LOG_DIR", "logs")
    return cfg.log_dir
def _default_work_dir():
    """Default working directory (None when no config is available)."""
    cfg = _get_config_defaults()
    if not cfg:
        return None
    return cfg.work_dir
class JobStatus(Enum):
    """Job status enumeration for both SLURM jobs and workflow jobs."""

    # Values mirror SLURM state strings so ``JobStatus(state)`` can parse
    # sacct output directly (see BaseJob.refresh()).
    UNKNOWN = "UNKNOWN"
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"
    TIMEOUT = "TIMEOUT"
class JobResource(BaseModel):
    """SLURM resource allocation requirements.

    Every default is resolved lazily from the srunx configuration via the
    module-level ``_default_*`` helpers, falling back to hard-coded values
    when no configuration is available.
    """

    nodes: int = Field(
        default_factory=_default_nodes, ge=1, description="Number of compute nodes"
    )
    gpus_per_node: int = Field(
        default_factory=_default_gpus_per_node,
        ge=0,
        description="Number of GPUs per node",
    )
    ntasks_per_node: int = Field(
        default_factory=_default_ntasks_per_node,
        ge=1,
        description="Number of jobs per node",
    )
    cpus_per_task: int = Field(
        default_factory=_default_cpus_per_task,
        ge=1,
        description="Number of CPUs per task",
    )
    memory_per_node: str | None = Field(
        default_factory=_default_memory_per_node,
        description="Memory per node (e.g., '32GB')",
    )
    time_limit: str | None = Field(
        default_factory=_default_time_limit, description="Time limit (e.g., '1:00:00')"
    )
    nodelist: str | None = Field(
        default_factory=_default_nodelist,
        description="Specific nodes to use (e.g., 'node001,node002')",
    )
    partition: str | None = Field(
        default_factory=_default_partition,
        description="SLURM partition to use (e.g., 'gpu', 'cpu')",
    )
class JobEnvironment(BaseModel):
    """Job environment configuration.

    At most one of ``conda``, ``venv``, or ``sqsh`` may be set; ``env_vars``
    are exported in the generated job script.
    """

    conda: str | None = Field(
        default_factory=_default_conda, description="Conda environment name"
    )
    venv: str | None = Field(
        default_factory=_default_venv, description="Virtual environment path"
    )
    sqsh: str | None = Field(
        default_factory=_default_sqsh, description="SquashFS image path"
    )
    env_vars: dict[str, str] = Field(
        default_factory=_default_env_vars, description="Environment variables"
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Reject configurations that specify more than one environment kind."""
        specified = [e for e in (self.conda, self.venv, self.sqsh) if e is not None]
        if len(specified) > 1:
            raise ValueError(
                "Only one virtual environment (conda, venv, or sqsh) can be specified"
            )
        if not specified:
            logger.info("No virtual environment is set.")
        return self
class BaseJob(BaseModel):
    """Common state for all job variants: name, SLURM id, dependencies, status."""

    name: str = Field(default="job", description="Job name")
    job_id: int | None = Field(default=None, description="SLURM job ID")
    depends_on: list[str] = Field(
        default_factory=list, description="Task dependencies for workflow execution"
    )

    # Cached status; private so the ``status`` property can refresh lazily.
    _status: JobStatus = PrivateAttr(default=JobStatus.PENDING)

    @property
    def status(self) -> JobStatus:
        """
        Accessing ``job.status`` always triggers a lightweight refresh
        (only if we have a ``job_id`` and the status isn't terminal).
        """
        if self.job_id is not None and self._status not in {
            JobStatus.COMPLETED,
            JobStatus.FAILED,
            JobStatus.CANCELLED,
            JobStatus.TIMEOUT,
        }:
            self.refresh()
        return self._status

    @status.setter
    def status(self, value: JobStatus) -> None:
        self._status = value

    def refresh(self, retries: int = 3) -> Self:
        """Query sacct and update ``_status`` in-place.

        Args:
            retries: Attempts to make while sacct has no record yet
                (accounting can lag briefly right after submission).

        Returns:
            Self, for chaining.

        Raises:
            subprocess.CalledProcessError: If the sacct command itself fails.
        """
        if self.job_id is None:
            return self

        for retry in range(retries):
            try:
                result = subprocess.run(
                    [
                        "sacct",
                        "-j",
                        str(self.job_id),
                        "--format",
                        "JobID,State",
                        "--noheader",
                        "--parsable2",
                    ],
                    capture_output=True,
                    text=True,
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to query job {self.job_id}: {e}")
                raise

            # First line is the parent job record (steps like "1234.batch" follow).
            line = result.stdout.strip().split("\n")[0] if result.stdout.strip() else ""
            if not line:
                if retry < retries - 1:
                    time.sleep(1)
                    continue
                self._status = JobStatus.UNKNOWN
                return self
            break

        _, state = line.split("|", 1)
        # BUG FIX: sacct can report states outside our enum (e.g. "COMPLETING",
        # "NODE_FAIL") and annotates operator cancels as "CANCELLED by <uid>".
        # Previously ``JobStatus(state)`` raised ValueError for those; normalize
        # to the first token and fall back to UNKNOWN instead of crashing.
        try:
            self._status = JobStatus(state.split(" ")[0])
        except ValueError:
            self._status = JobStatus.UNKNOWN
        return self

    def dependencies_satisfied(self, completed_job_names: list[str]) -> bool:
        """All dependencies are completed & this job is still pending."""
        return self.status == JobStatus.PENDING and all(
            dep in completed_job_names for dep in self.depends_on
        )
class Job(BaseJob):
    """Represents a SLURM job with complete configuration."""

    # Command tokens; joined with spaces when rendering the sbatch script
    # (see render_job_script()).
    command: list[str] = Field(description="Command to execute")
    resources: JobResource = Field(
        default_factory=JobResource, description="Resource requirements"
    )
    environment: JobEnvironment = Field(
        default_factory=JobEnvironment, description="Environment setup"
    )
    log_dir: str = Field(
        default_factory=_default_log_dir,
        description="Directory for log files",
    )
    # Falls back to the current working directory when no config value is set.
    work_dir: str = Field(
        default_factory=lambda: _default_work_dir() or os.getcwd(),
        description="Working directory",
    )
class ShellJob(BaseJob):
    """Job defined by a pre-existing shell script rather than a rendered command."""

    path: str = Field(description="Shell script path to execute")
# PEP 695 type aliases: all job models vs. the concrete kinds a runner executes.
type JobType = BaseJob | Job | ShellJob
type RunnableJobType = Job | ShellJob
class Workflow:
    """Represents a workflow containing multiple jobs with dependencies."""

    def __init__(self, name: str, jobs: list[RunnableJobType] | None = None) -> None:
        """Create a workflow.

        Args:
            name: Workflow name.
            jobs: Initial job list (a fresh empty list when omitted).
        """
        if jobs is None:
            jobs = []
        self.name = name
        self.jobs = jobs

    def add(self, job: RunnableJobType) -> None:
        """Append a job after verifying its dependencies refer to known jobs.

        BUG FIX: the previous check tested ``dep not in self.jobs``, comparing
        a dependency *name* (str) against job objects, so any job with
        dependencies was always rejected. Compare against job names instead,
        consistent with ``validate()``.

        Raises:
            WorkflowValidationError: If a dependency names an unknown job.
        """
        if job.depends_on:
            known_names = {existing.name for existing in self.jobs}
            for dep in job.depends_on:
                if dep not in known_names:
                    raise WorkflowValidationError(
                        f"Job '{job.name}' depends on unknown job '{dep}'"
                    )
        self.jobs.append(job)

    def remove(self, job: RunnableJobType) -> None:
        """Remove a job from the workflow (raises ValueError if absent)."""
        self.jobs.remove(job)

    def get(self, name: str) -> RunnableJobType | None:
        """Get a job by name.

        Note: refreshes the job's SLURM status as a side effect before
        returning it; returns None when no job matches.
        """
        for job in self.jobs:
            if job.name == name:
                return job.refresh()
        return None

    def get_dependencies(self, job_name: str) -> list[str]:
        """Get dependencies for a specific job (empty list if job is unknown)."""
        job = self.get(job_name)
        return job.depends_on if job else []

    def show(self):
        """Print a human-readable plan of the workflow's jobs."""
        msg = f"""\
{" PLAN ":=^80}
Workflow: {self.name}
Jobs: {len(self.jobs)}
"""

        def add_indent(indent: int, msg: str) -> str:
            # One indent level == one leading space.
            return " " * indent + msg

        for job in self.jobs:
            msg += add_indent(1, f"Job: {job.name}\n")
            if isinstance(job, Job):
                msg += add_indent(
                    2, f"{'Command:': <13} {' '.join(job.command or [])}\n"
                )
                msg += add_indent(
                    2,
                    f"{'Resources:': <13} {job.resources.nodes} nodes, {job.resources.gpus_per_node} GPUs/node\n",
                )
                if job.environment.conda:
                    msg += add_indent(
                        2, f"{'Conda env:': <13} {job.environment.conda}\n"
                    )
                if job.environment.sqsh:
                    msg += add_indent(2, f"{'Sqsh:': <13} {job.environment.sqsh}\n")
                if job.environment.venv:
                    msg += add_indent(2, f"{'Venv:': <13} {job.environment.venv}\n")
            elif isinstance(job, ShellJob):
                msg += add_indent(2, f"{'Path:': <13} {job.path}\n")
            if job.depends_on:
                msg += add_indent(
                    2, f"{'Dependencies:': <13} {', '.join(job.depends_on)}\n"
                )

        msg += f"{'=' * 80}\n"
        print(msg)

    def validate(self):
        """Validate workflow job dependencies.

        Raises:
            WorkflowValidationError: On duplicate names, unknown dependencies,
                or circular dependencies.
        """
        job_names = {job.name for job in self.jobs}

        if len(job_names) != len(self.jobs):
            raise WorkflowValidationError("Duplicate job names found in workflow")

        for job in self.jobs:
            for dependency in job.depends_on:
                if dependency not in job_names:
                    raise WorkflowValidationError(
                        f"Job '{job.name}' depends on unknown job '{dependency}'"
                    )

        # Check for circular dependencies (simple DFS with a recursion stack).
        visited = set()
        rec_stack = set()

        def has_cycle(job_name: str) -> bool:
            if job_name in rec_stack:
                return True
            if job_name in visited:
                return False

            visited.add(job_name)
            rec_stack.add(job_name)

            job = self.get(job_name)
            if job:
                for dependency in job.depends_on:
                    if has_cycle(dependency):
                        return True

            rec_stack.remove(job_name)
            return False

        for job in self.jobs:
            if has_cycle(job.name):
                raise WorkflowValidationError(
                    f"Circular dependency detected involving job '{job.name}'"
                )
def render_job_script(
    template_path: Path | str,
    job: Job,
    output_dir: Path | str,
    verbose: bool = False,
) -> str:
    """Render a SLURM job script from a template.

    Args:
        template_path: Path to the Jinja template file.
        job: Job configuration.
        output_dir: Directory where the generated script will be saved.
        verbose: Whether to print the rendered content.

    Returns:
        Path to the generated SLURM batch script.

    Raises:
        FileNotFoundError: If the template file does not exist.
        jinja2.TemplateError: If template rendering fails.
    """
    source = Path(template_path)
    if not source.is_file():
        raise FileNotFoundError(f"Template file '{template_path}' not found")

    # StrictUndefined makes missing template variables fail loudly.
    template = jinja2.Template(
        source.read_text(encoding="utf-8"), undefined=jinja2.StrictUndefined
    )

    # Resource fields (nodes, partition, ...) are flattened into the context.
    context = {
        "job_name": job.name,
        "command": " ".join(job.command or []),
        "log_dir": job.log_dir,
        "work_dir": job.work_dir,
        "environment_setup": _build_environment_setup(job.environment),
        **job.resources.model_dump(),
    }
    logger.debug(f"Template variables: {context}")

    rendered = template.render(context)
    if verbose:
        print(rendered)

    # Write the script next to other generated scripts, named after the job.
    script_path = Path(output_dir) / f"{job.name}.slurm"
    script_path.write_text(rendered, encoding="utf-8")
    return str(script_path)
def _build_environment_setup(environment: JobEnvironment) -> str:
    """Build the shell snippet that exports env vars and activates the environment.

    Exactly one of conda/venv/sqsh is honored (the model validator guarantees
    at most one is set); env var exports always come first.
    """
    lines = [f"export {key}={value}" for key, value in environment.env_vars.items()]

    if environment.conda:
        # assumes a ~/miniconda3 install — TODO confirm against deployment docs
        home_dir = Path.home()
        lines += [
            f"source {str(home_dir)}/miniconda3/bin/activate",
            "conda deactivate",
            f"conda activate {environment.conda}",
        ]
    elif environment.venv:
        lines += [f"source {environment.venv}/bin/activate"]
    elif environment.sqsh:
        # Container args consumed by the srun/pyxis invocation in the template.
        lines += [
            f': "${{IMAGE:={environment.sqsh}}}"',
            "declare -a CONTAINER_ARGS=(",
            ' --container-image "$IMAGE"',
            ")",
        ]

    return "\n".join(lines)