Coverage for src/srunx/runner.py: 93%

148 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-09-04 07:47 +0000

1"""Workflow runner for executing YAML-defined workflows with SLURM""" 

2 

3import time 

4from collections import defaultdict 

5from collections.abc import Sequence 

6from concurrent.futures import ThreadPoolExecutor 

7from pathlib import Path 

8from typing import Any, Self 

9 

10import jinja2 

11import yaml 

12 

13from srunx.callbacks import Callback 

14from srunx.client import Slurm 

15from srunx.exceptions import WorkflowValidationError 

16from srunx.logging import get_logger 

17from srunx.models import ( 

18 Job, 

19 JobEnvironment, 

20 JobResource, 

21 JobStatus, 

22 RunnableJobType, 

23 ShellJob, 

24 Workflow, 

25) 

26 

# Module-level logger for workflow execution progress and errors.
logger = get_logger(__name__)

28 

29 

class WorkflowRunner:
    """Runner for executing workflows defined in YAML with dynamic job scheduling.

    Jobs are executed as soon as their dependencies are satisfied,
    rather than waiting for entire dependency levels to complete.
    """

    def __init__(
        self,
        workflow: Workflow,
        callbacks: Sequence[Callback] | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        """Initialize workflow runner.

        Args:
            workflow: Workflow to execute.
            callbacks: List of callbacks for job notifications.
            args: Template variables from the YAML ``args`` section.
        """
        self.workflow = workflow
        self.slurm = Slurm(callbacks=callbacks)
        self.callbacks = callbacks or []
        self.args = args or {}

    @classmethod
    def from_yaml(
        cls, yaml_path: str | Path, callbacks: Sequence[Callback] | None = None
    ) -> Self:
        """Load and validate a workflow from a YAML file.

        Args:
            yaml_path: Path to the YAML workflow definition file.
            callbacks: List of callbacks for job notifications.

        Returns:
            WorkflowRunner instance with loaded workflow.

        Raises:
            FileNotFoundError: If the YAML file doesn't exist.
            yaml.YAMLError: If the YAML is malformed.
            ValidationError: If the workflow structure is invalid.
        """
        yaml_file = Path(yaml_path)
        if not yaml_file.exists():
            raise FileNotFoundError(f"Workflow file not found: {yaml_path}")

        with open(yaml_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        name = data.get("name", "unnamed")
        args = data.get("args", {})
        jobs_data = data.get("jobs", [])

        # Substitute template variables before the job models are validated,
        # so rendered values participate in validation.
        rendered_jobs_data = cls._render_jobs_with_args(jobs_data, args)

        jobs = [cls.parse_job(job_data) for job_data in rendered_jobs_data]
        return cls(
            workflow=Workflow(name=name, jobs=jobs), callbacks=callbacks, args=args
        )

    @staticmethod
    def _render_jobs_with_args(
        jobs_data: list[dict[str, Any]], args: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Render Jinja templates in job data using args.

        String args of the form ``"python:<expression>"`` are evaluated as
        Python expressions first, and ``args`` is updated *in place* with the
        evaluated results before rendering.

        Args:
            jobs_data: List of job configurations from YAML.
            args: Template variables from the YAML args section.

        Returns:
            List of job configurations with rendered templates.

        Raises:
            WorkflowValidationError: If template rendering fails, including
                references to undefined variables (``StrictUndefined``).
        """
        if not args:
            return jobs_data

        # Convert jobs_data to YAML string, render as template, then parse back.
        jobs_yaml = yaml.dump(jobs_data, default_flow_style=False)
        template = jinja2.Template(jobs_yaml, undefined=jinja2.StrictUndefined)

        for key, value in args.items():
            if isinstance(value, str) and value.startswith("python:"):
                # BUG FIX: use maxsplit=1 so colons inside the expression
                # (e.g. strftime("%H:%M")) survive; a bare split(":")[1]
                # silently truncated such expressions at the second colon.
                cmd = value.split(":", 1)[1]
                if "datetime" in cmd:
                    import datetime  # noqa: F401

                # SECURITY: eval() executes arbitrary code taken from the
                # workflow file. Only run workflow definitions from trusted
                # sources; consider a restricted evaluator if that changes.
                args[key] = eval(cmd)

        try:
            rendered_yaml = template.render(args)
            return yaml.safe_load(rendered_yaml)
        except jinja2.TemplateError as e:
            logger.error(f"Jinja template rendering failed: {e}")
            raise WorkflowValidationError(f"Template rendering failed: {e}") from e

    def get_independent_jobs(self) -> list[RunnableJobType]:
        """Get all jobs that are independent of any other job."""
        return [job for job in self.workflow.jobs if not job.depends_on]

    def run(self) -> dict[str, RunnableJobType]:
        """Run a workflow with dynamic job scheduling.

        Jobs are executed as soon as their dependencies are satisfied.

        Returns:
            Dictionary mapping job names to completed Job instances.

        Raises:
            RuntimeError: If any job fails or never reaches completion.
        """
        logger.info(
            f"🚀 Starting Workflow {self.workflow.name} with {len(self.workflow.jobs)} jobs"
        )
        for callback in self.callbacks:
            callback.on_workflow_started(self.workflow)

        # Track all jobs and results
        all_jobs = self.workflow.jobs.copy()
        results: dict[str, RunnableJobType] = {}
        running_futures: dict[str, Any] = {}

        # Build reverse dependency map (job name -> names of jobs that
        # depend on it) for efficient lookups on completion.
        dependents = defaultdict(set)
        for job in all_jobs:
            for dep in job.depends_on:
                dependents[dep].add(job.name)

        def execute_job(job: RunnableJobType) -> RunnableJobType:
            """Submit a single job and block until SLURM reports a result."""
            logger.info(f"🌋 {'SUBMITTED':<12} Job {job.name:<12}")
            # Exceptions propagate into the future and are surfaced by the
            # scheduling loop below (the original try/except-reraise was a
            # no-op and has been removed).
            return self.slurm.run(job)

        def on_job_complete(job_name: str, result: RunnableJobType) -> list[str]:
            """Handle job completion and return newly ready job names."""
            results[job_name] = result
            # dict keys are already unique; no set() round-trip needed.
            completed_job_names = list(results)

            # Only direct dependents of the finished job can become ready.
            newly_ready = []
            for dependent_name in dependents[job_name]:
                dependent_job = next(j for j in all_jobs if j.name == dependent_name)
                if (
                    dependent_job.status == JobStatus.PENDING
                    and dependent_job.dependencies_satisfied(completed_job_names)
                ):
                    newly_ready.append(dependent_name)

            return newly_ready

        # Execute workflow with ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=8) as executor:
            # Submit jobs with no dependencies first.
            for job in self.get_independent_jobs():
                running_futures[job.name] = executor.submit(execute_job, job)

            # Poll for completed futures and schedule newly ready jobs.
            while running_futures:
                completed = []
                for job_name, future in list(running_futures.items()):
                    if future.done():
                        completed.append((job_name, future))
                        del running_futures[job_name]

                if not completed:
                    time.sleep(0.1)  # Brief sleep to avoid busy waiting
                    continue

                # Handle completed jobs
                for job_name, future in completed:
                    try:
                        result = future.result()
                        newly_ready_names = on_job_complete(job_name, result)

                        # Schedule newly ready jobs
                        for ready_name in newly_ready_names:
                            if ready_name not in running_futures:
                                ready_job = next(
                                    j for j in all_jobs if j.name == ready_name
                                )
                                running_futures[ready_name] = executor.submit(
                                    execute_job, ready_job
                                )

                    except Exception as e:
                        logger.error(f"❌ Job {job_name} failed: {e}")
                        raise

        # Verify all jobs completed successfully
        failed_jobs = [j.name for j in all_jobs if j.status == JobStatus.FAILED]
        incomplete_jobs = [
            j.name
            for j in all_jobs
            if j.status not in [JobStatus.COMPLETED, JobStatus.FAILED]
        ]

        if failed_jobs:
            logger.error(f"❌ Jobs failed: {failed_jobs}")
            raise RuntimeError(f"Workflow execution failed: {failed_jobs}")

        if incomplete_jobs:
            logger.error(f"❌ Jobs did not complete: {incomplete_jobs}")
            raise RuntimeError(f"Workflow execution incomplete: {incomplete_jobs}")

        logger.success(f"🎉 Workflow {self.workflow.name} completed!!")

        for callback in self.callbacks:
            callback.on_workflow_completed(self.workflow)

        return results

    def execute_from_yaml(self, yaml_path: str | Path) -> dict[str, RunnableJobType]:
        """Load and execute a workflow from YAML file.

        Note:
            Builds a fresh runner via ``from_yaml``; the workflow held by
            ``self`` is not used.

        Args:
            yaml_path: Path to YAML workflow file.

        Returns:
            Dictionary mapping job names to completed Job instances.
        """
        logger.info(f"Loading workflow from {yaml_path}")
        runner = self.from_yaml(yaml_path)
        return runner.run()

    @staticmethod
    def parse_job(data: dict[str, Any]) -> RunnableJobType:
        """Build a runnable job model from one rendered YAML job mapping.

        Args:
            data: Job configuration; must contain "name" and exactly one of
                "path" (shell-script job) or "command" (SLURM job).

        Returns:
            ShellJob when "path" is given, otherwise Job.

        Raises:
            WorkflowValidationError: If both "path" and "command" are set.
            KeyError: If required keys ("name", "command") are missing.
        """
        if data.get("path") and data.get("command"):
            raise WorkflowValidationError("Job cannot have both 'path' and 'command'")

        base = {"name": data["name"], "depends_on": data.get("depends_on", [])}

        if data.get("path"):
            return ShellJob.model_validate({**base, "path": data["path"]})

        resource = JobResource.model_validate(data.get("resources", {}))
        environment = JobEnvironment.model_validate(data.get("environment", {}))

        job_data = {
            **base,
            "command": data["command"],
            "resources": resource,
            "environment": environment,
        }
        # Optional overrides are forwarded only when present so the model
        # defaults apply otherwise.
        if data.get("log_dir"):
            job_data["log_dir"] = data["log_dir"]
        if data.get("work_dir"):
            job_data["work_dir"] = data["work_dir"]

        return Job.model_validate(job_data)

293 

294 

def run_workflow_from_file(yaml_path: str | Path) -> dict[str, RunnableJobType]:
    """Load a workflow from a YAML file and execute it.

    Thin convenience wrapper around ``WorkflowRunner.from_yaml(...).run()``.

    Args:
        yaml_path: Path to YAML workflow file.

    Returns:
        Dictionary mapping job names to completed Job instances.
    """
    return WorkflowRunner.from_yaml(yaml_path).run()