1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
use std::io;
use std::path::{Path};
use std::io::{Read};
use std::fs::{File};
use model::Matrix;
use flate2::read::GzDecoder;

/// Reasons why loading an MNIST dataset file can fail.
///
/// The variants carry no payload, so they can be compared directly with
/// `==` (for example in tests or when a caller wants to branch on one
/// specific failure).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DatasetError {
    /// The dataset file could not be opened.
    FileNotFound,
    /// A gzip decoder could not be created for the opened file.
    InvalidGzip,
    /// Reading from the decompressed stream failed, e.g. because the
    /// stream ended before a header field was complete.
    FileReadError,
    /// The leading magic number did not match the expected file kind.
    BadMagicNumber,
    /// The image pixel data could not be read as described by the header.
    BadRead
}

fn read_u32_msb(reader: &mut Read) -> io::Result<u32> {
    let mut r: [u8; 4] = [0; 4];
    try!(reader.read_exact(&mut r));

    Ok(r.iter().fold(0 as u32, |acc, y| { acc * (256 as u32) + (*y as u32)}))
}

/// Unwrap the `Ok` value of a `Result` expression, or return early from the
/// enclosing function with `Err($error)`.
///
/// The original error value is discarded; `$error` is only evaluated when
/// the expression is an `Err`.
macro_rules! try_or_else {
    ( $result: expr, $error: expr ) => {
        match $result {
            Ok(value) => value,
            Err(_) => return Err($error),
        }
    };
}

/// Return a matrix representation of the set of images in the
/// gzip-compressed dataset file. Each row is an image, whose elements are
/// the pixel-wise intensities laid out in row-major order.
///
/// On success the result is `((width, height), images)`:
///
/// * `width` and `height` are the third and fourth big-endian `u32` fields
///   of the header, returned in the order they appear in the file.
/// * `images` has one row per image and `width * height` columns; every
///   pixel byte is divided by 255, so values lie between 0.0 and 1.0.
///
/// NOTE(review): the IDX format description lists the row count before the
/// column count, so the first tuple element may really be the image height.
/// This is harmless for square images; confirm against callers before
/// relying on the order for non-square data.
///
/// # Errors
///
/// * `DatasetError::FileNotFound` - the file could not be opened.
/// * `DatasetError::InvalidGzip` - the gzip decoder could not be created.
/// * `DatasetError::FileReadError` - a header field could not be read.
/// * `DatasetError::BadMagicNumber` - the first header field is not 2051.
/// * `DatasetError::BadRead` - the sizes in the header overflow `usize`,
///   the pixel data is shorter than the header announces, or the matrix
///   could not be built from the decoded data.
pub fn load_mnist_images(filename: &Path) -> Result<((u32, u32), Matrix<f32>), DatasetError> {
    let f = try_or_else!(File::open(filename), DatasetError::FileNotFound);

    let mut reader = try_or_else!(GzDecoder::new(f), DatasetError::InvalidGzip);

    // magic number
    match read_u32_msb(&mut reader) {
        Err(_) => return Err(DatasetError::FileReadError),
        Ok(2051) => {},
        Ok(x) => {
            println!("wrong magic number: {}", x);
            return Err(DatasetError::BadMagicNumber)
        }
    }

    // A truncated header is reported as an error instead of panicking.
    let num_items = try_or_else!(read_u32_msb(&mut reader), DatasetError::FileReadError);
    let num_items = num_items as usize;
    let width = try_or_else!(read_u32_msb(&mut reader), DatasetError::FileReadError);
    let height = try_or_else!(read_u32_msb(&mut reader), DatasetError::FileReadError);

    // The header comes from the file, so the sizes are untrusted: widen to
    // `usize` before multiplying and reject anything that overflows rather
    // than wrapping around to a wrong (smaller) buffer size.
    let d = try_or_else!(
        (width as usize).checked_mul(height as usize).ok_or(()),
        DatasetError::BadRead
    );
    let total = try_or_else!(d.checked_mul(num_items).ok_or(()), DatasetError::BadRead);

    let mut buf_data: Vec<u8> = vec![0; total];
    try_or_else!(reader.read_exact(&mut buf_data), DatasetError::BadRead);

    // Scale every pixel byte into a float intensity.
    let buf_v: Vec<f32> = buf_data.iter().map(|&x| (x as f32) / 255.0).collect();

    match Matrix::from_shape_vec((num_items, d), buf_v).ok() {
        Some(images) => Ok(((width, height), images)),
        None => Err(DatasetError::BadRead),
    }
}

pub fn load_mnist_labels(filename: &Path) -> Result<Vec<u8>, DatasetError> {
    let f = try_or_else!(File::open(filename), DatasetError::FileNotFound);

    let mut reader = try_or_else!(GzDecoder::new(f), DatasetError::InvalidGzip);

    // check the magic number
    let magic = try_or_else!(read_u32_msb(&mut reader), DatasetError::FileReadError);
    if magic != 2049 {
        return Err(DatasetError::BadMagicNumber);
    }

    // Read the labels data.
    let num_items = try_or_else!(read_u32_msb(&mut reader), DatasetError::FileReadError);

    let mut v: Vec<u8> = Vec::new();
    v.resize(num_items as usize, 0);
    try_or_else!(reader.read_exact(&mut v), DatasetError::FileReadError);

    Ok(v)
}