gaitsetpy.classification.models

 1from .random_forest import RandomForestModel
 2from .mlp import MLPModel
 3
 4# Optional PyTorch-dependent imports
 5try:
 6    from .lstm import LSTMModel
 7    from .bilstm import BiLSTMModel
 8    from .gnn import GNNModel
 9    from .cnn import CNNModel
10    PYTORCH_AVAILABLE = True
11except ImportError:
12    LSTMModel = None
13    BiLSTMModel = None
14    GNNModel = None
15    CNNModel = None
16    PYTORCH_AVAILABLE = False
17
def get_classification_model(name: str, **kwargs):
    """
    Factory function to get a classification model by name.

    Args:
        name (str): Name of the model; it is lower-cased before matching.
            One of: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.
        **kwargs: Model-specific parameters, forwarded unchanged to the
            model's constructor.

    Returns:
        An instance of the requested model.

    Raises:
        ImportError: If 'lstm', 'bilstm', 'gnn' or 'cnn' is requested while
            PYTORCH_AVAILABLE is False or the model class is None.
        ValueError: If the model name is not recognized.

    Example:
        model = get_classification_model('cnn', input_channels=20, num_classes=4)
    """
    name = name.lower()

    # Models imported unconditionally at module level.
    core_models = {
        'random_forest': RandomForestModel,
        'mlp': MLPModel,
    }
    # Models imported inside the module-level try/except; each class is None
    # when that import failed. Values are (label used in the error message,
    # model class or None).
    torch_models = {
        'lstm': ('LSTM', LSTMModel),
        'bilstm': ('BiLSTM', BiLSTMModel),
        'gnn': ('GNN', GNNModel),
        'cnn': ('CNN', CNNModel),
    }

    if name in core_models:
        return core_models[name](**kwargs)

    if name in torch_models:
        label, model_cls = torch_models[name]
        if not PYTORCH_AVAILABLE or model_cls is None:
            raise ImportError(
                f"{label} model requires PyTorch. Please install PyTorch to use this model."
            )
        return model_cls(**kwargs)

    # Only advertise the PyTorch-backed models when they can actually be built.
    available_models = list(core_models)
    if PYTORCH_AVAILABLE:
        available_models.extend(torch_models)
    raise ValueError(f"Unknown model name: {name}. Supported: {available_models}.")
def get_classification_model(name: str, **kwargs):
def get_classification_model(name: str, **kwargs):
    """
    Factory function to get a classification model by name.

    Args:
        name (str): Name of the model; it is lower-cased before matching.
            One of: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'.
        **kwargs: Model-specific parameters, forwarded unchanged to the
            model's constructor.

    Returns:
        An instance of the requested model.

    Raises:
        ImportError: If 'lstm', 'bilstm', 'gnn' or 'cnn' is requested while
            PYTORCH_AVAILABLE is False or the model class is None.
        ValueError: If the model name is not recognized.

    Example:
        model = get_classification_model('cnn', input_channels=20, num_classes=4)
    """
    name = name.lower()

    # Models imported unconditionally at module level.
    core_models = {
        'random_forest': RandomForestModel,
        'mlp': MLPModel,
    }
    # Models imported inside the module-level try/except; each class is None
    # when that import failed. Values are (label used in the error message,
    # model class or None).
    torch_models = {
        'lstm': ('LSTM', LSTMModel),
        'bilstm': ('BiLSTM', BiLSTMModel),
        'gnn': ('GNN', GNNModel),
        'cnn': ('CNN', CNNModel),
    }

    if name in core_models:
        return core_models[name](**kwargs)

    if name in torch_models:
        label, model_cls = torch_models[name]
        if not PYTORCH_AVAILABLE or model_cls is None:
            raise ImportError(
                f"{label} model requires PyTorch. Please install PyTorch to use this model."
            )
        return model_cls(**kwargs)

    # Only advertise the PyTorch-backed models when they can actually be built.
    available_models = list(core_models)
    if PYTORCH_AVAILABLE:
        available_models.extend(torch_models)
    raise ValueError(f"Unknown model name: {name}. Supported: {available_models}.")

Factory function to get a classification model by name.

Args: name (str): Name of the model. One of: 'random_forest', 'mlp', 'lstm', 'bilstm', 'gnn', 'cnn'. **kwargs: Model-specific parameters.

Returns: An instance of the requested model.

Raises: ImportError: If a PyTorch-based model ('lstm', 'bilstm', 'gnn', 'cnn') is requested but PyTorch is not available. ValueError: If the model name is not recognized.

Example: model = get_classification_model('cnn', input_channels=20, num_classes=4)