summary refs log tree commit diff
path: root/pkgs/development/python-modules/xformers/default.nix
blob: 8054520520d97da234fbee28f08ce00f314c15fc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
{
  # nixpkgs infrastructure
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  pythonOlder, # NOTE(review): not referenced anywhere in this expression

  # build tooling
  setuptools,
  ninja,
  which,
  openmp, # only pulled in on Darwin

  # runtime dependencies
  numpy,
  torch,

  # test dependencies (tests are currently disabled via doCheck)
  pytestCheckHook,
  pytest-cov-stub,
  pytest-timeout,
  hydra-core,
  fairscale,
  scipy,
  cmake,
  networkx,
  triton,
  einops,
  transformers,
  timm,
  # not yet available / not wired up:
  #   pytest-mpi, pytorch-image-models, apex, flash-attn
}:
let
  # Held at 0.0.30: 0.0.32.post2 was confirmed to break CUDA.
  # Drop this note once the newest upstream release "just works".
  version = "0.0.30";

  # Follow whatever CUDA configuration the torch we build against uses.
  inherit (torch) cudaSupport cudaPackages cudaCapabilities;

  # With CUDA enabled, build with cudaPackages.backendStdenv (the stdenv
  # nixpkgs provides for CUDA builds); otherwise use the regular stdenv.
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv;
in
buildPythonPackage.override { stdenv = effectiveStdenv; } {
  pname = "xformers";
  inherit version;
  pyproject = true;

  src = fetchFromGitHub {
    owner = "facebookresearch";
    repo = "xformers";
    tag = "v${version}";
    fetchSubmodules = true;
    hash = "sha256-ozaw9z8qnGpZ28LQNtwmKeVnrn7KDWNeJKtT6g6Q/W0=";
  };

  patches = [ ./0001-fix-allow-building-without-git.patch ];

  build-system = [ setuptools ];

  # The source is built without git metadata (see the patch above), so
  # write xformers/version.py by hand. MAX_JOBS caps the parallelism of
  # torch's C++/CUDA extension builder; give it every core Nix allots.
  preBuild = ''
    cat << EOF > ./xformers/version.py
    # noqa: C801
    __version__ = "${version}"
    EOF

    export MAX_JOBS=$NIX_BUILD_CORES
  '';

  # Compile CUDA kernels for the same capabilities as the torch we link
  # against (cudaCapabilities is inherited from torch above).
  env = lib.optionalAttrs cudaSupport {
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" cudaCapabilities;
  };

  buildInputs =
    lib.optional effectiveStdenv.hostPlatform.isDarwin openmp
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        # headers needed by the flash-attn build
        cuda_cudart # cuda_runtime_api.h
        libcusparse # cusparse.h
        cuda_cccl # nv/target
        libcublas # cublas_v2.h
        libcusolver # cusolverDn.h
        libcurand # curand_kernel.h
      ]
    );

  nativeBuildInputs = [
    ninja
    which
  ]
  ++ lib.optional cudaSupport cudaPackages.cuda_nvcc
  ++ lib.optional effectiveStdenv.hostPlatform.isDarwin openmp.dev;

  dependencies = [
    numpy
    torch
  ];

  pythonImportsCheck = [ "xformers" ];

  # Upstream has a broken "0.03" version, so automated bulk updates are
  # disabled:
  # https://github.com/NixOS/nixpkgs/pull/285495#issuecomment-1920730720
  passthru.skipBulkUpdate = true;

  # cmake is only listed as a check input; keep its setup hook from
  # taking over the configure phase of this setuptools build.
  dontUseCmakeConfigure = true;

  # Disabled because several test dependencies are missing (see the
  # commented-out entries below).
  doCheck = false;

  nativeCheckInputs = [
    pytestCheckHook
    pytest-cov-stub
    pytest-timeout
    hydra-core
    fairscale
    scipy
    cmake
    networkx
    triton
    # apex
    einops
    transformers
    timm
    # flash-attn
  ];

  meta = {
    description = "Collection of composable Transformer building blocks";
    homepage = "https://github.com/facebookresearch/xformers";
    # Upstream tags are "v"-prefixed (see src.tag), so the ref must be too.
    changelog = "https://github.com/facebookresearch/xformers/blob/v${version}/CHANGELOG.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ happysalada ];
  };
}