Coverage for src/srunx/runner.py: 93%
148 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-09-04 07:47 +0000
1"""Workflow runner for executing YAML-defined workflows with SLURM"""
3import time
4from collections import defaultdict
5from collections.abc import Sequence
6from concurrent.futures import ThreadPoolExecutor
7from pathlib import Path
8from typing import Any, Self
10import jinja2
11import yaml
13from srunx.callbacks import Callback
14from srunx.client import Slurm
15from srunx.exceptions import WorkflowValidationError
16from srunx.logging import get_logger
17from srunx.models import (
18 Job,
19 JobEnvironment,
20 JobResource,
21 JobStatus,
22 RunnableJobType,
23 ShellJob,
24 Workflow,
25)
27logger = get_logger(__name__)
class WorkflowRunner:
    """Runner for executing workflows defined in YAML with dynamic job scheduling.

    Jobs are executed as soon as their dependencies are satisfied,
    rather than waiting for entire dependency levels to complete.
    """

    def __init__(
        self,
        workflow: Workflow,
        callbacks: Sequence[Callback] | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        """Initialize workflow runner.

        Args:
            workflow: Workflow to execute.
            callbacks: List of callbacks for job notifications.
            args: Template variables from the YAML args section.
        """
        self.workflow = workflow
        self.slurm = Slurm(callbacks=callbacks)
        self.callbacks = callbacks or []
        self.args = args or {}

    @classmethod
    def from_yaml(
        cls, yaml_path: str | Path, callbacks: Sequence[Callback] | None = None
    ) -> Self:
        """Load and validate a workflow from a YAML file.

        Args:
            yaml_path: Path to the YAML workflow definition file.
            callbacks: List of callbacks for job notifications.

        Returns:
            WorkflowRunner instance with loaded workflow.

        Raises:
            FileNotFoundError: If the YAML file doesn't exist.
            yaml.YAMLError: If the YAML is malformed.
            ValidationError: If the workflow structure is invalid.
        """
        yaml_file = Path(yaml_path)
        if not yaml_file.exists():
            raise FileNotFoundError(f"Workflow file not found: {yaml_path}")

        with open(yaml_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        name = data.get("name", "unnamed")
        args = data.get("args", {})
        jobs_data = data.get("jobs", [])

        # Render Jinja templates in jobs_data using args
        rendered_jobs_data = cls._render_jobs_with_args(jobs_data, args)

        jobs = [cls.parse_job(job_data) for job_data in rendered_jobs_data]
        return cls(
            workflow=Workflow(name=name, jobs=jobs), callbacks=callbacks, args=args
        )

    @staticmethod
    def _render_jobs_with_args(
        jobs_data: list[dict[str, Any]], args: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Render Jinja templates in job data using args.

        ``args`` values of the form ``"python:<expr>"`` are evaluated with
        ``eval`` and replace the original string *in place*, so evaluated
        values are also visible to the caller's ``args`` dict.

        Args:
            jobs_data: List of job configurations from YAML.
            args: Template variables from the YAML args section.

        Returns:
            List of job configurations with rendered templates.

        Raises:
            WorkflowValidationError: If template rendering fails.
        """
        if not args:
            return jobs_data

        # Convert jobs_data to YAML string, render as template, then parse back.
        # StrictUndefined makes references to missing args raise instead of
        # silently rendering as empty strings.
        jobs_yaml = yaml.dump(jobs_data, default_flow_style=False)
        template = jinja2.Template(jobs_yaml, undefined=jinja2.StrictUndefined)

        for key, value in args.items():
            if isinstance(value, str) and value.startswith("python:"):
                # Split only on the first ":" so the expression itself may
                # contain colons (e.g. "python:now().strftime('%H:%M')").
                cmd = value.split(":", 1)[1]
                if "datetime" in cmd:
                    # Make datetime available to the evaluated expression.
                    import datetime  # noqa: F401

                # SECURITY: this eval-s arbitrary code from the workflow file.
                # Only run workflow YAML from trusted sources.
                args[key] = eval(cmd)

        try:
            rendered_yaml = template.render(args)
            return yaml.safe_load(rendered_yaml)
        except jinja2.TemplateError as e:
            logger.error(f"Jinja template rendering failed: {e}")
            raise WorkflowValidationError(f"Template rendering failed: {e}") from e

    def get_independent_jobs(self) -> list[RunnableJobType]:
        """Get all jobs that are independent of any other job."""
        return [job for job in self.workflow.jobs if not job.depends_on]

    def run(self) -> dict[str, RunnableJobType]:
        """Run a workflow with dynamic job scheduling.

        Jobs are executed as soon as their dependencies are satisfied.

        Returns:
            Dictionary mapping job names to completed Job instances.

        Raises:
            RuntimeError: If any job failed or did not complete.
        """
        logger.info(
            f"🚀 Starting Workflow {self.workflow.name} with {len(self.workflow.jobs)} jobs"
        )
        for callback in self.callbacks:
            callback.on_workflow_started(self.workflow)

        # Track all jobs and results
        all_jobs = self.workflow.jobs.copy()
        results: dict[str, RunnableJobType] = {}
        running_futures: dict[str, Any] = {}

        # Build reverse dependency map for efficient lookups
        dependents = defaultdict(set)
        for job in all_jobs:
            for dep in job.depends_on:
                dependents[dep].add(job.name)

        def execute_job(job: RunnableJobType) -> RunnableJobType:
            """Execute a single job via SLURM (runs in a worker thread)."""
            logger.info(f"🌋 {'SUBMITTED':<12} Job {job.name:<12}")
            return self.slurm.run(job)

        def on_job_complete(job_name: str, result: RunnableJobType) -> list[str]:
            """Handle job completion and return newly ready job names."""
            results[job_name] = result
            # Dict keys are already unique; no dedup needed.
            completed_job_names = list(results.keys())

            # Find newly ready jobs: dependents that are still pending and
            # whose dependencies are all now complete.
            newly_ready = []
            for dependent_name in dependents[job_name]:
                dependent_job = next(j for j in all_jobs if j.name == dependent_name)
                if (
                    dependent_job.status == JobStatus.PENDING
                    and dependent_job.dependencies_satisfied(completed_job_names)
                ):
                    newly_ready.append(dependent_name)

            return newly_ready

        # Execute workflow with ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=8) as executor:
            # Submit initial ready jobs
            for job in self.get_independent_jobs():
                running_futures[job.name] = executor.submit(execute_job, job)

            # Process completed jobs and schedule new ones
            while running_futures:
                # Check for completed futures
                completed = []
                for job_name, future in list(running_futures.items()):
                    if future.done():
                        completed.append((job_name, future))
                        del running_futures[job_name]

                if not completed:
                    time.sleep(0.1)  # Brief sleep to avoid busy waiting
                    continue

                # Handle completed jobs
                for job_name, future in completed:
                    try:
                        result = future.result()
                        newly_ready_names = on_job_complete(job_name, result)

                        # Schedule newly ready jobs
                        for ready_name in newly_ready_names:
                            if ready_name not in running_futures:
                                ready_job = next(
                                    j for j in all_jobs if j.name == ready_name
                                )
                                running_futures[ready_name] = executor.submit(
                                    execute_job, ready_job
                                )
                    except Exception as e:
                        logger.error(f"❌ Job {job_name} failed: {e}")
                        raise

        # Verify all jobs completed successfully
        failed_jobs = [j.name for j in all_jobs if j.status == JobStatus.FAILED]
        incomplete_jobs = [
            j.name
            for j in all_jobs
            if j.status not in [JobStatus.COMPLETED, JobStatus.FAILED]
        ]

        if failed_jobs:
            logger.error(f"❌ Jobs failed: {failed_jobs}")
            raise RuntimeError(f"Workflow execution failed: {failed_jobs}")

        if incomplete_jobs:
            logger.error(f"❌ Jobs did not complete: {incomplete_jobs}")
            raise RuntimeError(f"Workflow execution incomplete: {incomplete_jobs}")

        logger.success(f"🎉 Workflow {self.workflow.name} completed!!")

        for callback in self.callbacks:
            callback.on_workflow_completed(self.workflow)

        return results

    def execute_from_yaml(self, yaml_path: str | Path) -> dict[str, RunnableJobType]:
        """Load and execute a workflow from YAML file.

        Args:
            yaml_path: Path to YAML workflow file.

        Returns:
            Dictionary mapping job names to completed Job instances.
        """
        logger.info(f"Loading workflow from {yaml_path}")
        runner = self.from_yaml(yaml_path)
        return runner.run()

    @staticmethod
    def parse_job(data: dict[str, Any]) -> RunnableJobType:
        """Build a ShellJob or Job from a rendered YAML job mapping.

        Args:
            data: One job entry from the workflow YAML (already rendered).

        Returns:
            ShellJob when ``path`` is given, otherwise a Job built from
            ``command``, ``resources`` and ``environment``.

        Raises:
            WorkflowValidationError: If both 'path' and 'command' are given.
        """
        if data.get("path") and data.get("command"):
            raise WorkflowValidationError("Job cannot have both 'path' and 'command'")

        base = {"name": data["name"], "depends_on": data.get("depends_on", [])}

        if data.get("path"):
            return ShellJob.model_validate({**base, "path": data["path"]})

        resource = JobResource.model_validate(data.get("resources", {}))
        environment = JobEnvironment.model_validate(data.get("environment", {}))

        job_data = {
            **base,
            "command": data["command"],
            "resources": resource,
            "environment": environment,
        }
        # Optional overrides are only set when present so model defaults apply.
        if data.get("log_dir"):
            job_data["log_dir"] = data["log_dir"]
        if data.get("work_dir"):
            job_data["work_dir"] = data["work_dir"]

        return Job.model_validate(job_data)
def run_workflow_from_file(yaml_path: str | Path) -> dict[str, RunnableJobType]:
    """Convenience function to run workflow from YAML file.

    Args:
        yaml_path: Path to YAML workflow file.

    Returns:
        Dictionary mapping job names to completed Job instances.
    """
    return WorkflowRunner.from_yaml(yaml_path).run()