blob: cd56cd518d952cb1dfb7805d17675577a2316336 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
# Smoke test for JAX's CUDA support: builds a `jax-test-cuda` executable that
# asserts the first JAX device is a GPU, then runs one small operation per
# CUDA library listed in the inline comments before printing "success!".
{
  jax,
  pkgs,
}:
pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    libraries = [
      jax
    ]
    ++ jax.optional-dependencies.cuda;
  }
  ''
    import jax
    import jax.numpy as jnp
    from jax import random
    from jax.experimental import sparse

    # JAX dispatches device work asynchronously, so each result is explicitly
    # waited on; otherwise a failure inside a CUDA library could be deferred
    # past (or never raised before) the final "success!" message.
    devices = jax.devices()
    assert devices[0].platform == "gpu", devices  # libcuda.so
    rng = random.key(0)  # libcudart.so, libcudnn.so
    x = random.normal(rng, (100, 100)).block_until_ready()
    (x @ x).block_until_ready()  # libcublas.so
    jnp.fft.fft(x).block_until_ready()  # libcufft.so
    jnp.linalg.inv(x).block_until_ready()  # libcusolver.so
    (sparse.CSR.fromdense(x) @ x).block_until_ready()  # libcusparse.so
    print("success!")
  ''
|