Coverage for src/srunx/config.py: 85%
133 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-09-04 07:47 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-09-04 07:47 +0000
1"""Configuration management for srunx."""
3import json
4import os
5from pathlib import Path
6from typing import Any
8from pydantic import BaseModel, Field
10from srunx.logging import get_logger
12logger = get_logger(__name__)
class ResourceDefaults(BaseModel):
    """Default resource configuration."""

    # --- Job geometry -----------------------------------------------------
    # The ``ge`` bounds make validation reject counts below the stated
    # minimum; GPUs are the only count that may be zero.
    nodes: int = Field(
        default=1,
        ge=1,
        description="Default number of compute nodes",
    )
    gpus_per_node: int = Field(
        default=0,
        ge=0,
        description="Default number of GPUs per node",
    )
    ntasks_per_node: int = Field(
        default=1,
        ge=1,
        description="Default number of tasks per node",
    )
    cpus_per_task: int = Field(
        default=1,
        ge=1,
        description="Default number of CPUs per task",
    )

    # --- Optional settings ------------------------------------------------
    # Free-form strings with no validation here; ``None`` means "no default
    # configured".
    memory_per_node: str | None = Field(
        default=None,
        description="Default memory per node",
    )
    time_limit: str | None = Field(
        default=None,
        description="Default time limit",
    )
    nodelist: str | None = Field(
        default=None,
        description="Default nodelist",
    )
    partition: str | None = Field(
        default=None,
        description="Default partition",
    )
class EnvironmentDefaults(BaseModel):
    """Default environment configuration."""

    # Each of these is an optional string; ``None`` means "not configured".
    conda: str | None = Field(
        default=None,
        description="Default conda environment",
    )
    venv: str | None = Field(
        default=None,
        description="Default virtual environment path",
    )
    sqsh: str | None = Field(
        default=None,
        description="Default SquashFS image path",
    )

    # ``default_factory`` gives every instance its own empty mapping instead
    # of sharing one mutable default.
    env_vars: dict[str, str] = Field(
        default_factory=dict,
        description="Default environment variables",
    )
class SrunxConfig(BaseModel):
    """Main srunx configuration."""

    # Nested sections; each is built from its own model's defaults when the
    # input data does not provide the section.
    resources: ResourceDefaults = Field(
        default_factory=ResourceDefaults,
    )
    environment: EnvironmentDefaults = Field(
        default_factory=EnvironmentDefaults,
    )

    # Top-level scalar settings.
    log_dir: str = Field(
        default="logs",
        description="Default log directory",
    )
    work_dir: str | None = Field(
        default=None,
        description="Default working directory",
    )
def get_config_paths() -> list[Path]:
    """Return the candidate config file locations, lowest precedence first.

    The result always has four entries, in this order:

    1. the system-wide file (``/etc/srunx/config.json`` when ``os.name`` is
       ``"posix"``, ``C:/ProgramData/srunx/config.json`` otherwise),
    2. the per-user file (``~/.config/srunx/config.json`` on POSIX,
       ``~/AppData/Roaming/srunx/config.json`` otherwise),
    3. ``.srunx.json`` in the current working directory,
    4. ``srunx.json`` in the current working directory.

    Paths are returned whether or not the files exist.
    """
    home_dir = Path.home()
    if os.name == "posix":
        system_file = Path("/etc/srunx/config.json")
        user_dir = home_dir / ".config" / "srunx"
    else:
        system_file = Path("C:/ProgramData/srunx/config.json")
        user_dir = home_dir / "AppData" / "Roaming" / "srunx"

    project_dir = Path.cwd()

    # Order matters: later entries override earlier ones when merged.
    return [
        system_file,
        user_dir / "config.json",
        project_dir / ".srunx.json",
        project_dir / "srunx.json",
    ]
def load_config_from_file(config_path: Path) -> dict[str, Any]:
    """Load configuration from a JSON file.

    Args:
        config_path: Location of the JSON file to read.

    Returns:
        The parsed top-level JSON object. An empty dict is returned when the
        file does not exist, cannot be read, is not valid UTF-8 / JSON, or
        holds something other than a JSON object at the top level. All
        failure cases except "file does not exist" are logged as warnings.
    """
    try:
        if not config_path.exists():
            return {}
        logger.debug(f"Loading config from {config_path}")
        with open(config_path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError as well as the
        # UnicodeDecodeError raised when the file is not valid UTF-8.
        logger.warning(f"Failed to load config from {config_path}: {e}")
        return {}

    # Valid JSON is not necessarily an object (e.g. ``[1, 2]`` or ``"x"``);
    # callers merge the result key by key, so anything else is rejected here.
    if not isinstance(data, dict):
        logger.warning(
            f"Failed to load config from {config_path}: "
            f"expected a JSON object, got {type(data).__name__}"
        )
        return {}
    return data
def merge_config(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict holding ``base`` overlaid with ``override``.

    When a key maps to a dict on both sides, the two dicts are merged
    recursively. In every other case the value from ``override`` replaces
    the one from ``base``. Neither argument's top level is modified; the
    result starts from a shallow copy of ``base``.
    """
    merged = base.copy()

    for name, incoming in override.items():
        existing = merged.get(name)
        both_are_dicts = isinstance(existing, dict) and isinstance(incoming, dict)
        merged[name] = merge_config(existing, incoming) if both_are_dicts else incoming

    return merged
def load_config_from_env() -> dict[str, Any]:
    """Build a partial configuration dict from ``SRUNX_DEFAULT_*`` variables.

    Unset or empty variables are skipped. The four integer-valued resource
    variables are converted with ``int()``; a value that does not convert is
    logged as a warning and skipped. Sections (``"resources"``,
    ``"environment"``) appear in the result only when at least one of their
    variables produced a value.

    Returns:
        A dict with any of the keys ``"resources"``, ``"environment"``,
        ``"log_dir"`` and ``"work_dir"``; empty when nothing is set.
    """
    # (environment variable, config key) pairs, grouped by section and type.
    int_resource_vars = (
        ("SRUNX_DEFAULT_NODES", "nodes"),
        ("SRUNX_DEFAULT_GPUS_PER_NODE", "gpus_per_node"),
        ("SRUNX_DEFAULT_NTASKS_PER_NODE", "ntasks_per_node"),
        ("SRUNX_DEFAULT_CPUS_PER_TASK", "cpus_per_task"),
    )
    str_resource_vars = (
        ("SRUNX_DEFAULT_MEMORY_PER_NODE", "memory_per_node"),
        ("SRUNX_DEFAULT_TIME_LIMIT", "time_limit"),
        ("SRUNX_DEFAULT_NODELIST", "nodelist"),
        ("SRUNX_DEFAULT_PARTITION", "partition"),
    )
    environment_vars = (
        ("SRUNX_DEFAULT_CONDA", "conda"),
        ("SRUNX_DEFAULT_VENV", "venv"),
        ("SRUNX_DEFAULT_SQSH", "sqsh"),
    )
    general_vars = (
        ("SRUNX_DEFAULT_LOG_DIR", "log_dir"),
        ("SRUNX_DEFAULT_WORK_DIR", "work_dir"),
    )

    loaded: dict[str, Any] = {}

    # Resource defaults: integer-valued variables first, then plain strings.
    resource_values: dict[str, Any] = {}
    for env_name, key in int_resource_vars:
        raw = os.getenv(env_name)
        if not raw:
            continue
        try:
            resource_values[key] = int(raw)
        except ValueError:
            logger.warning(f"Invalid {env_name} value: {raw}")
    for env_name, key in str_resource_vars:
        raw = os.getenv(env_name)
        if raw:
            resource_values[key] = raw
    if resource_values:
        loaded["resources"] = resource_values

    # Environment defaults.
    environment_values: dict[str, Any] = {
        key: raw for env_name, key in environment_vars if (raw := os.getenv(env_name))
    }
    if environment_values:
        loaded["environment"] = environment_values

    # General top-level defaults.
    for env_name, key in general_vars:
        raw = os.getenv(env_name)
        if raw:
            loaded[key] = raw

    return loaded
def load_config() -> SrunxConfig:
    """Assemble the effective configuration from every source.

    Sources are merged from lowest to highest precedence: each file returned
    by ``get_config_paths()`` in order, then the environment variables read
    by ``load_config_from_env()``. The merged data is validated as a
    ``SrunxConfig``; if validation raises, a warning is logged and a
    default-constructed ``SrunxConfig`` is returned instead.
    """
    merged: dict[str, Any] = {}

    # Lazy generator: each file is loaded right before it is merged.
    file_layers = (load_config_from_file(path) for path in get_config_paths())
    for layer in file_layers:
        if layer:
            merged = merge_config(merged, layer)

    # Environment variables are applied last, so they win over every file.
    env_layer = load_config_from_env()
    if env_layer:
        merged = merge_config(merged, env_layer)

    try:
        validated = SrunxConfig.model_validate(merged)
    except Exception as e:
        logger.warning(f"Failed to validate config: {e}. Using defaults.")
        return SrunxConfig()
    return validated
def save_user_config(config: SrunxConfig) -> None:
    """Save configuration to the user-wide config file.

    The target is the second entry of ``get_config_paths()``. Its parent
    directory is created if needed. Only fields that were explicitly set on
    ``config`` are written (``exclude_unset=True``).

    Any ``OSError`` is logged as an error rather than raised, whether it
    comes from creating the directory or from writing the file.

    Args:
        config: The configuration to persist.
    """
    # get_config_paths() lists system, user, then project files; index 1 is
    # the user-wide location.
    user_config_path = get_config_paths()[1]

    try:
        # Directory creation sits inside the try so that a failure here is
        # reported the same way as a failed write.
        user_config_path.parent.mkdir(parents=True, exist_ok=True)
        with open(user_config_path, "w", encoding="utf-8") as f:
            json.dump(config.model_dump(exclude_unset=True), f, indent=2)
    except OSError as e:
        logger.error(f"Failed to save config to {user_config_path}: {e}")
    else:
        logger.info(f"Configuration saved to {user_config_path}")
def create_example_config() -> str:
    """Return a sample configuration as a JSON string indented by two spaces.

    The sample has the sections ``resources`` and ``environment`` plus the
    top-level ``log_dir`` and ``work_dir`` settings.
    """
    resources_section = {
        "nodes": 1,
        "gpus_per_node": 1,
        "ntasks_per_node": 1,
        "cpus_per_task": 8,
        "memory_per_node": "32GB",
        "time_limit": "2:00:00",
        "partition": "gpu",
    }
    environment_section = {
        "conda": "ml_env",
        "env_vars": {"CUDA_VISIBLE_DEVICES": "0", "OMP_NUM_THREADS": "8"},
    }
    sample = {
        "resources": resources_section,
        "environment": environment_section,
        "log_dir": "slurm_logs",
        "work_dir": "/scratch/username",
    }
    return json.dumps(sample, indent=2)
# Module-level cache for the loaded configuration; ``None`` until the first
# call to get_config().
_config: SrunxConfig | None = None


def get_config(reload: bool = False) -> SrunxConfig:
    """Return the cached configuration, loading it when necessary.

    Args:
        reload: When true, call ``load_config()`` again and replace the
            cached instance even if one already exists.

    Returns:
        The module-level ``SrunxConfig`` instance.
    """
    global _config
    needs_load = reload or _config is None
    if needs_load:
        _config = load_config()
    return _config