Commit 02bf7c3

bug fixing

gc031298 committed Feb 1, 2024
1 parent f208bab commit 02bf7c3

Showing 4 changed files with 128 additions and 257 deletions.
pina/model/gnn_model.py (20 additions, 57 deletions)
@@ -1,79 +1,42 @@
 import torch
-from pina.model.layers.gnn_layer import GNN_Layer
+from gnn_layer import GNN_Layer
 from pina import LabelTensor
 
 class GNN(torch.nn.Module):
     """
     Message passing model class.
     """
-    def __init__(self,
-                 time_window: int,
-                 n_variables: int,
-                 t_max: float,
-                 embedding_dimension: int = 128,
-                 processing_layers: int = 6):
-        """
-        Initialize Message Passing model class.
-        Args:
-            time_window: temporal bundling parameter
-            n_variables: number of parameters of the PDE
-            output_dimension: dimension of the output
-            embedding_dimension: dimension of node features
-            processing_layers: number of message passing layers
-        """
+    def __init__(self, time_window, t_max, n_variables, embedding_dimension=128, processing_layers=6):
         super().__init__()
         self.output_dimension = time_window
         self.embedding_dimension = embedding_dimension
         self.processing_layers = processing_layers
         self.time_window = time_window
         self.n_variables = n_variables
         self.t_max = t_max
 
-        self.encoder = torch.nn.Sequential(torch.nn.Linear(self.time_window + self.n_variables + 1, self.embedding_dimension),
-                                           torch.nn.SiLU(),
-                                           torch.nn.Linear(self.embedding_dimension, self.embedding_dimension),
-                                           torch.nn.SiLU())
-
+        # Encoder
+        # TODO: the user should be able to define as many layers as wanted
+        self.encoder = torch.nn.Sequential(torch.nn.Linear(self.time_window + self.n_variables + 1, self.embedding_dimension), torch.nn.SiLU(),
+                                           torch.nn.Linear(self.embedding_dimension, self.embedding_dimension), torch.nn.SiLU())
 
         # GNN layers
-        self.gnn_layers = torch.nn.ModuleList(modules=(GNN_Layer(in_features=self.embedding_dimension,
+        self.gnn_layers = torch.nn.ModuleList(modules=(GNN_Layer(in_features=self.embedding_dimension,
                                                                  hidden_features=self.embedding_dimension,
                                                                  out_features=self.embedding_dimension,
                                                                  time_window=self.time_window,
                                                                  n_variables=self.n_variables) for _ in range(self.processing_layers)))
 
         # Decoder
+        # TODO: use a linear layer after convolutions to allow easier management of strides and kernel sizes.
+        # However, it is not clean nor always correct: for self.embedding_dimension < 55, it is meaningless.
+        # self.decoder = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=8, kernel_size=15, stride=4),
+        #                                    torch.nn.SiLU(),
+        #                                    torch.nn.Conv1d(in_channels=8, out_channels=1, kernel_size=10, stride=1),
+        #                                    torch.nn.SiLU(),
+        #                                    torch.nn.Linear(in_features=(int((self.embedding_dimension-15)/4)-8), out_features=self.output_dimension))
+
+        # At the moment we use a fixed time_window = 25
         self.decoder = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=8, kernel_size=16, stride=3),
-                                           torch.nn.SiLU(),
+                                           torch.nn.SiLU(),
                                            torch.nn.Conv1d(in_channels=8, out_channels=1, kernel_size=14, stride=1))
 
-    def forward(self, graph):
-
-        # Normalization
-        vars = graph.variables.extract(['alpha', 'beta', 'gamma'])
-        time = graph.variables.extract(['t'])/self.t_max
-        graph.pos = graph.pos/graph.pos.max()
-
-        # Encoder
-        input = torch.cat((graph.x, graph.pos, time, vars), dim=-1)
-        h = self.encoder(input)
-
-        # Processor
+    def forward(self, data, pos, time, variables, batch, edge_index, dt):
+        pos = pos/pos.max()
+        time = time/self.t_max
+        var = torch.cat((time, variables), dim=-1)
+        node_input = torch.cat((data, pos, var), dim=-1)
+        h = self.encoder(node_input)
         for i in range(self.processing_layers):
-            h = self.gnn_layers[i](h, graph.x, graph.pos, graph.variables, graph.edge_index, graph.batch)
-
-        # Decoder -- check that dt works correctly
-        dt = (torch.ones(1, self.time_window)*graph.dt).to(h.device)
-        dt = torch.cumsum(dt, dim=1)
+            h = self.gnn_layers[i](edge_index, h, data, pos, var, batch)
+        dt = torch.cumsum((torch.ones(1, self.time_window)*dt).to(device=h.device), dim=1)
         diff = self.decoder(h[:, None]).squeeze(1)
-        out = graph.x[:, -1].repeat(1, self.time_window) + dt*diff
-
+        out = data[:, -1].repeat(1, self.time_window) + dt*diff
        return out
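
The hard-coded decoder above is what pins the model to time_window = 25: with the default embedding_dimension = 128 and no padding, Conv1d(kernel_size=16, stride=3) maps a length-128 signal to floor((128-16)/3)+1 = 38 values, and Conv1d(kernel_size=14, stride=1) maps those 38 to 38-14+1 = 25. A minimal sketch (not part of the commit) that checks this arithmetic:

# Sketch only: verifies the decoder's output length for the fixed
# configuration noted in the diff (embedding_dimension=128, time_window=25).
import torch

decoder = torch.nn.Sequential(
    torch.nn.Conv1d(in_channels=1, out_channels=8, kernel_size=16, stride=3),
    torch.nn.SiLU(),
    torch.nn.Conv1d(in_channels=8, out_channels=1, kernel_size=14, stride=1),
)

h = torch.randn(10, 1, 128)   # (nodes, channels, embedding_dimension)
print(decoder(h).shape)       # torch.Size([10, 1, 25]) -> time_window = 25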
pina/model/graph_handler.py (32 additions, 18 deletions)
@@ -1,30 +1,44 @@
 from torch_geometric.data import Data
-from torch_cluster import radius_graph
-
+from torch_geometric.nn import radius_graph
+from pina import LabelTensor
+import torch
 
 class GraphHandler():
     """
     Creates and manages a graph with the following attributes:
     - graph.u: values of u(x,t) at point x and time t in the considered window
     - graph.pos: spatial coordinates
     - graph.variables: variables of the equation (time and parameters)
     - graph.x: node features
     """
-    def __init__(self, dt, num_neighs=10):
-
+    def __init__(self, dt, neighbors=2):
         super().__init__()
-        self.num_neighs = num_neighs
+        self.neighbors = neighbors
         self.dt = dt
         self.graph = None
 
-    def create_ball_graph(self, coordinates, data, variables, batch):
+    def create_graph(self, pts, labels, steps):
+        time_window = labels.shape[1]
+        x = torch.cat([pts[i, :, st-time_window:st].extract(['u']) for i, st in enumerate(steps)]).squeeze(-1)
+        coordinates = torch.Tensor(torch.cat([pts[i, :, st].extract(['x']) for i, st in enumerate(steps)])).squeeze(-1)
+        variables = torch.cat([pts[i, :, st].extract(['alpha', 'beta', 'gamma']) for i, st in enumerate(steps)])
+        variables = LabelTensor(variables, labels=['alpha', 'beta', 'gamma'])
+        time = torch.cat([pts[i, :, st].extract(['t']) for i, st in enumerate(steps)])
+        time = LabelTensor(time, labels=['t'])
+        num_x = torch.unique(coordinates).shape[0]
+        batch = torch.cat([torch.ones(num_x)*i for i in range(pts.shape[0])]).to(device=pts.device)
         dx = coordinates[1] - coordinates[0]
-        radius = self.num_neighs * dx + 0.000001
+        radius = self.neighbors * dx + 0.0001
         edge_index = radius_graph(coordinates, r=radius, loop=False, batch=batch)
-
-        graph = Data(x=data, edge_index=edge_index, edge_attr=None)
+        graph = Data(x=x, edge_index=edge_index, edge_attr=None)
+        graph.y = labels
         graph.pos = coordinates.unsqueeze(-1)
         graph.variables = variables
         graph.dt = self.dt
-        graph.batch = batch
-        return graph
+        graph.batch = batch.long()
+        graph.time = time
+        return graph
 
+    def update_graph(self, graph, pred, labels, steps, batch_size):
+        time_window = labels.shape[1]
+        graph.x = torch.cat((graph.x, pred), dim=1)[:, time_window:]
+        graph.y = labels
+        num_x = labels.shape[0] // batch_size
+        time = [torch.ones(num_x)*steps[i]*self.dt for i in range(len(steps))]
+        graph.time = torch.cat(time).unsqueeze(-1).to(device=pred.device)
+        return graph
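
In create_graph, two nodes are connected whenever they lie within neighbors * dx of each other, where dx is the (assumed uniform) grid spacing; the 0.0001 epsilon keeps floating-point round-off from excluding the outermost intended neighbours. A small illustrative sketch of the same construction on a hypothetical 1-D grid (values not from the repository):

# Sketch only: radius_graph on a uniform grid with spacing dx = 0.1.
# With r = 2*dx + eps, each interior node links to 2 neighbours per side.
import torch
from torch_geometric.nn import radius_graph

coordinates = torch.linspace(0.0, 1.0, 11).unsqueeze(-1)  # 11 points, dx = 0.1
dx = 0.1
edge_index = radius_graph(coordinates, r=2*dx + 1e-4, loop=False)
print(edge_index.shape)  # expected torch.Size([2, 38]): boundary nodes have fewer edges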
pina/model/layers/gnn_layer.py (12 additions, 38 deletions)
@@ -1,27 +1,10 @@
-from typing import Any
 import torch
 from torch_geometric.nn import MessagePassing, InstanceNorm
 
 class GNN_Layer(MessagePassing):
     """
     Message passing layer
     """
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 hidden_features: int,
-                 time_window: int,
-                 n_variables: int,
-                 n_spatial: int = 1):
-        """
-        Initialize message passing layers
-        Args:
-            in_features (int): number of node input features
-            out_features (int): number of node output features
-            hidden_features (int): number of hidden features
-            time_window (int): number of input/output timesteps (temporal bundling)
-            n_variables (int): number of equation-specific parameters used in the solver
-            n_spatial (int): number of spatial variables (e.g. x -> 1, [x, y] -> 2)
-        """
+    def __init__(self, in_features, out_features, hidden_features, time_window, n_variables, n_spatial=1):
         super(GNN_Layer, self).__init__(node_dim=-2, aggr='mean')
         self.in_features = in_features
         self.out_features = out_features
@@ -30,39 +13,30 @@ def __init__(self,
         self.n_variables = n_variables
         self.n_spatial = n_spatial
 
         # Message network -- equation 8
-        self.message_net_1 = torch.nn.Sequential(torch.nn.Linear(2*self.in_features + self.time_window + self.n_spatial + self.n_variables, self.hidden_features), torch.nn.SiLU())
+        self.message_net_1 = torch.nn.Sequential(torch.nn.Linear(2*self.in_features + self.time_window + self.n_spatial + self.n_variables, self.hidden_features),
+                                                 torch.nn.SiLU())
         self.message_net_2 = torch.nn.Sequential(torch.nn.Linear(self.hidden_features, self.hidden_features), torch.nn.SiLU())
 
         # Update network -- equation 9
-        self.update_net_1 = torch.nn.Sequential(torch.nn.Linear(self.in_features + self.hidden_features + self.n_variables, self.hidden_features), torch.nn.SiLU())
+        self.update_net_1 = torch.nn.Sequential(torch.nn.Linear(self.in_features + self.hidden_features + self.n_variables, self.hidden_features),
+                                                torch.nn.SiLU())
         self.update_net_2 = torch.nn.Sequential(torch.nn.Linear(self.hidden_features, self.out_features), torch.nn.SiLU())
 
         self.norm = InstanceNorm(self.hidden_features)
 
-    def forward(self, x, u, pos, variables, edge_index, batch):
-        """
-        Propagate messages along edges
-        """
+    def forward(self, edge_index, x, u, pos, variables, batch):
         f = self.propagate(edge_index=edge_index, x=x, u=u, pos=pos, variables=variables)
         f = self.norm(f, batch)
         return f
 
     def message(self, x_i, x_j, u_i, u_j, pos_i, pos_j, variables_i):
-        """
-        Message update following formula 8 of the paper
-        """
         message = self.message_net_1(torch.cat((x_i, x_j, u_i - u_j, pos_i - pos_j, variables_i), dim=-1))
         message = self.message_net_2(message)
         return message
 
     def update(self, message, x, variables):
-        """
-        Node update following formula 9 of the paper
-        """
         update = self.update_net_1(torch.cat((x, message, variables), dim=-1))
         update = self.update_net_2(update)
         if self.in_features == self.out_features:
[diff truncated: the remainder of update() and the fourth changed file are not rendered in this view]
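
For readers unfamiliar with the MessagePassing base class that GNN_Layer extends: propagate() routes each _i/_j-suffixed argument of message() to the target and source node of every edge, aggregates the resulting messages with the chosen reduction ('mean' here, matching the diff), and hands the aggregate to update(). A toy illustration of that pattern, not taken from the repository:

# Toy sketch of the propagate/message pattern; names are illustrative.
import torch
from torch_geometric.nn import MessagePassing

class ToyLayer(MessagePassing):
    def __init__(self):
        super().__init__(aggr='mean')  # mean aggregation, as in GNN_Layer

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # difference between neighbouring node features, in the spirit of equation 8
        return x_j - x_i

x = torch.randn(4, 3)                                  # 4 nodes, 3 features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # a directed 4-cycle
out = ToyLayer()(x, edge_index)
print(out.shape)  # torch.Size([4, 3])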
