summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/warp-lang/default.nix
blob: d9550d87bd0fff29f11de72e595ea54fa5badbbb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
{
  autoAddDriverRunpath,
  buildPythonPackage,
  callPackage,
  config,
  cudaPackages,
  fetchFromGitHub,
  jax,
  lib,
  llvmPackages, # TODO: version below is already 1.10; switch to LLVM 21 (see python-packages.nix)
  numpy,
  pkgsBuildHost,
  python,
  replaceVars,
  runCommand,
  setuptools,
  stdenv,
  torch,
  warp-lang, # This package itself; passthru uses it to build testers and overridden test variants
  writableTmpDirAsHomeHook,
  writeShellApplication,

  # Build the standalone LLVM-based JIT compiler, which provides CPU device support.
  standaloneSupport ? true,

  # Build with the CUDA toolchain, which provides GPU device support.
  cudaSupport ? config.cudaSupport,

  # Link against NVIDIA MathDx (only valid together with cudaSupport).
  # Most linear-algebra tile operations such as tile_cholesky(), tile_fft()
  # and tile_matmul() are only available when Warp is built with MathDx.
  # A Linux-only default was considered but is currently disabled:
  #   libmathdxSupport ? cudaSupport && stdenv.hostPlatform.isLinux,
  libmathdxSupport ? cudaSupport,
}@args:

# MathDx requires CUDA support, so requesting it without CUDA is rejected at evaluation time.
assert libmathdxSupport -> cudaSupport;

let
  version = "1.10.0";

  # CUDA builds use the stdenv provided by the CUDA package set; all other builds
  # keep the stdenv that was passed in.
  effectiveStdenv = if !cudaSupport then args.stdenv else cudaPackages.backendStdenv;

  # Shadow the `stdenv` argument so that any accidental direct use below fails loudly.
  stdenv = throw "Use effectiveStdenv instead of stdenv directly, as it may be replaced by cudaPackages.backendStdenv";

  libmathdx = callPackage ./libmathdx.nix { };
in
buildPythonPackage.override { stdenv = effectiveStdenv; } {
  pname = "warp-lang";
  inherit version;
  pyproject = true;

  # TODO(@connorbaker): Some CUDA setup hook is failing when __structuredAttrs is false,
  # causing a bunch of missing math symbols (like expf) when linking against the static library
  # provided by NVCC.
  __structuredAttrs = true;

  src = fetchFromGitHub {
    owner = "NVIDIA";
    repo = "warp";
    tag = "v${version}";
    hash = "sha256-9OEyYdVq+/SzxHfNT+sa/YeBKklaUfpKUiJZuiuzxhQ=";
  };

  # The dynamic-link patch is templated with the store paths of the LLVM and libclang
  # libraries; it is only needed when the standalone (LLVM JIT / CPU) backend is built.
  patches = lib.optionals standaloneSupport [
    (replaceVars ./dynamic-link.patch {
      LLVM_LIB = llvmPackages.llvm.lib;
      LIBCLANG_LIB = llvmPackages.libclang.lib;
    })
  ];

  # postPatch is assembled from independent, conditionally-included shell fragments.
  # Every fragment logs what it patches so failures in the build log are easy to attribute.
  #
  # First fragment: strip the pre-C++11 ABI flag (-D_GLIBCXX_USE_CXX11_ABI=0) that upstream
  # adds for its CentOS 7 toolchain, from both the LLVM build script and the native library
  # build script.
  postPatch = ''
    nixLog "patching $PWD/build_llvm.py to remove pre-C++11 ABI flag"
    substituteInPlace "$PWD/build_llvm.py" \
      --replace-fail \
        '"-D", f"CMAKE_CXX_FLAGS=-D_GLIBCXX_USE_CXX11_ABI=0 {abi_version}",  # The pre-C++11 ABI is still the default on the CentOS 7 toolchain' \
        ""

    nixLog "patching $PWD/warp/_src/build_dll.py to remove pre-C++11 ABI flag"
    substituteInPlace "$PWD/warp/_src/build_dll.py" \
      --replace-fail " -D_GLIBCXX_USE_CXX11_ABI=0" ""
  ''
  + lib.optionalString effectiveStdenv.hostPlatform.isDarwin (
    # Drop the hard-coded macOS target triple and explicitly link libc++ from llvmPackages.
    ''
      nixLog "patching $PWD/warp/_src/build_dll.py to remove macOS target flag and link against libc++"
      substituteInPlace "$PWD/warp/_src/build_dll.py" \
        --replace-fail "--target={arch}-apple-macos11" "" \
        --replace-fail 'ld_inputs = []' "ld_inputs = ['-L\"${llvmPackages.libcxx}/lib\" -lc++']"
    ''
    # test_codegen compares floats to 5 decimal places, which fails on darwin with e.g.:
    #   AssertionError: 0.4082476496696472 != 0.40824246406555176 within 5 places
    # Relax the tolerance to 4 places instead of disabling the test.
    + ''
      nixLog "patching $PWD/warp/tests/test_codegen.py to relax floating-point tolerance on darwin"
      substituteInPlace "$PWD/warp/tests/test_codegen.py" \
        --replace-fail 'places=5' 'places=4'
    ''
  )
  # With a clang-based stdenv, replace the bare `clang++` used by build_dll.py with the
  # stdenv compiler wrapper from the Nix store.
  # NOTE(review): this substitutes the C driver `cc`, not `c++`; darwin links libc++
  # explicitly above — confirm linking C++ via `cc` is intended on Linux clang stdenvs too.
  + lib.optionalString effectiveStdenv.cc.isClang ''
    nixLog "patching $PWD/warp/_src/build_dll.py to use the stdenv compiler wrapper"
    substituteInPlace "$PWD/warp/_src/build_dll.py" \
      --replace-fail "clang++" "${effectiveStdenv.cc}/bin/cc"
  ''
  # Point the LLVM and libclang include paths at the Nix-provided headers instead of
  # upstream's vendored build directories.
  + lib.optionalString standaloneSupport ''
    nixLog "patching $PWD/warp/_src/build_dll.py to use Nix-provided LLVM and libclang headers"
    substituteInPlace "$PWD/warp/_src/build_dll.py" \
      --replace-fail \
        '-I"{warp_home_path.parent}/external/llvm-project/out/install/{mode}-{arch}/include"' \
        '-I"${llvmPackages.llvm.dev}/include"' \
      --replace-fail \
        '-I"{warp_home_path.parent}/_build/host-deps/llvm-project/release-{arch}/include"' \
        '-I"${llvmPackages.libclang.dev}/include"'
  ''
  # Patch build_dll.py to use our gencode flags rather than NVIDIA's very broad defaults.
  + lib.optionalString cudaSupport (
    let
      # Each entry is emitted as a quoted Python string, comma-separated, so it can replace
      # the `*gencode_opts,` / `*clang_arch_flags,` splats inside a Python list literal.
      gencodeOpts = lib.concatMapStringsSep ", " (
        gencodeString: ''"${gencodeString}"''
      ) cudaPackages.flags.gencode;

      clangArchFlags = lib.concatMapStringsSep ", " (
        realArch: ''"--cuda-gpu-arch=${realArch}"''
      ) cudaPackages.flags.realArches;
    in
    ''
      nixLog "patching $PWD/warp/_src/build_dll.py to use our gencode flags"
      substituteInPlace "$PWD/warp/_src/build_dll.py" \
        --replace-fail '*gencode_opts,' '${gencodeOpts},' \
        --replace-fail '*clang_arch_flags,' '${clangArchFlags},'
    ''
    # Patch build_dll.py to use dynamic libraries rather than static ones.
    # NOTE: We do not patch the `nvptxcompiler_static` path because it is not available as a dynamic library.
    + ''
      nixLog "patching $PWD/warp/_src/build_dll.py to use dynamic libraries"
      substituteInPlace "$PWD/warp/_src/build_dll.py" \
        --replace-fail '-lcudart_static' '-lcudart' \
        --replace-fail '-lnvrtc_static' '-lnvrtc' \
        --replace-fail '-lnvrtc-builtins_static' '-lnvrtc-builtins' \
        --replace-fail '-lnvJitLink_static' '-lnvJitLink' \
        --replace-fail '-lmathdx_static' '-lmathdx'
    ''
  )
  # These tests fail on CPU and CUDA.
  + ''
    nixLog "patching $PWD/warp/tests/test_reload.py to disable broken tests"
    substituteInPlace "$PWD/warp/tests/test_reload.py" \
      --replace-fail \
        'add_function_test(TestReload, "test_reload", test_reload, devices=devices)' \
        "" \
      --replace-fail \
        'add_function_test(TestReload, "test_reload_references", test_reload_references, devices=get_test_devices("basic"))' \
        ""
  '';

  build-system = [
    setuptools
  ];

  dependencies = [
    numpy
  ];

  nativeBuildInputs = lib.optionals cudaSupport [
    # NOTE: While normally we wouldn't include autoAddDriverRunpath for packages built from source, since Warp
    # will be loading GPU drivers at runtime, we need to inject the path to our video drivers.
    autoAddDriverRunpath
  ];

  buildInputs =
    lib.optionals standaloneSupport [
      llvmPackages.llvm
      llvmPackages.clang
      llvmPackages.libcxx
    ]
    ++ lib.optionals cudaSupport [
      (lib.getStatic cudaPackages.cuda_nvcc) # dependency on nvptxcompiler_static; no dynamic version available
      cudaPackages.cuda_cccl
      cudaPackages.cuda_cudart
      cudaPackages.cuda_nvcc
      cudaPackages.cuda_nvrtc
    ]
    ++ lib.optionals libmathdxSupport [
      libmathdx
      cudaPackages.libcublas
      cudaPackages.libcufft
      cudaPackages.libcusolver
      cudaPackages.libnvjitlink
    ];

  # Build the native components with upstream's build_lib.py before the wheel is built.
  # The flags passed mirror the *Support arguments of this derivation.
  preBuild =
    let
      buildOptions =
        lib.optionals effectiveStdenv.cc.isClang [
          "--clang_build_toolchain"
        ]
        ++ lib.optionals (!standaloneSupport) [
          "--no_standalone"
        ]
        ++ lib.optionals cudaSupport [
          # NOTE: The `cuda_path` argument is the directory which contains `bin/nvcc` (i.e., the bin output).
          "--cuda_path=${lib.getBin pkgsBuildHost.cudaPackages.cuda_nvcc}"
        ]
        ++ lib.optionals libmathdxSupport [
          "--libmathdx"
          "--libmathdx_path=${libmathdx}"
        ]
        ++ lib.optionals (!libmathdxSupport) [
          "--no_libmathdx"
        ];

      buildOptionString = lib.concatStringsSep " " buildOptions;
    in
    ''
      nixLog "running $PWD/build_lib.py to create components necessary to build the wheel"
      "${python.pythonOnBuildForHost.interpreter}" "$PWD/build_lib.py" ${buildOptionString}
    '';

  pythonImportsCheck = [
    "warp"
  ];

  # See passthru.tests.
  doCheck = false;

  passthru = {
    # Make libmathdx available for introspection.
    inherit libmathdx;

    # Scripts which provide test packages and implement test logic.
    testers.unit-tests =
      let
        # `warp-lang` is this package's own argument rather than a fresh build, so the tester
        # follows overrides: passthru.tests passes the overridden package back in as `warp-lang`.
        python' = python.withPackages (_: [
          warp-lang
          jax
          torch
        ]);
        # Disable paddlepaddle interop tests: malloc(): unaligned tcache chunk detected
        #  (paddlepaddle.override { inherit cudaSupport; })
      in
      writeShellApplication {
        name = "warp-lang-unit-tests";
        runtimeInputs = [ python' ];
        text = ''
          ${python'}/bin/python3 -m warp.tests
        '';
      };

    # Tests run within the Nix sandbox.
    tests =
      let
        # Build a derivation that runs the unit-test script against a warp-lang variant
        # configured with the given feature flags; it succeeds only if the tests pass.
        mkUnitTests =
          {
            cudaSupport,
            libmathdxSupport,
          }:
          let
            name =
              "warp-lang-unit-tests-cpu" # CPU is baseline
              + lib.optionalString cudaSupport "-cuda"
              + lib.optionalString libmathdxSupport "-libmathdx";

            warp-lang' = warp-lang.override {
              inherit cudaSupport libmathdxSupport;
              # Make sure the warp-lang provided through callPackage is replaced with the override we're making.
              warp-lang = warp-lang';
            };
          in
          runCommand name
            {
              nativeBuildInputs = [
                warp-lang'.passthru.testers.unit-tests
                writableTmpDirAsHomeHook
              ];
              # CUDA variants must be scheduled on builders that advertise GPU access.
              requiredSystemFeatures = lib.optionals cudaSupport [ "cuda" ];
            }
            ''
              nixLog "running ${name}"

              if warp-lang-unit-tests; then
                nixLog "${name} passed"
                touch "$out"
              else
                nixErrorLog "${name} failed"
                exit 1
              fi
            '';
      in
      {
        cpu = mkUnitTests {
          cudaSupport = false;
          libmathdxSupport = false;
        };
        cuda = {
          cudaOnly = mkUnitTests {
            cudaSupport = true;
            libmathdxSupport = false;
          };
          cudaWithLibmathDx = mkUnitTests {
            cudaSupport = true;
            libmathdxSupport = true;
          };
        };
      };
  };

  meta = {
    description = "Python framework for high performance GPU simulation and graphics";
    longDescription = ''
      Warp is a Python framework for writing high-performance simulation
      and graphics code. Warp takes regular Python functions and JIT
      compiles them to efficient kernel code that can run on the CPU or
      GPU.

      Warp is designed for spatial computing and comes with a rich set
      of primitives that make it easy to write programs for physics
      simulation, perception, robotics, and geometry processing. In
      addition, Warp kernels are differentiable and can be used as part
      of machine-learning pipelines with frameworks such as PyTorch,
      JAX and Paddle.
    '';
    homepage = "https://github.com/NVIDIA/warp";
    changelog = "https://github.com/NVIDIA/warp/blob/v${version}/CHANGELOG.md";
    license = lib.licenses.asl20;
    platforms = lib.platforms.linux ++ [ "aarch64-darwin" ];
    maintainers = with lib.maintainers; [ yzx9 ];
  };
}