summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/torch/tests/mk-runtime-check.nix
blob: 3ead20d3eeb08841a30b9c8dd1001d9700b0dff9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Produces a runtime check that wraps a small Python script with
# `cudaPackages.writeGpuTestPython`. The script imports torch and asserts that
# both `torch.cuda.is_available()` and `torch.version.<versionAttr>` are truthy.
#
# Arguments:
#   cudaPackages - package set that provides `writeGpuTestPython`.
#   feature      - forwarded as the test's `feature`; also used to build the
#                  test name "<feature>Available".
#   libraries    - forwarded unchanged to `writeGpuTestPython` (presumably a
#                  Python-package selector such as `ps: [ ps.torch ]` — confirm
#                  at call sites).
#   versionAttr  - name of the `torch.version` attribute that must be truthy;
#                  spliced into the script as an identifier (e.g. "cuda" —
#                  assumption, verify against callers).
{
  cudaPackages,
  feature,
  libraries,
  versionAttr,
}:

let
  name = "${feature}Available";

  # The assert raises AssertionError (carrying `message`) unless both values
  # are truthy; on success the same diagnostic line is printed.
  script = ''
    import torch
    message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
    assert torch.cuda.is_available() and torch.version.${versionAttr}, message
    print(message)
  '';
in
cudaPackages.writeGpuTestPython { inherit feature libraries name; } script