1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
|
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  setuptools-scm,

  # dependencies
  flaxlib,
  jax,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pyyaml,
  rich,
  tensorstore,
  treescope,
  typing-extensions,

  # optional-dependencies
  matplotlib,

  # tests
  cloudpickle,
  einops,
  keras,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,

  # passthru.updateScript
  writeScript,
  tomlq,
}:

# Flax: neural network library for JAX, built from the upstream GitHub tag.
buildPythonPackage rec {
  pname = "flax";
  version = "0.12.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    tag = "v${version}";
    hash = "sha256-Wdfc35/iah98C5WNYZWiAd2FJUJlyGLJ8xELpuYD3GU=";
  };

  build-system = [
    setuptools
    setuptools-scm
  ];

  dependencies = [
    flaxlib
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    treescope
    typing-extensions
  ];

  optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    einops
    keras
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
  ];

  pytestFlags = [
    # FutureWarning: In the future `np.object` will be defined as the corresponding NumPy scalar.
    "-Wignore::FutureWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"
    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have the modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
    # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"
  ];

  disabledTests = [
    # AssertionError: [Chex] Function 'add' is traced > 1 times!
    "PadShardUnpadTest"
    # AssertionError: nnx_model.kernel.value.sharding = NamedSharding(...
    "test_linen_to_nnx_metadata"
    # AssertionError: 'Linear_0' not found in State({})
    "test_compact_basic"
    # KeyError: 'intermediates'
    "test_linen_submodule"
    "test_pure_nnx_submodule"
    # KeyError: 'counts'
    "test_mutable_state"
    # AttributeError: 'Top' object has no attribute '_pytree__state'. Did you mean: '_pytree__flatten'?
    "test_shared_modules"
    # AttributeError: 'MLP' object has no attribute 'scope'
    "test_transforms"
  ];

  passthru = {
    # Updates flax, builds flax.src into ./result, then updates flaxlib using a
    # version string read from a Cargo.toml inside that source tree.
    # NOTE(review): the query `.something.version` (and possibly the path
    # `result/Cargo.toml`) looks like a placeholder — confirm the actual
    # Cargo.toml location in the flax source tree, the TOML key path
    # (presumably `package.version`), and the query syntax expected by the
    # `tomlq` executable before relying on this script.
    updateScript = writeScript "update.sh" ''
      nix-update flax # does not --build by default
      nix-build . -A flax.src # src is essentially a passthru
      nix-update flaxlib --version="$(${lib.getExe tomlq} <result/Cargo.toml .something.version)" --commit
    '';
  };

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}
|