1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
|
diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py
index f5fcb0e..87b198f 100644
--- a/brax/training/agents/apg/train.py
+++ b/brax/training/agents/apg/train.py
@@ -310,7 +310,7 @@ def train(
specs.Array((env.observation_size,), jnp.dtype(dtype))
),
)
- training_state = jax.device_put_replicated(
+ training_state = pmap.device_put_replicated(
training_state, jax.local_devices()[:local_devices_to_use]
)
diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py
index 9aec960..6624733 100644
--- a/brax/training/agents/ppo/train.py
+++ b/brax/training/agents/ppo/train.py
@@ -753,7 +753,7 @@ def train(
{},
)
- training_state = jax.device_put_replicated(
+ training_state = pmap.device_put_replicated(
training_state, jax.local_devices()[:local_devices_to_use]
)
diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py
index be716e9..8dcf3bf 100644
--- a/brax/training/agents/sac/train.py
+++ b/brax/training/agents/sac/train.py
@@ -108,7 +108,7 @@ def _init_training_state(
alpha_params=log_alpha,
normalizer_params=normalizer_params,
)
- return jax.device_put_replicated(
+ return pmap.device_put_replicated(
training_state, jax.local_devices()[:local_devices_to_use]
)
diff --git a/brax/training/pmap.py b/brax/training/pmap.py
index 82760fc..af62ef8 100644
--- a/brax/training/pmap.py
+++ b/brax/training/pmap.py
@@ -19,12 +19,23 @@ from typing import Any
import jax
import jax.numpy as jnp
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+import numpy as np
+
+
+def device_put_replicated(x, devices):  # x: pytree, mapped leaf-by-leaf below; devices: sequence of devices to spread the copies over
+ """Drop-in replacement for jax.device_put_replicated supporting pytrees."""
+ mesh = Mesh(np.array(devices), ('x',))  # 1-D device mesh whose only axis is named 'x'
+ sharding = NamedSharding(mesh, P('x'))  # P('x'): split the leading array dimension across mesh axis 'x'
+ return jax.tree.map(  # apply the same stack-and-place step to every leaf of x
+ lambda y: jax.device_put(jnp.stack([y] * len(devices)), sharding), x  # stack len(devices) copies on a new leading axis so each device's shard is one copy of y
+ )
def bcast_local_devices(value, local_devices_to_use=1):
"""Broadcasts an object to all local devices."""
devices = jax.local_devices()[:local_devices_to_use]
- return jax.device_put_replicated(value, devices)
+ return device_put_replicated(value, devices)  # uses this module's device_put_replicated; only the first local_devices_to_use local devices (sliced above, default 1) receive a copy
def synchronize_hosts():
|