Truncated Backpropagation Through Time (BPTT) is a practical way to train sequence models (like vanilla RNNs) on long sequences without backpropagating through the entire history. In this problem, you’ll implement one training step that computes gradients using a fixed truncation window (k), instead of the full sequence length.
Implement a function that performs one such training step; its arguments and return value are specified in the tables below.
Rules:
- Truncate backpropagation to at most k steps in time.
- Vanilla SGD update: param -= lr * grad.

Input:
| Argument | Type |
|---|---|
| k | int |
| bh | np.ndarray |
| by | np.ndarray |
| h0 | np.ndarray |
| lr | float |
| xs | np.ndarray |
| ys | np.ndarray |
| Whh | np.ndarray |
| Why | np.ndarray |
| Wxh | np.ndarray |
Output:

| Return Name | Type |
|---|---|
| value | tuple |
Use a numerically stable softmax (subtract the max logit).
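For reference, a minimal stable-softmax helper in NumPy could look like this (the helper name is illustrative, not part of the required interface):

```python
import numpy as np

def stable_softmax(logits):
    # Subtracting the max logit before exponentiating prevents overflow in np.exp;
    # the shift cancels in the normalization, so the probabilities are unchanged.
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / exp.sum()
```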
Return updated parameters as NumPy arrays.
Use NumPy only; vanilla SGD update.
Store intermediates from the forward pass: h[t], logits, and softmax probs p[t]. You’ll need them for BPTT (especially h[t-1] and x[t]).
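A possible forward-pass sketch that does this bookkeeping is below. The shapes are assumptions for illustration: xs is (T, D) with one input vector per step, ys is (T,) with integer class labels, h0 is (H,), Wxh is (H, D), Whh is (H, H), Why is (C, H), bh is (H,), by is (C,); the helper name is hypothetical.

```python
import numpy as np

def forward_pass(xs, ys, h0, Wxh, Whh, Why, bh, by):
    # Cache hs[t] and ps[t] at every step so the backward pass can reuse
    # h[t-1] and x[t]; hs[-1] holds the initial hidden state h0.
    T = xs.shape[0]
    hs = {-1: h0}
    ps = {}
    loss = 0.0
    for t in range(T):
        hs[t] = np.tanh(Wxh @ xs[t] + Whh @ hs[t - 1] + bh)
        logits = Why @ hs[t] + by
        logits = logits - np.max(logits)      # numerically stable softmax
        e = np.exp(logits)
        ps[t] = e / e.sum()
        loss += -np.log(ps[t][ys[t]])         # cross-entropy against the true label
    return hs, ps, loss / T                   # loss is averaged over the T steps
```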
Use the softmax + cross-entropy shortcut: for each time step t, set dp = p[t].copy(); dp[y[t]] -= 1, then scale by 1/T because the loss is averaged.
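As a sketch, that shortcut at a single time step could be written as the following hypothetical helper (p_t is the cached softmax vector at step t, y_t the true class index, and T the sequence length):

```python
import numpy as np

def dloss_dlogits(p_t, y_t, T):
    # Gradient of the averaged cross-entropy loss w.r.t. the logits at step t:
    # softmax + cross-entropy collapse to (p_t - one_hot(y_t)), scaled by 1/T.
    dlogits = p_t.copy()
    dlogits[y_t] -= 1.0
    return dlogits / T
```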
Implement truncation explicitly: when backpropagating from time t, only loop backward to t_start = max(0, t-k+1). Propagate dh through the tanh nonlinearity and then through Whh.T at each step, accumulating dWxh, dWhh, and dbh.
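Putting the pieces together, one possible end-to-end sketch of the training step is below. The function name, the argument shapes, and the layout of the returned tuple are assumptions for illustration; adapt them to the required signature.

```python
import numpy as np

def truncated_bptt_step(xs, ys, h0, Wxh, Whh, Why, bh, by, lr, k):
    # Hypothetical layout: xs (T, D), ys (T,) int labels, h0 (H,),
    # Wxh (H, D), Whh (H, H), Why (C, H), bh (H,), by (C,).
    T = xs.shape[0]

    # Forward pass: cache hidden states and softmax probabilities.
    hs = {-1: h0}
    ps = {}
    for t in range(T):
        hs[t] = np.tanh(Wxh @ xs[t] + Whh @ hs[t - 1] + bh)
        logits = Why @ hs[t] + by
        logits = logits - np.max(logits)          # numerically stable softmax
        e = np.exp(logits)
        ps[t] = e / e.sum()

    # Backward pass with truncation to at most k steps per output.
    dWxh, dWhh, dWhy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)
    dbh, dby = np.zeros_like(bh), np.zeros_like(by)
    for t in range(T):
        # Softmax + cross-entropy shortcut, scaled by 1/T for the averaged loss.
        dlogits = ps[t].copy()
        dlogits[ys[t]] -= 1.0
        dlogits /= T

        dWhy += np.outer(dlogits, hs[t])
        dby += dlogits

        # Only propagate back through steps t, t-1, ..., max(0, t - k + 1).
        dh = Why.T @ dlogits
        for step in range(t, max(0, t - k + 1) - 1, -1):
            dh_raw = (1.0 - hs[step] ** 2) * dh   # backprop through tanh
            dbh += dh_raw
            dWxh += np.outer(dh_raw, xs[step])
            dWhh += np.outer(dh_raw, hs[step - 1])
            dh = Whh.T @ dh_raw                   # carry gradient to h[step-1]

    # Vanilla SGD update: param -= lr * grad.
    Wxh = Wxh - lr * dWxh
    Whh = Whh - lr * dWhh
    Why = Why - lr * dWhy
    bh = bh - lr * dbh
    by = by - lr * dby
    return Wxh, Whh, Why, bh, by
```

A quick sanity check is to run a finite-difference gradient comparison on a tiny sequence with k >= T, where truncated and full BPTT should coincide.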