gaitsetpy.classification.models.gnn

  1import torch
  2import torch.nn as nn
  3import torch.optim as optim
  4import numpy as np
  5from typing import List, Dict, Any, Optional, Union
  6from ...core.base_classes import BaseClassificationModel
  7from ..utils.preprocess import preprocess_features
  8from sklearn.model_selection import train_test_split
  9from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
 10
 11class SimpleGCN(nn.Module):
 12    def __init__(self, input_dim, hidden_dim, output_dim):
 13        super(SimpleGCN, self).__init__()
 14        self.fc1 = nn.Linear(input_dim, hidden_dim)
 15        self.fc2 = nn.Linear(hidden_dim, output_dim)
 16    def forward(self, x, adj):
 17        h = torch.relu(self.fc1(torch.matmul(adj, x)))
 18        out = self.fc2(torch.matmul(adj, h))
 19        return out
 20
 21class GNNModel(BaseClassificationModel):
 22    """
 23    Simple Graph Neural Network (GCN) classification model using PyTorch.
 24    Implements the BaseClassificationModel interface.
 25    Expects features as node features and adjacency matrix in kwargs.
 26    """
 27    def __init__(self, input_dim=10, hidden_dim=32, output_dim=2, lr=0.001, epochs=20, device=None):
 28        super().__init__(
 29            name="gnn",
 30            description="Graph Convolutional Network (GCN) classifier for gait data classification"
 31        )
 32        self.config = {
 33            'input_dim': input_dim,
 34            'hidden_dim': hidden_dim,
 35            'output_dim': output_dim,
 36            'lr': lr,
 37            'epochs': epochs
 38        }
 39        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 40        self.model = SimpleGCN(input_dim, hidden_dim, output_dim).to(self.device)
 41        self.epochs = epochs
 42        self.trained = False
 43        self.feature_names = []
 44        self.class_names = []
 45
 46    def train(self, features: List[Dict], **kwargs):
 47        X, y = preprocess_features(features)
 48        # X: (num_nodes, num_features), y: (num_nodes,)
 49        adj = kwargs.get('adjacency_matrix')
 50        if adj is None:
 51            raise ValueError("Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN training.")
 52        X = torch.tensor(X, dtype=torch.float32).to(self.device)
 53        y = torch.tensor(y, dtype=torch.long).to(self.device)
 54        adj = torch.tensor(adj, dtype=torch.float32).to(self.device)
 55        self.feature_names = [f"feature_{i}" for i in range(X.shape[1])]
 56        self.class_names = list(set(y.cpu().numpy()))
 57        criterion = nn.CrossEntropyLoss()
 58        optimizer = optim.Adam(self.model.parameters(), lr=self.config['lr'])
 59        for epoch in range(self.epochs):
 60            self.model.train()
 61            optimizer.zero_grad()
 62            outputs = self.model(X, adj)
 63            loss = criterion(outputs, y)
 64            loss.backward()
 65            optimizer.step()
 66            if (epoch+1) % 5 == 0 or epoch == 0:
 67                print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {loss.item():.4f}")
 68        self.trained = True
 69        print("GNN model trained successfully.")
 70
 71    def predict(self, features: List[Dict], **kwargs) -> np.ndarray:
 72        if not self.trained:
 73            raise ValueError("Model must be trained before making predictions")
 74        X, _ = preprocess_features(features)
 75        adj = kwargs.get('adjacency_matrix')
 76        if adj is None:
 77            raise ValueError("Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN prediction.")
 78        X = torch.tensor(X, dtype=torch.float32).to(self.device)
 79        adj = torch.tensor(adj, dtype=torch.float32).to(self.device)
 80        self.model.eval()
 81        with torch.no_grad():
 82            outputs = self.model(X, adj)
 83            _, predicted = torch.max(outputs.data, 1)
 84        return predicted.cpu().numpy()
 85
 86    def evaluate(self, features: List[Dict], **kwargs) -> Dict[str, float]:
 87        if not self.trained:
 88            raise ValueError("Model must be trained before evaluation")
 89        X, y = preprocess_features(features)
 90        adj = kwargs.get('adjacency_matrix')
 91        if adj is None:
 92            raise ValueError("Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN evaluation.")
 93        X = torch.tensor(X, dtype=torch.float32).to(self.device)
 94        y = np.array(y)
 95        adj = torch.tensor(adj, dtype=torch.float32).to(self.device)
 96        self.model.eval()
 97        with torch.no_grad():
 98            outputs = self.model(X, adj)
 99            _, y_pred = torch.max(outputs.data, 1)
100        y_pred = y_pred.cpu().numpy()
101        accuracy = accuracy_score(y, y_pred)
102        conf_matrix = confusion_matrix(y, y_pred)
103        metrics = {
104            'accuracy': accuracy,
105            'confusion_matrix': conf_matrix.tolist()
106        }
107        detailed_report = kwargs.get('detailed_report', False)
108        if detailed_report:
109            class_report = classification_report(y, y_pred, output_dict=True)
110            metrics['classification_report'] = class_report
111        return metrics
112
113    def save_model(self, filepath: str):
114        if not self.trained:
115            raise ValueError("Model must be trained before saving")
116        torch.save({
117            'model_state_dict': self.model.state_dict(),
118            'config': self.config,
119            'feature_names': self.feature_names,
120            'class_names': self.class_names,
121            'trained': self.trained
122        }, filepath)
123        print(f"GNN model saved to {filepath}")
124
125    def load_model(self, filepath: str):
126        checkpoint = torch.load(filepath, map_location=self.device)
127        self.model = SimpleGCN(
128            self.config['input_dim'],
129            self.config['hidden_dim'],
130            self.config['output_dim']
131        ).to(self.device)
132        self.model.load_state_dict(checkpoint['model_state_dict'])
133        self.config = checkpoint.get('config', self.config)
134        self.feature_names = checkpoint.get('feature_names', [])
135        self.class_names = checkpoint.get('class_names', [])
136        self.trained = checkpoint.get('trained', True)
137        print(f"GNN model loaded from {filepath}")
class SimpleGCN(torch.nn.modules.module.Module):
class SimpleGCN(nn.Module):
    """
    Two-layer network that mixes rows through an adjacency matrix.

    Before each linear layer the incoming matrix is left-multiplied by
    ``adj``; a ReLU is applied between the two layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Create the two linear layers: input -> hidden and hidden -> output."""
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, adj):
        """
        Run both layers and return the second layer's raw output.

        Args:
            x: Feature matrix that ``adj`` can be matrix-multiplied with.
            adj: Matrix left-multiplied onto the features before each layer.
        """
        mixed_input = torch.matmul(adj, x)
        hidden = torch.relu(self.fc1(mixed_input))
        mixed_hidden = torch.matmul(adj, hidden)
        return self.fc2(mixed_hidden)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute `training` (bool): indicates whether this module is in training or evaluation mode.

SimpleGCN(input_dim, hidden_dim, output_dim)
13    def __init__(self, input_dim, hidden_dim, output_dim):
14        super(SimpleGCN, self).__init__()
15        self.fc1 = nn.Linear(input_dim, hidden_dim)
16        self.fc2 = nn.Linear(hidden_dim, output_dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

fc1
fc2
def forward(self, x, adj):
17    def forward(self, x, adj):
18        h = torch.relu(self.fc1(torch.matmul(adj, x)))
19        out = self.fc2(torch.matmul(adj, h))
20        return out

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

 22class GNNModel(BaseClassificationModel):
 23    """
 24    Simple Graph Neural Network (GCN) classification model using PyTorch.
 25    Implements the BaseClassificationModel interface.
 26    Expects features as node features and adjacency matrix in kwargs.
 27    """
 28    def __init__(self, input_dim=10, hidden_dim=32, output_dim=2, lr=0.001, epochs=20, device=None):
 29        super().__init__(
 30            name="gnn",
 31            description="Graph Convolutional Network (GCN) classifier for gait data classification"
 32        )
 33        self.config = {
 34            'input_dim': input_dim,
 35            'hidden_dim': hidden_dim,
 36            'output_dim': output_dim,
 37            'lr': lr,
 38            'epochs': epochs
 39        }
 40        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 41        self.model = SimpleGCN(input_dim, hidden_dim, output_dim).to(self.device)
 42        self.epochs = epochs
 43        self.trained = False
 44        self.feature_names = []
 45        self.class_names = []
 46
 47    def train(self, features: List[Dict], **kwargs):
 48        X, y = preprocess_features(features)
 49        # X: (num_nodes, num_features), y: (num_nodes,)
 50        adj = kwargs.get('adjacency_matrix')
 51        if adj is None:
 52            raise ValueError("Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN training.")
 53        X = torch.tensor(X, dtype=torch.float32).to(self.device)
 54        y = torch.tensor(y, dtype=torch.long).to(self.device)
 55        adj = torch.tensor(adj, dtype=torch.float32).to(self.device)
 56        self.feature_names = [f"feature_{i}" for i in range(X.shape[1])]
 57        self.class_names = list(set(y.cpu().numpy()))
 58        criterion = nn.CrossEntropyLoss()
 59        optimizer = optim.Adam(self.model.parameters(), lr=self.config['lr'])
 60        for epoch in range(self.epochs):
 61            self.model.train()
 62            optimizer.zero_grad()
 63            outputs = self.model(X, adj)
 64            loss = criterion(outputs, y)
 65            loss.backward()
 66            optimizer.step()
 67            if (epoch+1) % 5 == 0 or epoch == 0:
 68                print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {loss.item():.4f}")
 69        self.trained = True
 70        print("GNN model trained successfully.")
 71
 72    def predict(self, features: List[Dict], **kwargs) -> np.ndarray:
 73        if not self.trained:
 74            raise ValueError("Model must be trained before making predictions")
 75        X, _ = preprocess_features(features)
 76        adj = kwargs.get('adjacency_matrix')
 77        if adj is None:
 78            raise ValueError("Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN prediction.")
 79        X = torch.tensor(X, dtype=torch.float32).to(self.device)
 80        adj = torch.tensor(adj, dtype=torch.float32).to(self.device)
 81        self.model.eval()
 82        with torch.no_grad():
 83            outputs = self.model(X, adj)
 84            _, predicted = torch.max(outputs.data, 1)
 85        return predicted.cpu().numpy()
 86
 87    def evaluate(self, features: List[Dict], **kwargs) -> Dict[str, float]:
 88        if not self.trained:
 89            raise ValueError("Model must be trained before evaluation")
 90        X, y = preprocess_features(features)
 91        adj = kwargs.get('adjacency_matrix')
 92        if adj is None:
 93            raise ValueError("Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN evaluation.")
 94        X = torch.tensor(X, dtype=torch.float32).to(self.device)
 95        y = np.array(y)
 96        adj = torch.tensor(adj, dtype=torch.float32).to(self.device)
 97        self.model.eval()
 98        with torch.no_grad():
 99            outputs = self.model(X, adj)
100            _, y_pred = torch.max(outputs.data, 1)
101        y_pred = y_pred.cpu().numpy()
102        accuracy = accuracy_score(y, y_pred)
103        conf_matrix = confusion_matrix(y, y_pred)
104        metrics = {
105            'accuracy': accuracy,
106            'confusion_matrix': conf_matrix.tolist()
107        }
108        detailed_report = kwargs.get('detailed_report', False)
109        if detailed_report:
110            class_report = classification_report(y, y_pred, output_dict=True)
111            metrics['classification_report'] = class_report
112        return metrics
113
114    def save_model(self, filepath: str):
115        if not self.trained:
116            raise ValueError("Model must be trained before saving")
117        torch.save({
118            'model_state_dict': self.model.state_dict(),
119            'config': self.config,
120            'feature_names': self.feature_names,
121            'class_names': self.class_names,
122            'trained': self.trained
123        }, filepath)
124        print(f"GNN model saved to {filepath}")
125
126    def load_model(self, filepath: str):
127        checkpoint = torch.load(filepath, map_location=self.device)
128        self.model = SimpleGCN(
129            self.config['input_dim'],
130            self.config['hidden_dim'],
131            self.config['output_dim']
132        ).to(self.device)
133        self.model.load_state_dict(checkpoint['model_state_dict'])
134        self.config = checkpoint.get('config', self.config)
135        self.feature_names = checkpoint.get('feature_names', [])
136        self.class_names = checkpoint.get('class_names', [])
137        self.trained = checkpoint.get('trained', True)
138        print(f"GNN model loaded from {filepath}")

Simple Graph Neural Network (GCN) classification model using PyTorch. Implements the BaseClassificationModel interface. Expects features as node features and adjacency matrix in kwargs.

GNNModel( input_dim=10, hidden_dim=32, output_dim=2, lr=0.001, epochs=20, device=None)
28    def __init__(self, input_dim=10, hidden_dim=32, output_dim=2, lr=0.001, epochs=20, device=None):
29        super().__init__(
30            name="gnn",
31            description="Graph Convolutional Network (GCN) classifier for gait data classification"
32        )
33        self.config = {
34            'input_dim': input_dim,
35            'hidden_dim': hidden_dim,
36            'output_dim': output_dim,
37            'lr': lr,
38            'epochs': epochs
39        }
40        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
41        self.model = SimpleGCN(input_dim, hidden_dim, output_dim).to(self.device)
42        self.epochs = epochs
43        self.trained = False
44        self.feature_names = []
45        self.class_names = []

Initialize the classification model.

Args: name: Name of the classification model description: Description of the classification model

config
device
model
epochs
trained
feature_names
class_names
def train(self, features: List[Dict], **kwargs):
47    def train(self, features: List[Dict], **kwargs):
48        X, y = preprocess_features(features)
49        # X: (num_nodes, num_features), y: (num_nodes,)
50        adj = kwargs.get('adjacency_matrix')
51        if adj is None:
52            raise ValueError("Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN training.")
53        X = torch.tensor(X, dtype=torch.float32).to(self.device)
54        y = torch.tensor(y, dtype=torch.long).to(self.device)
55        adj = torch.tensor(adj, dtype=torch.float32).to(self.device)
56        self.feature_names = [f"feature_{i}" for i in range(X.shape[1])]
57        self.class_names = list(set(y.cpu().numpy()))
58        criterion = nn.CrossEntropyLoss()
59        optimizer = optim.Adam(self.model.parameters(), lr=self.config['lr'])
60        for epoch in range(self.epochs):
61            self.model.train()
62            optimizer.zero_grad()
63            outputs = self.model(X, adj)
64            loss = criterion(outputs, y)
65            loss.backward()
66            optimizer.step()
67            if (epoch+1) % 5 == 0 or epoch == 0:
68                print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {loss.item():.4f}")
69        self.trained = True
70        print("GNN model trained successfully.")

Train the classification model.

Args: features: List of feature dictionaries **kwargs: Additional arguments for training

def predict(self, features: List[Dict], **kwargs) -> numpy.ndarray:
72    def predict(self, features: List[Dict], **kwargs) -> np.ndarray:
73        if not self.trained:
74            raise ValueError("Model must be trained before making predictions")
75        X, _ = preprocess_features(features)
76        adj = kwargs.get('adjacency_matrix')
77        if adj is None:
78            raise ValueError("Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN prediction.")
79        X = torch.tensor(X, dtype=torch.float32).to(self.device)
80        adj = torch.tensor(adj, dtype=torch.float32).to(self.device)
81        self.model.eval()
82        with torch.no_grad():
83            outputs = self.model(X, adj)
84            _, predicted = torch.max(outputs.data, 1)
85        return predicted.cpu().numpy()

Make predictions using the trained model.

Args: features: List of feature dictionaries **kwargs: Additional arguments for prediction

Returns: Array of predictions

def evaluate(self, features: List[Dict], **kwargs) -> Dict[str, float]:
 87    def evaluate(self, features: List[Dict], **kwargs) -> Dict[str, float]:
 88        if not self.trained:
 89            raise ValueError("Model must be trained before evaluation")
 90        X, y = preprocess_features(features)
 91        adj = kwargs.get('adjacency_matrix')
 92        if adj is None:
 93            raise ValueError("Adjacency matrix must be provided as 'adjacency_matrix' in kwargs for GNN evaluation.")
 94        X = torch.tensor(X, dtype=torch.float32).to(self.device)
 95        y = np.array(y)
 96        adj = torch.tensor(adj, dtype=torch.float32).to(self.device)
 97        self.model.eval()
 98        with torch.no_grad():
 99            outputs = self.model(X, adj)
100            _, y_pred = torch.max(outputs.data, 1)
101        y_pred = y_pred.cpu().numpy()
102        accuracy = accuracy_score(y, y_pred)
103        conf_matrix = confusion_matrix(y, y_pred)
104        metrics = {
105            'accuracy': accuracy,
106            'confusion_matrix': conf_matrix.tolist()
107        }
108        detailed_report = kwargs.get('detailed_report', False)
109        if detailed_report:
110            class_report = classification_report(y, y_pred, output_dict=True)
111            metrics['classification_report'] = class_report
112        return metrics

Evaluate the model performance.

Args: features: List of feature dictionaries **kwargs: Additional arguments for evaluation

Returns: Dictionary containing evaluation metrics

def save_model(self, filepath: str):
114    def save_model(self, filepath: str):
115        if not self.trained:
116            raise ValueError("Model must be trained before saving")
117        torch.save({
118            'model_state_dict': self.model.state_dict(),
119            'config': self.config,
120            'feature_names': self.feature_names,
121            'class_names': self.class_names,
122            'trained': self.trained
123        }, filepath)
124        print(f"GNN model saved to {filepath}")

Save the trained model to a file.

Args: filepath: Path to save the model

def load_model(self, filepath: str):
126    def load_model(self, filepath: str):
127        checkpoint = torch.load(filepath, map_location=self.device)
128        self.model = SimpleGCN(
129            self.config['input_dim'],
130            self.config['hidden_dim'],
131            self.config['output_dim']
132        ).to(self.device)
133        self.model.load_state_dict(checkpoint['model_state_dict'])
134        self.config = checkpoint.get('config', self.config)
135        self.feature_names = checkpoint.get('feature_names', [])
136        self.class_names = checkpoint.get('class_names', [])
137        self.trained = checkpoint.get('trained', True)
138        print(f"GNN model loaded from {filepath}")

Load a trained model from a file.

Args: filepath: Path to the saved model