Skip to content

Commit

Permalink
make remat reduce precision of saved values to avoid xla excess preci…
Browse files Browse the repository at this point in the history
…sion

Problem: f(x) could differ from value_and_grad(f)(x)[0], because XLA may keep excess precision in the values saved by remat.

Co-authored-by: Peter Hawkins <[email protected]>
  • Loading branch information
mattjj and hawkinsp committed Jul 2, 2024
1 parent fcaeea4 commit 17e31eb
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import effects
from jax._src import source_info_util
Expand Down Expand Up @@ -546,6 +547,9 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, out_used_known)
num_res = sum(used_res)

# insert reduce_precision calls
jaxpr_known = _insert_reduce_precision(jaxpr_known, num_res)

# compute known outputs and residuals (hoisted out of remat primitive)
_, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
_, in_consts = partition_list(in_used_known, in_consts_)
Expand Down Expand Up @@ -589,6 +593,39 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
# Register remat_partial_eval in pe.custom_partial_eval_rules under the
# remat_p primitive.
pe.custom_partial_eval_rules[remat_p] = remat_partial_eval

@weakref_lru_cache
def _insert_reduce_precision(jaxpr: core.Jaxpr, num_res: int) -> core.Jaxpr:
  """Pass every inexact residual output of ``jaxpr`` through reduce_precision.

  The trailing ``num_res`` entries of ``jaxpr.outvars`` are treated as
  residuals. For each residual variable with an inexact dtype, a
  ``reduce_precision`` equation is added whose ``exponent_bits`` and
  ``mantissa_bits`` come from ``dtypes.finfo`` of the residual's own dtype.
  The residual keeps its original variable, which becomes the output of the
  new equation, so ``jaxpr.outvars`` itself is left untouched.

  Args:
    jaxpr: the jaxpr whose trailing outputs are residuals. It is not mutated;
      its ``invars``, ``constvars`` and ``eqns`` lists are copied first.
    num_res: number of trailing entries of ``jaxpr.outvars`` that are
      residuals.

  Returns:
    A new jaxpr with the extra ``reduce_precision`` equations inserted.

  Raises:
    ValueError: if a residual variable that is neither an input nor a constant
      is not produced by exactly one equation.
  """
  # Slice with an explicit start index: `outvars[-num_res:]` would select
  # *every* output when num_res == 0.
  res_vars = jaxpr.outvars[len(jaxpr.outvars) - num_res:]
  invars, constvars, eqns = jaxpr.invars[:], jaxpr.constvars[:], jaxpr.eqns[:]
  for v in res_vars:
    # Outvars may also be Literal atoms. A literal has no `suffix` and no
    # producing equation, so there is nothing to rewrite for it.
    if not isinstance(v, core.Var):
      continue
    if (not isinstance(v.aval, core.UnshapedArray) or
        not dtypes.issubdtype(v.aval.dtype, np.inexact)):
      continue
    newvar = core.Var(v.suffix, v.aval)
    finfo = dtypes.finfo(v.aval.dtype)
    params = dict(exponent_bits=finfo.nexp, mantissa_bits=finfo.nmant)
    if v in constvars or v in invars:
      # The residual is a direct input/constant: rename that binder to
      # `newvar` and define `v` as reduce_precision(newvar) at the very start,
      # so every existing use of `v` sees the reduced value.
      lst = constvars if v in constvars else invars
      new_eqn = core.new_jaxpr_eqn(
          [newvar], [v], lax_internal.reduce_precision_p, params, set())
      lst[lst.index(v)] = newvar
      eqns.insert(0, new_eqn)
    else:
      # Exactly one equation must produce `v`; the single-element unpacking
      # raises ValueError otherwise.
      (eqn_idx, eqn), = ((i, e) for i, e in enumerate(eqns) if v in e.outvars)
      if (eqn.primitive == lax_internal.reduce_precision_p and
          eqn.params == params):
        # `v` already comes out of an identical reduce_precision equation.
        continue
      # Make the producer write to `newvar`, then define `v` right after it as
      # reduce_precision(newvar).
      replace_eqn = eqn.replace(outvars=[v_ if v_ != v else newvar
                                         for v_ in eqn.outvars])
      new_eqn = core.new_jaxpr_eqn(
          [newvar], [v], lax_internal.reduce_precision_p, params, set(),
          eqn.source_info, eqn.ctx)
      eqns[eqn_idx] = replace_eqn
      eqns.insert(eqn_idx+1, new_eqn)
  new_jaxpr = jaxpr.replace(invars=invars, constvars=constvars, eqns=eqns)
  if config.enable_checks.value:
    core.check_jaxpr(new_jaxpr)
  return new_jaxpr

def remat_partial_eval_custom_params_updater(*args):
  """Return the known params unchanged and the staged params marked differentiated.

  Only the last two positional arguments are used: the params of the known
  computation and the params of the staged computation. The staged params are
  copied into a new dict with ``differentiated`` set to ``True``; the input
  is not modified.
  """
  known_params, staged_params = args[-2:]
  staged_out = dict(staged_params)
  staged_out["differentiated"] = True
  return known_params, staged_out
Expand Down

0 comments on commit 17e31eb

Please sign in to comment.