简介: 在网页上实现了一个可视化的全连接神经网络,用 Rust 编写,编译成 WebAssembly 部署。
工具链部分: Rust 编译至 WebAssembly,这部分与众不同的有两点
NN 部分: 由于只是一个多分类问题。网络非常简单,只有一层全连接层也就是只有两个矩阵(FC层加起来很简单,但是多几层CPU根本跑不动(超小声))
这事还得从一只蝙蝠说起......(省略一万字)
我挺想试试 Rust 和 WebAssembly 结合。最近天天写 Rust,而 WebAssembly 在我上上个学期中写 Electron App 的时候差点就用了,但是后来因为 Node 有 C++ Addon 就去写了C++ Addon。但我一直对 WebAssembly 心心念念。
上学期在学校的实验室跟学长一起学了相当一段时间的机器学习巫术,写这个正好可以巩固一下我学过的但是过去、现在、将来都不会用到的相对底层的手写NN的手艺。
从需求出发,既然需要实现一个可视化的神经网络的训练,而且要利用 WASM 的力量,那么就自然而然的把所有性能敏感的部分用 WASM 实现,可视化部分用一个 Canvas 就可以显示测试数据点以及当前的测试结果了。接下来选择编译到 WASM 的语言,用 C++ 写就显得很俗,使用 Rust 写就有点 Hack 的感觉了,所以使用了 Rust 以及配套的工具链。
Canvas渲染部分:
Neural Network 部分:采用和一般 Python 写机器学习的结构相似的结构,大约就是一个 Data Provider,一个 Network。由于是手写网络,所以 Network 部分还得手写一个 forward propagation、backward propagation 和一个 loss calculation。有 rust-ndarray 的助力,把我写过七八遍的 Python 代码 port 到 Rust 上是相对来说是很轻松的事情(相对,后面有吐槽)。BP算法这里就不推导了(写了一点,但是矩阵太多了我比较懒就放弃了),仅用对应的代码解释:
Weight 和 Bias:
前向传播:
let act1 = &points.dot(&self.w1) + &self.b1;
...
let act2 = &fc_layer.dot(&self.w2) + &self.b2
反向传播:
let dw2 = fc_layer.t().dot(&dact2) + regular_rate * &self.w2;
let db2 = dact2.sum_axis(Axis(0)).insert_axis(Axis(0));
let dw1 = points.t().dot(&dact1) + regular_rate * &self.w1;
let db1 = dact1.sum_axis(Axis(0)).insert_axis(Axis(0));
Cross Entropy:
前向传播:
let exp_scores = scores.mapv(f32::exp);
let softmax = &exp_scores / &exp_scores.sum_axis(Axis(1)).insert_axis(Axis(1));
反向传播:
let mut dscores = softmax.clone();
for (i, mut dscore) in dscores.axis_iter_mut(Axis(0)).enumerate() {
dscore[[labels[[i]] as usize]] -= 1f32;
}
dscores /= num_data as f32;
let dact2 = dscores;
ReLU:
前向传播
let fc_layer = act1.mapv(|x| x.max(0f32));
反向传播
let dfc_layer = dact2.dot(&self.w2.t());
let mut dact1 = dfc_layer.clone();
Zip::from(&mut dact1)
.and(fc_layer)
.apply(|act1, &fc| {
if fc == 0f32 {
*act1 = 0f32;
}
});
由于 data generation 需要使用随机数,如果是 C++ 的话手写一个 LCG 简单高效,但是 Rust 中 rust-ndarray 里面 random_using
函数需要一个实现 rand_core 中的 SeedableRng
trait 的参数,所以手写没有那么容易,就使用了一些随机数的库。使用 ndarray-rand 这个库的时候 random 总会在运行的时候 panic,这个让我 debug 了很久,后来找到原因居然是 random 使用了 thread_rng
然而 WASM 不支持多线程(其实是支持的[2],但是使用不够方便)。
rust-ndarray 的文档虽然有那么一点用,但是要用到的东西永远查不到,很多时候还得翻源码找没有文档的函数。如果有大佬能一小时把 rust-ndarray 的 API 们用会,请 Email 我接受膜拜。之后如果有时间我也会给这个库发几个 PR。
由于是 WASM,全局变量需要加 Mutex,cargo test
的时候的默认多线程。对于这个库如果按照正常流程运行是不会出现死锁的,但是 test 的时候对于一个函数会调用多次,这些调用之间会死锁。解决方案有两种:
cargo test -- --test-threads 1
或者 set RUST_TEST_THREADS=1
我使用的方法是把所有的全局变量全部放到一个 tuple struct 中然后用 Mutex 包裹,这在我的需求下不影响性能。
struct CriticalSection(MetaData, Data, Network);
lazy_static! {
static ref DATA: Mutex<CriticalSection> = Mutex::default();
}
上面的解决引出了另一个问题,我严重怀疑是 Rust 编译器的 bug。我写了这么一个解构赋值(Destructuring assignment)的语句,略微有点复杂(?)
let ref mut CriticalSection(metadata, data, network) = *DATA.lock().unwrap();
这在当前最新的 rustc (1.43.0-nightly)中会报错,然后就我改写成下面的形式就可以过了。
let ref mut tmp = *DATA.lock().unwrap();
let CriticalSection(metadata, data, network) = tmp;
Rust - WebAssembly 结合的教程(包括官方的教程 [3])我看到基本上都是使用 wasm-pack,且用了 npm 和 nodejs 引入了一些胶水代码。我很气愤,一个入门教程应当是简洁清晰无依赖的。我仅仅找到了几个 Zero Dependency 的教程 [4][5],但是都有各种各样的问题,最后自己一点点摸索出来了 glue code,如果读者需要的话可以看一下 GitHub 的源码。我是通过 Google rust wasm without
这样的关键词组合找到这些教程的。
[1] Stanford CS231n class http://cs231n.github.io/neural-networks-case-study/
[2] WASM thread support https://developers.google.com/web/updates/2018/10/wasm-threads
[3] Rust And Assembly 官方教程 https://rustwasm.github.io/book/game-of-life/setup.html
[4] 一个远古的 Rust 部署 WebAssembly 的教程(记得它的js代码是有坑的所以我没使用,可我忘记坑是什么了,但这篇思路是正确的) https://www.reddit.com/r/rust/comments/9t95fd/howto_setting_up_webassembly_on_stable_rust/
[5] 一个现代的 Rust 部署 WebAssembly 教程(我也忘记它是不是正确的了,但是我的部署方法一定是对的 XD) https://dev.to/dandyvica/wasm-in-rust-without-nodejs-2e0c