Skip to content

Commit

Permalink
Merge pull request #3936 from google:nnx-stabilize
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636951986
  • Loading branch information
Flax Authors committed May 24, 2024
2 parents b1cb952 + 67fa051 commit c84510d
Show file tree
Hide file tree
Showing 129 changed files with 451 additions and 385 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ vNext

0.8.0
-----
- Added [NNX](https://github.com/google/flax/tree/main/flax/experimental/nnx#nnx), a neural network library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of PyTorch.
- Added [NNX](https://github.com/google/flax/tree/main/flax/nnx#nnx), a neural network library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of PyTorch.
- Added `nn.compact_name_scope` decorator that enables methods to act as compact name scopes as with regular Haiku methods. This makes porting Haiku code easier.
- Add copy() method to Module. This is a user-friendly version of the internal clone() method with better
defaults for common use cases.
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
| [**What does Flax look like?**](#what-does-flax-look-like)
| [**Documentation**](https://flax.readthedocs.io/)

**📣 NEW**: Check out the [**NNX**](https://flax.readthedocs.io/en/latest/nnx/index.html) API!

This README is a very short intro. **To learn everything you need to know about Flax, refer to our [full documentation](https://flax.readthedocs.io/).**

Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community.
Expand Down
8 changes: 0 additions & 8 deletions docs/api_reference/flax.experimental.nnx/nn/stochastic.rst

This file was deleted.

This file was deleted.

7 changes: 0 additions & 7 deletions docs/api_reference/flax.experimental.nnx/visualization.rst

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
graph
------------------------

.. automodule:: flax.experimental.nnx
.. currentmodule:: flax.experimental.nnx
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx


.. autofunction:: split
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
helpers
------------------------

.. automodule:: flax.experimental.nnx
.. currentmodule:: flax.experimental.nnx
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: Dict
:members:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
flax.experimental.nnx
flax.nnx
------------------------

Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/experimental/nnx/index.html>`__ for more details.
Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for more details.

.. toctree::
:maxdepth: 3
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module
------------------------

.. automodule:: flax.experimental.nnx
.. currentmodule:: flax.experimental.nnx
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: Module
:members:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Activation functions
------------------------

.. automodule:: flax.experimental.nnx
.. currentmodule:: flax.experimental.nnx
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autofunction:: celu
.. autofunction:: elu
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Attention
------------------------

.. automodule:: flax.experimental.nnx
.. currentmodule:: flax.experimental.nnx
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: MultiHeadAttention
:members:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
nn
----------------------------

Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/experimental/nnx/index.html>`__ for more details.
Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for more details.

.. toctree::
:maxdepth: 3
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Initializers
------------------------

.. automodule:: flax.experimental.nnx.initializers
.. currentmodule:: flax.experimental.nnx.initializers
.. automodule:: flax.nnx.initializers
.. currentmodule:: flax.nnx.initializers

.. autofunction:: constant
.. autofunction:: delta_orthogonal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ Linear

NNX linear layer classes.

.. automodule:: flax.experimental.nnx
.. currentmodule:: flax.experimental.nnx
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: Conv
:members:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Normalization
------------------------

.. automodule:: flax.experimental.nnx
.. currentmodule:: flax.experimental.nnx
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: BatchNorm
:members:
Expand Down
8 changes: 8 additions & 0 deletions docs/api_reference/flax.nnx/nn/stochastic.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Stochastic
------------------------

.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: Dropout
:members:
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
rnglib
------------------------

.. automodule:: flax.experimental.nnx
.. currentmodule:: flax.experimental.nnx
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: Rngs
:members:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
spmd
------------------------

.. automodule:: flax.experimental.nnx
.. currentmodule:: flax.experimental.nnx
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autofunction:: get_partition_spec
.. autofunction:: get_named_sharding
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
training
----------------------------

Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/experimental/nnx/index.html>`__ for more details.
Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for more details.

.. toctree::
:maxdepth: 3
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Metrics
------------------------

.. automodule:: flax.experimental.nnx.metrics
.. currentmodule:: flax.experimental.nnx.metrics
.. automodule:: flax.nnx.metrics
.. currentmodule:: flax.nnx.metrics

.. autoclass:: Metric
:members:
Expand Down
8 changes: 8 additions & 0 deletions docs/api_reference/flax.nnx/training/optimizer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Optimizer
------------------------

.. automodule:: flax.nnx.optimizer
.. currentmodule:: flax.nnx.optimizer

.. autoclass:: Optimizer
:members:
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
transforms
------------------------

.. automodule:: flax.experimental.nnx
.. currentmodule:: flax.experimental.nnx
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: JIT
:members:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
variables
------------------------

.. automodule:: flax.experimental.nnx
.. currentmodule:: flax.experimental.nnx
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: BatchStat
:members:
Expand Down
7 changes: 7 additions & 0 deletions docs/api_reference/flax.nnx/visualization.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
visualization
------------------------

.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autofunction:: display
2 changes: 1 addition & 1 deletion docs/api_reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ API Reference
flax.core.frozen_dict
flax.cursor
flax.errors
flax.experimental.nnx/index
flax.nnx/index
flax.jax_utils
flax.linen/index
flax.serialization
Expand Down
15 changes: 13 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@

html_extra_path = ['robots.txt']

# href with no underline and white bold text color
announcement = """
<a
href="https://flax.readthedocs.io/en/latest/nnx/index.html"
style="text-decoration: none; color: white;"
>
📣 Check out the new <b>NNX</b> API!
</a>
"""

html_theme_options = {
'repository_url': 'https://github.com/google/flax',
'use_repository_button': True, # add a 'link to repository' button
Expand All @@ -122,6 +132,7 @@
},
'prev_next_buttons_location': None,
'show_navbar_depth': 1,
'announcement': announcement,
}

# -- Options for myst ----------------------------------------------
Expand All @@ -135,7 +146,7 @@
nb_execution_excludepatterns = [
'quick_start.ipynb', # <-- times out
'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0
'flax/experimental/nnx', # exclude nnx
'flax/nnx', # exclude nnx
]
# raise exceptions on execution so CI can catch errors
nb_execution_allow_errors = False
Expand All @@ -151,7 +162,7 @@
doctest_global_setup = """
import jax
import jax.numpy as jnp
from flax.experimental import nnx
from flax import nnx
import logging as slog
from absl import logging as alog
Expand Down
7 changes: 0 additions & 7 deletions docs/experimental/index.rst

This file was deleted.

2 changes: 1 addition & 1 deletion docs/guides/flax_fundamentals/flax_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@
"source": [
"### Exporting to Tensorflow's SavedModel with jax2tf\n",
"\n",
"JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax."
"JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax."
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion docs/guides/flax_fundamentals/flax_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -469,4 +469,4 @@ Flax provides a handy wrapper - `TrainState` - that simplifies the above code. C

### Exporting to Tensorflow's SavedModel with jax2tf

JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax.
JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax.
6 changes: 4 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ both in the open source community
(like `Hugging Face <https://huggingface.co/flax-community>`__)
and at Google
(like
`PaLM <https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html>`__,
`Gemini <https://deepmind.google/technologies/gemini>`__,
`Imagen <https://imagen.research.google>`__,
`Scenic <https://github.com/google-research/scenic/>`__,
and `Big Vision <https://github.com/google-research/big_vision>`__).
Expand Down Expand Up @@ -309,6 +309,8 @@ Notable examples in Flax include:



.. role:: bold
:class: bold

.. toctree::
:hidden:
Expand All @@ -325,4 +327,4 @@ Notable examples in Flax include:
contributing
experimental
api_reference/index
experimental/index
NNX <nnx/index>
Loading

0 comments on commit c84510d

Please sign in to comment.