diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 6f752e0e2e4a..109f32d8f1ae 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -28,6 +28,7 @@ from jax._src import config from jax._src import core from jax._src import dispatch +from jax._src import dtypes from jax._src import linear_util as lu from jax._src import effects from jax._src import source_info_util @@ -546,6 +547,9 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params): jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, out_used_known) num_res = sum(used_res) + # insert reduce_precision calls + jaxpr_known = _insert_reduce_precision(jaxpr_known, num_res) + # compute known outputs and residuals (hoisted out of remat primitive) _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known()) _, in_consts = partition_list(in_used_known, in_consts_) @@ -589,6 +593,39 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params): return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers) pe.custom_partial_eval_rules[remat_p] = remat_partial_eval +@weakref_lru_cache +def _insert_reduce_precision(jaxpr: core.Jaxpr, num_res: int) -> core.Jaxpr: + res_vars = jaxpr.outvars[len(jaxpr.outvars) - num_res:] + invars, constvars, eqns = jaxpr.invars[:], jaxpr.constvars[:], jaxpr.eqns[:] + for v in res_vars: + if (not isinstance(v.aval, core.UnshapedArray) or + not dtypes.issubdtype(v.aval.dtype, np.inexact)): + continue + newvar = core.Var(v.suffix, v.aval) + finfo = dtypes.finfo(v.aval.dtype) + params = dict(exponent_bits=finfo.nexp, mantissa_bits=finfo.nmant) + if v in constvars or v in invars: + lst = constvars if v in constvars else invars + new_eqn = core.new_jaxpr_eqn( + [newvar], [v], lax_internal.reduce_precision_p, params, set()) + lst[lst.index(v)] = newvar + eqns.insert(0, new_eqn) + else: + (eqn_idx, eqn), = ((i, e) for i, e in enumerate(eqns) if v in e.outvars) + if (eqn.primitive == lax_internal.reduce_precision_p and + eqn.params == params): + continue + replace_eqn = eqn.replace(outvars=[v_ if v_ != v else newvar + for v_ in eqn.outvars]) + new_eqn = core.new_jaxpr_eqn( + [newvar], [v], lax_internal.reduce_precision_p, params, set(), + eqn.source_info, eqn.ctx) + eqns[eqn_idx] = replace_eqn + eqns.insert(eqn_idx+1, new_eqn) + new_jaxpr = jaxpr.replace(invars=invars, constvars=constvars, eqns=eqns) + config.enable_checks.value and core.check_jaxpr(new_jaxpr) + return new_jaxpr + def remat_partial_eval_custom_params_updater(*args): *_, params_known, params_staged = args return params_known, dict(params_staged, differentiated=True)