Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

better log10 doc #22192

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

better log10 doc #22192

wants to merge 1 commit into from

Conversation

pkgoogle
Copy link
Contributor

better log10 doc

Part of #21461

Examples:
>>> x1 = jnp.array([0.01, 0.1, 1, 10, 100, 1000])
>>> jnp.log10(x1)
Array([-1.9999999 , -0.99999994, 0. , 0.99999994, 1.9999999 ,
Copy link
Collaborator

@jakevdp jakevdp Jun 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Floating point outputs like this are problematic for doctests, because different backends may round results differently. One way to deal with this is to print the values within a jnp.printoptions context; you can check out some examples in jax/_src/numpy.lax_numpy.py

@jakevdp jakevdp assigned jakevdp and dfm and unassigned jakevdp Jun 28, 2024
Copy link
Member

@dfm dfm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have requested a small change. After you remove that example, can you squash into a single commit? Then we can merge. Thanks!

Comment on lines 856 to 858
>>> x2 = jnp.array([-0.01, -0.1, -1, -10, -100, -1000])
>>> jnp.log10(x2)
Array([nan, nan, nan, nan, nan, nan], dtype=float32)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like in #22191, I think we should skip this example with negative numbers.

removing complex example

forced precision context to work better with doctests

removing negative example
Copy link
Member

@dfm dfm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good - thanks!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jul 2, 2024
@dfm
Copy link
Member

dfm commented Jul 3, 2024

@pkgoogle — It looks like the merge of this one collided with the log2 PR. Would you mind rebasing your branch onto the current jax main branch? Sorry about that!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants