🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services ℹ️ About ✉️ Contact View Pricing Plans from $10

Graph Neural Networks

⭐ Premium

Advertisement

Graph Neural Networks

Data isn't always tabular. Social networks, molecules, knowledge graphs, and citation networks are inherently graph-structured. Graph Neural Networks (GNNs) learn on this structure by repeatedly aggregating information from each node's neighbors — the message-passing paradigm.

Graph Message Passing

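Figure — the GNN message-passing paradigm. Step 1 (input graph): a target node $v$ with neighbors $v_1, v_2, v_3$. Step 2 (aggregate messages): $v$ receives $\mathrm{msg}(v_1), \mathrm{msg}(v_2), \mathrm{msg}(v_3)$ and combines them with an aggregator such as sum ($\Sigma$), mean, or max. Step 3 (update): $h_v^{(l+1)} = \mathrm{UPDATE}\big(h_v^{(l)},\ \mathrm{AGG}(\{m_i\})\big)$, where $m_i = \mathrm{msg}(v_i)$ and $h_v^{(l)}$ is node $v$'s representation after layer $l$.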

Graph Fundamentals

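A graph $G = (V, E)$ has a node set $V$ and an edge set $E$. Each node has a feature vector, and the task might be node classification, link prediction, or graph classification. In PyTorch Geometric, edges are stored as an `edge_index` tensor of shape `[2, num_edges]` whose columns are (source, target) node-index pairs.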

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, global_mean_pool
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import numpy as np
import warnings
warnings.filterwarnings('ignore')

Creating Graph Data

# Create a simple synthetic graph for node classification.
num_nodes = 100
num_features = 16
num_classes = 3

# Node features: one random feature vector per node.
x = torch.randn(num_nodes, num_features)

# Edges (source, target) in COO layout: row 0 = sources, row 1 = targets.
# Endpoints are drawn independently, so the graph is directed and may contain
# duplicate edges and self-loops -- acceptable for a demo.
edge_index = torch.randint(0, num_nodes, (2, 300))

# Labels for node classification. They are drawn independently of the features,
# so a model can only memorize the training nodes; expect test accuracy near
# chance (about 1 / num_classes).
y = torch.randint(0, num_classes, (num_nodes,))

# Train/test masks: first 70 nodes for training, remaining 30 for testing.
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[:70] = True
test_mask[70:] = True

data = Data(x=x, edge_index=edge_index, y=y,
            train_mask=train_mask, test_mask=test_mask)
print(f"Graph: {data.num_nodes} nodes, {data.num_edges} edges")
print(f"Node features: {data.num_node_features}")
# A single `Data` object has no `num_classes` attribute (datasets do), so
# reading `data.num_classes` raises AttributeError; report the configured count.
print(f"Classes: {num_classes}")

Graph Convolutional Network (GCN)

class GCN(nn.Module):
    """Three-layer graph convolutional network that returns raw per-node logits.

    Layers: GCNConv -> ReLU -> dropout(0.5, training only) -> GCNConv -> ReLU
    -> GCNConv. No softmax is applied; pair the output with ``F.cross_entropy``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        hidden = self.conv1(x, edge_index).relu()
        # Dropout is active only in train mode (self.training).
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        hidden = self.conv2(hidden, edge_index).relu()
        logits = self.conv3(hidden, edge_index)
        return logits

# A 3-layer GCN with hidden width 32, optimized by Adam with weight decay (L2 penalty).
model = GCN(in_channels=num_features, hidden_channels=32, out_channels=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

# Training
def train(model, data, opt=None):
    """Run one full-batch training step on the nodes selected by ``data.train_mask``.

    Args:
        model: module called as ``model(data.x, data.edge_index)`` that returns
            per-node logits.
        data: graph object exposing ``x``, ``edge_index``, ``y`` and ``train_mask``.
        opt: optimizer over ``model``'s parameters. Defaults to the module-level
            ``optimizer`` for backward compatibility. That optimizer was built
            from the GCN ``model``'s parameters only, so pass ``opt`` explicitly
            when training any other model, or its weights will never change.

    Returns:
        The training loss as a Python float.
    """
    if opt is None:
        opt = optimizer
    model.train()
    opt.zero_grad()
    out = model(data.x, data.edge_index)
    # The forward pass sees the whole graph; the loss uses training nodes only.
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    opt.step()
    return loss.item()

@torch.no_grad()
def test(model, data):
    """Return ``[train_acc, test_acc]``: accuracy on ``data.train_mask`` and ``data.test_mask``.

    Runs one full-graph forward pass in eval mode, so training-only layers such
    as GCN's dropout are off. Gradient tracking is disabled because the output
    is never backpropagated, which avoids building an autograd graph on every
    evaluation.
    """
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=1)
    return [
        (pred[mask] == data.y[mask]).float().mean().item()
        for mask in (data.train_mask, data.test_mask)
    ]

# Train for 200 epochs and report every 50. Accuracy is computed only on the
# reported epochs: test() does a full forward pass whose result would otherwise
# be thrown away. Skipping it does not change training, because test() runs in
# eval mode and takes no optimizer step.
for epoch in range(200):
    loss = train(model, data)
    if (epoch + 1) % 50 == 0:
        train_acc, test_acc = test(model, data)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}, Train Acc={train_acc:.4f}, Test Acc={test_acc:.4f}")

Graph Attention Network (GAT)

class GAT(nn.Module):
    """Three-layer graph attention network that returns raw per-node logits.

    The first two layers use ``heads`` attention heads whose outputs are
    concatenated (width ``hidden_channels * heads``). The final layer uses a
    single, non-concatenated head to produce ``out_channels`` values per node.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super().__init__()
        concat_width = hidden_channels * heads  # feature width after concatenating all heads
        self.conv1 = GATConv(in_channels, hidden_channels, concat=True, heads=heads)
        self.conv2 = GATConv(concat_width, hidden_channels, concat=True, heads=heads)
        self.conv3 = GATConv(concat_width, out_channels, concat=False, heads=1)

    def forward(self, x, edge_index):
        # Layer 1: multi-head attention + ELU, then dropout (train mode only).
        h = F.dropout(F.elu(self.conv1(x, edge_index)), p=0.5, training=self.training)
        # Layer 2: multi-head attention + ELU.
        h = F.elu(self.conv2(h, edge_index))
        # Layer 3: single-head output layer.
        return self.conv3(h, edge_index)

gat_model = GAT(in_channels=num_features, hidden_channels=8, out_channels=num_classes, heads=4)
# Total number of scalar weights across all layers, with thousands separators.
print(f"GAT parameters: {sum(map(torch.numel, gat_model.parameters())):,}")

GraphSAGE

class GraphSAGE(nn.Module):
    """Three stacked SAGEConv layers with ReLU between them; returns raw logits.

    ``forward`` applies no dropout, so the block has no train/eval-specific step.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        h = x
        # Two hidden layers, each followed by ReLU.
        for conv in (self.conv1, self.conv2):
            h = F.relu(conv(h, edge_index))
        # Output layer without activation.
        return self.conv3(h, edge_index)

# NOTE: the model itself aggregates over every edge it is given; neighbor
# sampling for large graphs is done by the data loader that feeds it.
sage_model = GraphSAGE(in_channels=num_features, hidden_channels=32, out_channels=num_classes)
print("GraphSAGE: learns from sampled neighbors, scalable to large graphs")

Link Prediction

class LinkPredictor(nn.Module):
    """GCN encoder plus MLP decoder that scores node pairs as likely edges.

    ``encode`` produces node embeddings. ``decode`` concatenates the embeddings
    of each pair's two endpoints and maps them to a probability in (0, 1).
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.encoder = GCNConv(in_channels, hidden_channels)
        # Input is [z_src || z_dst], hence twice the embedding width.
        self.decoder = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        )

    def encode(self, x, edge_index):
        """Return ReLU-activated node embeddings."""
        return F.relu(self.encoder(x, edge_index))

    def decode(self, z, edge_index):
        """Return one edge probability per column of ``edge_index`` (shape ``[E, 1]``)."""
        sources, targets = edge_index[0], edge_index[1]
        pair_features = torch.cat([z[sources], z[targets]], dim=1)
        logits = self.decoder(pair_features)
        return logits.sigmoid()

# Create positive and negative edges.
# Positives are the observed edges. Negatives must be node pairs that are NOT
# edges: negative_sampling excludes pairs present in `pos_edge_index`, whereas
# independently drawn random pairs can coincide with real edges and corrupt the
# labels. It may return slightly fewer samples than requested.
from torch_geometric.utils import negative_sampling

pos_edge_index = edge_index
neg_edge_index = negative_sampling(
    pos_edge_index, num_nodes=num_nodes, num_neg_samples=pos_edge_index.size(1)
)

# Split positives 80/20 into train/test via a random permutation of columns.
n_pos = pos_edge_index.size(1)
perm = torch.randperm(n_pos)
train_pos = pos_edge_index[:, perm[:int(0.8 * n_pos)]]
test_pos = pos_edge_index[:, perm[int(0.8 * n_pos):]]

# Split negatives the same way so each split has its own non-edges.
n_neg = neg_edge_index.size(1)
neg_perm = torch.randperm(n_neg)
train_neg = neg_edge_index[:, neg_perm[:int(0.8 * n_neg)]]
test_neg = neg_edge_index[:, neg_perm[int(0.8 * n_neg):]]

# NOTE: when training, encode with `train_pos` only, so held-out test edges do
# not leak into the node embeddings.
link_model = LinkPredictor(num_features, 32)
print(f"Link predictor ready: {link_model}")

Message Passing Framework

from torch_geometric.nn import MessagePassing

class CustomGNN(MessagePassing):
    """Minimal message-passing layer: sums MLP-transformed neighbor features.

    Each edge sends ``mlp(x_j)`` (an MLP of the sending node's features). Messages
    arriving at a node are summed, and the sum is returned unchanged as the new
    node representation.
    """

    def __init__(self, in_channels, out_channels):
        # 'add' sums incoming messages; 'mean' and 'max' are the other built-ins.
        super().__init__(aggr='add')
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        )

    def forward(self, x, edge_index):
        # propagate() drives message -> aggregate -> update.
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # The `_j` suffix tells PyG to gather the sending node's features per
        # edge, so this parameter name must stay `x_j`.
        transformed = self.mlp(x_j)
        return transformed

    def update(self, aggr_out):
        # Identity update: the aggregated messages are the layer output.
        return aggr_out

# Apply the custom layer once over the node-classification graph.
custom_gnn = CustomGNN(in_channels=num_features, out_channels=32)
out = custom_gnn(data.x, data.edge_index)
print(f"Custom GNN output: {out.shape}")

Graph Classification

# Create multiple random graphs for binary graph classification.
def _random_graph(num_graph_nodes):
    """Return a random graph with `num_graph_nodes` nodes, 2x as many edges, and one 0/1 label.

    Built inside a function so the temporaries do not overwrite the module-level
    `x`, `y` and `edge_index` that belong to the node-classification example.
    """
    edges = torch.randint(0, num_graph_nodes, (2, num_graph_nodes * 2))
    feats = torch.randn(num_graph_nodes, num_features)
    label = torch.randint(0, 2, (1,))  # one label per graph
    # An all-zeros `batch` lets a single graph be fed directly to a pooling
    # model; a DataLoader builds its own batch vector when collating graphs.
    return Data(x=feats, edge_index=edges, y=label,
                batch=torch.zeros(num_graph_nodes, dtype=torch.long))


# Graph sizes are drawn uniformly from 10..49 nodes.
graphs = [_random_graph(np.random.randint(10, 50)) for _ in range(100)]

class GraphClassifier(nn.Module):
    """Graph-level classifier: two GCN layers, mean pooling per graph, linear head.

    ``forward`` returns one row of ``num_classes`` logits per graph identified
    in ``batch``.
    """

    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.classifier = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        node_emb = F.relu(self.conv1(x, edge_index))
        node_emb = F.relu(self.conv2(node_emb, edge_index))
        # `batch` maps each node to its graph; pooling averages node embeddings
        # within each graph, giving one vector per graph.
        graph_emb = global_mean_pool(node_emb, batch)
        return self.classifier(graph_emb)

graph_model = GraphClassifier(num_features, 32, 2)
print("Graph classifier with global mean pooling")

# Smoke-test one mini-batch. DataLoader merges several graphs into a single
# disconnected graph and supplies `batch`, the node -> graph-index vector that
# global_mean_pool uses to pool each graph separately.
graph_loader = DataLoader(graphs, batch_size=32, shuffle=True)
with torch.no_grad():
    sample_batch = next(iter(graph_loader))
    graph_logits = graph_model(sample_batch.x, sample_batch.edge_index, sample_batch.batch)
print(f"Batch of {sample_batch.num_graphs} graphs -> logits of shape {tuple(graph_logits.shape)}")

Best Practices

  1. Normalize the adjacency matrix — use symmetric normalization for GCN (GCNConv applies it by default)
  2. Use skip connections — for deep GNNs (more than 3 layers)
  3. Watch for oversmoothing — in deep GCNs, all node representations converge toward the same value
  4. Mini-batch with neighbor sampling — for large graphs (GraphSAGE)
  5. Evaluate node-, link-, and graph-level tasks separately — different metrics apply
  6. Handle heterogeneity — different node/edge types need specialized architectures

Summary

GNNs learn on graph-structured data through message passing. GCN, GAT, and GraphSAGE are foundational architectures for node classification, link prediction, and graph classification. Master message passing and neighborhood aggregation to work with social networks, molecules, and knowledge graphs.

Advertisement