Replace jax.device_put_replicated with a local helper.

Adds brax.training.pmap.device_put_replicated, which stacks every pytree
leaf len(devices) times along a new leading axis and places the result
with a NamedSharding over a 1-D device mesh (axis 'x'). The APG, PPO and
SAC trainers and pmap.bcast_local_devices are switched over to it.

NOTE(review): the line structure of this patch was restored from a
whitespace-collapsed copy. Hunk line counts match the @@ headers, but the
indentation of context lines is assumed to follow 2-space nesting --
confirm against the target tree, or apply with
`git apply --ignore-whitespace`.

diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py
index f5fcb0e..87b198f 100644
--- a/brax/training/agents/apg/train.py
+++ b/brax/training/agents/apg/train.py
@@ -310,7 +310,7 @@ def train(
           specs.Array((env.observation_size,), jnp.dtype(dtype))
       ),
   )
-  training_state = jax.device_put_replicated(
+  training_state = pmap.device_put_replicated(
       training_state, jax.local_devices()[:local_devices_to_use]
   )
 
diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py
index 9aec960..6624733 100644
--- a/brax/training/agents/ppo/train.py
+++ b/brax/training/agents/ppo/train.py
@@ -753,7 +753,7 @@ def train(
         {},
     )
 
-  training_state = jax.device_put_replicated(
+  training_state = pmap.device_put_replicated(
       training_state, jax.local_devices()[:local_devices_to_use]
   )
 
diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py
index be716e9..8dcf3bf 100644
--- a/brax/training/agents/sac/train.py
+++ b/brax/training/agents/sac/train.py
@@ -108,7 +108,7 @@ def _init_training_state(
       alpha_params=log_alpha,
       normalizer_params=normalizer_params,
   )
-  return jax.device_put_replicated(
+  return pmap.device_put_replicated(
       training_state, jax.local_devices()[:local_devices_to_use]
   )
 
diff --git a/brax/training/pmap.py b/brax/training/pmap.py
index 82760fc..af62ef8 100644
--- a/brax/training/pmap.py
+++ b/brax/training/pmap.py
@@ -19,12 +19,23 @@ from typing import Any
 
 import jax
 import jax.numpy as jnp
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+import numpy as np
+
+
+def device_put_replicated(x, devices):
+  """Drop-in replacement for jax.device_put_replicated supporting pytrees."""
+  mesh = Mesh(np.array(devices), ('x',))
+  sharding = NamedSharding(mesh, P('x'))
+  return jax.tree.map(
+      lambda y: jax.device_put(jnp.stack([y] * len(devices)), sharding), x
+  )
 
 
 def bcast_local_devices(value, local_devices_to_use=1):
   """Broadcasts an object to all local devices."""
   devices = jax.local_devices()[:local_devices_to_use]
-  return jax.device_put_replicated(value, devices)
+  return device_put_replicated(value, devices)
 
 
 def synchronize_hosts():