summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/brax/dont-use-device_put_replicated-compat.patch
blob: 261aa5907a605a063811b277c5590db436aac95d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py
index f5fcb0e..87b198f 100644
--- a/brax/training/agents/apg/train.py
+++ b/brax/training/agents/apg/train.py
@@ -310,7 +310,7 @@ def train(
           specs.Array((env.observation_size,), jnp.dtype(dtype))
       ),
   )
-  training_state = jax.device_put_replicated(
+  training_state = pmap.device_put_replicated(
       training_state, jax.local_devices()[:local_devices_to_use]
   )
 
diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py
index 9aec960..6624733 100644
--- a/brax/training/agents/ppo/train.py
+++ b/brax/training/agents/ppo/train.py
@@ -753,7 +753,7 @@ def train(
         {},
     )
 
-  training_state = jax.device_put_replicated(
+  training_state = pmap.device_put_replicated(
       training_state, jax.local_devices()[:local_devices_to_use]
   )
 
diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py
index be716e9..8dcf3bf 100644
--- a/brax/training/agents/sac/train.py
+++ b/brax/training/agents/sac/train.py
@@ -108,7 +108,7 @@ def _init_training_state(
       alpha_params=log_alpha,
       normalizer_params=normalizer_params,
   )
-  return jax.device_put_replicated(
+  return pmap.device_put_replicated(
       training_state, jax.local_devices()[:local_devices_to_use]
   )
 
diff --git a/brax/training/pmap.py b/brax/training/pmap.py
index 82760fc..af62ef8 100644
--- a/brax/training/pmap.py
+++ b/brax/training/pmap.py
@@ -19,12 +19,23 @@ from typing import Any
 
 import jax
 import jax.numpy as jnp
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+import numpy as np
+
+
+def device_put_replicated(x, devices):
+  """Drop-in replacement for jax.device_put_replicated supporting pytrees."""
+  mesh = Mesh(np.array(devices), ('x',))
+  sharding = NamedSharding(mesh, P('x'))
+  return jax.tree.map(
+      lambda y: jax.device_put(jnp.stack([y] * len(devices)), sharding), x
+  )
 
 
 def bcast_local_devices(value, local_devices_to_use=1):
   """Broadcasts an object to all local devices."""
   devices = jax.local_devices()[:local_devices_to_use]
-  return jax.device_put_replicated(value, devices)
+  return device_put_replicated(value, devices)
 
 
 def synchronize_hosts():