Skip to content

Commit

Permalink
Merge pull request #3905 from chiamp:lm1b
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631512606
  • Loading branch information
Flax Authors committed May 7, 2024
2 parents 40c7eaa + fb5515c commit 3d98eda
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions examples/lm1b/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def encode_strings(strs, max_len):
) # pylint: disable=cell-var-from-loop
summary["learning_rate"] = lr
summary["perplexity"] = jnp.clip(
jnp.exp(summary["loss"]), a_max=1.0e4
jnp.exp(summary["loss"]), max=1.0e4
)
summary = {"train_" + k: v for k, v in summary.items()}
writer.write_scalars(step, summary)
Expand All @@ -598,7 +598,7 @@ def encode_strings(strs, max_len):
)
# (clipped) perplexity after averaging log-perplexitie
eval_results["perplexity"] = jnp.clip(
jnp.exp(eval_results["loss"]), a_max=1.0e4
jnp.exp(eval_results["loss"]), max=1.0e4
)
writer.write_scalars(
step, {"eval_" + k: v for k, v in eval_results.items()}
Expand Down
4 changes: 2 additions & 2 deletions flax/experimental/nnx/examples/lm1b/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
) # pylint: disable=cell-var-from-loop
summary['learning_rate'] = lr
summary['perplexity'] = jnp.clip(
jnp.exp(summary['loss']), a_max=1.0e4
jnp.exp(summary['loss']), max=1.0e4
)
summary = {'train_' + k: v for k, v in summary.items()}
writer.write_scalars(step, summary)
Expand All @@ -621,7 +621,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
)
# (clipped) perplexity after averaging log-perplexitie
eval_results['perplexity'] = jnp.clip(
jnp.exp(eval_results['loss']), a_max=1.0e4
jnp.exp(eval_results['loss']), max=1.0e4
)
writer.write_scalars(
step, {'eval_' + k: v for k, v in eval_results.items()}
Expand Down

0 comments on commit 3d98eda

Please sign in to comment.