Involuntary Full Rematerialization #15207

Open · chaserileyroberts opened this issue Jul 22, 2024 · 2 comments
@chaserileyroberts (Contributor)

Ported this issue from google/jax#21562

This code:

import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as PS, NamedSharding, Mesh


# 4x2 device mesh with axes ('x', 'y'), i.e. 8 devices total.
devices = np.asarray(jax.devices()).reshape((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

# shardtype2 shards dim 1 over all 8 devices; shardtype1 shards dim 0 over 'y' and dim 2 over 'x'.
shardtype2 = NamedSharding(mesh, PS(None, ('x', 'y'), None))
shardtype1 = NamedSharding(mesh, PS('y', None, 'x'))

def f(a, b, c):
    d = a + b
    # Reshard the intermediate from the input sharding to shardtype2.
    d = jax.lax.with_sharding_constraint(d, shardtype2)
    return c + d


fjit = jax.jit(f, in_shardings=(shardtype1, shardtype1, shardtype2), out_shardings=shardtype2)

a = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
b = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)
c = jnp.arange(16 * 16 * 16).reshape((16, 16, 16)).astype(np.float32)

print(fjit(a, b, c).block_until_ready())

Gives this warning:

E0531 10:54:04.832741 2609008 spmd_partitioner.cc:569] [spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[2,1,4]<=[4,2]T(1,0)} to {devices=[1,8,1]<=[8]} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.
E0531 10:54:04.832805 2609008 spmd_partitioner.cc:569] [spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[2,1,4]<=[4,2]T(1,0)} to {devices=[1,8,1]<=[8]} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.

We've had to write our own resharding logic instead of solely relying on with_sharding_constraint to avoid this issue.
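
For illustration, here is a minimal sketch of the kind of staged resharding this can involve (an assumption about the workaround, not our actual code; shard_step is a hypothetical intermediate sharding, and the snippet reuses the mesh and shardings from the repro above). The idea is to move one mesh axis at a time so the partitioner has a cheaper path than a full rematerialization:

shard_step = NamedSharding(mesh, PS(None, 'y', 'x'))  # hypothetical intermediate sharding

def f_staged(a, b, c):
    d = a + b
    # Reshard in two steps instead of jumping straight from shardtype1 to shardtype2.
    d = jax.lax.with_sharding_constraint(d, shard_step)
    d = jax.lax.with_sharding_constraint(d, shardtype2)
    return c + d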

@ptoulme-aws (Contributor)

This means there was a conflict in the shardings, potentially a sharding that was propagated from an intermediate. The XLA compiler had to rematerialize the full tensor in order to reshard it.

This PR improves the logging to show which HloInstruction has the conflict:
#15402

@ptoulme-aws (Contributor)

If you want to debug this further, look at the HLO dump after the ShardingPropagation pass but before SPMD partitioning. Then, using my logging PR, look at the HloSharding metadata of that HloInstruction and the instructions around it. Most likely you will see that a conflict like (4,8)->(8,4) triggered the reshard.
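
A minimal sketch of how to get such a dump (the dump directory is a placeholder, and --xla_dump_hlo_pass_re=.* simply writes the module after every HLO pass, so you can pick out the snapshot that follows ShardingPropagation and precedes the SPMD partitioner):

import os

# XLA_FLAGS must be set before jax initializes its backend.
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/xla_dump "    # placeholder output directory
    "--xla_dump_hlo_as_text "         # dump HLO modules as readable text
    "--xla_dump_hlo_pass_re=.*"       # dump after every HLO pass
)

import jax
# ... then run the repro above and inspect /tmp/xla_dump.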
