feedforward

master
q3k 2021-03-28 15:06:51 +02:00
commit b3fc3c59f1
6 changed files with 583 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
/target
*swp
*gz

364
Cargo.lock generated Normal file
View File

@ -0,0 +1,364 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
[[package]]
name = "adler"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aho-corasick"
version = "0.7.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7404febffaa47dac81aa44dba71523c9d069b1bdc50a77db41195149e17f68e5"
dependencies = [
"memchr",
]
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi",
"libc",
"winapi",
]
[[package]]
name = "autocfg"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "crc32fast"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81156fece84ab6a9f2afdb109ce3ae577e42b1228441eded99bd77f627953b1a"
dependencies = [
"cfg-if",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dca26ee1f8d361640700bde38b2c37d8c22b3ce2d360e1fc1c74ea4b0aa7d775"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9"
dependencies = [
"cfg-if",
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2584f639eb95fea8c798496315b297cf81b9b58b6d30ab066a75455333cf4b12"
dependencies = [
"cfg-if",
"crossbeam-utils",
"lazy_static",
"memoffset",
"scopeguard",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7e9d99fa91428effe99c5c6d4634cdeba32b8cf784fc428a2a687f61a952c49"
dependencies = [
"autocfg",
"cfg-if",
"lazy_static",
]
[[package]]
name = "either"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457"
[[package]]
name = "env_logger"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17392a012ea30ef05a610aa97dfb49496e71c9f676b27879922ea5bdf60d9d3f"
dependencies = [
"atty",
"humantime",
"log",
"regex",
"termcolor",
]
[[package]]
name = "flate2"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd3aec53de10fe96d7d8c565eb17f2c687bb5518a2ec453b5b1252964526abe0"
dependencies = [
"cfg-if",
"crc32fast",
"libc",
"miniz_oxide",
]
[[package]]
name = "getrandom"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9495705279e7140bf035dde1f6e750c162df8b625267cd52cc44e0b156732c8"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "hermit-abi"
version = "0.1.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "322f4de77956e22ed0e5032c359a0f1273f1f7f0d79bfa3b8ffbc730d7fbcc5c"
dependencies = [
"libc",
]
[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.91"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8916b1f6ca17130ec6568feccee27c156ad12037880833a3b842a823236502e7"
[[package]]
name = "log"
version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710"
dependencies = [
"cfg-if",
]
[[package]]
name = "memchr"
version = "2.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525"
[[package]]
name = "memoffset"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "157b4208e3059a8f9e78d559edc658e13df41410cb3ae03979c83130067fdd87"
dependencies = [
"autocfg",
]
[[package]]
name = "miniz_oxide"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b"
dependencies = [
"adler",
"autocfg",
]
[[package]]
name = "neural"
version = "0.1.0"
dependencies = [
"byteorder",
"env_logger",
"flate2",
"log",
"rand",
"rayon",
]
[[package]]
name = "num_cpus"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
dependencies = [
"hermit-abi",
"libc",
]
[[package]]
name = "ppv-lite86"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
[[package]]
name = "rand"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ef9e7e66b4468674bfcb0c81af8b7fa0bb154fa9f28eb840da5c447baeb8d7e"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
"rand_hc",
]
[[package]]
name = "rand_chacha"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e12735cf05c9e10bf21534da50a147b924d555dc7a547c42e6bb2d5b6017ae0d"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34cf66eb183df1c5876e2dcf6b13d57340741e8dc255b48e40a26de954d06ae7"
dependencies = [
"getrandom",
]
[[package]]
name = "rand_hc"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3190ef7066a446f2e7f42e239d161e905420ccab01eb967c9eb27d21b2322a73"
dependencies = [
"rand_core",
]
[[package]]
name = "rayon"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b0d8e0819fadc20c74ea8373106ead0600e3a67ef1fe8da56e39b9ae7275674"
dependencies = [
"autocfg",
"crossbeam-deque",
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ab346ac5921dc62ffa9f89b7a773907511cdfa5490c572ae9be1be33e8afa4a"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-utils",
"lazy_static",
"num_cpus",
]
[[package]]
name = "regex"
version = "1.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "957056ecddbeba1b26965114e191d2e8589ce74db242b6ea25fc4062427a5c19"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.6.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5f089152e60f62d28b835fbff2cd2e8dc0baf1ac13343bef92ab7eed84548"
[[package]]
name = "scopeguard"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "termcolor"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dfed899f0eb03f32ee8c6a0aabdb8a7949659e3466561fc0adf54e26d88c5f4"
dependencies = [
"winapi-util",
]
[[package]]
name = "wasi"
version = "0.10.2+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
dependencies = [
"winapi",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"

15
Cargo.toml Normal file
View File

@ -0,0 +1,15 @@
[package]
name = "neural"
version = "0.1.0"
authors = ["Serge Bazanski <q3k@q3k.org>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
byteorder = "1"
flate2 = "1"
log = "0.4.0"
env_logger = "0.8.3"
rand = "0.8.3"
rayon = "1"

5
shell.nix Normal file
View File

@ -0,0 +1,5 @@
with import <nixpkgs> {};
pkgs.mkShell {
buildInputs = with pkgs; [ rustup ];
}

83
src/main.rs Normal file
View File

@ -0,0 +1,83 @@
use std::time;
use rand::Rng;
use rayon::prelude::*;
mod mnist;
struct Network {
sizes: Vec<usize>,
biases: Vec<Vec<f32>>,
weights: Vec<Vec<Vec<f32>>>,
}
impl Network {
fn new(sizes: &Vec<usize>) -> Self {
let mut rng = rand::thread_rng();
let biases = sizes[1..].iter().map(|s| {
(0..s.clone()).map(|_| rng.gen::<f32>()).collect()
}).collect();
let weights = sizes.iter().zip(sizes[1..].iter()).map(|(from, to)| {
(0..to.clone()).map(|_| {
(0..from.clone()).map(|_| rng.gen::<f32>()).collect()
}).collect()
}).collect();
Self {
sizes: sizes.clone(),
biases, weights,
}
}
fn feedforward(&self, a: &Vec<f32>) -> Vec<f32> {
let res: Option<Vec<f32>> = None;
self.biases.iter()
.zip(self.weights.iter())
.fold(a.clone(), |a, (biases, weights)| {
mat_vec_mult(weights, &a).iter()
.zip(biases.iter())
.map(|(aa, b)| {
sigmoid(aa+b)
}).collect()
})
}
}
fn sigmoid(z: f32) -> f32 {
1. / (1. + (-z).exp())
}
fn mat_vec_mult(m: &Vec<Vec<f32>>, v: &Vec<f32>) -> Vec<f32> {
m.iter().enumerate().map(|(i, row)| {
if row.len() != v.len() {
panic!("Mismatch: V {}, M {}x{}", v.len(), m.len(), row.len());
}
row.iter().zip(v.iter()).map(|(rr, vv)| rr * vv).sum()
}).collect()
}
fn main() {
env_logger::init();
let now = time::Instant::now();
let set = mnist::Set::load("train").unwrap();
let elapsed = now.elapsed();
log::info!("Loaded {} images in {:?}.", set.images.images.len(), elapsed);
let now = time::Instant::now();
let size = vec![
28*28 as usize,
15usize,
10usize,
];
let network = Network::new(&size);
let elapsed = now.elapsed();
log::info!("Created random network {:?} in {:?}.", size, elapsed);
let now = time::Instant::now();
let image: Vec<f32> = (&set.images.images[0]).into();
let res = network.feedforward(&image);
let elapsed = now.elapsed();
log::info!("feedforward: {:?} in {:?}", res, elapsed);
}

113
src/mnist.rs Normal file
View File

@ -0,0 +1,113 @@
use std::{io, fs};
use byteorder::{
BigEndian,
ReadBytesExt,
};
use flate2::read::GzDecoder;
#[derive(Debug)]
pub enum Error {
IO(io::Error),
InvalidMagic,
Parse(String),
}
impl From<io::Error> for Error {
fn from(err: io::Error) -> Self {
Error::IO(err)
}
}
pub type Result<T> = std::result::Result<T, Error>;
pub struct Image {
pixels: Vec<u8>,
width: usize,
height: usize,
}
impl Image {
pub fn parse(
rdr: &mut impl io::Read,
width: usize,
height: usize,
) -> Result<Self> {
let npixels: usize = width * height;
let mut pixels: Vec<u8> = vec![0u8; npixels];
rdr.read(&mut pixels)?;
Ok(Image {
pixels, width, height,
})
}
}
impl Into<Vec<f32>> for &Image {
fn into(self) -> Vec<f32> {
self.pixels.iter().map(|px| {
(*px as f32) / 255.0
}).collect()
}
}
pub struct ImageFile {
pub images: Vec<Image>,
}
impl ImageFile {
pub fn parse(mut rdr: impl io::Read) -> Result<Self> {
let magic = rdr.read_u32::<BigEndian>()?;
if magic != 2051 {
return Err(Error::InvalidMagic);
}
let num_images = rdr.read_u32::<BigEndian>()?;
let num_rows = rdr.read_u32::<BigEndian>()?;
let num_columns = rdr.read_u32::<BigEndian>()?;
let mut images: Vec<Image> = Vec::with_capacity(num_images as usize);
for _ in 0..num_images {
let image = Image::parse(&mut rdr, num_rows as usize, num_columns as usize)?;
images.push(image);
}
Ok(ImageFile {
images,
})
}
}
pub struct LabelFile {
magic: u32,
num_labels: u32,
labels: Vec<u8>,
}
impl LabelFile {
pub fn parse(mut rdr: impl io::Read) -> Result<Self> {
let magic = rdr.read_u32::<BigEndian>()?;
if magic != 2049 {
return Err(Error::InvalidMagic);
}
let num_labels = rdr.read_u32::<BigEndian>()?;
let mut labels: Vec<u8> = vec![0u8; num_labels as usize];
rdr.read(&mut labels)?;
Ok(LabelFile {
magic, num_labels, labels,
})
}
}
pub struct Set {
pub images: ImageFile,
labels: LabelFile,
}
impl Set {
pub fn load(name: &str) -> Result<Self> {
let labels_gz = fs::File::open(format!("{}-labels-idx1-ubyte.gz", name))?;
let labels = LabelFile::parse(GzDecoder::new(labels_gz))?;
let images_gz = fs::File::open(format!("{}-images-idx3-ubyte.gz", name))?;
let images = ImageFile::parse(GzDecoder::new(images_gz))?;
Ok(Set {
images, labels,
})
}
}