Coverage for src/srunx/cli/main.py: 58%
314 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-09-04 07:47 +0000
1"""Main CLI interface for srunx."""
3import argparse
4import os
5import sys
6from pathlib import Path
8from rich.console import Console
9from rich.table import Table
11from srunx.callbacks import SlackCallback
12from srunx.client import Slurm
13from srunx.config import (
14 create_example_config,
15 get_config,
16 get_config_paths,
17)
18from srunx.logging import (
19 configure_cli_logging,
20 configure_workflow_logging,
21 get_logger,
22)
23from srunx.models import Job, JobEnvironment, JobResource
24from srunx.runner import WorkflowRunner
# Module-level logger shared by every CLI command handler in this file.
logger = get_logger(__name__)
def create_job_parser() -> argparse.ArgumentParser:
    """Build the argument parser for submitting a single SLURM job.

    Most option defaults are pulled from the loaded srunx configuration,
    so the generated ``--help`` text reflects the user's own settings.

    Returns:
        A fully configured :class:`argparse.ArgumentParser`.
    """
    # Configuration supplies the per-user/per-project defaults.
    cfg = get_config()

    job_parser = argparse.ArgumentParser(
        description="Submit SLURM jobs with various configurations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Positional argument: the command line to run inside the job.
    job_parser.add_argument(
        "command",
        nargs="+",
        help="Command to execute in the SLURM job",
    )

    # Basic job identity and placement options.
    job_parser.add_argument(
        "--name",
        "--job-name",
        type=str,
        default="job",
        help="Job name (default: %(default)s)",
    )
    job_parser.add_argument(
        "--log-dir",
        type=str,
        default=cfg.log_dir,
        help="Log directory (default: %(default)s)",
    )
    job_parser.add_argument(
        "--work-dir",
        "--chdir",
        type=str,
        default=cfg.work_dir,
        help="Working directory for the job",
    )

    # Hardware/resource requests.
    resources = job_parser.add_argument_group("Resource Options")
    resources.add_argument(
        "-N",
        "--nodes",
        type=int,
        default=cfg.resources.nodes,
        help="Number of nodes (default: %(default)s)",
    )
    resources.add_argument(
        "--gpus-per-node",
        type=int,
        default=cfg.resources.gpus_per_node,
        help="Number of GPUs per node (default: %(default)s)",
    )
    resources.add_argument(
        "--ntasks-per-node",
        type=int,
        default=cfg.resources.ntasks_per_node,
        help="Number of tasks per node (default: %(default)s)",
    )
    resources.add_argument(
        "--cpus-per-task",
        type=int,
        default=cfg.resources.cpus_per_task,
        help="Number of CPUs per task (default: %(default)s)",
    )
    resources.add_argument(
        "--memory",
        "--mem",
        type=str,
        default=cfg.resources.memory_per_node,
        help="Memory per node (e.g., '32GB', '1TB') (default: %(default)s)",
    )
    resources.add_argument(
        "--time",
        "--time-limit",
        type=str,
        default=cfg.resources.time_limit,
        help="Time limit (e.g., '1:00:00', '30:00', '1-12:00:00') (default: %(default)s)",
    )
    resources.add_argument(
        "--nodelist",
        type=str,
        default=cfg.resources.nodelist,
        help="Specific nodes to use (e.g., 'node001,node002') (default: %(default)s)",
    )
    resources.add_argument(
        "--partition",
        type=str,
        default=cfg.resources.partition,
        help="SLURM partition to use (e.g., 'gpu', 'cpu') (default: %(default)s)",
    )

    # Software environment selection (conda / venv / sqsh are alternatives).
    env_opts = job_parser.add_argument_group("Environment Options")
    env_opts.add_argument(
        "--conda",
        type=str,
        default=cfg.environment.conda,
        help="Conda environment name (default: %(default)s)",
    )
    env_opts.add_argument(
        "--venv",
        type=str,
        default=cfg.environment.venv,
        help="Virtual environment path (default: %(default)s)",
    )
    env_opts.add_argument(
        "--sqsh",
        type=str,
        default=cfg.environment.sqsh,
        help="SquashFS image path (default: %(default)s)",
    )
    env_opts.add_argument(
        "--env",
        action="append",
        dest="env_vars",
        help="Environment variable KEY=VALUE (can be used multiple times)",
    )

    # How the job is rendered and monitored after submission.
    exec_opts = job_parser.add_argument_group("Execution Options")
    exec_opts.add_argument(
        "--template",
        type=str,
        help="Path to custom SLURM template file",
    )
    exec_opts.add_argument(
        "--wait",
        action="store_true",
        help="Wait for job completion",
    )
    exec_opts.add_argument(
        "--poll-interval",
        type=int,
        default=5,
        help="Polling interval in seconds when waiting (default: %(default)s)",
    )

    # Logging verbosity for the CLI itself.
    log_opts = job_parser.add_argument_group("Logging Options")
    log_opts.add_argument(
        "--log-level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default="INFO",
        help="Set logging level (default: %(default)s)",
    )
    log_opts.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="Only show warnings and errors",
    )

    # External notifications.
    notify_opts = job_parser.add_argument_group("Notification Options")
    notify_opts.add_argument(
        "--slack",
        action="store_true",
        help="Send notifications to Slack",
    )

    # Everything else.
    misc_opts = job_parser.add_argument_group("Misc Options")
    misc_opts.add_argument(
        "--verbose",
        action="store_true",
        help="Print the rendered content",
    )

    return job_parser
def create_status_parser() -> argparse.ArgumentParser:
    """Build the argument parser for checking a job's status.

    Returns:
        Parser accepting a single required ``job_id`` positional.
    """
    status_parser = argparse.ArgumentParser(
        description="Check SLURM job status",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    status_parser.add_argument(
        "job_id",
        type=int,
        help="SLURM job ID to check",
    )
    return status_parser
def create_queue_parser() -> argparse.ArgumentParser:
    """Build the argument parser for listing queued jobs.

    Returns:
        Parser with an optional ``--user``/``-u`` filter.
    """
    queue_parser = argparse.ArgumentParser(
        description="Queue SLURM jobs",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    queue_parser.add_argument(
        "--user",
        "-u",
        type=str,
        help="Queue jobs for specific user (default: current user)",
    )
    return queue_parser
def create_cancel_parser() -> argparse.ArgumentParser:
    """Build the argument parser for cancelling a job.

    Returns:
        Parser accepting a single required ``job_id`` positional.
    """
    cancel_parser = argparse.ArgumentParser(
        description="Cancel SLURM job",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cancel_parser.add_argument(
        "job_id",
        type=int,
        help="SLURM job ID to cancel",
    )
    return cancel_parser
def create_main_parser() -> argparse.ArgumentParser:
    """Build the top-level srunx parser with all subcommands attached.

    Each subcommand stores its handler in ``args.func`` via ``set_defaults``;
    ``flow`` and ``config`` are containers whose own subcommands override it.

    Returns:
        The root :class:`argparse.ArgumentParser`.
    """
    root = argparse.ArgumentParser(
        description="srunx - Python library for SLURM job management",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Options shared by every subcommand.
    root.add_argument(
        "--log-level",
        "-l",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default="INFO",
        help="Set logging level (default: %(default)s)",
    )
    root.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="Only show warnings and errors",
    )

    commands = root.add_subparsers(dest="command", help="Available commands")

    # submit: reuse the standalone job parser's arguments.
    submit = commands.add_parser("submit", help="Submit a SLURM job")
    submit.set_defaults(func=cmd_submit)
    _copy_parser_args(create_job_parser(), submit)

    # status
    status = commands.add_parser("status", help="Check job status")
    status.set_defaults(func=cmd_status)
    _copy_parser_args(create_status_parser(), status)

    # queue
    queue = commands.add_parser("queue", help="Queue jobs")
    queue.set_defaults(func=cmd_queue)
    _copy_parser_args(create_queue_parser(), queue)

    # cancel
    cancel = commands.add_parser("cancel", help="Cancel job")
    cancel.set_defaults(func=cmd_cancel)
    _copy_parser_args(create_cancel_parser(), cancel)

    # flow: workflow management, with its own nested subcommands.
    flow = commands.add_parser("flow", help="Workflow management")
    flow.set_defaults(func=None)  # Will be overridden by subcommands
    flow_commands = flow.add_subparsers(dest="flow_command", help="Flow commands")

    flow_run = flow_commands.add_parser("run", help="Execute workflow")
    flow_run.set_defaults(func=cmd_flow_run)
    flow_run.add_argument(
        "yaml_file",
        type=str,
        help="Path to YAML workflow definition file",
    )
    flow_run.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be executed without running jobs",
    )
    flow_run.add_argument(
        "--slack",
        action="store_true",
        help="Send notifications to Slack",
    )

    flow_validate = flow_commands.add_parser("validate", help="Validate workflow")
    flow_validate.set_defaults(func=cmd_flow_validate)
    flow_validate.add_argument(
        "yaml_file",
        type=str,
        help="Path to YAML workflow definition file",
    )

    # config: configuration management, with its own nested subcommands.
    config_cmd = commands.add_parser("config", help="Configuration management")
    config_cmd.set_defaults(func=None)  # Will be overridden by subcommands
    config_commands = config_cmd.add_subparsers(
        dest="config_command", help="Configuration commands"
    )

    config_show = config_commands.add_parser(
        "show", help="Show current configuration"
    )
    config_show.set_defaults(func=cmd_config_show)

    config_paths = config_commands.add_parser(
        "paths", help="Show configuration file paths"
    )
    config_paths.set_defaults(func=cmd_config_paths)

    config_init = config_commands.add_parser(
        "init", help="Initialize configuration file"
    )
    config_init.set_defaults(func=cmd_config_init)
    config_init.add_argument(
        "--global",
        action="store_true",
        dest="global_config",  # "global" is a Python keyword, so rename the dest
        help="Create global user config instead of project config",
    )

    return root
370def _copy_parser_args(
371 source_parser: argparse.ArgumentParser, target_parser: argparse.ArgumentParser
372) -> None:
373 """Copy arguments from source parser to target parser."""
374 for action in source_parser._actions:
375 if action.dest == "help":
376 continue
377 target_parser._add_action(action)
380def _parse_env_vars(env_var_list: list[str] | None) -> dict[str, str]:
381 """Parse environment variables from list of KEY=VALUE strings."""
382 env_vars = {}
383 if env_var_list:
384 for env_var in env_var_list:
385 if "=" in env_var:
386 key, value = env_var.split("=", 1)
387 env_vars[key] = value
388 else:
389 logger.warning(f"Invalid environment variable format: {env_var}")
390 return env_vars
def cmd_submit(args: argparse.Namespace) -> None:
    """Handle job submission command.

    Builds the job model from CLI arguments merged with config defaults,
    submits it, and optionally waits for completion. Exits with status 1
    on any failure.
    """
    try:
        # Environment variables: CLI --env values win over config defaults.
        cfg = get_config()
        merged_env = dict(cfg.environment.env_vars)
        merged_env.update(_parse_env_vars(getattr(args, "env_vars", None)))

        resources = JobResource(
            nodes=args.nodes,
            gpus_per_node=args.gpus_per_node,
            ntasks_per_node=args.ntasks_per_node,
            cpus_per_task=args.cpus_per_task,
            memory_per_node=getattr(args, "memory", None),
            time_limit=getattr(args, "time", None),
            nodelist=getattr(args, "nodelist", None),
            partition=getattr(args, "partition", None),
        )

        # If no environment was chosen at all, let JobEnvironment apply its
        # own defaults; otherwise validate only the provided non-None fields
        # to avoid conflicts with the model's validation.
        if not any([args.conda, args.venv, args.sqsh]):
            environment = JobEnvironment(env_vars=merged_env)
        else:
            env_config = {
                field: value
                for field, value in (
                    ("conda", args.conda),
                    ("venv", args.venv),
                    ("sqsh", args.sqsh),
                )
                if value is not None
            }
            env_config["env_vars"] = merged_env
            environment = JobEnvironment.model_validate(env_config)

        job_data = {
            "name": args.name,
            "command": args.command,
            "resources": resources,
            "environment": environment,
            "log_dir": args.log_dir,
        }
        # Only override work_dir when explicitly given.
        if args.work_dir is not None:
            job_data["work_dir"] = args.work_dir
        job = Job.model_validate(job_data)

        # Slack notifications require the webhook URL in the environment.
        callbacks = []
        if args.slack:
            webhook_url = os.getenv("SLACK_WEBHOOK_URL")
            if not webhook_url:
                raise ValueError("SLACK_WEBHOOK_URL is not set")
            callbacks.append(SlackCallback(webhook_url=webhook_url))

        client = Slurm(callbacks=callbacks)
        submitted = client.submit(
            job, getattr(args, "template", None), verbose=args.verbose
        )
        logger.info(f"Submitted job {submitted.job_id}: {submitted.name}")

        # Optionally block until the job reaches a terminal state.
        if getattr(args, "wait", False):
            logger.info(f"Waiting for job {submitted.job_id} to complete...")
            finished = client.monitor(submitted, poll_interval=args.poll_interval)
            status_str = finished.status.value if finished.status else "Unknown"
            logger.info(
                f"Job {submitted.job_id} completed with status: {status_str}"
            )

    except Exception as e:
        logger.error(f"Error submitting job: {e}")
        sys.exit(1)
def cmd_status(args: argparse.Namespace) -> None:
    """Handle job status command.

    Looks up the job by ID and logs its ID, name, and status; exits with
    status 1 on failure.
    """
    try:
        job = Slurm().retrieve(args.job_id)
        logger.info(f"Job ID: {job.job_id}")
        logger.info(f"Name: {job.name}")
        # Status may be unset for jobs SLURM no longer tracks.
        if job.status:
            logger.info(f"Status: {job.status.value}")
        else:
            logger.info("Status: Unknown")
    except Exception as e:
        logger.error(f"Error getting job status: {e}")
        sys.exit(1)
def cmd_queue(args: argparse.Namespace) -> None:
    """Handle job queueing command.

    Lists queued jobs (optionally filtered by user) as a fixed-width table
    via the logger; exits with status 1 on failure.
    """
    try:
        jobs = Slurm().queue(getattr(args, "user", None))
        if not jobs:
            logger.info("No jobs found")
            return

        # Fixed-width header, separator, then one row per job.
        logger.info(f"{'Job ID':<12} {'Name':<20} {'Status':<12}")
        logger.info("-" * 45)
        for job in jobs:
            status_str = job.status.value if job.status else "Unknown"
            logger.info(f"{job.job_id:<12} {job.name:<20} {status_str:<12}")

    except Exception as e:
        logger.error(f"Error queueing jobs: {e}")
        sys.exit(1)
def cmd_cancel(args: argparse.Namespace) -> None:
    """Handle job cancellation command.

    Cancels the job with the given ID; exits with status 1 on failure.
    """
    try:
        Slurm().cancel(args.job_id)
        logger.info(f"Cancelled job {args.job_id}")
    except Exception as e:
        logger.error(f"Error cancelling job: {e}")
        sys.exit(1)
def cmd_flow_run(args: argparse.Namespace) -> None:
    """Handle flow run command.

    Loads the YAML workflow, validates its dependencies, and either shows
    the plan (``--dry-run``) or executes it and prints a summary table.
    Exits with status 1 on failure.
    """
    # Workflow runs use their own logging setup.
    configure_workflow_logging(level=getattr(args, "log_level", "INFO"))

    try:
        workflow_path = Path(args.yaml_file)
        if not workflow_path.exists():
            logger.error(f"Workflow file not found: {args.yaml_file}")
            sys.exit(1)

        # Optional Slack notifications (requires webhook URL in environment).
        callbacks = []
        if getattr(args, "slack", False):
            webhook_url = os.getenv("SLACK_WEBHOOK_URL")
            if not webhook_url:
                raise ValueError("SLACK_WEBHOOK_URL environment variable is not set")
            callbacks.append(SlackCallback(webhook_url=webhook_url))

        runner = WorkflowRunner.from_yaml(workflow_path, callbacks=callbacks)

        # Fail fast on dependency problems before launching anything.
        runner.workflow.validate()

        if args.dry_run:
            runner.workflow.show()
            return

        results = runner.run()

        logger.success(f"🎉 Workflow {runner.workflow.name} completed!!")
        summary = Table(title=f"Workflow {runner.workflow.name} Summary")
        for heading in ("Job", "Status", "ID"):
            summary.add_column(heading, justify="left", style="cyan", no_wrap=True)
        for job in results.values():
            summary.add_row(job.name, job.status.value, str(job.job_id))
        Console().print(summary)

    except Exception as e:
        logger.error(f"Workflow execution failed: {e}")
        sys.exit(1)
def cmd_flow_validate(args: argparse.Namespace) -> None:
    """Handle flow validate command.

    Parses the YAML workflow and validates its dependency graph without
    submitting anything; exits with status 1 on failure.
    """
    # Workflow validation shares the workflow logging setup.
    configure_workflow_logging(level=getattr(args, "log_level", "INFO"))

    try:
        workflow_path = Path(args.yaml_file)
        if not workflow_path.exists():
            logger.error(f"Workflow file not found: {args.yaml_file}")
            sys.exit(1)

        # Loading parses the YAML; validate() checks the dependency graph.
        runner = WorkflowRunner.from_yaml(workflow_path)
        runner.workflow.validate()
        logger.info("Workflow validation successful")

    except Exception as e:
        logger.error(f"Workflow validation failed: {e}")
        sys.exit(1)
def cmd_config_show(args: argparse.Namespace) -> None:
    """Handle config show command.

    Renders the effective configuration (resources, environment, general
    settings) as a Rich table; exits with status 1 on failure.
    """
    try:
        cfg = get_config()

        console = Console()
        console.print("[bold cyan]Current Configuration:[/bold cyan]")

        table = Table(
            title="srunx Configuration", show_header=True, header_style="bold magenta"
        )
        table.add_column("Section", style="cyan")
        table.add_column("Setting", style="green")
        table.add_column("Value", style="yellow")

        # Resource settings — section label appears only on the first row.
        resource_rows = [
            ("nodes", cfg.resources.nodes),
            ("gpus_per_node", cfg.resources.gpus_per_node),
            ("ntasks_per_node", cfg.resources.ntasks_per_node),
            ("cpus_per_task", cfg.resources.cpus_per_task),
            ("memory_per_node", cfg.resources.memory_per_node),
            ("time_limit", cfg.resources.time_limit),
            ("nodelist", cfg.resources.nodelist),
            ("partition", cfg.resources.partition),
        ]
        for i, (setting, value) in enumerate(resource_rows):
            table.add_row("resources" if i == 0 else "", setting, str(value))

        # Environment settings.
        table.add_row("environment", "conda", str(cfg.environment.conda))
        table.add_row("", "venv", str(cfg.environment.venv))
        table.add_row("", "sqsh", str(cfg.environment.sqsh))
        if cfg.environment.env_vars:
            for key, value in cfg.environment.env_vars.items():
                table.add_row("", f"env_vars.{key}", value)
        else:
            table.add_row("", "env_vars", "(empty)")

        # General settings.
        table.add_row("general", "log_dir", cfg.log_dir)
        table.add_row("", "work_dir", str(cfg.work_dir))

        console.print(table)

    except Exception as e:
        logger.error(f"Error showing configuration: {e}")
        sys.exit(1)
def cmd_config_paths(args: argparse.Namespace) -> None:
    """Handle config paths command.

    Prints every configuration search path with an existence marker;
    exits with status 1 on failure.
    """
    try:
        console = Console()
        console.print("[bold cyan]Configuration File Paths:[/bold cyan]")
        console.print("(Listed in order of precedence - last one wins)")

        for index, path in enumerate(get_config_paths(), start=1):
            marker = "✓" if path.exists() else "✗"
            console.print(f"{index}. [{marker}] {path}")

    except Exception as e:
        logger.error(f"Error showing configuration paths: {e}")
        sys.exit(1)
def cmd_config_init(args: argparse.Namespace) -> None:
    """Handle config init command.

    Writes an example configuration file either to the global user config
    location (``--global``) or to ``srunx.json`` in the current directory.
    Refuses to overwrite an existing file; exits with status 1 on failure.
    """
    try:
        if getattr(args, "global_config", False):
            # Index 1 of the search paths is the user config location —
            # assumes get_config_paths() ordering; TODO confirm.
            config_path = get_config_paths()[1]
        else:
            # Project-local config in the current working directory.
            config_path = Path.cwd() / "srunx.json"

        if config_path.exists():
            logger.error(f"Configuration file already exists: {config_path}")
            sys.exit(1)

        # Ensure the parent directory exists before writing.
        config_path.parent.mkdir(parents=True, exist_ok=True)

        with open(config_path, "w", encoding="utf-8") as f:
            f.write(create_example_config())

        logger.info(f"Configuration file created: {config_path}")
        logger.info("Edit this file to customize your defaults")

    except Exception as e:
        logger.error(f"Error creating configuration file: {e}")
        sys.exit(1)
def main() -> None:
    """Main entry point for the CLI.

    Parses arguments, configures logging, then dispatches to the handler
    stored in ``args.func`` by the subparsers. When no handler is set,
    falls back to legacy bare-submit behavior or reports a missing
    flow/config subcommand.
    """
    parser = create_main_parser()
    args = parser.parse_args()
    # Configure logging
    log_level = getattr(args, "log_level", "INFO")
    quiet = getattr(args, "quiet", False)
    configure_cli_logging(level=log_level, quiet=quiet)
    # If no command specified, default to submit behavior for backward compatibility
    # NOTE(review): `args.command` is the subparser dest, but the submit
    # subparser copies a positional also named "command" (the job command
    # list), so for `submit` this attribute holds the job command, not the
    # subcommand name. The "flow"/"config" comparisons below still work
    # because those subparsers define no such positional — confirm intended.
    if not hasattr(args, "func") or args.func is None:
        # Check if this is a flow command without subcommand
        if hasattr(args, "command") and args.command == "flow":
            if not hasattr(args, "flow_command") or args.flow_command is None:
                logger.error("Flow command requires a subcommand (run or validate)")
                parser.print_help()
                sys.exit(1)
        # Check if this is a config command without subcommand
        elif hasattr(args, "command") and args.command == "config":
            if not hasattr(args, "config_command") or args.config_command is None:
                logger.error(
                    "Config command requires a subcommand (show, paths, or init)"
                )
                parser.print_help()
                sys.exit(1)
        else:
            # Try to parse as submit command
            # NOTE(review): parse_args() here re-reads sys.argv from scratch
            # with the job parser; an unrecognized subcommand would already
            # have made the main parser exit above, so this path is only
            # reached for argv the main parser accepted — verify reachable.
            submit_parser = create_job_parser()
            try:
                submit_args = submit_parser.parse_args()
                cmd_submit(submit_args)
            except SystemExit:
                parser.print_help()
                sys.exit(1)
    else:
        args.func(args)
# Standard script entry point: run the CLI only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()