summary · refs · log · tree · commit · diff
path: root/pkgs/development/python-modules/torch/tests/default.nix
blob: 90fb21d018d7d39d825e32e40a5a52aa8fabae5b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
{ callPackage }:

let
  # Instantiate ./mk-runtime-check.nix for one backend.
  # `feature` and `versionAttr` are forwarded unchanged; `torchAttr` names the
  # torch variant picked out of the Python package set given to `libraries`.
  runtimeCheck =
    {
      feature,
      versionAttr,
      torchAttr,
    }:
    callPackage ./mk-runtime-check.nix {
      inherit feature versionAttr;
      libraries = ps: [ ps.${torchAttr} ];
    };

  # Instantiate ./mk-torch-compile-check.nix with the given `feature`
  # (null for the CPU-only variant) and the torch variant named by `torchAttr`.
  compileCheck =
    feature: torchAttr:
    callPackage ./mk-torch-compile-check.nix {
      inherit feature;
      libraries = ps: [ ps.${torchAttr} ];
    };

  # Bound here because `compileCpu` below also reads its `gpuCheck` attribute.
  tester-compileCpu = compileCheck null "torch";
in
{
  # To perform the runtime check use either
  # `nix run .#python3Packages.torch.tests.tester-cudaAvailable` (outside the sandbox), or
  # `nix build .#python3Packages.torch.tests.tester-cudaAvailable.gpuCheck` (in a relaxed sandbox)
  tester-cudaAvailable = runtimeCheck {
    feature = "cuda";
    versionAttr = "cuda";
    torchAttr = "torchWithCuda";
  };
  tester-rocmAvailable = runtimeCheck {
    feature = "rocm";
    versionAttr = "hip";
    torchAttr = "torchWithRocm";
  };

  # Only the CPU compile tester has its `gpuCheck` derivation surfaced as a
  # top-level test.
  # NOTE(review): presumably because with `feature = null` it needs no
  # GPU-capable builder — confirm against mk-torch-compile-check.nix.
  compileCpu = tester-compileCpu.gpuCheck;
  inherit tester-compileCpu;
  tester-compileCuda = compileCheck "cuda" "torchWithCuda";
  tester-compileRocm = compileCheck "rocm" "torchWithRocm";

  mnist-example = callPackage ./mnist-example { };
}