summary | refs | log | tree | commit | diff
path: root/pkgs/development/python-modules/torchtune/default.nix
blob: 1fac454501f06486e2a6ab622d3f428fe319eed5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Package expression for torchtune.
#
# This file is a function: it takes the attribute set of inputs listed below
# and returns the result of calling `buildPythonPackage` with the package
# definition. The inputs are grouped by how they are used (build-system,
# runtime dependencies, test-only inputs).
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  blobfile,
  datasets,
  huggingface-hub,
  kagglehub,
  numpy,
  omegaconf,
  pillow,
  psutil,
  safetensors,
  sentencepiece,
  tiktoken,
  tokenizers,
  torch,
  torchdata,
  tqdm,
  torchao,
  torchvision,

  # tests
  comet-ml,
  mlflow,
  pytest-integration,
  pytest-mock,
  pytestCheckHook,
  writableTmpDirAsHomeHook,
}:

# `rec` lets attributes below refer to each other: `src.tag` uses `version`,
# and `meta.changelog` uses `src.tag`.
buildPythonPackage rec {
  pname = "torchtune";
  version = "0.6.1";
  # Build from the project's pyproject metadata using the `build-system` below.
  pyproject = true;

  # Source is the upstream GitHub repository at the release tag `v<version>`.
  # `hash` pins the fetched content and must be updated together with `version`.
  src = fetchFromGitHub {
    owner = "meta-pytorch";
    repo = "torchtune";
    tag = "v${version}";
    hash = "sha256-evhQBpZiUXriL0PAYkEzGypH21iRs37Ix6Nl5YAyeQ0=";
  };

  build-system = [
    setuptools
  ];

  # Runtime dependencies, followed by the packages from huggingface-hub's
  # `hf_transfer` optional-dependency group (appended with `++` below).
  dependencies = [
    blobfile
    datasets
    huggingface-hub
    kagglehub
    numpy
    omegaconf
    pillow
    psutil
    safetensors
    sentencepiece
    tiktoken
    tokenizers
    torch
    torchdata
    tqdm

    # Not explicitly listed as requirements, but effectively imported at runtime
    torchao
    torchvision
  ]
  ++ huggingface-hub.optional-dependencies.hf_transfer;

  # Smoke test: the top-level `torchtune` module must be importable.
  pythonImportsCheck = [ "torchtune" ];

  # Inputs needed only while running the test suite.
  # NOTE(review): `writableTmpDirAsHomeHook` presumably provides a writable
  # HOME for tests that write there -- confirm against the hook's docs.
  nativeCheckInputs = [
    comet-ml
    mlflow
    pytest-integration
    pytest-mock
    pytestCheckHook
    writableTmpDirAsHomeHook
  ];

  # Only run the unit tests under tests/torchtune.
  # This excludes `regression`, which depends on a specific llama model, and
  # `recipes`, which are sample code.
  enabledTestPaths = [ "tests/torchtune" ];

  # Tests skipped by name. The first list applies on every platform; the
  # `lib.optionals` lists are appended only when their platform condition holds.
  disabledTests = [
    # AssertionError (tensors are not equal)
    "test_stop_tokens"
    "test_stop_tokens_batched"
    "test_stop_tokens_batched_uneven_stopping"
    "test_stop_tokens_batched_uneven_stopping_left_padded"

    # RuntimeError: not allowed to set torch.backends.cudnn flags after disable_global_flags;
    # please use flags() context manager instead
    "test_deterministic_false"
    "test_deterministic_true"

    # TypeError: exceptions must be derived from Warning, not <class 'NoneType'>
    "test_deprecated"

    # Flaky: results differ slightly from the expected values, e.g.
    # AssertionError: actual: -83.3048095703125, expected: -83.15229797363281
    "test_forward"
    "test_forward_kv_cache"
    "test_forward_with_2d_pos_ids"
    "test_forward_with_curr_pos"
    "test_forward_with_packed_pos"
  ]
  # Additionally skipped on aarch64 Linux only.
  ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
    # Fatal Python error: Segmentation fault
    "test_forward_gqa"
  ]
  # Additionally skipped on Darwin only.
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # Defined in tests/torchtune/training/test_distributed.py
    # (the failure reason is not recorded here)
    "test_init_from_env_no_dup"
    "test_init_from_env_dup"
  ];

  # Test files / node IDs skipped on Darwin only; the list is empty elsewhere.
  disabledTestPaths = lib.optionals stdenv.hostPlatform.isDarwin [
    # Fail due to floating-point precision differences
    "tests/torchtune/models/flux/test_flux_autoencoder.py::TestFluxAutoencoder::test_encode"
    "tests/torchtune/modules/peft/test_dora.py::TestDoRALinear::test_qdora_parity[True-dtype1]"
    "tests/torchtune/modules/peft/test_lora.py::TestLoRALinear::test_qlora_parity[True-dtype1]"

    # The tests in this directory hang
    "tests/torchtune/utils"
  ];

  meta = {
    description = "PyTorch native post-training library";
    homepage = "https://github.com/meta-pytorch/torchtune";
    # Points at the GitHub release page for the packaged tag (`src.tag`).
    changelog = "https://github.com/meta-pytorch/torchtune/releases/tag/${src.tag}";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      GaetanLepage
      sarahec
    ];
  };
}