# path: root/pkgs/development/python-modules/bsuite/default.nix
# blob: 2d7f72d9e1b4024eae8e8b0d921e5e72c0526ab1
# Package expression for bsuite, a Python package fetched from PyPI and built
# with setuptools.
#
# The test suite is disabled in the default build (see `doCheck` below) and is
# exposed separately as `passthru.tests.check`.
{
  lib,
  absl-py,
  buildPythonPackage,
  distrax,
  dm-env,
  dm-haiku,
  dm-sonnet,
  dm-tree,
  fetchpatch,
  fetchPypi,
  frozendict,
  gym,
  matplotlib,
  mizani,
  optax,
  pandas,
  patsy,
  plotnine,
  pytestCheckHook,
  rlax,
  scikit-image,
  scipy,
  setuptools,
  statsmodels,
  tensorflow-probability,
  termcolor,
}:

let
  # Bound to a name so that `passthru.tests.check` can refer back to the
  # package itself and override it with checks enabled.
  bsuite = buildPythonPackage rec {
    pname = "bsuite";
    version = "0.3.5";
    pyproject = true;

    src = fetchPypi {
      inherit pname version;
      hash = "sha256-ak9McvXl7Nz5toUaPaRaJek9lurxiQiIW209GnZEjX0=";
    };

    patches = [
      # Convert np.int -> np.int32 since np.int is deprecated, https://github.com/google-deepmind/bsuite/pull/48
      (fetchpatch {
        url = "https://github.com/google-deepmind/bsuite/pull/48/commits/f8d81b2f1c27ef2c8c71ae286001ed879ea306ab.patch";
        hash = "sha256-FXtvVS+U8brulq8Z27+yWIimB+kigGiUOIv1SHb1TA8=";
      })
      # Replace imp with importlib, https://github.com/google-deepmind/bsuite/pull/50
      (fetchpatch {
        name = "replace-imp.patch";
        url = "https://github.com/google-deepmind/bsuite/commit/d08b63655c7efa5b5bb0f35e825e17549d23e812.patch";
        hash = "sha256-V5p/6edNXTpEckuSuxJ/mvfJng5yE/pfeMoYbvlNpEo=";
      })
    ];

    build-system = [ setuptools ];

    # Runtime dependencies.
    dependencies = [
      absl-py
      dm-env
      dm-tree
      frozendict
      gym
      matplotlib
      mizani
      pandas
      patsy
      plotnine
      scikit-image
      scipy
      statsmodels
      termcolor
    ];

    # Only needed by the test suite; these take effect when `doCheck` is
    # enabled, i.e. through `passthru.tests.check`.
    nativeCheckInputs = [
      distrax
      dm-haiku
      dm-sonnet
      optax
      pytestCheckHook
      rlax
      tensorflow-probability
    ];

    pythonImportsCheck = [ "bsuite" ];

    # pytestCheckHook deselects tests by name, so each name is listed once
    # even when several test classes define a test with that name.
    disabledTests = [
      # Tests require network connection
      "test_run9"
      "test_longer_action_sequence"
      "test_reset"
      "test_step_after_reset"
      "test_step_on_fresh_environment"
      "test_logger"
      "test_episode_truncation"
    ];

    # Escape infinite recursion with rlax
    # (rlax is one of the check inputs above; presumably it depends on bsuite
    # in turn -- confirm against the rlax expression.)
    doCheck = false;

    # Run the test suite as a separate derivation: the same package with
    # checks switched back on.
    passthru.tests = {
      check = bsuite.overridePythonAttrs (_: {
        doCheck = true;
      });
    };

    meta = {
      description = "Collection of experiments that investigate core capabilities of a reinforcement learning (RL) agent";
      # Same GitHub organisation as the changelog and patch URLs.
      homepage = "https://github.com/google-deepmind/bsuite";
      changelog = "https://github.com/google-deepmind/bsuite/releases/tag/${version}";
      license = lib.licenses.asl20;
      maintainers = with lib.maintainers; [ onny ];
    };
  };
in
bsuite