summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/flax/default.nix
blob: 9e9a91f0fd7070c70e126fa9a98c6439fdbeb055 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Package expression for Flax, built with `buildPythonPackage` from the
# upstream GitHub tag `v${version}`.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  setuptools-scm,

  # dependencies
  flaxlib,
  jax,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pyyaml,
  rich,
  tensorstore,
  treescope,
  typing-extensions,

  # optional-dependencies
  matplotlib,

  # tests
  cloudpickle,
  keras,
  einops,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,

  # passthru.updateScript
  writeScript,
  tomlq,
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.12.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    tag = "v${version}";
    hash = "sha256-Wdfc35/iah98C5WNYZWiAd2FJUJlyGLJ8xELpuYD3GU=";
  };

  build-system = [
    setuptools
    setuptools-scm
  ];

  dependencies = [
    flaxlib
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    treescope
    typing-extensions
  ];

  optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    keras
    einops
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
  ];

  pytestFlags = [
    # FutureWarning: In the future `np.object` will be defined as the corresponding NumPy scalar.
    "-Wignore::FutureWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have the modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
    # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"
  ];

  disabledTests = [
    # AssertionError: [Chex] Function 'add' is traced > 1 times!
    "PadShardUnpadTest"

    # AssertionError: nnx_model.kernel.value.sharding = NamedSharding(...
    "test_linen_to_nnx_metadata"

    # AssertionError: 'Linear_0' not found in State({})
    "test_compact_basic"
    # KeyError: 'intermediates'
    "test_linen_submodule"
    "test_pure_nnx_submodule"
    # KeyError: 'counts'
    "test_mutable_state"
    # AttributeError: 'Top' object has no attribute '_pytree__state'. Did you mean: '_pytree__flatten'?
    "test_shared_modules"
    # AttributeError: 'MLP' object has no attribute 'scope'
    "test_transforms"
  ];

  passthru = {
    # Bumps flax first, then bumps flaxlib to the version read from the freshly
    # fetched flax source tree.
    #
    # `writeScript` does not add an interpreter line on its own, so the script
    # carries an explicit nix-shell shebang; this also puts `nix-update` on
    # PATH instead of relying on the caller's environment. `set -euo pipefail`
    # aborts before the `--commit` step if `nix-update flax` or `nix-build`
    # fails.
    #
    # NOTE(review): the `.something.version` query and the `result/Cargo.toml`
    # path look like placeholders -- confirm where the flaxlib version lives in
    # the flax source tree before relying on this script.
    updateScript = writeScript "update.sh" ''
      #!/usr/bin/env nix-shell
      #!nix-shell -i bash -p nix-update

      set -euo pipefail

      nix-update flax # does not --build by default
      nix-build . -A flax.src # src is essentially a passthru
      nix-update flaxlib --version="$(${lib.getExe tomlq} <result/Cargo.toml .something.version)" --commit
    '';
  };

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}