Issue with problem.make_net Shape Compatibility #110

Open
Duxo opened this issue Jul 28, 2024 · 1 comment

Duxo commented Jul 28, 2024

Description:

I am experiencing an issue with the problem.make_net method. When I try to generate a trained network using the center parameter from the searcher.status, the method fails unless I manually call squeeze() on the center parameter.

Details:

  • Library Version: evotorch==0.5.1
  • PyTorch Version: torch==2.3.0
  • PyTorch Geometric: torch_geometric==2.5.3
  • Python Version: Python 3.10.12
  • Operating System: Ubuntu 22.04.4 LTS

Code to Reproduce:

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import (
    GCNConv,
    global_mean_pool,
)
from evotorch.neuroevolution import NEProblem
from evotorch.algorithms import CMAES

class GCN_xs(torch.nn.Module):
    def __init__(self):
        super().__init__()

        hidden_dim = 64
        node_features = 15

        self.conv1 = GCNConv(node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc1 = Linear(hidden_dim, hidden_dim)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        x = self.fc1(x)
        return x

def fitness(network):
    return 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network = GCN_xs()

problem = NEProblem(
    objective_sense="max",
    network=network,
    network_eval_func=fitness,
    device=device,
)
searcher = CMAES(problem, stdev_init=0.01)
searcher.run(0)
trained = problem.make_net(searcher.status["center"])

Issue:

The code above raises an error when trying to generate a trained network using problem.make_net(searcher.status["center"]). However, if I modify the last line to:

trained = problem.make_net(searcher.status["center"].squeeze())

it works without any issues.

Expected Behavior:

problem.make_net should handle the shape of the center parameter directly without needing to call squeeze().
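
For readers who want to check the shapes involved, here is a minimal sketch in plain PyTorch. It does not call evotorch at all: `net` is a small stand-in network, and `center` is a hypothetical stand-in for `searcher.status["center"]`, assumed to carry an extra leading dimension of size 1 (which is what the `squeeze()` workaround suggests).

```python
import torch
from torch.nn.utils import parameters_to_vector

# Stand-in network; any torch.nn.Module works for this check.
net = torch.nn.Linear(3, 2)

# Length of the flat vector that holds every parameter of the network.
n = parameters_to_vector(net.parameters()).numel()  # 3*2 weights + 2 biases

# Hypothetical stand-in for searcher.status["center"] (assumed shape: (1, n)).
center = torch.zeros(1, n)

print(tuple(center.shape))            # (1, 8)
print(tuple(center.squeeze().shape))  # (8,)
```

Comparing `center.shape` against `n` in this way shows quickly whether a stray size-1 axis is present before the tensor is passed on.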

engintoklu added a commit that referenced this issue Jul 30, 2024
Addressed issue:
#110

During the initialization of CMAES, the center
point of the search was generated not as a vector
of length `n`, but as a tensor of shape `(1, n)`.
Because of this, the reported "center" solution
in the status dictionary also ended up with an
unexpected leftmost dimension with size 1.

This fix introduces a `squeeze()` operation on
the initial center tensor, and also a shape
verification, ensuring that the center tensor is
1-dimensional and has a correct length.
With these changes, the reported "center" in the
status dictionary becomes 1-dimensional.
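
The two steps named in this commit message (a `squeeze()` followed by a shape verification) could look roughly like the sketch below. This is not the code from the actual commit: the helper name `normalize_center` and its error messages are hypothetical, and only standard PyTorch tensor operations are used.

```python
import torch


def normalize_center(center: torch.Tensor, n: int) -> torch.Tensor:
    """Return `center` as a 1-D tensor of length `n`, or raise ValueError.

    Hypothetical helper for illustration; not part of evotorch.
    """
    # Drop every size-1 axis, e.g. (1, n) -> (n,).
    center = center.squeeze()

    # squeeze() turns a single-element tensor into a 0-D scalar;
    # restore the single axis so that n == 1 still yields shape (1,).
    if center.ndim == 0:
        center = center.reshape(1)

    # Shape verification: exactly one axis, with the expected length.
    if center.ndim != 1:
        raise ValueError(
            f"center must be 1-dimensional, got shape {tuple(center.shape)}"
        )
    if center.shape[0] != n:
        raise ValueError(
            f"center must have length {n}, got {center.shape[0]}"
        )
    return center


# A (1, 5) tensor is reduced to a vector of length 5.
fixed = normalize_center(torch.zeros(1, 5), 5)
print(tuple(fixed.shape))  # (5,)
```

Raising on a wrong shape, rather than reshaping silently, keeps a genuinely mismatched tensor (for example shape `(2, n)`) from being accepted as a center.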
@engintoklu
Collaborator

Hello @Duxo,

Thank you very much for this very helpful feedback!

A pull request addressing this issue is here: #111

Feel free to let us know whether or not this pull request correctly addresses the issue you encountered.

Thanks again!

flukeskywalker pushed a commit that referenced this issue Aug 2, 2024