homepage
cluster
fun
网页上的机器学习
2020/2/22
用 Rust 写了一个 Neural Network,编译成 WebAssembly,在网页上运行。

Live Demo
Source Code

这是什么

简介: 在网页上实现了一个可视化的全连接神经网络,用 Rust 编写,编译成 WebAssembly 部署。

  1. 工具链部分: Rust 编译至 WebAssembly,这部分与众不同的有两点

    • 矩阵计算部分使用了rust-ndarray,是 Python numpy库中的 ndarray 部分函数在 Rust 中的port,非常赞。
    • 写编译到 WebAssembly 的代码没有使用 wasm-bindgen,而是直接手写的 WebAssembly 接口,我觉得很干净。
  2. NN 部分: 由于只是一个多分类问题。网络非常简单,只有一层全连接层也就是只有两个矩阵(FC层加起来很简单,但是多几层CPU根本跑不动(超小声))

    • 网络是一层FC层 + ReLU + Softmax
    • Loss 是 Cross Entropy 加上简单的 Regular Loss
    • 测试数据排布方法是来自于 CS231n 课程中的一个多分类问题的 Python 代码[1],旋臂真的很帅啊!有种星系的感觉。

为什么写这个

  1. 这事还得从一只蝙蝠说起......(省略一万字)

  2. 我挺想试试 Rust 和 WebAssembly 结合。最近天天写 Rust,而 WebAssembly 在我上上个学期中写 Electron App 的时候差点就用了,但是后来因为 Node 有 C++ Addon 就去写了C++ Addon。但我一直对 WebAssembly 心心念念。

  3. 上学期在学校的实验室跟学长一起学了相当一段时间的机器学习巫术,写这个正好可以巩固一下我学过的但是过去、现在、将来都不会用到的相对底层的手写NN的手艺。

怎么写的

从需求出发,既然需要实现一个可视化的神经网络的训练,而且要利用 WASM 的力量,那么就自然而然的把所有性能敏感的部分用 WASM 实现,可视化部分用一个 Canvas 就可以显示测试数据点以及当前的测试结果了。接下来选择编译到 WASM 的语言,用 C++ 写就显得很俗,使用 Rust 写就有点 Hack 的感觉了,所以使用了 Rust 以及配套的工具链。

  1. Canvas渲染部分:

    1. 对于画预测结果,canvas 有 64x64 的大小,如果在 WASM 中算出来然后交给 JavaScript 中一个一个画 pixel,性能可能会成为较大的问题,我不想 WASM 省下的可怜算力被几千个 draw call 直接埋没了,而且 WASM 传 array 出来是非常麻烦的,所以采用 Canvas serialize 成 pixel 数组然后传指针进 WASM 中交给 WASM 涂颜色比较合适。
    2. 对于画数据点,之前写图形库的经验告诉我手写pixel画圆挺麻烦,一方面要十几行代码,令一方面对于边界情况处理没有js中处理的那么好,再一方面data points不会太多,一般大约只有几百个,所以采用传递画点的 JavaScript 的 callback 函数到 WASM 模块的方法比较合适。
  2. Neural Network 部分:采用和一般 Python 写机器学习的结构相似的结构,大约就是一个 Data Provider,一个 Network。由于是手写网络,所以 Network 部分还得手写一个 forward propagation、backward propagation 和一个 loss calculation。有 rust-ndarray 的助力,把我写过七八遍的 Python 代码 port 到 Rust 上是相对来说是很轻松的事情(相对,后面有吐槽)。BP算法这里就不推导了(写了一点,但是矩阵太多了我比较懒就放弃了),仅用对应的代码解释:

    1. Weight 和 Bias:

      • 前向传播:

        let act1 = &points.dot(&self.w1) + &self.b1;
        ...
        let act2 = &fc_layer.dot(&self.w2) + &self.b2
        
      • 反向传播:

        let dw2 = fc_layer.t().dot(&dact2) + regular_rate * &self.w2;
        let db2 = dact2.sum_axis(Axis(0)).insert_axis(Axis(0));
        let dw1 = points.t().dot(&dact1) + regular_rate * &self.w1;
        let db1 = dact1.sum_axis(Axis(0)).insert_axis(Axis(0));
        
    2. Cross Entropy:

      • 前向传播:

        let exp_scores = scores.mapv(f32::exp);
        let softmax = &exp_scores / &exp_scores.sum_axis(Axis(1)).insert_axis(Axis(1));
        
      • 反向传播:

        let mut dscores = softmax.clone();
        for (i, mut dscore) in dscores.axis_iter_mut(Axis(0)).enumerate() {
            dscore[[labels[[i]] as usize]] -= 1f32;
        }
        dscores /= num_data as f32;
        let dact2 = dscores;
        
    3. ReLU:

      • 前向传播

        let fc_layer = act1.mapv(|x| x.max(0f32));
        
      • 反向传播

        let dfc_layer = dact2.dot(&self.w2.t());
        let mut dact1 = dfc_layer.clone();
        Zip::from(&mut dact1)
            .and(fc_layer)
            .apply(|act1, &fc| {
                if fc == 0f32 {
                    *act1 = 0f32;
                }
            });
        

吐槽

引用

[1] Stanford CS231n class http://cs231n.github.io/neural-networks-case-study/

[2] WASM thread support https://developers.google.com/web/updates/2018/10/wasm-threads

[3] Rust And Assembly 官方教程 https://rustwasm.github.io/book/game-of-life/setup.html

[4] 一个远古的 Rust 部署 WebAssembly 的教程(记得它的js代码是有坑的所以我没使用,可我忘记坑是什么了,但这篇思路是正确的) https://www.reddit.com/r/rust/comments/9t95fd/howto_setting_up_webassembly_on_stable_rust/

[5] 一个现代的 Rust 部署 WebAssembly 教程(我也忘记它是不是正确的了,但是我的部署方法一定是对的 XD) https://dev.to/dandyvica/wasm-in-rust-without-nodejs-2e0c

Rust
WebAssembly
Machine Learning
Math