← 返回目录


哪些 Rust 库让你相见恨晚?

学校≠教育≠技能;文凭溢价=80%信号传递+20%人力资本

108 👍 / 17 💬

问题描述

相似问题:哪些 Python 库让你相见恨晚?


burn-rs,一个深度学习框架,未来 FSRS 能不能在手机端跑,就看它了。

对我这个算法工程师来说,这个库比较能解决的痛点包括:

  1. 不依赖 torch/tensorflow,编译出来体积很小,可以在客户端进行参数训练,纯本地化
  2. 能自己写 forward,而不是几个网络层搭积木,比较自由,比如下面就是我重写的 FSRS:
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
pub w: Param<Tensor<B, 1>>,
}

impl<B: Backend<FloatElem = f32>> Model<B> {
pub fn new() -> Self {
Self {
w: Param::from(Tensor::from_floats([
0.4, 0.6, 2.4, 5.8, // initial stability
4.93, 0.94, 0.86, 0.01, // difficulty
1.49, 0.14, 0.94, // success
2.18, 0.05, 0.34, 1.26, // failure
0.29, 2.61, // hard penalty, easy bonus
])),
}
}

fn w(&self) -> Tensor<B, 1> {
self.w.val()
}

pub fn power_forgetting_curve(&self, t: Tensor<B, 1>, s: Tensor<B, 1>) -> Tensor<B, 1> {
let retrievability = (t / (s * 9) + 1).powf(-1.0);
retrievability
}

fn stability_after_success(
&self,
last_s: Tensor<B, 1>,
new_d: Tensor<B, 1>,
r: Tensor<B, 1>,
rating: Tensor<B, 1>,
) -> Tensor<B, 1> {
let batch_size = rating.dims()[0];
let hard_penalty = Tensor::ones([batch_size])
.mask_where(rating.clone().equal_elem(2), self.w().slice([15..16]));
let easy_bonus = Tensor::ones([batch_size])
.mask_where(rating.clone().equal_elem(4), self.w().slice([16..17]));
let new_s = last_s.clone()
* (self.w().slice([8..9]).exp()
* (-new_d + 11)
* (-self.w().slice([9..10]) * last_s.log()).exp()
* (((-r + 1) * self.w().slice([10..11])).exp() - 1)
* hard_penalty
* easy_bonus
+ 1);
new_s
}

fn stability_after_failure(
&self,
last_s: Tensor<B, 1>,
new_d: Tensor<B, 1>,
r: Tensor<B, 1>,
) -> Tensor<B, 1> {
let new_s = self.w().slice([11..12])
* (-self.w().slice([12..13]) * new_d.log()).exp()
* ((self.w().slice([13..14]) * (last_s + 1).log()).exp() - 1)
* ((-r + 1) * self.w().slice([14..15])).exp();
new_s
}

fn step(
&self,
i: usize,
delta_t: Tensor<B, 1>,
rating: Tensor<B, 1>,
stability: Tensor<B, 1>,
difficulty: Tensor<B, 1>,
) -> (Tensor<B, 1>, Tensor<B, 1>) {
if i == 0 {
let new_s = self.w().select(0, rating.clone().int() - 1);
let new_d = self.w().slice([4..5]) - self.w().slice([5..6]) * (rating - 3);
(new_s.clamp(0.1, 36500.0), new_d.clamp(1.0, 10.0))
} else {
let r = self.power_forgetting_curve(delta_t, stability.clone());
// dbg!(&r);
let new_d = difficulty.clone() - self.w().slice([6..7]) * (rating.clone() - 3);
let new_d = new_d.clamp(1.0, 10.0);
// dbg!(&new_d);
let s_recall = self.stability_after_success(
stability.clone(),
new_d.clone(),
r.clone(),
rating.clone(),
);
let s_forget = self.stability_after_failure(stability, new_d.clone(), r);
let new_s = s_recall.mask_where(rating.equal_elem(1), s_forget);
(new_s.clamp(0.1, 36500.0), new_d)
}
}

pub fn forward(
&self,
delta_ts: Tensor<B, 2>,
ratings: Tensor<B, 2, Float>,
) -> (Tensor<B, 1>, Tensor<B, 1>) {
let [seq_len, batch_size] = delta_ts.dims();
let mut stability = Tensor::zeros([batch_size]);
let mut difficulty = Tensor::zeros([batch_size]);
for i in 0..seq_len {
let delta_t = delta_ts.clone().slice([i..i + 1]).squeeze(0);
let rating = ratings.clone().slice([i..i + 1]).squeeze(0);
// dbg!(&delta_t);
// dbg!(&rating);
(stability, difficulty) = self.step(i, delta_t, rating, stability, difficulty);
// dbg!(&stability);
// dbg!(&difficulty);
// dbg!()
}
(stability, difficulty)
}
}

不过这个库还是有几个问题:

  1. 训练过程有点过度封装了,想要自己写个迭代的 loop 很难
  2. 不支持训练过程中修改参数,我想做个参数剪裁,目前还没找到方法

反正边写边踩坑,边给社区提 issue,希望今年能把这个项目写完吧。

open-spaced-repetition/fsrs-optimizer-burn: Rust-based Optimizer for FSRS (github.com)

也欢迎 Rust 大佬前来助阵!

更多有关 FSRS 的介绍,请见:

叶峻峣:KDD'22 | 墨墨背单词:基于时序模型与最优控制的记忆算法 [AI+教育]叶峻峣:解释 FSRS(上篇):算法描述与运作原理叶峻峣:解释 FSRS(下篇):准确度


← 返回目录