optax.multi_transform + nnx.State/nnx.Optimizer troubles #3955

Closed
cgarciae opened this issue Jun 1, 2024 Discussed in #3954 · 0 comments · Fixed by #3964
Labels
Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required)

cgarciae commented Jun 1, 2024

Discussed in #3954

Originally posted by yklcs June 1, 2024
optax.multi_transform takes several transforms as a Mapping[Hashable, GradientTransformation] and uses a PyTree of labels (or a function that returns one) to map each parameter to one of those keys.
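As a rough illustration of that label routing, here is a toy pure-Python model of the idea. It is not optax's implementation, which works on JAX PyTrees and keeps per-group optimizer state. Each leaf of the label tree names the transform applied to the gradient leaf at the same position. The two "transforms" are hypothetical plain gradient-descent steps chosen only for the example.

```python
from typing import Any, Callable, Dict

# Toy stand-in for a GradientTransformation: maps one gradient value to an update.
Transform = Callable[[float], float]


def route_updates(transforms: Dict[str, Transform], labels: Any, grads: Any) -> Any:
    """Apply transforms[label] to each gradient leaf, where label is the leaf
    at the same position in the label tree (toy model: nested dicts only)."""
    if isinstance(grads, dict):
        return {key: route_updates(transforms, labels[key], g) for key, g in grads.items()}
    return transforms[labels](grads)


# Hypothetical per-group transforms: gradient descent with different step sizes.
transforms = {
    "weights": lambda g: -0.5 * g,
    "biases": lambda g: -0.25 * g,
}
grads = {"layer": {"kernel": 2.0, "bias": 4.0}}
labels = {"layer": {"kernel": "weights", "bias": "biases"}}

updates = route_updates(transforms, labels, grads)
print(updates)  # {'layer': {'kernel': -1.0, 'bias': -1.0}}
```

The label tree mirrors the gradient tree key for key; only its leaves differ (group names instead of numbers).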

Using optax.multi_transform with nnx.Optimizer means that label mapping has to be an nnx.State, since nnx.Optimizer hands the transform its parameters as an nnx.State.
However, nnx.State is typed to hold StateLeaf values, so string or integer labels don't fit its annotations.
Ignoring the type annotations does work, but it feels brittle and might break later.
Is there any other solution to this problem?

```python
import optax
from flax import nnx

learning_rate, momentum = 1e-3, 0.9  # placeholder values for illustration

tx = optax.multi_transform(
    {
        "weights": optax.adamw(learning_rate, momentum),
        "biases": optax.adamw(learning_rate, momentum),
    },
    # this doesn't work:
    # {
    #     "weights": "weights",
    #     "biases": "biases",
    # },
    # this does, but is it safe?:
    nnx.State({
        "weights": "weights",
        "biases": "biases",
    }),
)
```
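Why the plain dict in the commented-out attempt is rejected can be sketched with another toy model, assuming (as with JAX PyTrees) that a tree's structure records its container types and not just its keys. `StateLike` below is a hypothetical stand-in for nnx.State, not the real class: a label tree built from a plain dict does not match a State-like parameter tree even when the keys are identical, while wrapping the labels in the same container type does.

```python
from typing import Any


class StateLike(dict):
    """Hypothetical stand-in for nnx.State: a mapping with its own container type."""


def structure(tree: Any) -> Any:
    """Describe a tree's shape, recording container types roughly the way a JAX treedef does."""
    if isinstance(tree, dict):
        children = tuple((key, structure(tree[key])) for key in sorted(tree))
        return (type(tree).__name__, children)
    return "*"  # any non-dict value is treated as a leaf


params = StateLike({"weights": 1.0, "biases": 0.0})
plain_labels = {"weights": "weights", "biases": "biases"}
state_labels = StateLike({"weights": "weights", "biases": "biases"})

print(structure(plain_labels) == structure(params))  # False: dict vs StateLike
print(structure(state_labels) == structure(params))  # True: same container type and keys
```

In this model the leaves are ignored entirely, which is why wrapping string labels in the State-like container satisfies the structural check even though the leaves are not the kind of values that container is meant to hold.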
@cgarciae cgarciae self-assigned this Jun 3, 2024
@cgarciae cgarciae added the Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) label Jun 3, 2024