gaitsetpy.classification.models.cnn

  1import torch
  2import torch.nn as nn
  3import torch.optim as optim
  4import numpy as np
  5from typing import List, Dict, Any, Optional, Union
  6from ...core.base_classes import BaseClassificationModel
  7from ..utils.preprocess import preprocess_features
  8from sklearn.model_selection import train_test_split
  9from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
 10
 11class SimpleCNN(nn.Module):
 12    def __init__(self, input_channels, num_classes, seq_len=1):
 13        super(SimpleCNN, self).__init__()
 14        self.conv1 = nn.Conv1d(input_channels, 32, kernel_size=3, padding=1)
 15        self.relu = nn.ReLU()
 16        self.pool = nn.AdaptiveMaxPool1d(1)
 17        self.fc = nn.Linear(32, num_classes)
 18    def forward(self, x):
 19        # x: (batch, channels, seq_len)
 20        x = self.conv1(x)
 21        x = self.relu(x)
 22        x = self.pool(x)
 23        x = x.view(x.size(0), -1)
 24        x = self.fc(x)
 25        return x
 26
 27class CNNModel(BaseClassificationModel):
 28    """
 29    Simple 1D CNN classification model using PyTorch.
 30    Implements the BaseClassificationModel interface.
 31    """
 32    def __init__(self, input_channels=10, num_classes=2, lr=0.001, epochs=20, batch_size=32, device=None):
 33        super().__init__(
 34            name="cnn",
 35            description="1D CNN classifier for gait data classification"
 36        )
 37        self.config = {
 38            'input_channels': input_channels,
 39            'num_classes': num_classes,
 40            'lr': lr,
 41            'epochs': epochs,
 42            'batch_size': batch_size
 43        }
 44        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 45        self.model = SimpleCNN(input_channels, num_classes).to(self.device)
 46        self.epochs = epochs
 47        self.batch_size = batch_size
 48        self.trained = False
 49        self.feature_names = []
 50        self.class_names = []
 51
 52    def train(self, features: List[Dict], **kwargs):
 53        X, y = preprocess_features(features)
 54        # Reshape X for CNN: (samples, channels, seq_len)
 55        # Here, treat each feature vector as a channel with seq_len=1
 56        X = X.reshape((X.shape[0], X.shape[1], 1))
 57        self.feature_names = [f"feature_{i}" for i in range(X.shape[1])]
 58        self.class_names = list(set(y))
 59        test_size = kwargs.get('test_size', 0.2)
 60        validation_split = kwargs.get('validation_split', True)
 61        if validation_split:
 62            X_train, X_test, y_train, y_test = train_test_split(
 63                X, y, test_size=test_size, random_state=42
 64            )
 65            self.X_test = X_test
 66            self.y_test = y_test
 67        else:
 68            X_train, y_train = X, y
 69        X_train = torch.tensor(X_train, dtype=torch.float32).to(self.device)
 70        y_train = torch.tensor(y_train, dtype=torch.long).to(self.device)
 71        criterion = nn.CrossEntropyLoss()
 72        optimizer = optim.Adam(self.model.parameters(), lr=self.config['lr'])
 73        for epoch in range(self.epochs):
 74            self.model.train()
 75            optimizer.zero_grad()
 76            outputs = self.model(X_train)
 77            loss = criterion(outputs, y_train)
 78            loss.backward()
 79            optimizer.step()
 80            if (epoch+1) % 5 == 0 or epoch == 0:
 81                print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {loss.item():.4f}")
 82        self.trained = True
 83        print("CNN model trained successfully.")
 84
 85    def predict(self, features: List[Dict], **kwargs) -> np.ndarray:
 86        if not self.trained:
 87            raise ValueError("Model must be trained before making predictions")
 88        X, _ = preprocess_features(features)
 89        X = X.reshape((X.shape[0], X.shape[1], 1))
 90        X = torch.tensor(X, dtype=torch.float32).to(self.device)
 91        self.model.eval()
 92        with torch.no_grad():
 93            outputs = self.model(X)
 94            _, predicted = torch.max(outputs.data, 1)
 95        return predicted.cpu().numpy()
 96
 97    def evaluate(self, features: List[Dict], **kwargs) -> Dict[str, float]:
 98        if not self.trained:
 99            raise ValueError("Model must be trained before evaluation")
100        if hasattr(self, 'X_test') and hasattr(self, 'y_test'):
101            X_test, y_test = self.X_test, self.y_test
102        else:
103            X_test, y_test = preprocess_features(features)
104            X_test = X_test.reshape((X_test.shape[0], X_test.shape[1], 1))
105        X_test = torch.tensor(X_test, dtype=torch.float32).to(self.device)
106        y_test = np.array(y_test)
107        self.model.eval()
108        with torch.no_grad():
109            outputs = self.model(X_test)
110            _, y_pred = torch.max(outputs.data, 1)
111        y_pred = y_pred.cpu().numpy()
112        accuracy = accuracy_score(y_test, y_pred)
113        conf_matrix = confusion_matrix(y_test, y_pred)
114        metrics = {
115            'accuracy': accuracy,
116            'confusion_matrix': conf_matrix.tolist()
117        }
118        detailed_report = kwargs.get('detailed_report', False)
119        if detailed_report:
120            class_report = classification_report(y_test, y_pred, output_dict=True)
121            metrics['classification_report'] = class_report
122        return metrics
123
124    def save_model(self, filepath: str):
125        if not self.trained:
126            raise ValueError("Model must be trained before saving")
127        torch.save({
128            'model_state_dict': self.model.state_dict(),
129            'config': self.config,
130            'feature_names': self.feature_names,
131            'class_names': self.class_names,
132            'trained': self.trained
133        }, filepath)
134        print(f"CNN model saved to {filepath}")
135
136    def load_model(self, filepath: str):
137        checkpoint = torch.load(filepath, map_location=self.device)
138        self.model = SimpleCNN(
139            self.config['input_channels'],
140            self.config['num_classes']
141        ).to(self.device)
142        self.model.load_state_dict(checkpoint['model_state_dict'])
143        self.config = checkpoint.get('config', self.config)
144        self.feature_names = checkpoint.get('feature_names', [])
145        self.class_names = checkpoint.get('class_names', [])
146        self.trained = checkpoint.get('trained', True)
147        print(f"CNN model loaded from {filepath}") 
class SimpleCNN(torch.nn.modules.module.Module):
12class SimpleCNN(nn.Module):  # conv1 -> ReLU -> global max-pool -> linear classifier
13    def __init__(self, input_channels, num_classes, seq_len=1):  # seq_len is accepted but unused
14        super(SimpleCNN, self).__init__()
15        self.conv1 = nn.Conv1d(input_channels, 32, kernel_size=3, padding=1)  # 32 filters; padding=1 keeps the length
16        self.relu = nn.ReLU()
17        self.pool = nn.AdaptiveMaxPool1d(1)  # max over the whole sequence -> length 1
18        self.fc = nn.Linear(32, num_classes)  # 32 pooled features -> one logit per class
19    def forward(self, x):  # returns (batch, num_classes) logits
20        # x: (batch, channels, seq_len)
21        x = self.conv1(x)
22        x = self.relu(x)
23        x = self.pool(x)  # -> (batch, 32, 1)
24        x = x.view(x.size(0), -1)  # -> (batch, 32)
25        x = self.fc(x)
26        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Two stacked 5x5 convolutions, each followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        out = self.conv2(hidden)
        return F.relu(out)

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean indicating whether this module is in training or evaluation mode.
:vartype training: bool

SimpleCNN(input_channels, num_classes, seq_len=1)
13    def __init__(self, input_channels, num_classes, seq_len=1):  # seq_len is accepted but unused
14        super(SimpleCNN, self).__init__()
15        self.conv1 = nn.Conv1d(input_channels, 32, kernel_size=3, padding=1)  # 32 output channels; padding=1 keeps the length
16        self.relu = nn.ReLU()
17        self.pool = nn.AdaptiveMaxPool1d(1)  # global max-pool to length 1
18        self.fc = nn.Linear(32, num_classes)  # 32 pooled features -> class logits

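Initialize internal Module state, shared by both nn.Module and ScriptModule. (Inherited docstring.) For SimpleCNN this builds conv1 (a Conv1d with 32 filters, kernel size 3 and padding 1), a ReLU, an adaptive max-pool to length 1, and a linear layer mapping the 32 pooled features to num_classes logits; seq_len is accepted but not used.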

conv1
relu
pool
fc
def forward(self, x):
19    def forward(self, x):  # returns (batch, num_classes) logits
20        # x: (batch, channels, seq_len)
21        x = self.conv1(x)
22        x = self.relu(x)
23        x = self.pool(x)  # -> (batch, 32, 1)
24        x = x.view(x.size(0), -1)  # -> (batch, 32)
25        x = self.fc(x)
26        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

 28class CNNModel(BaseClassificationModel):  # wraps SimpleCNN with train/predict/evaluate/save/load
 29    """
 30    Simple 1D CNN classification model using PyTorch.
 31    Implements the BaseClassificationModel interface.
 32    """
 33    def __init__(self, input_channels=10, num_classes=2, lr=0.001, epochs=20, batch_size=32, device=None):
 34        super().__init__(
 35            name="cnn",
 36            description="1D CNN classifier for gait data classification"
 37        )
 38        self.config = {  # persisted by save_model
 39            'input_channels': input_channels,
 40            'num_classes': num_classes,
 41            'lr': lr,
 42            'epochs': epochs,
 43            'batch_size': batch_size
 44        }
 45        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")  # explicit device wins
 46        self.model = SimpleCNN(input_channels, num_classes).to(self.device)
 47        self.epochs = epochs
 48        self.batch_size = batch_size  # NOTE(review): stored but never read by train(), which is full-batch
 49        self.trained = False
 50        self.feature_names = []
 51        self.class_names = []
 52
 53    def train(self, features: List[Dict], **kwargs):  # fit the network, optionally holding out a test split
 54        X, y = preprocess_features(features)
 55        # Reshape X for CNN: (samples, channels, seq_len)
 56        # Here, treat each feature vector as a channel with seq_len=1
 57        X = X.reshape((X.shape[0], X.shape[1], 1))
 58        self.feature_names = [f"feature_{i}" for i in range(X.shape[1])]
 59        self.class_names = list(set(y))  # NOTE(review): set order is not guaranteed; consider sorting
 60        test_size = kwargs.get('test_size', 0.2)  # fraction held out when validation_split is True
 61        validation_split = kwargs.get('validation_split', True)
 62        if validation_split:
 63            X_train, X_test, y_train, y_test = train_test_split(
 64                X, y, test_size=test_size, random_state=42
 65            )
 66            self.X_test = X_test  # cached hold-out split, read by evaluate()
 67            self.y_test = y_test
 68        else:
 69            X_train, y_train = X, y  # NOTE(review): a split cached by an earlier call is not cleared here
 70        X_train = torch.tensor(X_train, dtype=torch.float32).to(self.device)
 71        y_train = torch.tensor(y_train, dtype=torch.long).to(self.device)
 72        criterion = nn.CrossEntropyLoss()
 73        optimizer = optim.Adam(self.model.parameters(), lr=self.config['lr'])
 74        for epoch in range(self.epochs):  # one full-batch gradient step per epoch
 75            self.model.train()
 76            optimizer.zero_grad()
 77            outputs = self.model(X_train)
 78            loss = criterion(outputs, y_train)
 79            loss.backward()
 80            optimizer.step()
 81            if (epoch+1) % 5 == 0 or epoch == 0:  # log the first epoch and every 5th
 82                print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {loss.item():.4f}")
 83        self.trained = True
 84        print("CNN model trained successfully.")
 85
 86    def predict(self, features: List[Dict], **kwargs) -> np.ndarray:  # returns predicted class indices
 87        if not self.trained:
 88            raise ValueError("Model must be trained before making predictions")
 89        X, _ = preprocess_features(features)
 90        X = X.reshape((X.shape[0], X.shape[1], 1))  # same (samples, channels, 1) layout as training
 91        X = torch.tensor(X, dtype=torch.float32).to(self.device)
 92        self.model.eval()
 93        with torch.no_grad():
 94            outputs = self.model(X)
 95            _, predicted = torch.max(outputs.data, 1)  # index of the highest logit per sample
 96        return predicted.cpu().numpy()
 97
 98    def evaluate(self, features: List[Dict], **kwargs) -> Dict[str, float]:  # NOTE(review): values include a nested list
 99        if not self.trained:
100            raise ValueError("Model must be trained before evaluation")
101        if hasattr(self, 'X_test') and hasattr(self, 'y_test'):  # cached split wins; `features` is then ignored
102            X_test, y_test = self.X_test, self.y_test
103        else:
104            X_test, y_test = preprocess_features(features)
105            X_test = X_test.reshape((X_test.shape[0], X_test.shape[1], 1))
106        X_test = torch.tensor(X_test, dtype=torch.float32).to(self.device)
107        y_test = np.array(y_test)
108        self.model.eval()
109        with torch.no_grad():
110            outputs = self.model(X_test)
111            _, y_pred = torch.max(outputs.data, 1)  # predicted class index per sample
112        y_pred = y_pred.cpu().numpy()
113        accuracy = accuracy_score(y_test, y_pred)
114        conf_matrix = confusion_matrix(y_test, y_pred)
115        metrics = {
116            'accuracy': accuracy,
117            'confusion_matrix': conf_matrix.tolist()
118        }
119        detailed_report = kwargs.get('detailed_report', False)
120        if detailed_report:
121            class_report = classification_report(y_test, y_pred, output_dict=True)
122            metrics['classification_report'] = class_report
123        return metrics
124
125    def save_model(self, filepath: str):
126        if not self.trained:
127            raise ValueError("Model must be trained before saving")
128        torch.save({  # weights plus the metadata needed to rebuild the model
129            'model_state_dict': self.model.state_dict(),
130            'config': self.config,
131            'feature_names': self.feature_names,
132            'class_names': self.class_names,
133            'trained': self.trained
134        }, filepath)
135        print(f"CNN model saved to {filepath}")
136
137    def load_model(self, filepath: str):
138        checkpoint = torch.load(filepath, map_location=self.device)  # NOTE(review): unpickles data; load trusted files only
139        self.model = SimpleCNN(  # NOTE(review): built from the current config, before the checkpoint config is read (144)
140            self.config['input_channels'],
141            self.config['num_classes']
142        ).to(self.device)
143        self.model.load_state_dict(checkpoint['model_state_dict'])
144        self.config = checkpoint.get('config', self.config)  # NOTE(review): self.epochs/self.batch_size are not updated
145        self.feature_names = checkpoint.get('feature_names', [])
146        self.class_names = checkpoint.get('class_names', [])
147        self.trained = checkpoint.get('trained', True)
148        print(f"CNN model loaded from {filepath}")  # confirmation message

Simple 1D CNN classification model using PyTorch. Implements the BaseClassificationModel interface.

CNNModel(input_channels=10, num_classes=2, lr=0.001, epochs=20, batch_size=32, device=None)
33    def __init__(self, input_channels=10, num_classes=2, lr=0.001, epochs=20, batch_size=32, device=None):  # build the untrained network and config
34        super().__init__(
35            name="cnn",
36            description="1D CNN classifier for gait data classification"
37        )
38        self.config = {  # hyper-parameters persisted by save_model
39            'input_channels': input_channels,
40            'num_classes': num_classes,
41            'lr': lr,
42            'epochs': epochs,
43            'batch_size': batch_size
44        }
45        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")  # explicit device wins; else CUDA if available
46        self.model = SimpleCNN(input_channels, num_classes).to(self.device)
47        self.epochs = epochs
48        self.batch_size = batch_size  # NOTE(review): not read by train(), which uses full-batch updates
49        self.trained = False  # gates predict/evaluate/save_model
50        self.feature_names = []
51        self.class_names = []

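Initialize the CNN classification model.

Args: input_channels: Number of input features per sample (Conv1d input channels); num_classes: Number of output classes; lr: Adam learning rate; epochs: Number of training epochs; batch_size: Batch size (recorded in the config); device: Torch device string, defaulting to "cuda" when available and "cpu" otherwise. The model registers with the base class under the name "cnn".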

config
device
model
epochs
batch_size
trained
feature_names
class_names
def train(self, features: List[Dict], **kwargs):
53    def train(self, features: List[Dict], **kwargs):  # fit the network, optionally holding out a test split
54        X, y = preprocess_features(features)
55        # Reshape X for CNN: (samples, channels, seq_len)
56        # Here, treat each feature vector as a channel with seq_len=1
57        X = X.reshape((X.shape[0], X.shape[1], 1))
58        self.feature_names = [f"feature_{i}" for i in range(X.shape[1])]
59        self.class_names = list(set(y))  # NOTE(review): set order is not guaranteed; consider sorting
60        test_size = kwargs.get('test_size', 0.2)  # fraction held out when validation_split is True
61        validation_split = kwargs.get('validation_split', True)
62        if validation_split:
63            X_train, X_test, y_train, y_test = train_test_split(
64                X, y, test_size=test_size, random_state=42
65            )
66            self.X_test = X_test  # cached hold-out split, read by evaluate()
67            self.y_test = y_test
68        else:
69            X_train, y_train = X, y  # NOTE(review): a split cached by an earlier call is not cleared here
70        X_train = torch.tensor(X_train, dtype=torch.float32).to(self.device)
71        y_train = torch.tensor(y_train, dtype=torch.long).to(self.device)
72        criterion = nn.CrossEntropyLoss()
73        optimizer = optim.Adam(self.model.parameters(), lr=self.config['lr'])
74        for epoch in range(self.epochs):  # one full-batch gradient step per epoch; batch_size is unused
75            self.model.train()
76            optimizer.zero_grad()
77            outputs = self.model(X_train)
78            loss = criterion(outputs, y_train)
79            loss.backward()
80            optimizer.step()
81            if (epoch+1) % 5 == 0 or epoch == 0:  # log the first epoch and every 5th
82                print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {loss.item():.4f}")
83        self.trained = True
84        print("CNN model trained successfully.")

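Train the classification model.

Args: features: List of feature dictionaries; **kwargs: Additional arguments for training — test_size (fraction held out, default 0.2) and validation_split (whether to hold out a test split, default True). The held-out split is cached on the instance and used by evaluate().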

def predict(self, features: List[Dict], **kwargs) -> numpy.ndarray:
86    def predict(self, features: List[Dict], **kwargs) -> np.ndarray:  # returns predicted class indices
87        if not self.trained:
88            raise ValueError("Model must be trained before making predictions")
89        X, _ = preprocess_features(features)  # labels are discarded
90        X = X.reshape((X.shape[0], X.shape[1], 1))  # same (samples, channels, 1) layout as training
91        X = torch.tensor(X, dtype=torch.float32).to(self.device)
92        self.model.eval()
93        with torch.no_grad():
94            outputs = self.model(X)
95            _, predicted = torch.max(outputs.data, 1)  # index of the highest logit per sample
96        return predicted.cpu().numpy()

Make predictions using the trained model.

Args: features: List of feature dictionaries **kwargs: Additional arguments for prediction

Returns: Array of predicted class indices (the argmax of the network's outputs for each sample)

def evaluate(self, features: List[Dict], **kwargs) -> Dict[str, float]:
 98    def evaluate(self, features: List[Dict], **kwargs) -> Dict[str, float]:  # NOTE(review): values include a nested list, not only floats
 99        if not self.trained:
100            raise ValueError("Model must be trained before evaluation")
101        if hasattr(self, 'X_test') and hasattr(self, 'y_test'):  # cached split from train() wins; `features` is then ignored
102            X_test, y_test = self.X_test, self.y_test
103        else:
104            X_test, y_test = preprocess_features(features)
105            X_test = X_test.reshape((X_test.shape[0], X_test.shape[1], 1))  # match the (samples, channels, 1) training layout
106        X_test = torch.tensor(X_test, dtype=torch.float32).to(self.device)
107        y_test = np.array(y_test)
108        self.model.eval()
109        with torch.no_grad():
110            outputs = self.model(X_test)
111            _, y_pred = torch.max(outputs.data, 1)  # predicted class index per sample
112        y_pred = y_pred.cpu().numpy()
113        accuracy = accuracy_score(y_test, y_pred)
114        conf_matrix = confusion_matrix(y_test, y_pred)
115        metrics = {
116            'accuracy': accuracy,
117            'confusion_matrix': conf_matrix.tolist()  # converted to a plain nested list
118        }
119        detailed_report = kwargs.get('detailed_report', False)
120        if detailed_report:
121            class_report = classification_report(y_test, y_pred, output_dict=True)
122            metrics['classification_report'] = class_report
123        return metrics

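Evaluate the model performance.

Args: features: List of feature dictionaries (ignored when a held-out split cached by train() is available); **kwargs: Additional arguments for evaluation — detailed_report (default False) adds a scikit-learn classification report.

Returns: Dictionary with 'accuracy', 'confusion_matrix' (as a nested list) and, when requested, 'classification_report'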

def save_model(self, filepath: str):
125    def save_model(self, filepath: str):  # write weights and metadata to filepath
126        if not self.trained:
127            raise ValueError("Model must be trained before saving")
128        torch.save({  # weights plus the metadata needed to rebuild the model
129            'model_state_dict': self.model.state_dict(),
130            'config': self.config,
131            'feature_names': self.feature_names,
132            'class_names': self.class_names,
133            'trained': self.trained
134        }, filepath)
135        print(f"CNN model saved to {filepath}")

Save the trained model to a file.

Args: filepath: Path to save the model

def load_model(self, filepath: str):
137    def load_model(self, filepath: str):  # restore a checkpoint written by save_model
138        checkpoint = torch.load(filepath, map_location=self.device)  # NOTE(review): torch.load unpickles data; load trusted files only
139        self.model = SimpleCNN(  # NOTE(review): built from the current config, before the checkpoint config is read below
140            self.config['input_channels'],
141            self.config['num_classes']
142        ).to(self.device)
143        self.model.load_state_dict(checkpoint['model_state_dict'])
144        self.config = checkpoint.get('config', self.config)  # NOTE(review): self.epochs/self.batch_size are not updated from it
145        self.feature_names = checkpoint.get('feature_names', [])
146        self.class_names = checkpoint.get('class_names', [])
147        self.trained = checkpoint.get('trained', True)
148        print(f"CNN model loaded from {filepath}")  # confirmation message

Load a trained model from a file.

Args: filepath: Path to the saved model