1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
|
# Package expression for `dm-haiku`, a neural network library for JAX.
# The attribute-set argument lists every input the expression can receive.
{
  lib,
  buildPythonPackage,
  # NOTE(review): `pythonAtLeast` is not referenced anywhere in this
  # expression. It is kept in the argument set because dropping it would make
  # callers/overrides that pass it explicitly fail with an "unexpected
  # argument" error; consider removing it once nothing passes it.
  pythonAtLeast,
  fetchFromGitHub,
  fetchpatch,

  # build-system
  setuptools,

  # dependencies
  absl-py,
  jaxlib,
  jmp,
  numpy,
  tabulate,

  # optional-dependencies
  jax,
  flax,

  # tests
  pytest-xdist,
  pytestCheckHook,
  bsuite,
  chex,
  cloudpickle,
  dill,
  dm-env,
  dm-tree,
  optax,
  # Currently unused: it is commented out of nativeCheckInputs and the test
  # files that need it are listed in disabledTestPaths further down.
  rlax,
  tensorflow,
}:
let
  # Bound in a `let` instead of being returned directly so the derivation can
  # refer to itself: it is listed in its own nativeCheckInputs and is the base
  # of the passthru test derivation.
  dm-haiku = buildPythonPackage rec {
    pname = "dm-haiku";
    version = "0.0.16";
    pyproject = true;

    src = fetchFromGitHub {
      owner = "deepmind";
      repo = "dm-haiku";
      tag = "v${version}";
      hash = "sha256-XugzzHapnqXD8w17k6HaNeqWcxRe49r7OIb8v5LI2NM=";
    };

    patches = [
      # https://github.com/deepmind/dm-haiku/pull/672
      (fetchpatch {
        name = "fix-find-namespace-packages.patch";
        url = "https://github.com/deepmind/dm-haiku/commit/728031721f77d9aaa260bba0eddd9200d107ba5d.patch";
        hash = "sha256-qV94TdJnphlnpbq+B0G3KTx5CFGPno+8FvHyu/aZeQE=";
      })
    ];

    build-system = [ setuptools ];

    dependencies = [
      absl-py
      jaxlib # implicit runtime dependency
      jmp
      numpy
      tabulate
    ];

    optional-dependencies = {
      jax = [
        jax
        jaxlib
      ];
      flax = [ flax ];
    };

    pythonImportsCheck = [ "haiku" ];

    # Test inputs. `doCheck` is false for this derivation and set to true only
    # in `passthru.tests.pytest` below, which is where these get exercised.
    nativeCheckInputs = [
      bsuite
      chex
      cloudpickle
      dill
      dm-env
      dm-haiku # the let-bound derivation itself
      dm-tree
      flax
      jaxlib
      optax
      pytest-xdist
      pytestCheckHook
      # rlax (broken dependency tensorflow-probability)
      tensorflow
    ];

    disabledTests = [
      # See https://github.com/deepmind/dm-haiku/issues/366.
      "test_jit_Recurrent"

      # Assertion errors
      "testShapeChecking0"
      "testShapeChecking1"

      # This test requires a more recent version of tensorflow. The current one (2.13) is not enough.
      "test_reshape_convert"

      # This test requires JAX support for double precision (64bit), but enabling this causes several
      # other tests to fail.
      # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
      "test_doctest_haiku.experimental"

      # AssertionError: 1 != 0 : 1 doctests failed
      "test_doctest_haiku"

      # ValueError: pmap wrapped function must be passed at least one argument containing an array,
      # got empty *args=() and **kwargs={}
      "test_equivalent_when_passing_transformed_fn2"

      # AssertionError: ValueError not raised
      "test_passing_function_to_transform_pmap_transform"
      "test_passing_function_to_transform_pmap_transform_with_state"
    ];

    disabledTestPaths = [
      # Require rlax which is unavailable as its dependency tensorflow-probability is broken
      "examples/impala/actor_test.py"
      "examples/impala/learner_test.py"
      "examples/impala_lite_test.py"
    ];

    doCheck = false;

    # check in passthru.tests.pytest to escape infinite recursion with bsuite
    passthru.tests.pytest = dm-haiku.overridePythonAttrs (_: {
      # `pname` here resolves to the outer `rec` set's value ("dm-haiku").
      pname = "${pname}-tests";
      doCheck = true;

      # We don't have to install because the only purpose
      # of this passthru test is to, well, test.
      # This fixes having to set `catchConflicts` to false.
      dontInstall = true;
    });

    meta = {
      # Per nixpkgs meta guidelines, the description does not repeat the
      # package name.
      description = "Simple neural network library for JAX developed by some of the authors of Sonnet";
      # Same organisation as `changelog` below (the repository now lives
      # under google-deepmind).
      homepage = "https://github.com/google-deepmind/dm-haiku";
      changelog = "https://github.com/google-deepmind/dm-haiku/releases/tag/${src.tag}";
      license = lib.licenses.asl20;
      maintainers = with lib.maintainers; [ ndl ];
    };
  };
in
dm-haiku
|