summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/torch/tests/mnist-example/default.nix
blob: bb8b90aae5a95e7969e70433c5e90c0760c4a7d1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Smoke test for torch/torchvision: runs ./script.py inside the build sandbox
# against a pre-fetched copy of the Fashion-MNIST dataset. The derivation
# succeeds (producing an empty $out) only if the script exits successfully.
{
  lib,
  linkFarm,
  fetchurl,
  writers,
  torch,
  torchvision,
  runCommand,
}:
let
  # The four gzipped Fashion-MNIST files, fetched ahead of time so the test
  # does not need network access at build time. linkFarm exposes each one
  # under its original file name inside a single directory.
  # Plain HTTP is tolerable here: fetchurl is a fixed-output derivation and
  # any download that does not match the pinned hash is rejected.
  fashionMnistDataset = linkFarm "fashion-mnist-dataset" (
    lib.mapAttrsToList
      (name: hash: {
        inherit name;
        path = fetchurl {
          url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/${name}";
          inherit hash;
        };
      })
      {
        "train-images-idx3-ubyte.gz" = "sha256-Ou3jjWGGOQiteGE/ajLtJxYm3RKAC6JjZWlRI2kmioQ=";
        "train-labels-idx1-ubyte.gz" = "sha256-oE8XE0rANWCkfjdk4RuS/JfeTRv6+LoaOqKa9UzJCEU=";
        "t10k-images-idx3-ubyte.gz" = "sha256-NG5VuUjZc6l+WNI1Hd4WpIS9QV1FlSl2M7sI8D22oHM=";
        "t10k-labels-idx1-ubyte.gz" = "sha256-Z9oXx26v/KVEbDNhqqtcPNbRwmCHZNNd+xhQsIa/jdU=";
      }
  );

  # ./script.py wrapped as an executable with torch and torchvision available
  # to it. flake8's E501 (line too long) check is skipped for the script.
  mnist-script = writers.writePython3 "test_mnist" {
    libraries = [
      torch
      torchvision
    ];
    flakeIgnore = [ "E501" ];
  } (builtins.readFile ./script.py);
in
runCommand "mnist" { } ''
  # NOTE(review): this layout presumably matches what the script (via
  # torchvision's FashionMNIST loader with root "data") looks for -- confirm
  # against script.py.
  mkdir -p data/FashionMNIST/raw

  # Decompress every archive into raw/, dropping its final extension (".gz").
  # Iterate with a glob instead of parsing ls output (ShellCheck SC2045).
  for archive in "${fashionMnistDataset}"/*; do
    base=''${archive##*/}
    gzip -d < "$archive" > "data/FashionMNIST/raw/''${base%.*}"
  done

  ${mnist-script}

  # runCommand must create $out; the test has no artifact, so leave it empty.
  touch "$out"
''