summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/rangehttpserver
diff options
context:
space:
mode:
authorMartin Weinelt <hexa@darmstadt.ccc.de>2022-04-11 23:48:50 +0200
committerMartin Weinelt <hexa@darmstadt.ccc.de>2022-04-15 01:39:54 +0200
commit836e3af5447ee51c81566b6717f445213c50f47b (patch)
tree57ca8b18312f59109c33efac732d091c642ba7e2 /pkgs/development/python-modules/rangehttpserver
parent84cc0b7449edec98b4d94857e549a9430b257b3d (diff)
python3Packages.jax: disable test_custom_linear_solve_aux
``` ______________ CustomLinearSolveTest.test_custom_linear_solve_aux ______________ [gw3] linux -- Python 3.9.11 /nix/store/k1physzalj5vffsvl7ag6h6b6vaqip5x-python3-3.9.11/bin/python3.9 self = <custom_linear_solve_test.CustomLinearSolveTest testMethod=test_custom_linear_solve_aux> @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_custom_linear_solve_aux(self): def explicit_jacobian_solve_aux(matvec, b): x = lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b)) return x, array_aux def matrix_free_solve_aux(matvec, b): return lax.custom_linear_solve( matvec, b, explicit_jacobian_solve_aux, explicit_jacobian_solve_aux, symmetric=True, has_aux=True) def linear_solve_aux(a, b): return matrix_free_solve_aux(partial(high_precision_dot, a), b) # array aux values, to be able to use jtu.check_grads array_aux = {"converged": np.array(1.), "nfev": np.array(12345.)} rng = self.rng() a = rng.randn(3, 3) a = a + a.T b = rng.randn(3) expected = jnp.linalg.solve(a, b) actual_nojit, nojit_aux = linear_solve_aux(a, b) actual_jit, jit_aux = jax.jit(linear_solve_aux)(a, b) self.assertAllClose(expected, actual_nojit) self.assertAllClose(expected, actual_jit) # scalar dict equality check self.assertDictEqual(nojit_aux, array_aux) self.assertDictEqual(jit_aux, array_aux) # jvp / vjp test > jtu.check_grads(linear_solve_aux, (a, b), order=2, rtol=4e-3) tests/custom_linear_solve_test.py:157: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ jax/_src/test_util.py:372: in check_grads _check_grads(f, args, order) jax/_src/test_util.py:361: in _check_grads _check_grads(partial(api.jvp, f), (args, args), order - 1, fwd_msg) jax/_src/test_util.py:365: in _check_grads _check_vjp(f, partial(api.vjp, f), args, err_msg=rev_msg) jax/_src/test_util.py:325: in check_vjp check_close(ip, ip_expected, atol=atol, rtol=rtol, jax/_src/test_util.py:227: in check_close tree_all(tree_multimap(assert_close, xs, ys)) jax/_src/tree_util.py:180: in tree_map return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) jax/_src/tree_util.py:180: in <genexpr> return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) jax/_src/test_util.py:217: in _assert_numpy_close _assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size, _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ a = array(1.89683694), b = array(1.88698006), atol = 0.002, rtol = 0.004 err_msg = 'VJP of JVP cotangent projection' def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): if a.dtype == b.dtype == _dtypes.float0: np.testing.assert_array_equal(a, b, err_msg=err_msg) return a = a.astype(np.float32) if a.dtype == _dtypes.bfloat16 else a b = b.astype(np.float32) if b.dtype == _dtypes.bfloat16 else b kw = {} if atol: kw["atol"] = atol if rtol: kw["rtol"] = rtol with np.errstate(invalid='ignore'): # TODO(phawkins): surprisingly, assert_allclose sometimes reports invalid # value errors. It should not do that. > np.testing.assert_allclose(a, b, **kw, err_msg=err_msg) E AssertionError: E Not equal to tolerance rtol=0.004, atol=0.002 E VJP of JVP cotangent projection E Mismatched elements: 1 / 1 (100%) E Max absolute difference: 0.00985688 E Max relative difference: 0.00522363 E x: array(1.896837) E y: array(1.88698) jax/_src/test_util.py:187: AssertionError ```
Diffstat (limited to 'pkgs/development/python-modules/rangehttpserver')
0 files changed, 0 insertions, 0 deletions