Skip to content

Commit

Permalink
buggy mpsolver
Browse files Browse the repository at this point in the history
  • Loading branch information
Dario Coscia authored and Dario Coscia committed Nov 22, 2023
1 parent 57b0f2c commit cb36d1e
Showing 1 changed file with 67 additions and 2 deletions.
69 changes: 67 additions & 2 deletions pina/solvers/mp_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,79 @@ def __init__(

# inverse problem handling
if isinstance(self.problem, InverseProblem):
self._params = self.problem.unknown_parameters
raise ValueError('Message Passing only for forward problems.')
#self._params = self.problem.unknown_parameters
else:
self._params = None


def forward(self, x):
return self.neural_net(x)

def training_step(self, batch, batch_idx):

#############################
dataloader = self.trainer.train_dataloader
condition_idx = batch['condition']

for condition_id in range(condition_idx.min(), condition_idx.max()+1):

condition_name = dataloader.condition_names[condition_id]
condition = self.problem.conditions[condition_name]
pts = batch['pts']
out = batch['output']

if condition_name not in self.problem.conditions:
raise RuntimeError('Something wrong happened.')

# for data driven mode
if not hasattr(condition, 'output_points'):
raise NotImplementedError('Supervised solver works only in data-driven mode.')
#############################
u, x, variables = pts.shape.....

# Randomly choose number of unrollings
unrolling = 2
import random
unrolled_graphs = random.choice(unrolling)
steps = [t for t in range(graph_creator.tw,
graph_creator.t_res - graph_creator.tw - (graph_creator.tw * unrolled_graphs) + 1)]
# Randomly choose starting (time) point at the PDE solution manifold
random_steps = random.choices(steps, k=u.shape[0])
data, labels = graph_creator.create_data(u, random_steps)
if f'{model}' == 'GNN':
graph = graph_creator.create_graph(data, labels, x, variables, random_steps).to(device)
else:
data, labels = data.to(device), labels.to(device)

# Unrolling of the equation which serves as input at the current step
# This is the pushforward trick!!!
with torch.no_grad():
for _ in range(unrolled_graphs):
random_steps = [rs + graph_creator.tw for rs in random_steps]
_, labels = graph_creator.create_data(u_super, random_steps)
if f'{model}' == 'GNN':
pred = model(graph)
graph = graph_creator.create_next_graph(graph, pred, labels, random_steps).to(device)
else:
data = model(data)
labels = labels.to(device)

if f'{model}' == 'GNN':
pred = model(graph)
loss = criterion(pred, graph.y)
else:
pred = model(data)
loss = criterion(pred, labels)

return loss

def configure_optimizers(self):
"""Optimizer configuration for the solver.
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
"""
return self.optimizers, [self.scheduler]

@property
def scheduler(self):
Expand Down

0 comments on commit cb36d1e

Please sign in to comment.