1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
use std::io;
use std::path::{Path};
use std::io::{Read};
use std::fs::{File};
use model::Matrix;
use flate2::read::GzDecoder;
/// Errors returned by the MNIST dataset loaders in this module.
///
/// The variants carry no payload; the underlying I/O or decoder error is
/// discarded by the loaders.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DatasetError {
    /// The input file could not be opened (any `File::open` failure).
    FileNotFound,
    /// The gzip decoder could not be constructed over the opened file.
    InvalidGzip,
    /// A header field (or, for label files, the label payload) could not be read.
    FileReadError,
    /// The file's magic number did not match the expected IDX type.
    BadMagicNumber,
    /// The image payload could not be read.
    BadRead
}
/// Reads exactly four bytes from `reader` and decodes them as a big-endian
/// (most-significant byte first) `u32`.
///
/// # Errors
/// Propagates the error from `read_exact`, including `UnexpectedEof` when
/// fewer than four bytes remain.
fn read_u32_msb(reader: &mut dyn Read) -> io::Result<u32> {
    let mut bytes = [0u8; 4];
    reader.read_exact(&mut bytes)?;
    // Shift in each byte, most significant first; four bytes never overflow u32.
    Ok(bytes.iter().fold(0u32, |acc, &b| (acc << 8) | u32::from(b)))
}
/// Evaluates a `Result`-typed expression and yields its `Ok` value; on `Err`,
/// returns `Err($on_err)` from the enclosing function. The original error value
/// is discarded, and `$on_err` is only evaluated on the error path.
macro_rules! try_or_else {
    ( $result: expr, $on_err: expr ) => {
        match $result {
            Ok(value) => value,
            Err(_) => return Err($on_err),
        }
    }
}
/// Loads a gzip-compressed MNIST image file (magic number 2051).
///
/// The decompressed stream starts with four big-endian `u32`s: magic number,
/// item count, width, and height. They are followed by
/// `count * width * height` pixel bytes.
///
/// Returns `((width, height), matrix)`. The matrix has `count` rows and
/// `width * height` columns, and each pixel byte is scaled from `0..=255` to
/// `0.0..=1.0`.
///
/// # Errors
/// - `FileNotFound` if the file cannot be opened.
/// - `InvalidGzip` if the gzip decoder cannot be created.
/// - `FileReadError` if any header field cannot be read.
/// - `BadMagicNumber` if the magic number is not 2051 (the value is printed to stdout).
/// - `BadRead` if the header's dimensions overflow `usize` or the pixel data cannot be read.
pub fn load_mnist_images(filename: &Path) -> Result<((u32, u32), Matrix<f32>), DatasetError> {
    let f = try_or_else!(File::open(filename), DatasetError::FileNotFound);
    let mut reader = try_or_else!(GzDecoder::new(f), DatasetError::InvalidGzip);
    match read_u32_msb(&mut reader) {
        Err(_) => return Err(DatasetError::FileReadError),
        Ok(2051) => {}
        Ok(x) => {
            println!("wrong magic number: {}", x);
            return Err(DatasetError::BadMagicNumber);
        }
    }

    // A truncated header is reported as an error rather than panicking.
    let num_items = try_or_else!(read_u32_msb(&mut reader), DatasetError::FileReadError);
    let width = try_or_else!(read_u32_msb(&mut reader), DatasetError::FileReadError);
    let height = try_or_else!(read_u32_msb(&mut reader), DatasetError::FileReadError);

    // Header values come from the file, so the size arithmetic is checked:
    // `width * height` in u32 could overflow (panic in debug, wrap in release).
    let d = (width as usize)
        .checked_mul(height as usize)
        .ok_or(DatasetError::BadRead)?;
    let total = d
        .checked_mul(num_items as usize)
        .ok_or(DatasetError::BadRead)?;

    let mut buf_data = vec![0u8; total];
    try_or_else!(reader.read_exact(&mut buf_data), DatasetError::BadRead);
    let buf_v: Vec<f32> = buf_data.iter().map(|&x| x as f32 / 255.0).collect();

    // `buf_v.len() == num_items * d` by construction, so the shape always matches.
    let matrix = Matrix::from_shape_vec((num_items as usize, d), buf_v)
        .ok()
        .expect("pixel buffer length equals num_items * width * height");
    Ok(((width, height), matrix))
}
/// Loads a gzip-compressed MNIST label file (magic number 2049).
///
/// The decompressed stream starts with two big-endian `u32`s (magic number and
/// item count), followed by one byte per label. Returns those label bytes.
///
/// # Errors
/// - `FileNotFound` if the file cannot be opened.
/// - `InvalidGzip` if the gzip decoder cannot be created.
/// - `FileReadError` if the header or the label bytes cannot be read.
/// - `BadMagicNumber` if the magic number is not 2049.
pub fn load_mnist_labels(filename: &Path) -> Result<Vec<u8>, DatasetError> {
    let file = try_or_else!(File::open(filename), DatasetError::FileNotFound);
    let mut decoder = try_or_else!(GzDecoder::new(file), DatasetError::InvalidGzip);

    let magic_number = try_or_else!(read_u32_msb(&mut decoder), DatasetError::FileReadError);
    match magic_number {
        2049 => {}
        _ => return Err(DatasetError::BadMagicNumber),
    }

    let label_count = try_or_else!(read_u32_msb(&mut decoder), DatasetError::FileReadError);
    let mut labels = vec![0u8; label_count as usize];
    try_or_else!(decoder.read_exact(&mut labels), DatasetError::FileReadError);
    Ok(labels)
}