it learning
parent
ec7d29ceef
commit
16663ebd02
|
@ -6,6 +6,12 @@ version = "1.0.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||
|
||||
[[package]]
|
||||
name = "adler32"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234"
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "0.7.15"
|
||||
|
@ -32,6 +38,18 @@ version = "1.0.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bed57e2090563b83ba8f83366628ce535a7584c9afa4c9fc0612a03925c6df58"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.4.3"
|
||||
|
@ -44,6 +62,12 @@ version = "1.0.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "color_quant"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.2.1"
|
||||
|
@ -98,6 +122,16 @@ dependencies = [
|
|||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deflate"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73770f8e1fe7d64df17ca66ad28994a0a623ea497fa69486e14984e715c5d174"
|
||||
dependencies = [
|
||||
"adler32",
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.6.1"
|
||||
|
@ -126,7 +160,7 @@ dependencies = [
|
|||
"cfg-if",
|
||||
"crc32fast",
|
||||
"libc",
|
||||
"miniz_oxide",
|
||||
"miniz_oxide 0.4.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -140,6 +174,16 @@ dependencies = [
|
|||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gif"
|
||||
version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a668f699973d0f573d15749b7002a9ac9e1f9c6b220e7b165601334c173d8de"
|
||||
dependencies = [
|
||||
"color_quant",
|
||||
"weezl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.18"
|
||||
|
@ -155,6 +199,49 @@ version = "2.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||
|
||||
[[package]]
|
||||
name = "image"
|
||||
version = "0.23.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24ffcb7e7244a9bf19d35bf2883b9c080c4ced3c07a9895572178cdb8f13f6a1"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"byteorder",
|
||||
"color_quant",
|
||||
"gif",
|
||||
"jpeg-decoder",
|
||||
"num-iter",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"png",
|
||||
"scoped_threadpool",
|
||||
"tiff",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37d572918e350e82412fe766d24b15e6682fb2ed2bbe018280caa810397cb319"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "0.4.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
|
||||
|
||||
[[package]]
|
||||
name = "jpeg-decoder"
|
||||
version = "0.1.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "229d53d58899083193af11e15917b5640cd40b29ff475a1fe4ef725deb02d0f2"
|
||||
dependencies = [
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
|
@ -167,6 +254,12 @@ version = "0.2.91"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8916b1f6ca17130ec6568feccee27c156ad12037880833a3b842a823236502e7"
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7d73b3f436185384286bd8098d17ec07c9a7d2388a6599f824d8502b529702a"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.14"
|
||||
|
@ -191,6 +284,15 @@ dependencies = [
|
|||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "791daaae1ed6889560f8c4359194f56648355540573244a5448a83ba1ecc7435"
|
||||
dependencies = [
|
||||
"adler32",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.4.4"
|
||||
|
@ -208,9 +310,56 @@ dependencies = [
|
|||
"byteorder",
|
||||
"env_logger",
|
||||
"flate2",
|
||||
"image",
|
||||
"itertools",
|
||||
"log",
|
||||
"rand",
|
||||
"rand_distr",
|
||||
"rayon",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-iter"
|
||||
version = "0.1.42"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b2021c8337a54d21aca0d59a92577a029af9431cb59b909b03252b9c164fad59"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "12ac428b1cb17fce6f731001d307d351ec70a6d202fc2e60f7d4c5e42d8f4f07"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -223,12 +372,42 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "png"
|
||||
version = "0.16.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3c3287920cb847dee3de33d301c463fba14dda99db24214ddf93f83d3021f4c6"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"crc32fast",
|
||||
"deflate",
|
||||
"miniz_oxide 0.3.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71"
|
||||
dependencies = [
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.3"
|
||||
|
@ -260,6 +439,16 @@ dependencies = [
|
|||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_distr"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da9e8f32ad24fb80d07d2323a9a2ce8b30d68a62b8cb4df88119ff49a698f038"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"rand",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_hc"
|
||||
version = "0.3.0"
|
||||
|
@ -311,12 +500,66 @@ version = "0.6.23"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24d5f089152e60f62d28b835fbff2cd2e8dc0baf1ac13343bef92ab7eed84548"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e"
|
||||
|
||||
[[package]]
|
||||
name = "scoped_threadpool"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d51f5df5af43ab3f1360b429fa5e0152ac5ce8c0bd6485cae490332e96846a8"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.125"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "558dc50e1a5a5fa7112ca2ce4effcb321b0300c0d4ccf0776a9f60cd89031171"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.125"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b093b7a2bb58203b5da3056c05b4ec1fed827dcfdb37347a8841695263b3d06d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.64"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.67"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6498a9efc342871f91cc2d0d694c674368b4ceb40f62b65a7a08c3792935e702"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "1.1.2"
|
||||
|
@ -326,12 +569,35 @@ dependencies = [
|
|||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tiff"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a53f4706d65497df0c4349241deddf35f84cee19c87ed86ea8ca590f4464437"
|
||||
dependencies = [
|
||||
"jpeg-decoder",
|
||||
"miniz_oxide 0.4.4",
|
||||
"weezl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.10.2+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
|
||||
|
||||
[[package]]
|
||||
name = "weezl"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4a32b378380f4e9869b22f0b5177c68a5519f03b3454fde0b291455ddbae266c"
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
|
|
@ -12,4 +12,9 @@ flate2 = "1"
|
|||
log = "0.4.0"
|
||||
env_logger = "0.8.3"
|
||||
rand = "0.8.3"
|
||||
rand_distr = "0.4.0"
|
||||
rayon = "1"
|
||||
itertools = "0.10.0"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
image = "0.23"
|
||||
|
|
229
src/main.rs
229
src/main.rs
|
@ -1,30 +1,35 @@
|
|||
use std::time;
|
||||
use rand::Rng;
|
||||
use rayon::prelude::*;
|
||||
use itertools::Itertools;
|
||||
use crate::types as t;
|
||||
use crate::types::{Data, Input, Output};
|
||||
use rayon::prelude::*;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
pub mod types;
|
||||
pub mod maths;
|
||||
mod mnist;
|
||||
|
||||
use maths::{Vector, Matrix, Shape};
|
||||
|
||||
#[derive(Serialize,Deserialize)]
|
||||
struct JSONBiases(Vec<Vec<Vec<f32>>>);
|
||||
|
||||
#[derive(Serialize,Deserialize)]
|
||||
struct JSONWeights(Vec<Vec<Vec<f32>>>);
|
||||
|
||||
struct Network {
|
||||
sizes: Vec<usize>,
|
||||
biases: Vec<Vec<f32>>,
|
||||
weights: Vec<Vec<Vec<f32>>>,
|
||||
biases: Vec<Vector>,
|
||||
weights: Vec<Matrix>,
|
||||
}
|
||||
|
||||
impl Network {
|
||||
fn new(sizes: &Vec<usize>) -> Self {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let biases = sizes[1..].iter().map(|s| {
|
||||
(0..s.clone()).map(|_| rng.gen::<f32>()).collect()
|
||||
}).collect();
|
||||
|
||||
let biases = sizes[1..].iter().map(|s| Vector::random(*s, &mut rng)).collect();
|
||||
let weights = sizes.iter().zip(sizes[1..].iter()).map(|(from, to)| {
|
||||
(0..to.clone()).map(|_| {
|
||||
(0..from.clone()).map(|_| rng.gen::<f32>()).collect()
|
||||
}).collect()
|
||||
Matrix::random(Shape::new(*to, *from), &mut rng)
|
||||
}).collect();
|
||||
|
||||
Self {
|
||||
|
@ -33,18 +38,174 @@ impl Network {
|
|||
}
|
||||
}
|
||||
|
||||
fn dump(&self) {
|
||||
for (i, (b, w)) in self.biases.iter().zip(self.weights.iter()).enumerate() {
|
||||
log::info!("l {}: biases: {:?}", i, b);
|
||||
log::info!("l {}: weights:", i);
|
||||
for row in w.rows.iter() {
|
||||
log::info!(" {:?}", row);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn feedforward<D: t::Data>(&self, i: &D::Input) -> D::Output {
|
||||
let a = i.data();
|
||||
let o = self.biases.iter()
|
||||
.zip(self.weights.iter())
|
||||
.fold(a.clone(), |a, (biases, weights)| {
|
||||
mat_vec_mult(weights, &a).iter()
|
||||
.zip(biases.iter())
|
||||
.map(|(aa, b)| {
|
||||
sigmoid(aa+b)
|
||||
}).collect()
|
||||
});
|
||||
D::Output::from_nn_output(o)
|
||||
let mut a = i.data().clone();
|
||||
for (b, w) in self.biases.iter().zip(self.weights.iter()) {
|
||||
//log::info!("before: {:?}", a);
|
||||
let mut z = w.mult_vec(&a);
|
||||
z.add_mut(&b);
|
||||
for e in z.iter_mut() {
|
||||
*e = sigmoid(*e);
|
||||
}
|
||||
a = z;
|
||||
//log::info!("after: {:?}", a);
|
||||
}
|
||||
D::Output::from_nn_output(a)
|
||||
}
|
||||
|
||||
fn sgd<D: t::Data>(
|
||||
&mut self,
|
||||
data: &D,
|
||||
epochs: usize,
|
||||
mini_batch_size: usize,
|
||||
eta: f32,
|
||||
test_data: Option<&D>,
|
||||
) {
|
||||
for j in 0usize..epochs {
|
||||
//self.dump();
|
||||
let shuffled = data.shuffle();
|
||||
let batches = shuffled.iter().chunks(mini_batch_size);
|
||||
for batch in &batches {
|
||||
self.update_batch::<D, _>(batch, eta);
|
||||
}
|
||||
if let Some(td) = test_data {
|
||||
let n_test = td.size();
|
||||
log::info!("Epoch {}: {} / {}", j, self.evaluate(td), n_test);
|
||||
} else {
|
||||
log::info!("Epoch {} complete", j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_batch<'a, D, B>(
|
||||
&mut self,
|
||||
batch: B,
|
||||
eta: f32,
|
||||
) where
|
||||
D: t::Data,
|
||||
D::Input: 'a + Sync,
|
||||
D::Output: 'a + Sync,
|
||||
B: std::iter::Iterator<Item = &'a (D::Input, D::Output)>,
|
||||
{
|
||||
let mut nabla_b: Vec<Vector> = self.biases.iter().map(|el| {
|
||||
Vector::zeroes(el.len())
|
||||
}).collect();
|
||||
let mut nabla_w: Vec<Matrix> = self.weights.iter().map(|el| {
|
||||
Matrix::zeroes(*el.shape())
|
||||
}).collect();
|
||||
|
||||
let batch: Vec<(D::Input, D::Output)> = batch.cloned().collect();
|
||||
let par = batch.par_iter().map(|(input, output)| {
|
||||
self.backprop::<D>(input, output)
|
||||
});
|
||||
|
||||
let mut batch_size = 0usize;
|
||||
for (delta_nabla_b, delta_nabla_w) in par.collect::<Vec<_>>() {
|
||||
for (nb, dnb) in nabla_b.iter_mut().zip(delta_nabla_b.into_iter()) {
|
||||
nb.add_mut(&dnb)
|
||||
}
|
||||
for (nw, dnw) in nabla_w.iter_mut().zip(delta_nabla_w.into_iter()) {
|
||||
nw.add_mut(&dnw)
|
||||
}
|
||||
batch_size += 1;
|
||||
}
|
||||
|
||||
for (w, mut nw) in self.weights.iter_mut().zip(nabla_w.into_iter()) {
|
||||
nw.mult_scalar_mut(eta / (batch_size as f32));
|
||||
w.sub_mut(&nw)
|
||||
}
|
||||
|
||||
for (b, mut nb) in self.biases.iter_mut().zip(nabla_b.into_iter()) {
|
||||
nb.mult_scalar_mut(eta / (batch_size as f32));
|
||||
b.sub_mut(&nb)
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate<D: t::Data>(
|
||||
&self,
|
||||
test_data: &D,
|
||||
) -> usize {
|
||||
let mut count: usize = 0;
|
||||
for (i, (input, expected)) in test_data.iter().enumerate() {
|
||||
let got = self.feedforward::<D>(input);
|
||||
if got.onehot_decode() == expected.onehot_decode() {
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
count
|
||||
}
|
||||
|
||||
fn backprop<D: t::Data>(
|
||||
&self,
|
||||
x: &D::Input,
|
||||
y: &D::Output,
|
||||
) -> (Vec<Vector>, Vec<Matrix>) {
|
||||
let mut nabla_b: Vec<Vector> = vec![];
|
||||
let mut nabla_w: Vec<Matrix> = vec![];
|
||||
|
||||
|
||||
let mut activations: Vec<Vector> = vec![x.data().clone()];
|
||||
let mut activation = &activations[0];
|
||||
let mut zs: Vec<Vector> = vec![];
|
||||
|
||||
for (b, w) in self.biases.iter().zip(self.weights.iter()) {
|
||||
let mut z = w.mult_vec(activation);
|
||||
z.add_mut(b);
|
||||
|
||||
zs.push(z.clone());
|
||||
for e in z.iter_mut() {
|
||||
*e = sigmoid(*e);
|
||||
}
|
||||
activations.push(z);
|
||||
activation = activations.last().unwrap();
|
||||
}
|
||||
|
||||
activations.reverse();
|
||||
zs.reverse();
|
||||
|
||||
let sp: Vector = zs[0].iter().map(|el| sigmoid_prime(*el)).collect();
|
||||
let mut delta: Vector = self.cost_derivative::<D>(&activations[0], y);
|
||||
delta.mult_mut(&sp);
|
||||
|
||||
nabla_b.push(delta.clone());
|
||||
nabla_w.push(
|
||||
Matrix::from_column(&delta).mult(&Matrix::from_row(activations[1].clone()))
|
||||
);
|
||||
|
||||
for l in 1..(self.sizes.len()-1) {
|
||||
let z = &zs[l];
|
||||
let sp: Vector = z.iter().map(|el| sigmoid_prime(*el)).collect();
|
||||
|
||||
let weights = &self.weights[self.weights.len()-l];
|
||||
let weights = weights.transpose();
|
||||
delta = weights.mult_vec(&delta).mult(&sp);
|
||||
nabla_b.push(delta.clone());
|
||||
nabla_w.push(
|
||||
Matrix::from_column(&delta).mult(&Matrix::from_row(activations[l+1].clone()))
|
||||
);
|
||||
}
|
||||
|
||||
nabla_b.reverse();
|
||||
nabla_w.reverse();
|
||||
(nabla_b, nabla_w)
|
||||
}
|
||||
|
||||
fn cost_derivative<D: t::Data>(
|
||||
&self,
|
||||
output_activations: &Vector,
|
||||
y: &D::Output,
|
||||
) -> Vector {
|
||||
output_activations.sub(y.data())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -52,13 +213,8 @@ fn sigmoid(z: f32) -> f32 {
|
|||
1. / (1. + (-z).exp())
|
||||
}
|
||||
|
||||
fn mat_vec_mult(m: &Vec<Vec<f32>>, v: &Vec<f32>) -> Vec<f32> {
|
||||
m.iter().enumerate().map(|(i, row)| {
|
||||
if row.len() != v.len() {
|
||||
panic!("Mismatch: V {}, M {}x{}", v.len(), m.len(), row.len());
|
||||
}
|
||||
row.iter().zip(v.iter()).map(|(rr, vv)| rr * vv).sum()
|
||||
}).collect()
|
||||
fn sigmoid_prime(z: f32) -> f32 {
|
||||
sigmoid(z) * (1. - sigmoid(z))
|
||||
}
|
||||
|
||||
fn main() {
|
||||
|
@ -66,8 +222,9 @@ fn main() {
|
|||
|
||||
let now = time::Instant::now();
|
||||
let train = mnist::load("train").unwrap();
|
||||
let test = mnist::load("t10k").unwrap();
|
||||
let elapsed = now.elapsed();
|
||||
log::info!("Loaded {} images in {:?}.", train.len(), elapsed);
|
||||
log::info!("Loaded {} training / {} test images in {:?}.", train.len(), test.len(), elapsed);
|
||||
|
||||
let num_train = (train.len() as f64 * 5.0/6.0) as usize;
|
||||
let num_validation = train.len() - num_train;
|
||||
|
@ -77,16 +234,12 @@ fn main() {
|
|||
let now = time::Instant::now();
|
||||
let size = vec![
|
||||
28*28 as usize,
|
||||
15usize,
|
||||
30usize,
|
||||
10usize,
|
||||
];
|
||||
let network = Network::new(&size);
|
||||
let mut net = Network::new(&size);
|
||||
let elapsed = now.elapsed();
|
||||
log::info!("Created random network {:?} in {:?}.", size, elapsed);
|
||||
|
||||
let now = time::Instant::now();
|
||||
let data = train.iter().next().unwrap();
|
||||
let res = network.feedforward::<mnist::Data>(&data.0);
|
||||
let elapsed = now.elapsed();
|
||||
log::info!("feedforward: {:?} in {:?}", res, elapsed);
|
||||
|
||||
net.sgd(&train, 90, 10, 1.0, Some(&test));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,294 @@
|
|||
use rand_distr::StandardNormal;
|
||||
|
||||
#[derive(Clone,Debug)]
|
||||
pub struct Vector(pub Vec<f32>);
|
||||
|
||||
impl Vector {
|
||||
pub fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
pub fn zeroes(size: usize) -> Self {
|
||||
Vector(vec![0.0f32; size])
|
||||
}
|
||||
pub fn random<R: rand::Rng>(size: usize, rng: &mut R) -> Self {
|
||||
Vector((0..size).map(|_| rng.sample(StandardNormal)).collect())
|
||||
}
|
||||
pub fn dot(&self, o: &Self) -> f32 {
|
||||
if self.len() != o.len() {
|
||||
panic!("Shape mismatch: {} and {}", self.len(), o.len())
|
||||
}
|
||||
self.0.iter().zip(o.0.iter()).map(|(a, b)| a*b).sum()
|
||||
}
|
||||
pub fn add_mut(&mut self, o: &Self) {
|
||||
if self.len() != o.len() {
|
||||
panic!("Shape mismatch: {} and {}", self.len(), o.len())
|
||||
}
|
||||
for (a, b) in self.0.iter_mut().zip(o.0.iter()) {
|
||||
*a += b;
|
||||
}
|
||||
}
|
||||
pub fn add(&self, o: &Self) -> Self {
|
||||
let mut res = self.clone();
|
||||
res.add_mut(o);
|
||||
res
|
||||
}
|
||||
pub fn sub_mut(&mut self, o: &Self) {
|
||||
if self.len() != o.len() {
|
||||
panic!("Shape mismatch: {} and {}", self.len(), o.len())
|
||||
}
|
||||
for (a, b) in self.0.iter_mut().zip(o.0.iter()) {
|
||||
*a -= b;
|
||||
}
|
||||
}
|
||||
pub fn sub(&self, o: &Self) -> Self {
|
||||
let mut res = self.clone();
|
||||
res.sub_mut(o);
|
||||
res
|
||||
}
|
||||
pub fn mult_mut(&mut self, o: &Self) {
|
||||
if self.len() != o.len() {
|
||||
panic!("Shape mismatch: {} and {}", self.len(), o.len())
|
||||
}
|
||||
for (a, b) in self.0.iter_mut().zip(o.0.iter()) {
|
||||
*a *= b;
|
||||
}
|
||||
}
|
||||
pub fn mult(&self, o: &Self) -> Self {
|
||||
let mut res = self.clone();
|
||||
res.mult_mut(o);
|
||||
res
|
||||
}
|
||||
pub fn mult_scalar_mut(&mut self, val: f32) {
|
||||
for e in self.0.iter_mut() {
|
||||
*e *= val;
|
||||
}
|
||||
}
|
||||
pub fn iter_mut(&mut self) -> std::slice::IterMut<f32> {
|
||||
self.0.iter_mut()
|
||||
}
|
||||
pub fn iter(&self) -> std::slice::Iter<f32> {
|
||||
self.0.iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::iter::FromIterator<f32> for Vector {
|
||||
fn from_iter<I: std::iter::IntoIterator<Item=f32>>(iter: I) -> Self {
|
||||
Self(Vec::from_iter(iter))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::iter::IntoIterator for Vector {
|
||||
type Item = f32;
|
||||
type IntoIter = std::vec::IntoIter<Self::Item>;
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.0.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Index<usize> for Vector {
|
||||
type Output = f32;
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
self.0.index(index)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::IndexMut<usize> for Vector {
|
||||
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
||||
self.0.index_mut(index)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone,Copy,Debug,PartialEq,Eq)]
|
||||
pub struct Shape(usize, usize);
|
||||
|
||||
impl std::fmt::Display for Shape {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}x{}", self.rows(), self.columns())
|
||||
}
|
||||
}
|
||||
|
||||
impl Shape {
|
||||
pub fn new(rows: usize, columns: usize) -> Self {
|
||||
Self(rows, columns)
|
||||
}
|
||||
pub fn rows(&self) -> usize {
|
||||
self.0
|
||||
}
|
||||
pub fn columns(&self) -> usize {
|
||||
self.1
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone,Debug)]
|
||||
pub struct Matrix {
|
||||
shape: Shape,
|
||||
pub rows: Vec<Vector>,
|
||||
}
|
||||
|
||||
impl Matrix {
|
||||
pub fn zeroes(shape: Shape) -> Self {
|
||||
Self {
|
||||
shape,
|
||||
rows: vec![Vector::zeroes(shape.columns()); shape.rows()],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(rows: Vec<Vector>) -> Self {
|
||||
let nrows = rows.len();
|
||||
if nrows < 1 {
|
||||
panic!("No rows given");
|
||||
}
|
||||
let ncolumns = rows[0].len();
|
||||
if !rows.iter().all(|r| r.len() == ncolumns) {
|
||||
panic!("Given vectors are not the same length")
|
||||
}
|
||||
Self {
|
||||
shape: Shape::new(nrows, ncolumns),
|
||||
rows,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn random<R: rand::Rng>(shape: Shape, rng: &mut R) -> Self {
|
||||
Self {
|
||||
shape,
|
||||
rows: vec![Vector::random(shape.columns(), rng); shape.rows()],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_row(row: Vector) -> Self {
|
||||
Self {
|
||||
shape: Shape::new(1, row.len()),
|
||||
rows: vec![row],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_column(column: &Vector) -> Self {
|
||||
let rows = column.iter().map(|el| Vector(vec![*el])).collect();
|
||||
Self {
|
||||
shape: Shape::new(column.len(), 1),
|
||||
rows,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> &Shape {
|
||||
&self.shape
|
||||
}
|
||||
|
||||
pub fn mult_vec(&self, v: &Vector) -> Vector {
|
||||
if v.len() != self.shape().columns() {
|
||||
panic!("Shape mismatch: V {}, M {}", v.len(), self.shape())
|
||||
}
|
||||
self.rows.iter().map(|row| row.dot(v)).collect()
|
||||
}
|
||||
|
||||
pub fn column(&self, n: usize) -> Vector {
|
||||
if n >= self.shape.columns() {
|
||||
panic!("Out of bounds: want column {}, have {}", n, self.shape.columns())
|
||||
}
|
||||
|
||||
self.rows.iter().map(|row| row[n]).collect()
|
||||
}
|
||||
|
||||
pub fn mult(&self, o: &Matrix) -> Matrix {
|
||||
if self.shape().columns() != o.shape().rows() {
|
||||
panic!("Shape mismatch: self {}, other {}", self.shape(), o.shape())
|
||||
}
|
||||
|
||||
let rows = self.rows.iter().map(|row_a| {
|
||||
(0..o.shape().columns()).map(|n| {
|
||||
row_a.iter().enumerate().map(|(m, v)| {
|
||||
let v2 = o.rows[m][n];
|
||||
v * v2
|
||||
}).sum()
|
||||
}).collect()
|
||||
}).collect();
|
||||
Self {
|
||||
shape: Shape(self.shape().rows(), o.shape.columns()),
|
||||
rows,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add(&self, o: &Self) -> Self {
|
||||
if self.shape() != o.shape() {
|
||||
panic!("Shape mismatch: self {}, other {}", self.shape(), o.shape())
|
||||
}
|
||||
Self {
|
||||
shape: self.shape().clone(),
|
||||
rows: self.rows.iter().zip(o.rows.iter()).map(|(a, b)| a.add(b)).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_mut(&mut self, o: &Self) {
|
||||
if self.shape() != o.shape() {
|
||||
panic!("Shape mismatch: self {}, other {}", self.shape(), o.shape())
|
||||
}
|
||||
for (ra, rb) in self.rows.iter_mut().zip(o.rows.iter()) {
|
||||
ra.add_mut(rb);
|
||||
}
|
||||
}
|
||||
pub fn sub_mut(&mut self, o: &Self) {
|
||||
if self.shape() != o.shape() {
|
||||
panic!("Shape mismatch: self {}, other {}", self.shape(), o.shape())
|
||||
}
|
||||
for (ra, rb) in self.rows.iter_mut().zip(o.rows.iter()) {
|
||||
ra.sub_mut(rb);
|
||||
}
|
||||
}
|
||||
pub fn mult_scalar_mut(&mut self, val: f32) {
|
||||
for row in self.rows.iter_mut() {
|
||||
row.mult_scalar_mut(val);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn transpose(&self) -> Self {
|
||||
let rows = (0..self.shape().columns()).map(|cnum| {
|
||||
self.column(cnum)
|
||||
}).collect();
|
||||
Self {
|
||||
shape: Shape::new(self.shape().columns(), self.shape().rows()),
|
||||
rows,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_column_column_mult() {
|
||||
let m1 = Matrix::from_column(&Vector(vec![1.,2.,3.]));
|
||||
let m2 = Matrix::from_row(Vector(vec![2.,3.,4.]));
|
||||
let res = m1.mult(&m2);
|
||||
assert_eq!(res.rows[0].0, vec![ 2.0, 3.0, 4.0 ]);
|
||||
assert_eq!(res.rows[1].0, vec![ 4.0, 6.0, 8.0 ]);
|
||||
assert_eq!(res.rows[2].0, vec![ 6.0, 9.0, 12.0 ]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matrix_matrix_mult() {
|
||||
let m1 = Matrix::new(vec![
|
||||
Vector(vec![ 0.0, 4.0, -2.0 ]),
|
||||
Vector(vec![-4.0, -3.0, 0.0 ]),
|
||||
]);
|
||||
let m2 = Matrix::new(vec![
|
||||
Vector(vec![ 0.0, 1.0 ]),
|
||||
Vector(vec![ 1.0, -1.0 ]),
|
||||
Vector(vec![ 2.0, 3.0 ]),
|
||||
]);
|
||||
|
||||
let res = m1.mult(&m2);
|
||||
assert_eq!(res.rows[0].0, vec![ 0.0, -10.0 ]);
|
||||
assert_eq!(res.rows[1].0, vec![ -3.0, -1.0 ]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matrix_vector_mult() {
|
||||
let m1 = Matrix::new(vec![
|
||||
Vector(vec![ 1.0, -1.0, 2.0 ]),
|
||||
Vector(vec![ 0.0, -3.0, 1.0 ]),
|
||||
]);
|
||||
let v2 = Vector(vec![2.0, 1.0, 0.0]);
|
||||
let res = m1.mult_vec(&v2);
|
||||
assert_eq!(res.0, vec![1.0, -3.0]);
|
||||
}
|
||||
}
|
38
src/mnist.rs
38
src/mnist.rs
|
@ -5,12 +5,14 @@ use byteorder::{
|
|||
};
|
||||
use flate2::read::GzDecoder;
|
||||
use crate::types as t;
|
||||
use crate::maths::Vector;
|
||||
|
||||
use image::{RgbImage, Rgb};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
IO(io::Error),
|
||||
InvalidMagic,
|
||||
Parse(String),
|
||||
}
|
||||
|
||||
impl From<io::Error> for Error {
|
||||
|
@ -21,8 +23,9 @@ impl From<io::Error> for Error {
|
|||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Image {
|
||||
pixels: Vec<f32>,
|
||||
pixels: Vector,
|
||||
width: usize,
|
||||
height: usize,
|
||||
}
|
||||
|
@ -35,19 +38,32 @@ impl Image {
|
|||
) -> Result<Self> {
|
||||
let npixels: usize = width * height;
|
||||
let mut pixels: Vec<u8> = vec![0u8; npixels];
|
||||
rdr.read(&mut pixels)?;
|
||||
let pixels: Vec<f32> = pixels.into_iter().map(|u| (u as f32) / 255.0).collect();
|
||||
rdr.read_exact(&mut pixels)?;
|
||||
let pixels: Vector = pixels.into_iter().map(|u| (u as f32) / 255.0).collect();
|
||||
Ok(Image {
|
||||
pixels, width, height,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_rgb(
|
||||
&self,
|
||||
) -> RgbImage {
|
||||
let mut image = RgbImage::new(self.width as u32, self.height as u32);
|
||||
for x in 0..self.width {
|
||||
for y in 0..self.height {
|
||||
let val = (self.pixels[y*self.height+x] * 255.0) as u8;
|
||||
image.put_pixel(x as u32, y as u32, Rgb([val; 3]));
|
||||
}
|
||||
}
|
||||
image
|
||||
}
|
||||
}
|
||||
|
||||
impl t::Input for Image {
|
||||
fn size(&self) -> usize {
|
||||
self.width * self.height
|
||||
}
|
||||
fn data(&self) -> &Vec<f32> {
|
||||
fn data(&self) -> &Vector {
|
||||
&self.pixels
|
||||
}
|
||||
}
|
||||
|
@ -77,12 +93,12 @@ impl ImageFile {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RecognizedDigit(Vec<f32>);
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RecognizedDigit(Vector);
|
||||
|
||||
impl RecognizedDigit {
|
||||
pub fn parse(label: usize, max: usize) -> Self {
|
||||
let mut data: Vec<f32> = vec![0f32; max];
|
||||
let mut data = Vector::zeroes(max);
|
||||
data[label] = 1.0f32;
|
||||
Self(data)
|
||||
}
|
||||
|
@ -92,10 +108,10 @@ impl t::Output for RecognizedDigit {
|
|||
fn size(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
fn data(&self) -> &Vec<f32> {
|
||||
fn data(&self) -> &Vector {
|
||||
&self.0
|
||||
}
|
||||
fn from_nn_output(data: Vec<f32>) -> Self {
|
||||
fn from_nn_output(data: Vector) -> Self {
|
||||
if data.len() != 10 {
|
||||
panic!("invalid nn output size (got {}, wanted {})", data.len(), 10)
|
||||
}
|
||||
|
@ -115,7 +131,7 @@ impl LabelFile {
|
|||
}
|
||||
let num_labels = rdr.read_u32::<BigEndian>()?;
|
||||
let mut labels: Vec<u8> = vec![0u8; num_labels as usize];
|
||||
rdr.read(&mut labels)?;
|
||||
rdr.read_exact(&mut labels)?;
|
||||
Ok(LabelFile {
|
||||
labels: labels.into_iter().map(|l| RecognizedDigit::parse(l as usize, 10usize)).collect()
|
||||
})
|
||||
|
|
63
src/types.rs
63
src/types.rs
|
@ -1,12 +1,33 @@
|
|||
pub trait Input {
|
||||
use rand::thread_rng;
|
||||
use rand::seq::SliceRandom;
|
||||
|
||||
use crate::maths::Vector;
|
||||
|
||||
pub trait Input: std::fmt::Debug + Clone + Sync {
|
||||
fn size(&self) -> usize;
|
||||
fn data(&self) -> &Vec<f32>;
|
||||
fn data(&self) -> &Vector;
|
||||
}
|
||||
|
||||
pub trait Output {
|
||||
pub trait Output: std::fmt::Debug + Clone + Sync {
|
||||
fn size(&self) -> usize;
|
||||
fn data(&self) -> &Vec<f32>;
|
||||
fn from_nn_output(data: Vec<f32>) -> Self;
|
||||
fn data(&self) -> &Vector;
|
||||
fn from_nn_output(data: Vector) -> Self;
|
||||
fn onehot_decode(&self) -> usize {
|
||||
let mut cur: Option<(f32, usize)> = None;
|
||||
for (i, &v) in self.data().iter().enumerate() {
|
||||
match cur {
|
||||
None => {
|
||||
cur = Some((v, i));
|
||||
},
|
||||
Some((mv, _)) => {
|
||||
if v > mv {
|
||||
cur = Some((v, i));
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
cur.unwrap().1
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Data {
|
||||
|
@ -14,18 +35,32 @@ pub trait Data {
|
|||
type Output: Output;
|
||||
|
||||
fn iter<'a>(&'a self) -> Box<dyn std::iter::ExactSizeIterator<Item = &'a (Self::Input, Self::Output)> + 'a>;
|
||||
fn shuffle<'a> (&'a self) -> Box<dyn Data<Input = Self::Input, Output = Self::Output> + 'a>;
|
||||
fn size(&self) -> usize;
|
||||
}
|
||||
|
||||
pub struct InMemoryData<I: Input, O: Output> {
|
||||
inner: Vec<(I, O)>,
|
||||
}
|
||||
|
||||
pub struct InMemoryDataView<'a, I: Input, O: Output> {
|
||||
inner: Vec<&'a (I, O)>,
|
||||
}
|
||||
|
||||
impl <I: Input, O: Output> Data for InMemoryData<I, O> {
|
||||
type Input = I;
|
||||
type Output = O;
|
||||
fn iter<'a>(&'a self) -> Box<dyn std::iter::ExactSizeIterator<Item = &'a (I, O)> + 'a> {
|
||||
Box::new(self.inner.iter())
|
||||
}
|
||||
fn shuffle<'a>(&'a self) -> Box<dyn Data<Input = Self::Input, Output = Self::Output>+'a> {
|
||||
Box::new(InMemoryDataView {
|
||||
inner: self.inner.iter().collect(),
|
||||
})
|
||||
}
|
||||
fn size(&self) -> usize {
|
||||
self.inner.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl <I: Input, O: Output> InMemoryData<I, O> {
|
||||
|
@ -46,3 +81,21 @@ impl <I: Input, O: Output> InMemoryData<I, O> {
|
|||
self.inner.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl <'a, I: Input, O: Output> Data for InMemoryDataView<'a, I, O> {
|
||||
type Input = I;
|
||||
type Output = O;
|
||||
fn iter<'b>(&'b self) -> Box<dyn std::iter::ExactSizeIterator<Item = &'b (I, O)> + 'b> {
|
||||
let mut shuffled: Vec<&'b (I, O)> = self.inner.iter().map(|el| *el).collect();
|
||||
shuffled.shuffle(&mut thread_rng());
|
||||
Box::new(shuffled.into_iter())
|
||||
}
|
||||
fn shuffle<'b>(&'b self) -> Box<dyn Data<Input = Self::Input, Output = Self::Output>+'b> {
|
||||
Box::new(InMemoryDataView {
|
||||
inner: self.inner.iter().cloned().collect(),
|
||||
})
|
||||
}
|
||||
fn size(&self) -> usize {
|
||||
self.inner.len()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue