summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/oryx/default.nix
blob: 07f0302104df9460546956b39a918cc2e023004b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Nix expression for the `oryx` Python package (built on top of JAX).
# Sources come from the PyPI sdist because upstream stopped pushing git tags.
# The package is currently marked broken; see `meta.broken` below.
{
  lib,
  buildPythonPackage,
  fetchPypi,

  # build-system
  poetry-core,

  # dependencies
  jax,
  jaxlib,
  tensorflow-probability,

  # tests
  inference-gym,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "oryx";
  version = "0.2.9";
  pyproject = true;

  # No more tags on GitHub, so the release is fetched from PyPI instead.
  # See https://github.com/jax-ml/oryx/issues/95
  src = fetchPypi {
    inherit pname version;
    hash = "sha256-HlKUnguTNfs7gSqIJ0n2EjjLXPUgtI2JsQM70wKMeXs=";
  };

  build-system = [ poetry-core ];

  dependencies = [
    jax
    jaxlib
    tensorflow-probability
  ];

  pythonImportsCheck = [ "oryx" ];

  nativeCheckInputs = [
    inference-gym
    pytestCheckHook
  ];

  # Individual tests that fail in this build, grouped by the error each one raises.
  disabledTests = [
    # ValueError: Number of devices 1 must equal the product of mesh_shape (1, 2)
    # (these tests need more than one device)
    "test_plant"
    "test_plant_before_shmap"
    "test_plant_inside_shmap_fails"
    "test_reap"
    "test_reap_before_shmap"
    "test_reap_inside_shmap_fails"

    # ValueError: Variable has already been reaped
    "test_call_list"
    "test_call_tuple"
    "test_dense_combinator"
    "test_dense_function"
    "test_dense_imperative"
    "test_function_in_combinator_in_function"
    "test_grad_of_function_with_literal"
    "test_grad_of_shared_layer"
    "test_grad_of_stateful_function"
    "test_kwargs_rng"
    "test_kwargs_training"
    "test_kwargs_training_rng"
    "test_reshape_call"
    "test_scale_by_adam_should_scale_by_adam"
    "test_scale_by_schedule_should_update_scale"
    "test_scale_by_stddev_should_scale_by_stddev"
    "test_trace_should_keep_track_of_momentum_with_nesterov"

    # NotImplementedError: No registered inverse for `split`
    "test_inverse_of_split"

    # jax.errors.UnexpectedTracerError: Encountered an unexpected tracer
    "test_can_plant_into_jvp_of_custom_jvp_function_unimplemented"
    "test_forward_Scale"

    # ValueError: No variable declared for assign: update_1
    "test_optimizer_adam"
    "test_optimizer_noisy_sgd"
    "test_optimizer_rmsprop"
    "test_optimizer_sgd"
    "test_optimizer_sgd_with_momentum"
    "test_optimizer_sgd_with_nesterov_momentum"

    # AssertionError
    # ACTUAL: array(-2.337877, dtype=float32)
    # DESIRED: array(0., dtype=float32)
    "test_can_map_over_batches_with_vmap_and_reduce_to_scalar_log_prob"
    "test_vmapping_distribution_reduces_to_scalar_log_prob"

    # TypeError: _dot_general_shape_rule() missing 1 required keyword-only argument: 'out_sharding'
    "test_can_rewrite_dot_to_einsum"

    # AttributeError: 'float' object has no attribute 'shape'
    "test_add_noise_should_add_noise"
    "test_apply_every_should_delay_updates"

    # TypeError: Error interpreting argument to functools.partial(...) as an abstract array
    "test_can_rewrite_nested_expression_into_single_einsum"
  ];

  # Whole test modules skipped because (nearly) every test in them fails.
  disabledTestPaths = [
    # ValueError: Variable has already been reaped
    "oryx/experimental/nn/normalization_test.py"
    "oryx/experimental/nn/pooling_test.py"
  ];

  meta = {
    description = "Library for probabilistic programming and deep learning built on top of Jax";
    homepage = "https://github.com/jax-ml/oryx";
    # NOTE(review): upstream no longer pushes git tags (see the `src` comment),
    # so this release-tag URL may not resolve for ${version} — confirm or drop.
    changelog = "https://github.com/jax-ml/oryx/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ GaetanLepage ];
    # oryx seems to be incompatible with jax 0.5.1:
    # the jax bump caused 237 additional test failures.
    broken = true;
  };
}