Skip to content

Commit

Permalink
added message passing model and gnn layers
Browse files Browse the repository at this point in the history
  • Loading branch information
gc031298 committed Jan 8, 2024
1 parent fa20bc7 commit 68256f7
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 0 deletions.
110 changes: 110 additions & 0 deletions pina/model/layers/gnn_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import torch
import torch.nn as nn
from torch.nn import SiLU
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, InstanceNorm
from torch_cluster import radius_graph

class GNN_Layer(MessagePassing):
"""
Message passing layer
"""
def __init__(self,
in_features: int,
out_features: int,
hidden_features: int,
time_window: int,
n_variables: int):
"""
Initialize message passing layers
Args:
in_features (int): number of node input features
out_features (int): number of node output features
hidden_features (int): number of hidden features
time_window (int): number of input/output timesteps (temporal bundling)
n_variables (int): number of equation specific parameters used in the solver
"""
super(GNN_Layer, self).__init__(node_dim=-2, aggr='mean')
self.in_features = in_features
self.out_features = out_features
self.hidden_features = hidden_features
self.time_window = time_window
self.n_variables = n_variables

# Message network -- equation 8
# Notice: spacial dimension is hard-coded and set to 1. For now, we deal with 1d-spatial PDEs.
# TODO: manage spacial dimension - to be passed as a parameter?
self.message_net_1 = nn.Sequential(nn.Linear(2*self.in_features + self.time_window + 1 + self.n_variables, self.hidden_features), SiLU())
self.message_net_2 = nn.Sequential(nn.Linear(self.hidden_features, self.hidden_features), SiLU())

# Update network -- equation 9
self.update_net_1 = nn.Sequential(nn.Linear(self.in_features + self.hidden_features + self.n_variables, self.hidden_features), SiLU())
self.update_net_2 = nn.Sequential(nn.Linear(self.hidden_features, self.out_features), SiLU())

self.norm = InstanceNorm(self.hidden_features)


def forward(self, graph):
"""
Propagate messages along edges
"""
f = self.propagate(edge_index=graph.edge_index,
x=graph.x,
u=graph.u,
pos=graph.pos,
variables=graph.variables)
return f


def message(self, x_i, x_j, u_i, u_j, pos_i, pos_j, variables_i):
"""
Message update following formula 8 of the paper
"""
message = self.message_net_1(torch.cat((x_i, x_j, u_i - u_j, pos_i - pos_j, variables_i), dim=-1))
message = self.message_net_2(message)
return message


def update(self, message, x, variables):
"""
Node update following formula 9 of the paper
"""
update = self.update_net_1(torch.cat((x, message, variables), dim=-1))
update = self.update_net_2(update)
if self.in_features == self.out_features:
return x + update
else:
return update


class GraphHandler():
def __init__(self, coordinates, variables, num_neighs=10):
super().__init__()
self.n = num_neighs
self.graph = self.create_ball_graph(coordinates, variables)

def create_ball_graph(self, coordinates, variables):
# Get the smallest distance between the coordinates
if len(coordinates.shape) == 1:
dx = coordinates[1]-coordinates[0]
else:
dx = torch.pdist(coordinates).min()

# Set the radius so as to include the nearest neighbours
radius = self.n * dx + 0.000001

edge_index = radius_graph(coordinates, r=radius, loop=False)

# Features x are computed by the encoder preceeding the gnn_layer
graph = Data(x = None, edge_index = edge_index, edge_attr = None)
graph.pos = coordinates
graph.u = None
graph.variables = variables
return graph

def data_to_graph(self, data):
self.graph.u = data

# Probably not useful
# def update_graph(self, features):
# self.graph.x = features
73 changes: 73 additions & 0 deletions pina/model/mp_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
from layers.gnn_layer import GNN_Layer

class Model(torch.nn.Module):
"""
Message passing model class.
"""
def __init__(self,
handler,
time_window: int,
n_variables: int,
input_dimension: int,
output_dimension: int,
embedding_dimension: int = 128,
processing_layers: int = 6):
"""
Initialize Message Passing model class.
Args:
handler: GraphHandler object to manage the graph
time_window: temporal bundling parameter
n_variables: number of paramaters of the PDE
input_dimension: dimension of the input
output_dimension: dimension of the output
embedding_dimension: dimension of node features
processing_layers: number of message passing layers
"""

super().__init__()
self.input_dimension = input_dimension
self.output_dimension = output_dimension
self.embedding_dimension = embedding_dimension
self.processing_layers = processing_layers
self.handler = handler
self.time_window = time_window
self.n_variables = n_variables

# Encoder
self.encoder = nn.Sequential(nn.Linear(self.input_dimension, self.embedding_dimension), nn.SiLU(),
nn.Linear(self.embedding_dimension, self.embedding_dimension), nn.SiLU())

# GNN layers
self.gnn_layers = torch.nn.ModuleList(modules=(GNN_Layer(in_features=self.embedding_dimension,
hidden_features=self.embedding_dimension,
out_features=self.embedding_dimension,
time_window=self.time_window,
n_variables=self.n_variables) for _ in range(self.processing_layers)))

# Decoder
# TODO: to be transformed in a 1d-CNN
self.decoder = torch.nn.Linear(self.embedding_dimension, self.output_dimension)
# parameters to be set in a correct way
# self.decoder = nn.Sequential(nn.Conv1d(in_channels=1, out_channels=8, kernel_size=15, stride=4),
# nn.Swish(),
# nn.Conv1d(in_channels=8, out_channels=1, kernel_size=10, stride=1))

def forward(self, x):
# Insert graph.u data for message passing
self.handler.data_to_graph(x.extract(['k']))
graph = self.handler.graph

# Encoder
input = torch.cat((graph.u, graph.pos, graph.variables), dim = -1)
graph.x = self.encoder(input)

# Processor
for i in range(self.processing_layers):
h = self.gnn_layers[i](graph)
graph.x = h

# Decoder
out = self.decoder(graph.x)
return out

0 comments on commit 68256f7

Please sign in to comment.