gaitsetpy.classification.models.gnn
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from typing import List, Dict, Any, Optional, Union
from ...core.base_classes import BaseClassificationModel
from ..utils.preprocess import preprocess_features
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report


class SimpleGCN(nn.Module):
    """
    Two-layer graph convolutional network.

    Computes ``fc2(adj @ relu(fc1(adj @ x)))``: node features are aggregated
    with the adjacency matrix before each linear layer. The adjacency matrix
    is used exactly as passed in (no self-loops or normalisation are added).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleGCN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, adj):
        # Aggregate neighbour features, then transform; repeated for both layers.
        h = torch.relu(self.fc1(torch.matmul(adj, x)))
        out = self.fc2(torch.matmul(adj, h))
        return out


class GNNModel(BaseClassificationModel):
    """
    Simple Graph Neural Network (GCN) classification model using PyTorch.
    Implements the BaseClassificationModel interface.
    Expects features as node features and adjacency matrix in kwargs.

    ``train``, ``predict`` and ``evaluate`` all require an
    ``adjacency_matrix`` keyword argument (array-like or tensor).
    """

    def __init__(self, input_dim=10, hidden_dim=32, output_dim=2, lr=0.001, epochs=20, device=None):
        """
        Args:
            input_dim: Number of features per node.
            hidden_dim: Width of the hidden layer.
            output_dim: Number of output classes; labels must lie in [0, output_dim).
            lr: Adam learning rate used by ``train``.
            epochs: Number of full-graph optimisation steps in ``train``.
            device: Torch device; defaults to "cuda" when available, else "cpu".
        """
        super().__init__(
            name="gnn",
            description="Graph Convolutional Network (GCN) classifier for gait data classification"
        )
        self.config = {
            'input_dim': input_dim,
            'hidden_dim': hidden_dim,
            'output_dim': output_dim,
            'lr': lr,
            'epochs': epochs
        }
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SimpleGCN(input_dim, hidden_dim, output_dim).to(self.device)
        self.epochs = epochs
        self.trained = False
        self.feature_names = []
        self.class_names = []

    def _as_tensor(self, data, dtype):
        """Return *data* (array-like or tensor) as a ``dtype`` tensor on ``self.device``."""
        # torch.as_tensor accepts existing tensors without the copy-construct
        # warning that torch.tensor(tensor) emits.
        return torch.as_tensor(data, dtype=dtype, device=self.device)

    @staticmethod
    def _get_adjacency(kwargs, purpose):
        """Return the required ``adjacency_matrix`` kwarg or raise ``ValueError``."""
        adj = kwargs.get('adjacency_matrix')
        if adj is None:
            raise ValueError(
                f"Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN {purpose}."
            )
        return adj

    def _predict_indices(self, X, adj):
        """Run the network in eval mode and return the arg-max class index per node."""
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(X, adj)
        return outputs.argmax(dim=1).cpu().numpy()

    def train(self, features: List[Dict], **kwargs):
        """
        Train the GCN with full-graph gradient steps.

        Args:
            features: List of feature dictionaries, turned into node features
                and labels by ``preprocess_features``.
            **kwargs: Must contain ``adjacency_matrix``.

        Raises:
            ValueError: If ``adjacency_matrix`` is missing or a label lies
                outside [0, output_dim).
        """
        X, y = preprocess_features(features)
        # X: (num_nodes, num_features), y: (num_nodes,)
        adj = self._get_adjacency(kwargs, "training")
        X = self._as_tensor(X, torch.float32)
        y = self._as_tensor(y, torch.long)
        adj = self._as_tensor(adj, torch.float32)
        self.feature_names = [f"feature_{i}" for i in range(X.shape[1])]
        # Sorted plain Python ints: a stable order, and (unlike numpy scalars)
        # restorable by torch.load(weights_only=True) after save_model.
        self.class_names = torch.unique(y).tolist()
        num_classes = self.config['output_dim']
        if self.class_names and (self.class_names[0] < 0 or self.class_names[-1] >= num_classes):
            raise ValueError(
                f"Labels must be integers in [0, {num_classes}); got classes {self.class_names}."
            )
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.config['lr'])
        for epoch in range(self.epochs):
            self.model.train()
            optimizer.zero_grad()
            outputs = self.model(X, adj)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            if (epoch + 1) % 5 == 0 or epoch == 0:
                print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {loss.item():.4f}")
        self.trained = True
        print("GNN model trained successfully.")

    def predict(self, features: List[Dict], **kwargs) -> np.ndarray:
        """
        Predict a class index for every node.

        Args:
            features: List of feature dictionaries.
            **kwargs: Must contain ``adjacency_matrix``.

        Returns:
            Array of predicted class indices, one per node.

        Raises:
            ValueError: If the model is untrained or ``adjacency_matrix`` is missing.
        """
        if not self.trained:
            raise ValueError("Model must be trained before making predictions")
        X, _ = preprocess_features(features)
        adj = self._get_adjacency(kwargs, "prediction")
        return self._predict_indices(
            self._as_tensor(X, torch.float32), self._as_tensor(adj, torch.float32)
        )

    def evaluate(self, features: List[Dict], **kwargs) -> Dict[str, Any]:
        """
        Evaluate the model against the labels in *features*.

        Args:
            features: List of feature dictionaries.
            **kwargs: Must contain ``adjacency_matrix``; set
                ``detailed_report=True`` to include sklearn's classification report.

        Returns:
            Dict with 'accuracy' (float), 'confusion_matrix' (nested list) and,
            optionally, 'classification_report' (dict).

        Raises:
            ValueError: If the model is untrained or ``adjacency_matrix`` is missing.
        """
        if not self.trained:
            raise ValueError("Model must be trained before evaluation")
        X, y = preprocess_features(features)
        adj = self._get_adjacency(kwargs, "evaluation")
        y = np.asarray(y)
        y_pred = self._predict_indices(
            self._as_tensor(X, torch.float32), self._as_tensor(adj, torch.float32)
        )
        metrics = {
            'accuracy': accuracy_score(y, y_pred),
            'confusion_matrix': confusion_matrix(y, y_pred).tolist(),
        }
        if kwargs.get('detailed_report', False):
            metrics['classification_report'] = classification_report(y, y_pred, output_dict=True)
        return metrics

    def save_model(self, filepath: str):
        """
        Save weights, config and metadata to *filepath* with ``torch.save``.

        Raises:
            ValueError: If the model has not been trained.
        """
        if not self.trained:
            raise ValueError("Model must be trained before saving")
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'config': self.config,
            'feature_names': self.feature_names,
            'class_names': self.class_names,
            'trained': self.trained
        }, filepath)
        print(f"GNN model saved to {filepath}")

    def load_model(self, filepath: str):
        """
        Load a checkpoint written by ``save_model``.

        The network is rebuilt from the checkpoint's own ``config`` (falling
        back to this instance's config), so checkpoints trained with other
        layer sizes load correctly.

        Args:
            filepath: Path to the saved model.
        """
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(filepath, map_location=self.device)
        config = checkpoint.get('config', self.config)
        model = SimpleGCN(
            config['input_dim'],
            config['hidden_dim'],
            config['output_dim']
        ).to(self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # Commit new state only after the weights loaded successfully.
        self.model = model
        self.config = config
        self.epochs = config.get('epochs', self.epochs)
        self.feature_names = checkpoint.get('feature_names', [])
        self.class_names = checkpoint.get('class_names', [])
        self.trained = checkpoint.get('trained', True)
        print(f"GNN model loaded from {filepath}")
class SimpleGCN(nn.Module):
    """
    Two-layer graph convolutional network.

    Computes ``fc2(adj @ relu(fc1(adj @ x)))``: node features are aggregated
    with the adjacency matrix before each linear layer. The adjacency matrix
    is used exactly as passed in (no self-loops or normalisation are added).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names fc1/fc2 define the state_dict keys; keep them stable.
        self.fc1, self.fc2 = nn.Linear(input_dim, hidden_dim), nn.Linear(hidden_dim, output_dim)

    def forward(self, x, adj):
        """For x of shape (num_nodes, input_dim) and square adj, return (num_nodes, output_dim) logits."""
        aggregated = adj @ x
        hidden = torch.relu(self.fc1(aggregated))
        return self.fc2(adj @ hidden)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, input_dim, hidden_dim, output_dim):
    """Create fc1 (input_dim -> hidden_dim) and fc2 (hidden_dim -> output_dim)."""
    super(SimpleGCN, self).__init__()
    first = nn.Linear(input_dim, hidden_dim)
    second = nn.Linear(hidden_dim, output_dim)
    # These attribute names become the state_dict keys ("fc1.weight", ...).
    self.fc1 = first
    self.fc2 = second
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x, adj):
    """Aggregate with adj, apply fc1 + ReLU, aggregate again, then apply fc2."""
    aggregated = adj @ x
    hidden = torch.relu(self.fc1(aggregated))
    return self.fc2(adj @ hidden)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class GNNModel(BaseClassificationModel):
    """
    Simple Graph Neural Network (GCN) classification model using PyTorch.
    Implements the BaseClassificationModel interface.
    Expects features as node features and adjacency matrix in kwargs.

    ``train``, ``predict`` and ``evaluate`` all require an
    ``adjacency_matrix`` keyword argument (array-like or tensor).
    """

    def __init__(self, input_dim=10, hidden_dim=32, output_dim=2, lr=0.001, epochs=20, device=None):
        """
        Args:
            input_dim: Number of features per node.
            hidden_dim: Width of the hidden layer.
            output_dim: Number of output classes; labels must lie in [0, output_dim).
            lr: Adam learning rate used by ``train``.
            epochs: Number of full-graph optimisation steps in ``train``.
            device: Torch device; defaults to "cuda" when available, else "cpu".
        """
        super().__init__(
            name="gnn",
            description="Graph Convolutional Network (GCN) classifier for gait data classification"
        )
        self.config = {
            'input_dim': input_dim,
            'hidden_dim': hidden_dim,
            'output_dim': output_dim,
            'lr': lr,
            'epochs': epochs
        }
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SimpleGCN(input_dim, hidden_dim, output_dim).to(self.device)
        self.epochs = epochs
        self.trained = False
        self.feature_names = []
        self.class_names = []

    def _as_tensor(self, data, dtype):
        """Return *data* (array-like or tensor) as a ``dtype`` tensor on ``self.device``."""
        # torch.as_tensor accepts existing tensors without the copy-construct
        # warning that torch.tensor(tensor) emits.
        return torch.as_tensor(data, dtype=dtype, device=self.device)

    @staticmethod
    def _get_adjacency(kwargs, purpose):
        """Return the required ``adjacency_matrix`` kwarg or raise ``ValueError``."""
        adj = kwargs.get('adjacency_matrix')
        if adj is None:
            raise ValueError(
                f"Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN {purpose}."
            )
        return adj

    def _predict_indices(self, X, adj):
        """Run the network in eval mode and return the arg-max class index per node."""
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(X, adj)
        return outputs.argmax(dim=1).cpu().numpy()

    def train(self, features: List[Dict], **kwargs):
        """
        Train the GCN with full-graph gradient steps.

        Args:
            features: List of feature dictionaries, turned into node features
                and labels by ``preprocess_features``.
            **kwargs: Must contain ``adjacency_matrix``.

        Raises:
            ValueError: If ``adjacency_matrix`` is missing or a label lies
                outside [0, output_dim).
        """
        X, y = preprocess_features(features)
        # X: (num_nodes, num_features), y: (num_nodes,)
        adj = self._get_adjacency(kwargs, "training")
        X = self._as_tensor(X, torch.float32)
        y = self._as_tensor(y, torch.long)
        adj = self._as_tensor(adj, torch.float32)
        self.feature_names = [f"feature_{i}" for i in range(X.shape[1])]
        # Sorted plain Python ints: a stable order, and (unlike numpy scalars)
        # restorable by torch.load(weights_only=True) after save_model.
        self.class_names = torch.unique(y).tolist()
        num_classes = self.config['output_dim']
        if self.class_names and (self.class_names[0] < 0 or self.class_names[-1] >= num_classes):
            raise ValueError(
                f"Labels must be integers in [0, {num_classes}); got classes {self.class_names}."
            )
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.config['lr'])
        for epoch in range(self.epochs):
            self.model.train()
            optimizer.zero_grad()
            outputs = self.model(X, adj)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            if (epoch + 1) % 5 == 0 or epoch == 0:
                print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {loss.item():.4f}")
        self.trained = True
        print("GNN model trained successfully.")

    def predict(self, features: List[Dict], **kwargs) -> np.ndarray:
        """
        Predict a class index for every node.

        Args:
            features: List of feature dictionaries.
            **kwargs: Must contain ``adjacency_matrix``.

        Returns:
            Array of predicted class indices, one per node.

        Raises:
            ValueError: If the model is untrained or ``adjacency_matrix`` is missing.
        """
        if not self.trained:
            raise ValueError("Model must be trained before making predictions")
        X, _ = preprocess_features(features)
        adj = self._get_adjacency(kwargs, "prediction")
        return self._predict_indices(
            self._as_tensor(X, torch.float32), self._as_tensor(adj, torch.float32)
        )

    def evaluate(self, features: List[Dict], **kwargs) -> Dict[str, Any]:
        """
        Evaluate the model against the labels in *features*.

        Args:
            features: List of feature dictionaries.
            **kwargs: Must contain ``adjacency_matrix``; set
                ``detailed_report=True`` to include sklearn's classification report.

        Returns:
            Dict with 'accuracy' (float), 'confusion_matrix' (nested list) and,
            optionally, 'classification_report' (dict).

        Raises:
            ValueError: If the model is untrained or ``adjacency_matrix`` is missing.
        """
        if not self.trained:
            raise ValueError("Model must be trained before evaluation")
        X, y = preprocess_features(features)
        adj = self._get_adjacency(kwargs, "evaluation")
        y = np.asarray(y)
        y_pred = self._predict_indices(
            self._as_tensor(X, torch.float32), self._as_tensor(adj, torch.float32)
        )
        metrics = {
            'accuracy': accuracy_score(y, y_pred),
            'confusion_matrix': confusion_matrix(y, y_pred).tolist(),
        }
        if kwargs.get('detailed_report', False):
            metrics['classification_report'] = classification_report(y, y_pred, output_dict=True)
        return metrics

    def save_model(self, filepath: str):
        """
        Save weights, config and metadata to *filepath* with ``torch.save``.

        Raises:
            ValueError: If the model has not been trained.
        """
        if not self.trained:
            raise ValueError("Model must be trained before saving")
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'config': self.config,
            'feature_names': self.feature_names,
            'class_names': self.class_names,
            'trained': self.trained
        }, filepath)
        print(f"GNN model saved to {filepath}")

    def load_model(self, filepath: str):
        """
        Load a checkpoint written by ``save_model``.

        The network is rebuilt from the checkpoint's own ``config`` (falling
        back to this instance's config), so checkpoints trained with other
        layer sizes load correctly.

        Args:
            filepath: Path to the saved model.
        """
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(filepath, map_location=self.device)
        config = checkpoint.get('config', self.config)
        model = SimpleGCN(
            config['input_dim'],
            config['hidden_dim'],
            config['output_dim']
        ).to(self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # Commit new state only after the weights loaded successfully.
        self.model = model
        self.config = config
        self.epochs = config.get('epochs', self.epochs)
        self.feature_names = checkpoint.get('feature_names', [])
        self.class_names = checkpoint.get('class_names', [])
        self.trained = checkpoint.get('trained', True)
        print(f"GNN model loaded from {filepath}")
Simple Graph Neural Network (GCN) classification model using PyTorch. Implements the BaseClassificationModel interface. Expects features as node features and adjacency matrix in kwargs.
def __init__(self, input_dim=10, hidden_dim=32, output_dim=2, lr=0.001, epochs=20, device=None):
    """
    Store the hyper-parameters and build the SimpleGCN network on the chosen device.

    Args:
        input_dim: Number of features per node.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output classes.
        lr: Learning rate used by the optimiser in ``train``.
        epochs: Number of training epochs.
        device: Torch device; a falsy value selects "cuda" when available, else "cpu".
    """
    super().__init__(
        name="gnn",
        description="Graph Convolutional Network (GCN) classifier for gait data classification",
    )
    # Hyper-parameters kept together so save_model can persist them.
    self.config = dict(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        lr=lr,
        epochs=epochs,
    )
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = device
    network = SimpleGCN(input_dim, hidden_dim, output_dim)
    self.model = network.to(self.device)
    self.epochs = epochs
    self.trained = False
    self.feature_names, self.class_names = [], []
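Initialize the GCN classification model. The model name ("gnn") and description are fixed and passed to the base class.
Args: input_dim: Number of input features per node. hidden_dim: Width of the hidden layer. output_dim: Number of output classes. lr: Adam learning rate used during training. epochs: Number of training epochs. device: Torch device to run on; defaults to "cuda" when available, otherwise "cpu".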
def train(self, features: List[Dict], **kwargs):
    """
    Train the GCN with full-graph gradient steps.

    Args:
        features: List of feature dictionaries, turned into node features
            and labels by ``preprocess_features``.
        **kwargs: Must contain ``adjacency_matrix`` (array-like or tensor).

    Raises:
        ValueError: If ``adjacency_matrix`` is missing or a label lies
            outside [0, output_dim).
    """
    X, y = preprocess_features(features)
    # X: (num_nodes, num_features), y: (num_nodes,)
    adj = kwargs.get('adjacency_matrix')
    if adj is None:
        raise ValueError("Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN training.")
    # as_tensor accepts tensors without torch.tensor's copy-construct warning.
    X = torch.as_tensor(X, dtype=torch.float32, device=self.device)
    y = torch.as_tensor(y, dtype=torch.long, device=self.device)
    adj = torch.as_tensor(adj, dtype=torch.float32, device=self.device)
    self.feature_names = [f"feature_{i}" for i in range(X.shape[1])]
    # Sorted plain Python ints: a stable order, and (unlike numpy scalars)
    # restorable by torch.load(weights_only=True) after save_model.
    self.class_names = torch.unique(y).tolist()
    num_classes = self.config['output_dim']
    if self.class_names and (self.class_names[0] < 0 or self.class_names[-1] >= num_classes):
        raise ValueError(
            f"Labels must be integers in [0, {num_classes}); got classes {self.class_names}."
        )
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(self.model.parameters(), lr=self.config['lr'])
    for epoch in range(self.epochs):
        self.model.train()
        optimizer.zero_grad()
        outputs = self.model(X, adj)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {loss.item():.4f}")
    self.trained = True
    print("GNN model trained successfully.")
Train the classification model.
Args: features: List of feature dictionaries **kwargs: Additional arguments for training; must include adjacency_matrix (the graph's adjacency matrix), otherwise a ValueError is raised.
def predict(self, features: List[Dict], **kwargs) -> np.ndarray:
    """
    Predict a class index for every node.

    Args:
        features: List of feature dictionaries.
        **kwargs: Must contain ``adjacency_matrix`` (array-like or tensor).

    Returns:
        Array of predicted class indices, one per node.

    Raises:
        ValueError: If the model is untrained or ``adjacency_matrix`` is missing.
    """
    if not self.trained:
        raise ValueError("Model must be trained before making predictions")
    X, _ = preprocess_features(features)
    adj = kwargs.get('adjacency_matrix')
    if adj is None:
        raise ValueError("Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN prediction.")
    # as_tensor accepts tensors without torch.tensor's copy-construct warning.
    X = torch.as_tensor(X, dtype=torch.float32, device=self.device)
    adj = torch.as_tensor(adj, dtype=torch.float32, device=self.device)
    self.model.eval()
    with torch.no_grad():
        outputs = self.model(X, adj)
    return outputs.argmax(dim=1).cpu().numpy()
Make predictions using the trained model.
Args: features: List of feature dictionaries **kwargs: Additional arguments for prediction; must include adjacency_matrix (the graph's adjacency matrix), otherwise a ValueError is raised.
Returns: Array of predictions
def evaluate(self, features: List[Dict], **kwargs) -> Dict[str, Any]:
    """
    Evaluate the model against the labels in *features*.

    Args:
        features: List of feature dictionaries.
        **kwargs: Must contain ``adjacency_matrix``; set
            ``detailed_report=True`` to include sklearn's classification report.

    Returns:
        Dict with 'accuracy' (float), 'confusion_matrix' (nested list) and,
        optionally, 'classification_report' (dict).

    Raises:
        ValueError: If the model is untrained or ``adjacency_matrix`` is missing.
    """
    if not self.trained:
        raise ValueError("Model must be trained before evaluation")
    X, y = preprocess_features(features)
    adj = kwargs.get('adjacency_matrix')
    if adj is None:
        raise ValueError("Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN evaluation.")
    # as_tensor accepts tensors without torch.tensor's copy-construct warning.
    X = torch.as_tensor(X, dtype=torch.float32, device=self.device)
    adj = torch.as_tensor(adj, dtype=torch.float32, device=self.device)
    y = np.asarray(y)
    self.model.eval()
    with torch.no_grad():
        outputs = self.model(X, adj)
    y_pred = outputs.argmax(dim=1).cpu().numpy()
    metrics = {
        'accuracy': accuracy_score(y, y_pred),
        'confusion_matrix': confusion_matrix(y, y_pred).tolist(),
    }
    if kwargs.get('detailed_report', False):
        metrics['classification_report'] = classification_report(y, y_pred, output_dict=True)
    return metrics
Evaluate the model performance.
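Args: features: List of feature dictionaries **kwargs: Additional arguments for evaluation; must include adjacency_matrix (the graph's adjacency matrix), otherwise a ValueError is raised. Pass detailed_report=True to also include a per-class classification report.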
Returns: Dictionary containing evaluation metrics
def save_model(self, filepath: str):
    """
    Persist the trained model's weights, config and metadata with ``torch.save``.

    Args:
        filepath: Destination path.

    Raises:
        ValueError: If the model has not been trained yet.
    """
    if self.trained:
        payload = dict(
            model_state_dict=self.model.state_dict(),
            config=self.config,
            feature_names=self.feature_names,
            class_names=self.class_names,
            trained=self.trained,
        )
        torch.save(payload, filepath)
        print(f"GNN model saved to {filepath}")
        return
    raise ValueError("Model must be trained before saving")
Save the trained model to a file.
Args: filepath: Path to save the model
def load_model(self, filepath: str):
    """
    Load a checkpoint written by ``save_model``.

    The network is rebuilt from the checkpoint's own ``config`` (falling back
    to this instance's config), so checkpoints trained with other layer
    sizes load correctly. ``epochs`` is restored from that config too.

    Args:
        filepath: Path to the saved model.
    """
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(filepath, map_location=self.device)
    config = checkpoint.get('config', self.config)
    model = SimpleGCN(
        config['input_dim'],
        config['hidden_dim'],
        config['output_dim']
    ).to(self.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Commit new state only after the weights loaded successfully.
    self.model = model
    self.config = config
    self.epochs = config.get('epochs', self.epochs)
    self.feature_names = checkpoint.get('feature_names', [])
    self.class_names = checkpoint.get('class_names', [])
    self.trained = checkpoint.get('trained', True)
    print(f"GNN model loaded from {filepath}")
Load a trained model from a file.
Args: filepath: Path to the saved model