Skip to content

Commit

Permalink
Refactor & test internal deprecation APIs
Browse files Browse the repository at this point in the history
The names and APIs were previously too similar and therefore somewhat confusing; this will be more clear I think.

PiperOrigin-RevId: 635615163
  • Loading branch information
Jake VanderPlas authored and jax authors committed May 21, 2024
1 parent 2eff241 commit d33a568
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 22 deletions.
2 changes: 1 addition & 1 deletion jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@
del _ccache

from jax._src.deprecations import register as _register_deprecation
_register_deprecation("jax.experimental", "maps-module")
_register_deprecation("jax-experimental-maps-module")
del _register_deprecation

_deprecations = {
Expand Down
40 changes: 26 additions & 14 deletions jax/_src/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from types import ModuleType
import warnings

Expand Down Expand Up @@ -56,7 +57,7 @@ def getattr(name):
return getattr


def accelerate_module_deprecation(module: ModuleType, name: str) -> None:
def accelerate_getattr_deprecation(module: ModuleType, name: str) -> None:
"""Accelerate the deprecation of a module-level attribute.
Raises an AttributeError instead of a DeprecationWarning upon attribute access.
Expand All @@ -68,24 +69,35 @@ def accelerate_module_deprecation(module: ModuleType, name: str) -> None:
# The following mechanism is a separate one, for registering and
# accelerating deprecations that are not imports (for example, deprecations
# of a function argument).
# Maps a pair of strings to a boolean specifying whether the deprecation
# is accelerated. The intent is that non-accelerated deprecations will warn,
# and accelerated deprecations will error.
_registered_deprecations: dict[tuple[str, str], bool] = {}
# Maps a globally unique string ID to a DeprecationState, which tracks whether
# the deprecation is accelerated.
# The intent is that non-accelerated deprecations will warn, and accelerated
# deprecations will error.

@dataclass
class DeprecationState:
accelerated: bool = False

def register(module: str, key: str) -> None:
_registered_deprecations[module, key] = False
_registered_deprecations: dict[str, DeprecationState] = {}


def unregister(module: str, key: str) -> None:
_registered_deprecations.pop((module, key))
def register(deprecation_id: str) -> None:
_registered_deprecations[deprecation_id] = DeprecationState()


def accelerate(module: str, key: str) -> None:
assert (module, key) in _registered_deprecations
_registered_deprecations[module, key] = True
def unregister(deprecation_id: str) -> None:
if deprecation_id not in _registered_deprecations:
raise ValueError(f"{deprecation_id=!r} not registered.")
_registered_deprecations.pop(deprecation_id)


def is_accelerated(module: str, key: str) -> bool:
return _registered_deprecations[module, key]
def accelerate(deprecation_id: str) -> None:
if deprecation_id not in _registered_deprecations:
raise ValueError(f"{deprecation_id=!r} not registered.")
_registered_deprecations[deprecation_id].accelerated = True


def is_accelerated(deprecation_id: str) -> bool:
if deprecation_id not in _registered_deprecations:
raise ValueError(f"{deprecation_id=!r} not registered.")
return _registered_deprecations[deprecation_id].accelerated
4 changes: 2 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2704,7 +2704,7 @@ def _supports_buffer_protocol(obj):
https://jax.readthedocs.io/en/latest/faq.html).
"""

deprecations.register(__name__, "array-none")
deprecations.register("jax-numpy-array-none")

@util.implements(np.array, lax_description=_ARRAY_DOC)
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
Expand Down Expand Up @@ -2762,7 +2762,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
leaves = tree_leaves(object, is_leaf=lambda x: x is None)
if any(leaf is None for leaf in leaves):
# Added Nov 16 2023
if deprecations.is_accelerated(__name__, "array-none"):
if deprecations.is_accelerated("jax-numpy-array-none"):
raise TypeError("None is not a valid value for jnp.array")
warnings.warn(
"None encountered in jnp.array(); this is currently treated as NaN. "
Expand Down
3 changes: 0 additions & 3 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2014,9 +2014,6 @@ def _deprecated_stop_outfeed_receiver():
stop_outfeed_receiver = _deprecated_stop_outfeed_receiver
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
from jax._src.deprecations import register
for deprecated in _deprecations.keys():
register(__name__, deprecated)
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
2 changes: 1 addition & 1 deletion jax/experimental/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
" jax.experimental.shard_map nor jax.vmap are suitable for your use case."
)

if deprecations.is_accelerated("jax.experimental", "maps-module"):
if deprecations.is_accelerated("jax-experimental-maps-module"):
raise ImportError(_msg)
else:
warnings.warn(_msg, DeprecationWarning, stacklevel=2)
Expand Down
21 changes: 20 additions & 1 deletion tests/deprecation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
import warnings

from absl.testing import absltest
from jax._src import deprecations
from jax._src import test_util as jtu
from jax._src.internal_test_util import deprecation_module as m

class DeprecationTest(absltest.TestCase):

def testDeprecation(self):
def testModuleDeprecation(self):
with warnings.catch_warnings():
warnings.simplefilter("error")
self.assertEqual(m.x, 42)
Expand All @@ -35,6 +36,24 @@ def testDeprecation(self):
"module .* has no attribute 'w'"):
_ = m.w

def testNamedDeprecation(self):
some_unique_id = "some-unique-id"
try:
deprecations.register(some_unique_id)
self.assertFalse(deprecations.is_accelerated(some_unique_id))
deprecations.accelerate(some_unique_id)
self.assertTrue(deprecations.is_accelerated(some_unique_id))
finally:
deprecations.unregister(some_unique_id)

msg = f"deprecation_id={some_unique_id!r} not registered"
with self.assertRaisesRegex(ValueError, msg):
deprecations.accelerate(some_unique_id)
with self.assertRaisesRegex(ValueError, msg):
deprecations.is_accelerated(some_unique_id)
with self.assertRaisesRegex(ValueError, msg):
deprecations.unregister(some_unique_id)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit d33a568

Please sign in to comment.