it learning

master
q3k 2021-04-02 22:38:15 +02:00
parent ec7d29ceef
commit 16663ebd02
6 changed files with 842 additions and 55 deletions

268
Cargo.lock generated
View File

@ -6,6 +6,12 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "adler32"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234"
[[package]]
name = "aho-corasick"
version = "0.7.15"
@ -32,6 +38,18 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
[[package]]
name = "bitflags"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
[[package]]
name = "bytemuck"
version = "1.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bed57e2090563b83ba8f83366628ce535a7584c9afa4c9fc0612a03925c6df58"
[[package]]
name = "byteorder"
version = "1.4.3"
@ -44,6 +62,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "color_quant"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
[[package]]
name = "crc32fast"
version = "1.2.1"
@ -98,6 +122,16 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "deflate"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73770f8e1fe7d64df17ca66ad28994a0a623ea497fa69486e14984e715c5d174"
dependencies = [
"adler32",
"byteorder",
]
[[package]]
name = "either"
version = "1.6.1"
@ -126,7 +160,7 @@ dependencies = [
"cfg-if",
"crc32fast",
"libc",
"miniz_oxide",
"miniz_oxide 0.4.4",
]
[[package]]
@ -140,6 +174,16 @@ dependencies = [
"wasi",
]
[[package]]
name = "gif"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a668f699973d0f573d15749b7002a9ac9e1f9c6b220e7b165601334c173d8de"
dependencies = [
"color_quant",
"weezl",
]
[[package]]
name = "hermit-abi"
version = "0.1.18"
@ -155,6 +199,49 @@ version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "image"
version = "0.23.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24ffcb7e7244a9bf19d35bf2883b9c080c4ced3c07a9895572178cdb8f13f6a1"
dependencies = [
"bytemuck",
"byteorder",
"color_quant",
"gif",
"jpeg-decoder",
"num-iter",
"num-rational",
"num-traits",
"png",
"scoped_threadpool",
"tiff",
]
[[package]]
name = "itertools"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37d572918e350e82412fe766d24b15e6682fb2ed2bbe018280caa810397cb319"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
[[package]]
name = "jpeg-decoder"
version = "0.1.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "229d53d58899083193af11e15917b5640cd40b29ff475a1fe4ef725deb02d0f2"
dependencies = [
"rayon",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
@ -167,6 +254,12 @@ version = "0.2.91"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8916b1f6ca17130ec6568feccee27c156ad12037880833a3b842a823236502e7"
[[package]]
name = "libm"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7d73b3f436185384286bd8098d17ec07c9a7d2388a6599f824d8502b529702a"
[[package]]
name = "log"
version = "0.4.14"
@ -191,6 +284,15 @@ dependencies = [
"autocfg",
]
[[package]]
name = "miniz_oxide"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "791daaae1ed6889560f8c4359194f56648355540573244a5448a83ba1ecc7435"
dependencies = [
"adler32",
]
[[package]]
name = "miniz_oxide"
version = "0.4.4"
@ -208,9 +310,56 @@ dependencies = [
"byteorder",
"env_logger",
"flate2",
"image",
"itertools",
"log",
"rand",
"rand_distr",
"rayon",
"serde",
"serde_json",
]
[[package]]
name = "num-integer"
version = "0.1.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-iter"
version = "0.1.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2021c8337a54d21aca0d59a92577a029af9431cb59b909b03252b9c164fad59"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12ac428b1cb17fce6f731001d307d351ec70a6d202fc2e60f7d4c5e42d8f4f07"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290"
dependencies = [
"autocfg",
"libm",
]
[[package]]
@ -223,12 +372,42 @@ dependencies = [
"libc",
]
[[package]]
name = "png"
version = "0.16.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c3287920cb847dee3de33d301c463fba14dda99db24214ddf93f83d3021f4c6"
dependencies = [
"bitflags",
"crc32fast",
"deflate",
"miniz_oxide 0.3.7",
]
[[package]]
name = "ppv-lite86"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
[[package]]
name = "proc-macro2"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71"
dependencies = [
"unicode-xid",
]
[[package]]
name = "quote"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.3"
@ -260,6 +439,16 @@ dependencies = [
"getrandom",
]
[[package]]
name = "rand_distr"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da9e8f32ad24fb80d07d2323a9a2ce8b30d68a62b8cb4df88119ff49a698f038"
dependencies = [
"num-traits",
"rand",
]
[[package]]
name = "rand_hc"
version = "0.3.0"
@ -311,12 +500,66 @@ version = "0.6.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5f089152e60f62d28b835fbff2cd2e8dc0baf1ac13343bef92ab7eed84548"
[[package]]
name = "ryu"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e"
[[package]]
name = "scoped_threadpool"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d51f5df5af43ab3f1360b429fa5e0152ac5ce8c0bd6485cae490332e96846a8"
[[package]]
name = "scopeguard"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "serde"
version = "1.0.125"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "558dc50e1a5a5fa7112ca2ce4effcb321b0300c0d4ccf0776a9f60cd89031171"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.125"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b093b7a2bb58203b5da3056c05b4ec1fed827dcfdb37347a8841695263b3d06d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "syn"
version = "1.0.67"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6498a9efc342871f91cc2d0d694c674368b4ceb40f62b65a7a08c3792935e702"
dependencies = [
"proc-macro2",
"quote",
"unicode-xid",
]
[[package]]
name = "termcolor"
version = "1.1.2"
@ -326,12 +569,35 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "tiff"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a53f4706d65497df0c4349241deddf35f84cee19c87ed86ea8ca590f4464437"
dependencies = [
"jpeg-decoder",
"miniz_oxide 0.4.4",
"weezl",
]
[[package]]
name = "unicode-xid"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564"
[[package]]
name = "wasi"
version = "0.10.2+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
[[package]]
name = "weezl"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a32b378380f4e9869b22f0b5177c68a5519f03b3454fde0b291455ddbae266c"
[[package]]
name = "winapi"
version = "0.3.9"

View File

@ -12,4 +12,9 @@ flate2 = "1"
log = "0.4.0"
env_logger = "0.8.3"
rand = "0.8.3"
rand_distr = "0.4.0"
rayon = "1"
itertools = "0.10.0"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
image = "0.23"

View File

@ -1,30 +1,35 @@
use std::time;
use rand::Rng;
use rayon::prelude::*;
use itertools::Itertools;
use crate::types as t;
use crate::types::{Data, Input, Output};
use rayon::prelude::*;
use serde::{Serialize, Deserialize};
pub mod types;
pub mod maths;
mod mnist;
use maths::{Vector, Matrix, Shape};
#[derive(Serialize,Deserialize)]
struct JSONBiases(Vec<Vec<Vec<f32>>>);
#[derive(Serialize,Deserialize)]
struct JSONWeights(Vec<Vec<Vec<f32>>>);
struct Network {
sizes: Vec<usize>,
biases: Vec<Vec<f32>>,
weights: Vec<Vec<Vec<f32>>>,
biases: Vec<Vector>,
weights: Vec<Matrix>,
}
impl Network {
fn new(sizes: &Vec<usize>) -> Self {
let mut rng = rand::thread_rng();
let biases = sizes[1..].iter().map(|s| {
(0..s.clone()).map(|_| rng.gen::<f32>()).collect()
}).collect();
let biases = sizes[1..].iter().map(|s| Vector::random(*s, &mut rng)).collect();
let weights = sizes.iter().zip(sizes[1..].iter()).map(|(from, to)| {
(0..to.clone()).map(|_| {
(0..from.clone()).map(|_| rng.gen::<f32>()).collect()
}).collect()
Matrix::random(Shape::new(*to, *from), &mut rng)
}).collect();
Self {
@ -33,18 +38,174 @@ impl Network {
}
}
/// Logs every layer's bias vector and weight matrix (one log line per
/// matrix row) at `info` level.
fn dump(&self) {
    let layers = self.biases.iter().zip(self.weights.iter());
    for (layer, (bias, weight)) in layers.enumerate() {
        log::info!("l {}: biases: {:?}", layer, bias);
        log::info!("l {}: weights:", layer);
        weight.rows.iter().for_each(|row| {
            log::info!(" {:?}", row);
        });
    }
}
fn feedforward<D: t::Data>(&self, i: &D::Input) -> D::Output {
let a = i.data();
let o = self.biases.iter()
.zip(self.weights.iter())
.fold(a.clone(), |a, (biases, weights)| {
mat_vec_mult(weights, &a).iter()
.zip(biases.iter())
.map(|(aa, b)| {
sigmoid(aa+b)
}).collect()
});
D::Output::from_nn_output(o)
let mut a = i.data().clone();
for (b, w) in self.biases.iter().zip(self.weights.iter()) {
//log::info!("before: {:?}", a);
let mut z = w.mult_vec(&a);
z.add_mut(&b);
for e in z.iter_mut() {
*e = sigmoid(*e);
}
a = z;
//log::info!("after: {:?}", a);
}
D::Output::from_nn_output(a)
}
/// Trains the network with mini-batch stochastic gradient descent.
///
/// For each of `epochs` passes, `data` is shuffled, split into chunks of
/// `mini_batch_size` samples, and every chunk is applied through
/// `update_batch` with learning rate `eta`. When `test_data` is given, the
/// number of correctly classified test samples is logged after each epoch.
fn sgd<D: t::Data>(
    &mut self,
    data: &D,
    epochs: usize,
    mini_batch_size: usize,
    eta: f32,
    test_data: Option<&D>,
) {
    for epoch in 0..epochs {
        let shuffled = data.shuffle();
        // The chunk adaptor has to outlive the loop that borrows it.
        let mini_batches = shuffled.iter().chunks(mini_batch_size);
        for mini_batch in &mini_batches {
            self.update_batch::<D, _>(mini_batch, eta);
        }

        match test_data {
            Some(td) => {
                let total = td.size();
                let correct = self.evaluate(td);
                log::info!("Epoch {}: {} / {}", epoch, correct, total);
            }
            None => {
                log::info!("Epoch {} complete", epoch);
            }
        }
    }
}
/// Applies one gradient-descent step computed from a single mini-batch.
///
/// The per-sample gradients are computed in parallel via `backprop`, summed,
/// scaled by `eta / batch_size` and subtracted from the weights and biases.
fn update_batch<'a, D, B>(
    &mut self,
    batch: B,
    eta: f32,
) where
    D: t::Data,
    D::Input: 'a + Sync,
    D::Output: 'a + Sync,
    B: std::iter::Iterator<Item = &'a (D::Input, D::Output)>,
{
    // Zero-initialised accumulators with the same shapes as the parameters.
    let mut nabla_b: Vec<Vector> = self
        .biases
        .iter()
        .map(|bias| Vector::zeroes(bias.len()))
        .collect();
    let mut nabla_w: Vec<Matrix> = self
        .weights
        .iter()
        .map(|weight| Matrix::zeroes(*weight.shape()))
        .collect();

    // Own the samples so the batch can be handed to the parallel iterator.
    let samples: Vec<(D::Input, D::Output)> = batch.cloned().collect();
    let gradients: Vec<(Vec<Vector>, Vec<Matrix>)> = samples
        .par_iter()
        .map(|(input, output)| self.backprop::<D>(input, output))
        .collect();

    let batch_size = gradients.len();
    for (delta_nabla_b, delta_nabla_w) in gradients {
        for (nb, dnb) in nabla_b.iter_mut().zip(delta_nabla_b.into_iter()) {
            nb.add_mut(&dnb);
        }
        for (nw, dnw) in nabla_w.iter_mut().zip(delta_nabla_w.into_iter()) {
            nw.add_mut(&dnw);
        }
    }

    let step = eta / (batch_size as f32);
    for (weight, mut nw) in self.weights.iter_mut().zip(nabla_w.into_iter()) {
        nw.mult_scalar_mut(step);
        weight.sub_mut(&nw);
    }
    for (bias, mut nb) in self.biases.iter_mut().zip(nabla_b.into_iter()) {
        nb.mult_scalar_mut(step);
        bias.sub_mut(&nb);
    }
}
/// Counts the samples in `test_data` for which the decoded network output
/// equals the decoded expected output.
fn evaluate<D: t::Data>(
    &self,
    test_data: &D,
) -> usize {
    test_data
        .iter()
        .filter(|(input, expected)| {
            let got = self.feedforward::<D>(input);
            got.onehot_decode() == expected.onehot_decode()
        })
        .count()
}
/// Computes the gradient of the cost for a single training example `(x, y)`.
///
/// Returns `(nabla_b, nabla_w)`: one `Vector` per entry of `self.biases` and
/// one `Matrix` per entry of `self.weights`, ordered from the first layer to
/// the last (both lists are built back-to-front and reversed before return).
fn backprop<D: t::Data>(
&self,
x: &D::Input,
y: &D::Output,
) -> (Vec<Vector>, Vec<Matrix>) {
let mut nabla_b: Vec<Vector> = vec![];
let mut nabla_w: Vec<Matrix> = vec![];
// Forward pass: remember every weighted input `z` and every activation,
// starting with the input itself as the first activation.
let mut activations: Vec<Vector> = vec![x.data().clone()];
let mut activation = &activations[0];
let mut zs: Vec<Vector> = vec![];
for (b, w) in self.biases.iter().zip(self.weights.iter()) {
let mut z = w.mult_vec(activation);
z.add_mut(b);
// Keep the pre-sigmoid value; `z` itself is turned into the activation below.
zs.push(z.clone());
for e in z.iter_mut() {
*e = sigmoid(*e);
}
activations.push(z);
activation = activations.last().unwrap();
}
// From here on index 0 refers to the output layer.
activations.reverse();
zs.reverse();
// Output layer: delta = cost_derivative(output, y) * sigmoid'(z), element-wise.
let sp: Vector = zs[0].iter().map(|el| sigmoid_prime(*el)).collect();
let mut delta: Vector = self.cost_derivative::<D>(&activations[0], y);
delta.mult_mut(&sp);
nabla_b.push(delta.clone());
// Weight gradient: delta as a column times the previous layer's activation as a row.
nabla_w.push(
Matrix::from_column(&delta).mult(&Matrix::from_row(activations[1].clone()))
);
// Walk the remaining layers backwards; `l` counts layers from the output.
for l in 1..(self.sizes.len()-1) {
let z = &zs[l];
let sp: Vector = z.iter().map(|el| sigmoid_prime(*el)).collect();
// Weights of the layer processed in the previous step, transposed to carry delta back.
let weights = &self.weights[self.weights.len()-l];
let weights = weights.transpose();
delta = weights.mult_vec(&delta).mult(&sp);
nabla_b.push(delta.clone());
nabla_w.push(
Matrix::from_column(&delta).mult(&Matrix::from_row(activations[l+1].clone()))
);
}
// Restore first-to-last layer order.
nabla_b.reverse();
nabla_w.reverse();
(nabla_b, nabla_w)
}
/// Returns `output_activations - y`, element-wise: the difference between the
/// produced output activations and the expected output data.
fn cost_derivative<D: t::Data>(
    &self,
    output_activations: &Vector,
    y: &D::Output,
) -> Vector {
    let expected = y.data();
    output_activations.sub(expected)
}
}
@ -52,13 +213,8 @@ fn sigmoid(z: f32) -> f32 {
1. / (1. + (-z).exp())
}
fn mat_vec_mult(m: &Vec<Vec<f32>>, v: &Vec<f32>) -> Vec<f32> {
m.iter().enumerate().map(|(i, row)| {
if row.len() != v.len() {
panic!("Mismatch: V {}, M {}x{}", v.len(), m.len(), row.len());
}
row.iter().zip(v.iter()).map(|(rr, vv)| rr * vv).sum()
}).collect()
fn sigmoid_prime(z: f32) -> f32 {
sigmoid(z) * (1. - sigmoid(z))
}
fn main() {
@ -66,8 +222,9 @@ fn main() {
let now = time::Instant::now();
let train = mnist::load("train").unwrap();
let test = mnist::load("t10k").unwrap();
let elapsed = now.elapsed();
log::info!("Loaded {} images in {:?}.", train.len(), elapsed);
log::info!("Loaded {} training / {} test images in {:?}.", train.len(), test.len(), elapsed);
let num_train = (train.len() as f64 * 5.0/6.0) as usize;
let num_validation = train.len() - num_train;
@ -77,16 +234,12 @@ fn main() {
let now = time::Instant::now();
let size = vec![
28*28 as usize,
15usize,
30usize,
10usize,
];
let network = Network::new(&size);
let mut net = Network::new(&size);
let elapsed = now.elapsed();
log::info!("Created random network {:?} in {:?}.", size, elapsed);
let now = time::Instant::now();
let data = train.iter().next().unwrap();
let res = network.feedforward::<mnist::Data>(&data.0);
let elapsed = now.elapsed();
log::info!("feedforward: {:?} in {:?}", res, elapsed);
net.sgd(&train, 90, 10, 1.0, Some(&test));
}

294
src/maths.rs Normal file
View File

@ -0,0 +1,294 @@
use rand_distr::StandardNormal;
/// A dense vector of `f32` values, stored as a newtype around `Vec<f32>`.
#[derive(Debug, Clone)]
pub struct Vector(pub Vec<f32>);
impl Vector {
pub fn len(&self) -> usize {
self.0.len()
}
pub fn zeroes(size: usize) -> Self {
Vector(vec![0.0f32; size])
}
pub fn random<R: rand::Rng>(size: usize, rng: &mut R) -> Self {
Vector((0..size).map(|_| rng.sample(StandardNormal)).collect())
}
pub fn dot(&self, o: &Self) -> f32 {
if self.len() != o.len() {
panic!("Shape mismatch: {} and {}", self.len(), o.len())
}
self.0.iter().zip(o.0.iter()).map(|(a, b)| a*b).sum()
}
pub fn add_mut(&mut self, o: &Self) {
if self.len() != o.len() {
panic!("Shape mismatch: {} and {}", self.len(), o.len())
}
for (a, b) in self.0.iter_mut().zip(o.0.iter()) {
*a += b;
}
}
pub fn add(&self, o: &Self) -> Self {
let mut res = self.clone();
res.add_mut(o);
res
}
pub fn sub_mut(&mut self, o: &Self) {
if self.len() != o.len() {
panic!("Shape mismatch: {} and {}", self.len(), o.len())
}
for (a, b) in self.0.iter_mut().zip(o.0.iter()) {
*a -= b;
}
}
pub fn sub(&self, o: &Self) -> Self {
let mut res = self.clone();
res.sub_mut(o);
res
}
pub fn mult_mut(&mut self, o: &Self) {
if self.len() != o.len() {
panic!("Shape mismatch: {} and {}", self.len(), o.len())
}
for (a, b) in self.0.iter_mut().zip(o.0.iter()) {
*a *= b;
}
}
pub fn mult(&self, o: &Self) -> Self {
let mut res = self.clone();
res.mult_mut(o);
res
}
pub fn mult_scalar_mut(&mut self, val: f32) {
for e in self.0.iter_mut() {
*e *= val;
}
}
pub fn iter_mut(&mut self) -> std::slice::IterMut<f32> {
self.0.iter_mut()
}
pub fn iter(&self) -> std::slice::Iter<f32> {
self.0.iter()
}
}
/// Lets an iterator of `f32` be `collect()`ed directly into a `Vector`.
impl std::iter::FromIterator<f32> for Vector {
    fn from_iter<I: std::iter::IntoIterator<Item=f32>>(iter: I) -> Self {
        let values: Vec<f32> = iter.into_iter().collect();
        Self(values)
    }
}
impl std::iter::IntoIterator for Vector {
type Item = f32;
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl std::ops::Index<usize> for Vector {
type Output = f32;
fn index(&self, index: usize) -> &Self::Output {
self.0.index(index)
}
}
impl std::ops::IndexMut<usize> for Vector {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
self.0.index_mut(index)
}
}
/// Dimensions of a matrix: number of rows first, number of columns second.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Shape(usize, usize);
impl std::fmt::Display for Shape {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}x{}", self.rows(), self.columns())
}
}
impl Shape {
pub fn new(rows: usize, columns: usize) -> Self {
Self(rows, columns)
}
pub fn rows(&self) -> usize {
self.0
}
pub fn columns(&self) -> usize {
self.1
}
}
/// A dense matrix stored as a list of row vectors, together with its shape.
#[derive(Debug, Clone)]
pub struct Matrix {
    /// Rows x columns of the matrix.
    shape: Shape,
    /// The row vectors, top to bottom.
    pub rows: Vec<Vector>,
}
impl Matrix {
    /// Creates a matrix of the given shape with every element set to `0.0`.
    pub fn zeroes(shape: Shape) -> Self {
        Self {
            shape,
            rows: vec![Vector::zeroes(shape.columns()); shape.rows()],
        }
    }

    /// Builds a matrix from its rows; the shape is derived from the input.
    ///
    /// # Panics
    /// Panics if `rows` is empty or if the rows are not all the same length.
    pub fn new(rows: Vec<Vector>) -> Self {
        let nrows = rows.len();
        if nrows < 1 {
            panic!("No rows given");
        }
        let ncolumns = rows[0].len();
        if !rows.iter().all(|r| r.len() == ncolumns) {
            panic!("Given vectors are not the same length")
        }
        Self {
            shape: Shape::new(nrows, ncolumns),
            rows,
        }
    }

    /// Creates a matrix of the given shape whose elements are drawn
    /// independently via `Vector::random`.
    pub fn random<R: rand::Rng>(shape: Shape, rng: &mut R) -> Self {
        // Each row must be sampled separately: `vec![row; n]` would evaluate
        // the random row once and clone it, making all rows identical.
        let rows = (0..shape.rows())
            .map(|_| Vector::random(shape.columns(), &mut *rng))
            .collect();
        Self { shape, rows }
    }

    /// Wraps a single vector as a 1 x N matrix.
    pub fn from_row(row: Vector) -> Self {
        Self {
            shape: Shape::new(1, row.len()),
            rows: vec![row],
        }
    }

    /// Copies a vector into an N x 1 matrix (one single-element row per entry).
    pub fn from_column(column: &Vector) -> Self {
        let rows = column.iter().map(|el| Vector(vec![*el])).collect();
        Self {
            shape: Shape::new(column.len(), 1),
            rows,
        }
    }

    /// The matrix dimensions.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Matrix-vector product: the dot product of every row with `v`.
    ///
    /// # Panics
    /// Panics if `v.len()` differs from the number of columns.
    pub fn mult_vec(&self, v: &Vector) -> Vector {
        if v.len() != self.shape().columns() {
            panic!("Shape mismatch: V {}, M {}", v.len(), self.shape())
        }
        self.rows.iter().map(|row| row.dot(v)).collect()
    }

    /// Returns a copy of column `n` as a vector.
    ///
    /// # Panics
    /// Panics if `n` is not a valid column index.
    pub fn column(&self, n: usize) -> Vector {
        if n >= self.shape.columns() {
            panic!("Out of bounds: want column {}, have {}", n, self.shape.columns())
        }
        self.rows.iter().map(|row| row[n]).collect()
    }

    /// Matrix-matrix product `self * o`.
    ///
    /// # Panics
    /// Panics if the column count of `self` differs from the row count of `o`.
    pub fn mult(&self, o: &Matrix) -> Matrix {
        if self.shape().columns() != o.shape().rows() {
            panic!("Shape mismatch: self {}, other {}", self.shape(), o.shape())
        }
        let rows = self.rows.iter().map(|row_a| {
            (0..o.shape().columns()).map(|n| {
                // Element (row_a, n): row of `self` times column `n` of `o`.
                row_a.iter().enumerate().map(|(m, v)| v * o.rows[m][n]).sum::<f32>()
            }).collect()
        }).collect();
        Self {
            shape: Shape::new(self.shape().rows(), o.shape().columns()),
            rows,
        }
    }

    /// Returns the element-wise sum of two matrices.
    ///
    /// # Panics
    /// Panics if the shapes differ.
    pub fn add(&self, o: &Self) -> Self {
        if self.shape() != o.shape() {
            panic!("Shape mismatch: self {}, other {}", self.shape(), o.shape())
        }
        Self {
            shape: self.shape,
            rows: self.rows.iter().zip(o.rows.iter()).map(|(a, b)| a.add(b)).collect(),
        }
    }

    /// Adds `o` to `self` element-wise, in place.
    ///
    /// # Panics
    /// Panics if the shapes differ.
    pub fn add_mut(&mut self, o: &Self) {
        if self.shape() != o.shape() {
            panic!("Shape mismatch: self {}, other {}", self.shape(), o.shape())
        }
        for (ra, rb) in self.rows.iter_mut().zip(o.rows.iter()) {
            ra.add_mut(rb);
        }
    }

    /// Subtracts `o` from `self` element-wise, in place.
    ///
    /// # Panics
    /// Panics if the shapes differ.
    pub fn sub_mut(&mut self, o: &Self) {
        if self.shape() != o.shape() {
            panic!("Shape mismatch: self {}, other {}", self.shape(), o.shape())
        }
        for (ra, rb) in self.rows.iter_mut().zip(o.rows.iter()) {
            ra.sub_mut(rb);
        }
    }

    /// Multiplies every element by `val`, in place.
    pub fn mult_scalar_mut(&mut self, val: f32) {
        for row in self.rows.iter_mut() {
            row.mult_scalar_mut(val);
        }
    }

    /// Returns the transposed matrix: column `i` of `self` becomes row `i`.
    pub fn transpose(&self) -> Self {
        let rows = (0..self.shape().columns()).map(|cnum| {
            self.column(cnum)
        }).collect();
        Self {
            shape: Shape::new(self.shape().columns(), self.shape().rows()),
            rows,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 3x1 column times a 1x3 row yields the 3x3 outer product.
    #[test]
    fn test_column_column_mult() {
        let column = Matrix::from_column(&Vector(vec![1., 2., 3.]));
        let row = Matrix::from_row(Vector(vec![2., 3., 4.]));
        let product = column.mult(&row);

        let expected: [[f32; 3]; 3] = [
            [2.0, 3.0, 4.0],
            [4.0, 6.0, 8.0],
            [6.0, 9.0, 12.0],
        ];
        for (i, want) in expected.iter().enumerate() {
            assert_eq!(product.rows[i].0, want.to_vec());
        }
    }

    /// A 2x3 matrix times a 3x2 matrix yields the expected 2x2 product.
    #[test]
    fn test_matrix_matrix_mult() {
        let lhs = Matrix::new(vec![
            Vector(vec![0.0, 4.0, -2.0]),
            Vector(vec![-4.0, -3.0, 0.0]),
        ]);
        let rhs = Matrix::new(vec![
            Vector(vec![0.0, 1.0]),
            Vector(vec![1.0, -1.0]),
            Vector(vec![2.0, 3.0]),
        ]);
        let product = lhs.mult(&rhs);

        let expected: [[f32; 2]; 2] = [[0.0, -10.0], [-3.0, -1.0]];
        for (i, want) in expected.iter().enumerate() {
            assert_eq!(product.rows[i].0, want.to_vec());
        }
    }

    /// A 2x3 matrix times a 3-element vector yields a 2-element vector.
    #[test]
    fn test_matrix_vector_mult() {
        let matrix = Matrix::new(vec![
            Vector(vec![1.0, -1.0, 2.0]),
            Vector(vec![0.0, -3.0, 1.0]),
        ]);
        let vector = Vector(vec![2.0, 1.0, 0.0]);
        let product = matrix.mult_vec(&vector);
        assert_eq!(product.0, vec![1.0, -3.0]);
    }
}

View File

@ -5,12 +5,14 @@ use byteorder::{
};
use flate2::read::GzDecoder;
use crate::types as t;
use crate::maths::Vector;
use image::{RgbImage, Rgb};
#[derive(Debug)]
pub enum Error {
IO(io::Error),
InvalidMagic,
Parse(String),
}
impl From<io::Error> for Error {
@ -21,8 +23,9 @@ impl From<io::Error> for Error {
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, Clone)]
pub struct Image {
pixels: Vec<f32>,
pixels: Vector,
width: usize,
height: usize,
}
@ -35,19 +38,32 @@ impl Image {
) -> Result<Self> {
let npixels: usize = width * height;
let mut pixels: Vec<u8> = vec![0u8; npixels];
rdr.read(&mut pixels)?;
let pixels: Vec<f32> = pixels.into_iter().map(|u| (u as f32) / 255.0).collect();
rdr.read_exact(&mut pixels)?;
let pixels: Vector = pixels.into_iter().map(|u| (u as f32) / 255.0).collect();
Ok(Image {
pixels, width, height,
})
}
/// Renders the image as a greyscale `RgbImage`: every pixel value is scaled
/// by 255, truncated to `u8` and written to all three colour channels.
pub fn to_rgb(
    &self,
) -> RgbImage {
    let mut image = RgbImage::new(self.width as u32, self.height as u32);
    for y in 0..self.height {
        for x in 0..self.width {
            // The row stride is the image width. Using the height here only
            // works for square images and can index out of bounds otherwise.
            let val = (self.pixels[y * self.width + x] * 255.0) as u8;
            image.put_pixel(x as u32, y as u32, Rgb([val; 3]));
        }
    }
    image
}
}
impl t::Input for Image {
fn size(&self) -> usize {
self.width * self.height
}
fn data(&self) -> &Vec<f32> {
fn data(&self) -> &Vector {
&self.pixels
}
}
@ -77,12 +93,12 @@ impl ImageFile {
}
}
#[derive(Debug)]
pub struct RecognizedDigit(Vec<f32>);
#[derive(Debug, Clone)]
pub struct RecognizedDigit(Vector);
impl RecognizedDigit {
pub fn parse(label: usize, max: usize) -> Self {
let mut data: Vec<f32> = vec![0f32; max];
let mut data = Vector::zeroes(max);
data[label] = 1.0f32;
Self(data)
}
@ -92,10 +108,10 @@ impl t::Output for RecognizedDigit {
fn size(&self) -> usize {
self.0.len()
}
fn data(&self) -> &Vec<f32> {
fn data(&self) -> &Vector {
&self.0
}
fn from_nn_output(data: Vec<f32>) -> Self {
fn from_nn_output(data: Vector) -> Self {
if data.len() != 10 {
panic!("invalid nn output size (got {}, wanted {})", data.len(), 10)
}
@ -115,7 +131,7 @@ impl LabelFile {
}
let num_labels = rdr.read_u32::<BigEndian>()?;
let mut labels: Vec<u8> = vec![0u8; num_labels as usize];
rdr.read(&mut labels)?;
rdr.read_exact(&mut labels)?;
Ok(LabelFile {
labels: labels.into_iter().map(|l| RecognizedDigit::parse(l as usize, 10usize)).collect()
})

View File

@ -1,12 +1,33 @@
pub trait Input {
use rand::thread_rng;
use rand::seq::SliceRandom;
use crate::maths::Vector;
pub trait Input: std::fmt::Debug + Clone + Sync {
fn size(&self) -> usize;
fn data(&self) -> &Vec<f32>;
fn data(&self) -> &Vector;
}
pub trait Output {
pub trait Output: std::fmt::Debug + Clone + Sync {
fn size(&self) -> usize;
fn data(&self) -> &Vec<f32>;
fn from_nn_output(data: Vec<f32>) -> Self;
fn data(&self) -> &Vector;
fn from_nn_output(data: Vector) -> Self;
/// Returns the index of the largest value in `self.data()`.
///
/// On ties the earliest index wins, because a later value only replaces the
/// current best when it is strictly greater. Panics if the data is empty.
fn onehot_decode(&self) -> usize {
    let best = self
        .data()
        .iter()
        .enumerate()
        .fold(None, |best: Option<(f32, usize)>, (i, &v)| match best {
            // Keep the current maximum unless `v` is strictly greater.
            Some((max, _)) if !(v > max) => best,
            _ => Some((v, i)),
        });
    best.unwrap().1
}
}
pub trait Data {
@ -14,18 +35,32 @@ pub trait Data {
type Output: Output;
fn iter<'a>(&'a self) -> Box<dyn std::iter::ExactSizeIterator<Item = &'a (Self::Input, Self::Output)> + 'a>;
fn shuffle<'a> (&'a self) -> Box<dyn Data<Input = Self::Input, Output = Self::Output> + 'a>;
fn size(&self) -> usize;
}
/// A data set that owns all of its `(input, output)` pairs in a `Vec`.
pub struct InMemoryData<I: Input, O: Output> {
// The owned samples, in storage order.
inner: Vec<(I, O)>,
}
/// A view over `(input, output)` pairs that holds references to the pairs
/// (valid for `'a`) instead of owning them.
pub struct InMemoryDataView<'a, I: Input, O: Output> {
// Borrowed samples.
inner: Vec<&'a (I, O)>,
}
impl <I: Input, O: Output> Data for InMemoryData<I, O> {
    type Input = I;
    type Output = O;

    /// Iterates over the stored pairs in storage order.
    fn iter<'a>(&'a self) -> Box<dyn std::iter::ExactSizeIterator<Item = &'a (I, O)> + 'a> {
        let pairs = self.inner.iter();
        Box::new(pairs)
    }

    /// Returns an `InMemoryDataView` referencing every stored pair. The view
    /// is built in storage order; the reordering happens in the view's `iter`.
    fn shuffle<'a>(&'a self) -> Box<dyn Data<Input = Self::Input, Output = Self::Output>+'a> {
        let inner: Vec<&'a (I, O)> = self.inner.iter().collect();
        Box::new(InMemoryDataView { inner })
    }

    /// Number of stored pairs.
    fn size(&self) -> usize {
        self.inner.len()
    }
}
impl <I: Input, O: Output> InMemoryData<I, O> {
@ -46,3 +81,21 @@ impl <I: Input, O: Output> InMemoryData<I, O> {
self.inner.len()
}
}
impl <'a, I: Input, O: Output> Data for InMemoryDataView<'a, I, O> {
    type Input = I;
    type Output = O;

    /// Iterates over the referenced pairs in a random order; a fresh
    /// permutation is drawn from `thread_rng()` on every call.
    fn iter<'b>(&'b self) -> Box<dyn std::iter::ExactSizeIterator<Item = &'b (I, O)> + 'b> {
        let mut shuffled: Vec<&'b (I, O)> = self.inner.iter().copied().collect();
        let mut rng = thread_rng();
        shuffled.shuffle(&mut rng);
        Box::new(shuffled.into_iter())
    }

    /// Returns another view over the same referenced pairs.
    fn shuffle<'b>(&'b self) -> Box<dyn Data<Input = Self::Input, Output = Self::Output>+'b> {
        let inner = self.inner.iter().copied().collect();
        Box::new(InMemoryDataView { inner })
    }

    /// Number of referenced pairs.
    fn size(&self) -> usize {
        self.inner.len()
    }
}