summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/dm-haiku/default.nix
blob: 831516ec78cc3ba48fed1d5c0c6f0f1cb28075c5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# dm-haiku: a simple neural network library for JAX from Google DeepMind.
#
# The package derivation itself is built with `doCheck = false`; the test
# suite is run separately by `passthru.tests.pytest` (see the comment there).
{
  lib,
  buildPythonPackage,
  pythonAtLeast, # NOTE(review): not referenced in this expression
  fetchFromGitHub,
  fetchpatch,

  # build-system
  setuptools,

  # dependencies
  absl-py,
  jaxlib,
  jmp,
  numpy,
  tabulate,

  # optional-dependencies
  jax,
  flax,

  # tests
  pytest-xdist,
  pytestCheckHook,
  bsuite,
  chex,
  cloudpickle,
  dill,
  dm-env,
  dm-tree,
  optax,
  rlax, # currently unused: see the commented-out entry in nativeCheckInputs
  tensorflow,
}:

let
  dm-haiku = buildPythonPackage rec {
    pname = "dm-haiku";
    version = "0.0.16";
    pyproject = true;

    # The upstream repository now lives under the google-deepmind organisation
    # (matching meta.changelog below); the source hash covers the fetched
    # content, so it is unaffected by the owner name.
    src = fetchFromGitHub {
      owner = "google-deepmind";
      repo = "dm-haiku";
      tag = "v${version}";
      hash = "sha256-XugzzHapnqXD8w17k6HaNeqWcxRe49r7OIb8v5LI2NM=";
    };

    patches = [
      # https://github.com/google-deepmind/dm-haiku/pull/672
      (fetchpatch {
        name = "fix-find-namespace-packages.patch";
        url = "https://github.com/google-deepmind/dm-haiku/commit/728031721f77d9aaa260bba0eddd9200d107ba5d.patch";
        hash = "sha256-qV94TdJnphlnpbq+B0G3KTx5CFGPno+8FvHyu/aZeQE=";
      })
    ];

    build-system = [ setuptools ];

    dependencies = [
      absl-py
      jaxlib # implicit runtime dependency
      jmp
      numpy
      tabulate
    ];

    optional-dependencies = {
      jax = [
        jax
        jaxlib
      ];
      flax = [ flax ];
    };

    pythonImportsCheck = [ "haiku" ];

    nativeCheckInputs = [
      bsuite
      chex
      cloudpickle
      dill
      dm-env
      # The let-bound package itself; presumably needed because the test
      # derivation sets `dontInstall = true` (see passthru.tests.pytest).
      dm-haiku
      dm-tree
      flax
      jaxlib
      optax
      pytest-xdist
      pytestCheckHook
      # rlax (broken dependency tensorflow-probability)
      tensorflow
    ];

    disabledTests = [
      # See https://github.com/google-deepmind/dm-haiku/issues/366.
      "test_jit_Recurrent"

      # Assertion errors
      "testShapeChecking0"
      "testShapeChecking1"

      # This test requires a more recent version of tensorflow. The current one (2.13) is not enough.
      "test_reshape_convert"

      # This test requires JAX support for double precision (64bit), but enabling this causes several
      # other tests to fail.
      # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
      "test_doctest_haiku.experimental"

      # AssertionError: 1 != 0 : 1 doctests failed
      "test_doctest_haiku"

      # ValueError: pmap wrapped function must be passed at least one argument containing an array,
      # got empty *args=() and **kwargs={}
      "test_equivalent_when_passing_transformed_fn2"

      # AssertionError: ValueError not raised
      "test_passing_function_to_transform_pmap_transform"
      "test_passing_function_to_transform_pmap_transform_with_state"
    ];

    disabledTestPaths = [
      # Require rlax which is unavailable as its dependency tensorflow-probability is broken
      "examples/impala/actor_test.py"
      "examples/impala/learner_test.py"
      "examples/impala_lite_test.py"
    ];

    # Tests are disabled here and run in passthru.tests.pytest instead.
    doCheck = false;

    # check in passthru.tests.pytest to escape infinite recursion with bsuite
    passthru.tests.pytest = dm-haiku.overridePythonAttrs (_: {
      pname = "${pname}-tests";
      doCheck = true;

      # We don't have to install because the only purpose
      # of this passthru test is to, well, test.
      # This fixes having to set `catchConflicts` to false.
      dontInstall = true;
    });

    meta = {
      description = "Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet";
      homepage = "https://github.com/google-deepmind/dm-haiku";
      changelog = "https://github.com/google-deepmind/dm-haiku/releases/tag/${src.tag}";
      license = lib.licenses.asl20;
      maintainers = with lib.maintainers; [ ndl ];
    };
  };
in
dm-haiku