Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Casted finfo attributes from np.float to float for Array API compliance #21424

Closed
wants to merge 1 commit into from

Conversation

vfdev-5
Copy link
Collaborator

@vfdev-5 vfdev-5 commented May 24, 2024

Description:

>>> import jax.numpy as jnp
>>> f = jnp.finfo("float16")
>>> f.__dict__

Output on main:

{
  'dtype': dtype('float16'), 'precision': 3, 'iexp': 5, 'maxexp': 16, 
  'minexp': -14, 'negep': -11, 'machep': -10, 'resolution': np.float16(0.001), 
  'epsneg': np.float16(0.0004883), 'smallest_subnormal': np.float16(6e-08), 'bits': 16, 'max': np.float16(65500.0), 
  'min': np.float16(-65500.0), 'eps': np.float16(0.000977), 'nexp': 5, 'nmant': 10, '_machar': <numpy._core.getlimits.MachArLike object at 0x7f4b3bdaff70>, 
  '_str_tiny': '6.10352e-05', '_str_max': '6.55040e+04', '_str_epsneg': '4.88281e-04', '_str_eps': '9.76562e-04',
  '_str_resolution': '1.00040e-03', '_str_smallest_normal': '6.10352e-05', '_str_smallest_subnormal': '5.96046e-08'
}

This PR:

{
  'dtype': dtype('float16'), 'precision': 3, 'iexp': 5, 'maxexp': 16, 
  'minexp': -14, 'negep': -11, 'machep': -10, 'resolution': 0.0010004043579101562, 
  'epsneg': 0.00048828125, 'smallest_subnormal': 5.960464477539063e-08, 'bits': 16, 'max': 65504.0, 
  'min': -65504.0, 'eps': 0.0009765625, 'nexp': 5, 'nmant': 10, '_machar': <numpy._core.getlimits.MachArLike object at 0x7fd87f65aa10>, 
  '_str_tiny': '6.10352e-05', '_str_max': '6.55040e+04', '_str_epsneg': '4.88281e-04', '_str_eps': '9.76562e-04',
  '_str_resolution': '1.00040e-03', '_str_smallest_normal': '6.10352e-05', '_str_smallest_subnormal': '5.96046e-08'
}

@jakevdp
Copy link
Collaborator

jakevdp commented May 24, 2024

I think the better fix here would be to change this at the source, in ml_dtypes. The only reason it returns scalar typed values rather than built-in floats is because it was following the precedent in NumPy; presumably NumPy is changing, so we should change ml_dtypes as well.

@vfdev-5
Copy link
Collaborator Author

vfdev-5 commented May 24, 2024

I was pointed out to this issue: data-apis/array-api#405 where essentially

numpy has support for extended precision types, so python float doesn't work for longdouble

Source: numpy/numpy#26523

It means that most probably numpy will keep eps, max, min, smallest_normal attributes as np.floating and array-API will add additional note on returned types.
@jakevdp @Micky774 so, we may want to revert changes in this PR and update jax/experimental/array_api/_data_type_functions.py to return np.floats as numpy ?

@jakevdp
Copy link
Collaborator

jakevdp commented May 24, 2024

Yeah in that case let's not make any changes.

@vfdev-5 vfdev-5 closed this May 24, 2024
@vfdev-5 vfdev-5 deleted the finfo-with-float-attrs branch May 24, 2024 16:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants