summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/jax/test-cuda.nix
blob: cd56cd518d952cb1dfb7805d17675577a2316336 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Smoke test for jax's CUDA backend: builds a small Python program that runs
# one operation per CUDA library jax's GPU support relies on (the library each
# line is meant to hit is noted in its trailing comment) and prints
# "success!" only after every operation has completed.
#
# Arguments:
#   jax  - the jax Python package under test; its `cuda` optional
#          dependencies are added to the script's Python environment.
#   pkgs - nixpkgs, used for `pkgs.writers.writePython3Bin`.
{
  jax,
  pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    libraries = [
      jax
    ]
    ++ jax.optional-dependencies.cuda;
  }
  ''
    import jax
    import jax.numpy as jnp
    from jax import random
    from jax.experimental import sparse

    assert jax.devices()[0].platform == "gpu"  # libcuda.so

    rng = random.key(0)  # libcudart.so, libcudnn.so
    x = random.normal(rng, (100, 100))

    # jax dispatches work asynchronously, so an error raised while a kernel
    # runs may only surface once its result is awaited. Block on every
    # result so each library is really exercised before reporting success.
    (x @ x).block_until_ready()  # libcublas.so
    jnp.fft.fft(x).block_until_ready()  # libcufft.so
    jnp.linalg.inv(x).block_until_ready()  # libcusolver.so
    (sparse.CSR.fromdense(x) @ x).block_until_ready()  # libcusparse.so

    print("success!")
  ''