summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/tensorflow-probability/default.nix
blob: 92460ee2a3ae1686a4f7fc585c7b8cf5d69e1370 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
{
  lib,
  stdenv,

  # fetching the source and upstream fixes
  fetchFromGitHub,
  fetchpatch2,

  # stage 1: building the wheel with Bazel
  buildBazelPackage,
  bazel,
  cctools,
  absl-py,
  python,
  setuptools,
  wheel,

  # stage 2: installing the wheel as a Python package
  buildPythonPackage,

  # runtime Python dependencies (tensorflow is also needed while Bazel runs)
  cloudpickle,
  decorator,
  dm-tree,
  gast,
  keras,
  numpy,
  six,
  tensorflow,

  # test-only Python packages
  hypothesis,
  matplotlib,
  mock,
  mpmath,
  pandas,
  pytest,
  scipy,
}:

let
  version = "0.25.0";
  pname = "tensorflow-probability";

  # Stage 1: run the upstream `:pip_pkg` Bazel target and collect the
  # universal (py2.py3-none-any) wheel it produces. $out of this derivation
  # is the .whl file itself (not a directory); the Python package below
  # uses it directly as its `src`.
  bazel-wheel = buildBazelPackage {
    name = "tensorflow_probability-${version}-py2.py3-none-any.whl";
    src = fetchFromGitHub {
      owner = "tensorflow";
      repo = "probability";
      tag = "v${version}";
      hash = "sha256-LXQfGFgnM7WYUQjJ2Y3jskdeJ/dEKz+Afg+UOQjv5kc=";
    };

    patches = [
      # AttributeError: jax.interpreters.xla.pytype_aval_mappings was deprecated in JAX v0.5.0 and
      # removed in JAX v0.7.0. jax.core.pytype_aval_mappings can be used as a replacement in most cases.
      # TODO: remove when updating to the next release
      (fetchpatch2 {
        name = "future-proof-reference-to-deprecated-pytype_aval_mappings";
        url = "https://github.com/tensorflow/probability/commit/135080b6b1ac5724fc1731b0a9ca6f2010b1aea5.patch";
        hash = "sha256-27yWIw5pI86KcUz0TsYwRFyLDoeiqmxgsRMBXaauzVw=";
      })
    ];

    # Tools and Python modules that must be available while Bazel builds and
    # while `pip_pkg` assembles the wheel.
    nativeBuildInputs = [
      absl-py
      # needed to create the output wheel in installPhase
      python
      setuptools
      tensorflow
      wheel
    ];

    inherit bazel;

    bazelTargets = [ ":pip_pkg" ];
    # Build without Bzlmod. NOTE(review): presumably the upstream build still
    # relies on WORKSPACE-style dependencies — re-check when bumping bazel.
    bazelFlags = [ "--noenable_bzlmod" ];
    # Keep rules_cc in the fetched dependencies; buildBazelPackage strips it
    # unless this is set to false.
    removeRulesCC = false;
    # Points at cctools' libtool on Darwin only; lib.optionalString yields an
    # empty string on every other platform.
    LIBTOOL = lib.optionalString stdenv.hostPlatform.isDarwin "${cctools}/bin/libtool";

    # Fixed-output hash of the vendored Bazel dependencies. It has to be
    # refreshed whenever the source, the patches or the bazel version change.
    fetchAttrs = {
      sha256 = "sha256-7sPdIHWNFn13eaUanFgN988hFAwGnlU6cxmHOJUDpiQ=";
    };

    buildAttrs = {
      preBuild = ''
        patchShebangs .
      '';

      installPhase = ''
        # work around timestamp issues
        # https://github.com/NixOS/nixpkgs/issues/270#issuecomment-467583872
        # 315532800 is 1980-01-01T00:00:00Z, the earliest timestamp a ZIP
        # archive (and therefore a wheel) can store.
        export SOURCE_DATE_EPOCH=315532800

        # First build, then move. Otherwise pip_pkg would create the dir $out
        # and then put the wheel in that directory. However we want $out to
        # point directly to the wheel file.
        ./bazel-bin/pip_pkg . --release
        mv *.whl "$out"
      '';
    };
  };
in
# Stage 2: install the wheel built above as an ordinary Python package.
buildPythonPackage {
  inherit pname version;

  # `src` is a single prebuilt .whl file, so buildPythonPackage installs it as a wheel.
  src = bazel-wheel;
  format = "wheel";

  dependencies = [
    cloudpickle
    decorator
    dm-tree
    gast
    keras
    numpy
    six
    tensorflow
  ];

  # Test dependencies as enumerated upstream:
  # https://github.com/tensorflow/probability/blob/f3777158691787d3658b5e80883fe1a933d48989/testing/dependency_install_lib.sh#L83
  nativeCheckInputs = [
    hypothesis
    matplotlib
    mock
    mpmath
    pandas
    pytest
    scipy
  ];

  # Running the unit tests with pytest would be preferable, but only the
  # Bazel-built wheel is available in checkPhase. Upstream also does not
  # guarantee that pytest works for its test suite, see
  # https://github.com/tensorflow/probability/blob/c2a10877feb2c4c06a4dc58281e69c37a11315b9/CONTRIBUTING.md?plain=1#L69
  # Ideally the tests would be run through Bazel. Until then, the only check
  # is a sanity check that the module can be imported.
  pythonImportsCheck = [ "tensorflow_probability" ];

  meta = {
    changelog = "https://github.com/tensorflow/probability/releases/tag/v${version}";
    description = "Library for probabilistic reasoning and statistical analysis";
    homepage = "https://www.tensorflow.org/probability/";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.GaetanLepage ];
  };
}