summaryrefslogtreecommitdiff
path: root/pkgs/development/libraries/science/math/libtorch/bin.nix
blob: cc6d362619beb25a135f1bb51306a1eda3e7fbe0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
{
  callPackage,
  stdenv,
  fetchzip,
  lib,
  llvmPackages,
  config,

  autoAddDriverRunpath,
  autoPatchelfHook,
  patchelf,
  fixDarwinDylibNames,

  cudaSupport ? config.cudaSupport,
  cudaPackages_13,
  libz,
}:

# Prebuilt (binary) libtorch, the C++ API of PyTorch.
#
# Nothing is compiled: the upstream archive is split into `out` (shared
# libraries) and `dev` (headers + CMake package files), and postFixup
# rewrites library search paths so the prebuilt objects resolve their
# dependencies from the Nix store.
let
  # The binary libtorch distribution statically links the CUDA
  # toolkit. This means that we do not need to provide CUDA to
  # this derivation. However, we should ensure on version bumps
  # that the CUDA toolkit for `passthru.tests` is still
  # up-to-date.
  # NOTE(review): postFixup nonetheless puts several cudaPackages_13
  # libraries on the rpath when cudaSupport is set, so presumably not
  # everything is linked statically — confirm on the next bump.
  version = "2.9.0";

  # Selects the CPU-only or the CUDA-enabled upstream archive.
  device = if cudaSupport then "cuda" else "cpu";

  # binary-hashes.nix is applied to the version and yields an attribute
  # set looked up as "<system>-<device>" in `src` below.
  srcs = import ./binary-hashes.nix version;

  # Only forced (Nix is lazy) when that lookup misses.
  unavailable = throw "libtorch is not available for this platform";
in
stdenv.mkDerivation {
  inherit version;
  pname = "libtorch";

  src = fetchzip srcs."${stdenv.hostPlatform.system}-${device}" or unavailable;

  nativeBuildInputs =
    if stdenv.hostPlatform.isDarwin then
      [ fixDarwinDylibNames ]
    else
      [
        patchelf
        autoPatchelfHook
      ]
      # `++` binds inside the `else` branch, so the driver-runpath hook
      # is only ever added on non-Darwin hosts.
      ++ lib.optionals cudaSupport [ autoAddDriverRunpath ];

  # Prebuilt archive: there is nothing to configure or build, and the
  # upstream objects are kept unstripped.
  dontBuild = true;
  dontConfigure = true;
  dontStrip = true;

  installPhase = ''
    runHook preInstall

    # Copy headers and CMake files.
    mkdir -p $dev
    cp -r include $dev
    cp -r share $dev

    # Shared libraries only; the trailing `*` also matches versioned
    # names (e.g. libfoo.so.1).
    install -Dm755 -t $out/lib lib/*${stdenv.hostPlatform.extensions.sharedLibrary}*

    # Drop the JNI (Java binding) libraries; Java support is not needed.
    rm -f $out/lib/lib*jni* 2> /dev/null || true

    # Fix up library paths for split outputs: the CMake files now live in
    # $dev while the libraries live in $out, so the prefix-relative paths
    # (presumably derived from the CMake files' own location) are
    # replaced with absolute $out/lib paths. --replace-fail makes the
    # build fail if a pattern disappears on a version bump.
    substituteInPlace $dev/share/cmake/Torch/TorchConfig.cmake \
      --replace-fail \''${TORCH_INSTALL_PREFIX}/lib "$out/lib"

    substituteInPlace \
      $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
      --replace-fail \''${_IMPORT_PREFIX}/lib "$out/lib"

    runHook postInstall
  '';

  postFixup =
    let
      # Library search path baked into every ELF object on Linux.
      rpath = lib.makeLibraryPath (
        [ stdenv.cc.cc ]
        ++ lib.optionals cudaSupport [
          # libcudart.so; the driver's libcuda.so is not in the store and
          # is reached via addDriverRunpath below.
          cudaPackages_13.cuda_cudart
          cudaPackages_13.libcufft
          cudaPackages_13.libcurand
          cudaPackages_13.libcusolver
          cudaPackages_13.libcusparse
          libz
        ]
      );
    in
    lib.optionalString stdenv.hostPlatform.isLinux ''
      # Overwrite each shared object's rpath with the store paths above
      # plus $out/lib, so the bundled libraries find each other.
      find $out/lib -type f \( -name '*.so' -or -name '*.so.*' \) | while IFS= read -r lib; do
        echo "setting rpath for $lib..."
        patchelf --set-rpath "${rpath}:$out/lib" "$lib"
        ${lib.optionalString cudaSupport ''
          addDriverRunpath "$lib"
        ''}
      done
    ''
    + lib.optionalString stdenv.hostPlatform.isDarwin ''
      # Log load commands before patching (debugging aid).
      for f in $out/lib/*.dylib; do
        otool -L "$f"
      done
      for f in $out/lib/*.dylib; do
        # Point libomp at the store copy instead of an @rpath lookup.
        if otool -L "$f" | grep "@rpath/libomp.dylib" >& /dev/null; then
          install_name_tool -change "@rpath/libomp.dylib" ${llvmPackages.openmp}/lib/libomp.dylib "$f"
        fi
        install_name_tool -id $out/lib/$(basename "$f") "$f" || true
        # Rewrite every remaining rpath-relative dependency to the sibling
        # library of the same name in $out/lib. (`dep` is a shell variable,
        # distinct from the Nix `rpath` binding above.)
        for dep in $(otool -L "$f" | grep rpath | awk '{print $1}'); do
          install_name_tool -change "$dep" $out/lib/$(basename "$dep") "$f"
        done
      done
      # Log load commands after patching.
      for f in $out/lib/*.dylib; do
        otool -L "$f"
      done
    '';

  outputs = [
    "out"
    "dev"
  ];

  passthru.tests.cmake = callPackage ./test {
    inherit cudaSupport;
  };

  meta = {
    description = "C++ API of the PyTorch machine learning framework";
    homepage = "https://pytorch.org/";
    sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ];
    # The binaries bundle CUDA and Intel MKL components, but their license
    # terms do not restrict redistribution of the binary:
    # https://docs.nvidia.com/cuda/eula/index.html
    # https://www.intel.com/content/www/us/en/developer/articles/license/onemkl-license-faq.html
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ junjihashimoto ];
    platforms = [
      "aarch64-darwin"
      "x86_64-linux"
    ];
  };
}