gaitsetpy.classification.models
from .random_forest import RandomForestModel
from .mlp import MLPModel

# Optional PyTorch-dependent imports.
# NOTE: all four imports share one try block, so an ImportError from ANY of
# them disables ALL PyTorch-backed models (they are set to None together).
try:
    from .lstm import LSTMModel
    from .bilstm import BiLSTMModel
    from .gnn import GNNModel
    from .cnn import CNNModel
    PYTORCH_AVAILABLE = True
except ImportError:
    LSTMModel = None
    BiLSTMModel = None
    GNNModel = None
    CNNModel = None
    PYTORCH_AVAILABLE = False


def get_classification_model(name: str, **kwargs):
    """
    Factory function to get a classification model by name.

    Args:
        name (str): Name of the model (matched case-insensitively). One of:
            'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.
        **kwargs: Model-specific parameters, forwarded unchanged to the
            model's constructor.

    Returns:
        An instance of the requested model.

    Raises:
        ImportError: If a PyTorch-based model ('lstm', 'bilstm', 'gnn', 'cnn')
            is requested but its module could not be imported.
        ValueError: If the model name is not recognized.

    Example:
        model = get_classification_model('cnn', input_channels=20, num_classes=4)
    """
    name = name.lower()

    # Models that need no optional dependencies.
    base_models = {
        'random_forest': RandomForestModel,
        'mlp': MLPModel,
    }
    # PyTorch-backed models: name -> (label used in the error message, class).
    # Built per call because the classes may be None when PyTorch is missing.
    torch_models = {
        'lstm': ('LSTM', LSTMModel),
        'bilstm': ('BiLSTM', BiLSTMModel),
        'gnn': ('GNN', GNNModel),
        'cnn': ('CNN', CNNModel),
    }

    if name in base_models:
        return base_models[name](**kwargs)

    if name in torch_models:
        label, model_cls = torch_models[name]
        if not PYTORCH_AVAILABLE or model_cls is None:
            raise ImportError(
                f"{label} model requires PyTorch. "
                "Please install PyTorch to use this model."
            )
        return model_cls(**kwargs)

    # Only advertise the models that can actually be built in this environment.
    available_models = list(base_models)
    if PYTORCH_AVAILABLE:
        available_models.extend(torch_models)
    raise ValueError(f"Unknown model name: {name}. Supported: {available_models}.")
def get_classification_model(name: str, **kwargs):
def get_classification_model(name: str, **kwargs):
    """
    Factory function to get a classification model by name.

    Args:
        name (str): Name of the model (matched case-insensitively). One of:
            'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.
        **kwargs: Model-specific parameters, forwarded unchanged to the
            model's constructor.

    Returns:
        An instance of the requested model.

    Raises:
        ImportError: If a PyTorch-based model ('lstm', 'bilstm', 'gnn', 'cnn')
            is requested but its module could not be imported.
        ValueError: If the model name is not recognized.

    Example:
        model = get_classification_model('cnn', input_channels=20, num_classes=4)
    """
    name = name.lower()

    # Models that need no optional dependencies.
    base_models = {
        'random_forest': RandomForestModel,
        'mlp': MLPModel,
    }
    # PyTorch-backed models: name -> (label used in the error message, class).
    # Built per call because the classes may be None when PyTorch is missing.
    torch_models = {
        'lstm': ('LSTM', LSTMModel),
        'bilstm': ('BiLSTM', BiLSTMModel),
        'gnn': ('GNN', GNNModel),
        'cnn': ('CNN', CNNModel),
    }

    if name in base_models:
        return base_models[name](**kwargs)

    if name in torch_models:
        label, model_cls = torch_models[name]
        if not PYTORCH_AVAILABLE or model_cls is None:
            raise ImportError(
                f"{label} model requires PyTorch. "
                "Please install PyTorch to use this model."
            )
        return model_cls(**kwargs)

    # Only advertise the models that can actually be built in this environment.
    available_models = list(base_models)
    if PYTORCH_AVAILABLE:
        available_models.extend(torch_models)
    raise ValueError(f"Unknown model name: {name}. Supported: {available_models}.")
Factory function to get a classification model by name.
Args: name (str): Name of the model. One of: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'. **kwargs: Model-specific parameters.
Returns: An instance of the requested model.
Raises: ValueError: If the model name is not recognized. ImportError: If a PyTorch-based model ('lstm', 'bilstm', 'gnn', 'cnn') is requested but PyTorch is not available.
Example: model = get_classification_model('cnn', input_channels=20, num_classes=4)