From ffb9b7a78425975f4c0c75cd39fe5c97d7d2c5b8 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Tue, 2 Jul 2024 23:13:09 +0000
Subject: [PATCH] make remat reduce precision of saved values to avoid xla excess precision problem: f(x) != value_and_grad(f)(x)[0] ??

Co-authored-by: Peter Hawkins
---
 jax/_src/ad_checkpoint.py | 39 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index 6f752e0e2e4a..8b25a0746926 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -28,6 +28,7 @@
 from jax._src import config
 from jax._src import core
 from jax._src import dispatch
+from jax._src import dtypes
 from jax._src import linear_util as lu
 from jax._src import effects
 from jax._src import source_info_util
@@ -546,10 +547,13 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
   jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, out_used_known)
   num_res = sum(used_res)
 
+  # insert reduce_precision calls
+  jaxpr_known_ = _insert_reduce_precision(jaxpr_known, num_res)
+
   # compute known outputs and residuals (hoisted out of remat primitive)
   _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
   _, in_consts = partition_list(in_used_known, in_consts_)
-  out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
+  out_consts = core.eval_jaxpr(jaxpr_known_, (), *in_consts)
   out_knowns, residuals = split_list(out_consts, [len(out_consts)-num_res])
 
   # set up unknown outputs with a recipe to call remat
@@ -589,6 +593,39 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
   return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
 pe.custom_partial_eval_rules[remat_p] = remat_partial_eval
 
+@weakref_lru_cache
+def _insert_reduce_precision(jaxpr: core.Jaxpr, num_res: int) -> core.Jaxpr:
+  res_vars = jaxpr.outvars[len(jaxpr.outvars) - num_res:]
+  invars, constvars, eqns = jaxpr.invars[:], jaxpr.constvars[:], jaxpr.eqns[:]
+  for v in res_vars:
+    if (not isinstance(v.aval, core.UnshapedArray) or
+        not dtypes.issubdtype(v.aval.dtype, np.inexact)):
+      continue
+    newvar = core.Var(v.suffix, v.aval)
+    finfo = dtypes.finfo(v.aval.dtype)
+    params = dict(exponent_bits=finfo.nexp, mantissa_bits=finfo.nmant)
+    if v in constvars or v in invars:
+      lst = constvars if v in constvars else invars
+      new_eqn = core.new_jaxpr_eqn(
+          [newvar], [v], lax_internal.reduce_precision_p, params, set())
+      lst[lst.index(v)] = newvar
+      eqns.insert(0, new_eqn)
+    else:
+      (eqn_idx, eqn), = ((i, e) for i, e in enumerate(eqns) if v in e.outvars)
+      if (eqn.primitive == lax_internal.reduce_precision_p and
+          eqn.params == params):
+        continue
+      replace_eqn = eqn.replace(outvars=[v_ if v_ != v else newvar
+                                         for v_ in eqn.outvars])
+      new_eqn = core.new_jaxpr_eqn(
+          [newvar], [v], lax_internal.reduce_precision_p, params, set(),
+          eqn.source_info, eqn.ctx)
+      eqns[eqn_idx] = replace_eqn
+      eqns.insert(eqn_idx+1, new_eqn)
+  new_jaxpr = jaxpr.replace(invars=invars, constvars=constvars, eqns=eqns)
+  config.enable_checks.value and core.check_jaxpr(new_jaxpr)
+  return new_jaxpr
+
 def remat_partial_eval_custom_params_updater(*args):
   *_, params_known, params_staged = args
   return params_known, dict(params_staged, differentiated=True)