Coverage for src/srunx/config.py: 85%

133 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-09-04 07:47 +0000

1"""Configuration management for srunx.""" 

2 

3import json 

4import os 

5from pathlib import Path 

6from typing import Any 

7 

8from pydantic import BaseModel, Field 

9 

10from srunx.logging import get_logger 

11 

12logger = get_logger(__name__) 

13 

14 

class ResourceDefaults(BaseModel):
    """Resource values used as defaults.

    Integer fields carry a lower bound (``ge``) checked by pydantic; the
    string fields are optional and fall back to ``None``.
    """

    # Counts: each has a numeric default and a minimum allowed value.
    nodes: int = Field(
        description="Default number of compute nodes",
        default=1,
        ge=1,
    )
    gpus_per_node: int = Field(
        description="Default number of GPUs per node",
        default=0,
        ge=0,
    )
    ntasks_per_node: int = Field(
        description="Default number of tasks per node",
        default=1,
        ge=1,
    )
    cpus_per_task: int = Field(
        description="Default number of CPUs per task",
        default=1,
        ge=1,
    )

    # Free-form strings: unset (None) unless provided.
    memory_per_node: str | None = Field(
        description="Default memory per node",
        default=None,
    )
    time_limit: str | None = Field(
        description="Default time limit",
        default=None,
    )
    nodelist: str | None = Field(
        description="Default nodelist",
        default=None,
    )
    partition: str | None = Field(
        description="Default partition",
        default=None,
    )

34 

35 

class EnvironmentDefaults(BaseModel):
    """Environment values used as defaults.

    ``conda``, ``venv`` and ``sqsh`` are optional strings that default to
    ``None``; ``env_vars`` defaults to a fresh empty mapping per instance.
    """

    conda: str | None = Field(
        description="Default conda environment",
        default=None,
    )
    venv: str | None = Field(
        description="Default virtual environment path",
        default=None,
    )
    sqsh: str | None = Field(
        description="Default SquashFS image path",
        default=None,
    )
    # default_factory gives every instance its own dict.
    env_vars: dict[str, str] = Field(
        description="Default environment variables",
        default_factory=dict,
    )

47 

48 

class SrunxConfig(BaseModel):
    """Top-level srunx configuration model.

    Groups the resource and environment default sections together with the
    log directory and the optional working directory.
    """

    # Nested sections; each is built from its own defaults when omitted.
    resources: ResourceDefaults = Field(default_factory=ResourceDefaults)
    environment: EnvironmentDefaults = Field(default_factory=EnvironmentDefaults)

    # Top-level directory settings.
    log_dir: str = Field(
        description="Default log directory",
        default="logs",
    )
    work_dir: str | None = Field(
        description="Default working directory",
        default=None,
    )

56 

57 

def get_config_paths() -> list[Path]:
    """Return the candidate config file locations, lowest precedence first.

    The list always holds four entries, in this order:

    1. System-wide file: ``/etc/srunx/config.json`` when ``os.name`` is
       ``"posix"``, otherwise ``C:/ProgramData/srunx/config.json``.
    2. Per-user file: ``~/.config/srunx/config.json`` on POSIX, otherwise
       ``~/AppData/Roaming/srunx/config.json``.
    3. ``.srunx.json`` in the current working directory.
    4. ``srunx.json`` in the current working directory.

    The paths are only computed; nothing here checks that they exist.
    """
    if os.name == "posix":
        system_file = Path("/etc/srunx/config.json")
        user_dir = Path.home() / ".config" / "srunx"
    else:
        system_file = Path("C:/ProgramData/srunx/config.json")
        user_dir = Path.home() / "AppData" / "Roaming" / "srunx"

    project_dir = Path.cwd()

    return [
        system_file,
        user_dir / "config.json",
        project_dir / ".srunx.json",
        project_dir / "srunx.json",
    ]

84 

85 

def load_config_from_file(config_path: Path) -> dict[str, Any]:
    """Load configuration from a JSON file.

    Args:
        config_path: Location of the JSON file to read.

    Returns:
        The parsed top-level JSON object. An empty dict is returned when the
        file does not exist, cannot be read, is not valid UTF-8 encoded JSON,
        or when its top-level value is not a JSON object. Every failure other
        than a missing file is logged as a warning.
    """
    try:
        if not config_path.exists():
            return {}
        logger.debug(f"Loading config from {config_path}")
        with open(config_path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and the
        # UnicodeDecodeError raised for files that are not valid UTF-8.
        logger.warning(f"Failed to load config from {config_path}: {e}")
        return {}

    # Valid JSON may still be a list, string, number, ...; callers merge the
    # result with .items(), so only a JSON object is usable here.
    if not isinstance(data, dict):
        logger.warning(
            f"Failed to load config from {config_path}: "
            f"expected a JSON object at the top level, got {type(data).__name__}"
        )
        return {}

    return data

96 

97 

def merge_config(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict with ``override`` layered on top of ``base``.

    When both sides hold a dict under the same key, the two dicts are merged
    recursively. In every other case the value from ``override`` replaces
    (or adds to) the one from ``base``. ``base`` itself is not modified; the
    result starts from a shallow copy of it.
    """
    merged = base.copy()

    for name, incoming in override.items():
        existing = merged.get(name)
        # A missing key yields None, which is not a dict, so it is overwritten.
        if isinstance(existing, dict) and isinstance(incoming, dict):
            incoming = merge_config(existing, incoming)
        merged[name] = incoming

    return merged

109 

110 

def load_config_from_env() -> dict[str, Any]:
    """Build a partial configuration dict from ``SRUNX_DEFAULT_*`` variables.

    Unset or empty variables are skipped. Integer-valued resource variables
    that cannot be parsed with ``int`` are skipped with a warning. The
    ``"resources"`` and ``"environment"`` sections only appear in the result
    when at least one of their variables produced a value.
    """
    # (config key, environment variable) pairs, in output order.
    integer_resources = (
        ("nodes", "SRUNX_DEFAULT_NODES"),
        ("gpus_per_node", "SRUNX_DEFAULT_GPUS_PER_NODE"),
        ("ntasks_per_node", "SRUNX_DEFAULT_NTASKS_PER_NODE"),
        ("cpus_per_task", "SRUNX_DEFAULT_CPUS_PER_TASK"),
    )
    text_resources = (
        ("memory_per_node", "SRUNX_DEFAULT_MEMORY_PER_NODE"),
        ("time_limit", "SRUNX_DEFAULT_TIME_LIMIT"),
        ("nodelist", "SRUNX_DEFAULT_NODELIST"),
        ("partition", "SRUNX_DEFAULT_PARTITION"),
    )
    environment_settings = (
        ("conda", "SRUNX_DEFAULT_CONDA"),
        ("venv", "SRUNX_DEFAULT_VENV"),
        ("sqsh", "SRUNX_DEFAULT_SQSH"),
    )
    general_settings = (
        ("log_dir", "SRUNX_DEFAULT_LOG_DIR"),
        ("work_dir", "SRUNX_DEFAULT_WORK_DIR"),
    )

    # Resource defaults: integers first, then plain strings.
    resource_values: dict[str, Any] = {}
    for field_name, env_name in integer_resources:
        raw = os.getenv(env_name)
        if not raw:
            continue
        try:
            resource_values[field_name] = int(raw)
        except ValueError:
            logger.warning(f"Invalid {env_name} value: {raw}")
    for field_name, env_name in text_resources:
        raw = os.getenv(env_name)
        if raw:
            resource_values[field_name] = raw

    # Environment defaults are all plain strings.
    environment_values: dict[str, Any] = {}
    for field_name, env_name in environment_settings:
        raw = os.getenv(env_name)
        if raw:
            environment_values[field_name] = raw

    loaded: dict[str, Any] = {}
    if resource_values:
        loaded["resources"] = resource_values
    if environment_values:
        loaded["environment"] = environment_values

    # General defaults live at the top level of the config.
    for field_name, env_name in general_settings:
        raw = os.getenv(env_name)
        if raw:
            loaded[field_name] = raw

    return loaded

178 

179 

def load_config() -> SrunxConfig:
    """Assemble the effective configuration from every source.

    Config files are read in the order given by ``get_config_paths()``
    (lowest to highest precedence) and merged one on top of the other; the
    environment-variable settings are merged last, so they take precedence
    over every file. The merged data is then validated as a ``SrunxConfig``.

    Returns:
        The validated configuration, or a ``SrunxConfig`` built purely from
        defaults when validation raises (a warning is logged in that case).
    """
    # Collect the layers in precedence order: files first, environment last.
    layers = [load_config_from_file(path) for path in get_config_paths()]
    layers.append(load_config_from_env())

    merged: dict[str, Any] = {}
    for layer in layers:
        # Empty layers contribute nothing, so they are not merged.
        if layer:
            merged = merge_config(merged, layer)

    try:
        validated = SrunxConfig.model_validate(merged)
    except Exception as e:
        logger.warning(f"Failed to validate config: {e}. Using defaults.")
        return SrunxConfig()
    return validated

202 

203 

def save_user_config(config: SrunxConfig) -> None:
    """Save configuration to the user-wide config file.

    Only fields that were explicitly set on ``config`` are written
    (``model_dump(exclude_unset=True)``), as JSON indented by two spaces.

    Args:
        config: The configuration to persist.

    Any ``OSError`` raised while creating the parent directory or writing the
    file is logged as an error and not re-raised.
    """
    # get_config_paths() lists the system-wide file first and the user-wide
    # file second, so index 1 is the per-user location.
    user_config_path = get_config_paths()[1]

    # Serialise before touching the filesystem: opening with "w" truncates
    # the file, so a serialisation failure must not happen after that point.
    payload = json.dumps(config.model_dump(exclude_unset=True), indent=2)

    try:
        # Directory creation can fail for the same reasons as the write
        # (permissions, read-only filesystem), so it shares the handler.
        user_config_path.parent.mkdir(parents=True, exist_ok=True)
        with open(user_config_path, "w", encoding="utf-8") as f:
            f.write(payload)
    except OSError as e:
        logger.error(f"Failed to save config to {user_config_path}: {e}")
        return

    logger.info(f"Configuration saved to {user_config_path}")

220 

221 

def create_example_config() -> str:
    """Return sample configuration file content as a JSON string.

    The sample has a ``resources`` section, an ``environment`` section and
    the top-level ``log_dir`` / ``work_dir`` keys, serialised with a
    two-space indent.
    """
    resources_section = {
        "nodes": 1,
        "gpus_per_node": 1,
        "ntasks_per_node": 1,
        "cpus_per_task": 8,
        "memory_per_node": "32GB",
        "time_limit": "2:00:00",
        "partition": "gpu",
    }
    environment_section = {
        "conda": "ml_env",
        "env_vars": {"CUDA_VISIBLE_DEVICES": "0", "OMP_NUM_THREADS": "8"},
    }

    # Key order here is the key order of the emitted JSON.
    sample = {
        "resources": resources_section,
        "environment": environment_section,
        "log_dir": "slurm_logs",
        "work_dir": "/scratch/username",
    }
    return json.dumps(sample, indent=2)

242 

243 

# Module-level cache for the loaded configuration; None until first use.
_config: SrunxConfig | None = None


def get_config(reload: bool = False) -> SrunxConfig:
    """Return the cached global configuration, loading it when needed.

    Args:
        reload: When true, call ``load_config()`` again and replace the
            cached instance even if one is already present.

    Returns:
        The cached ``SrunxConfig`` instance.
    """
    global _config
    must_load = reload or _config is None
    if must_load:
        _config = load_config()
    return _config

253 return _config