

90. Truncated BPTT

Difficulty: hard

Truncated Backpropagation Through Time (BPTT) is a practical way to train sequence models (like vanilla RNNs) on long sequences without backpropagating through the entire history. In this problem, you’ll implement one training step that computes gradients using a fixed truncation window k instead of the full sequence length T.

Requirements

Implement the function

python

Rules:

  • Use the equations above and compute the average cross-entropy loss: \( L = -\frac{1}{T}\sum_{t=0}^{T-1}\log p_t[y_t] \)
  • Implement a numerically stable softmax (subtract max logit before exponentiating).
  • Implement truncated BPTT: only propagate gradients back at most k steps in time.
  • Update parameters with one step of vanilla SGD: param -= lr * grad.
  • Use only NumPy + Python built-ins; return parameters as NumPy arrays.
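
The softmax, loss, and update rules above can be sketched in a few lines. This is a minimal illustration, not the reference solution: the helper names `stable_softmax`, `average_cross_entropy`, and `sgd_update` are hypothetical, and the sketch assumes 1-D logit vectors and integer class targets.

```python
import numpy as np

def stable_softmax(logits):
    # Subtracting the max logit does not change the result,
    # but it keeps np.exp from overflowing on large inputs.
    shifted = logits - np.max(logits)
    exps = np.exp(shifted)
    return exps / np.sum(exps)

def average_cross_entropy(probs, ys):
    # probs: sequence of T probability vectors; ys: T integer targets (assumed).
    T = len(ys)
    return -sum(np.log(probs[t][ys[t]]) for t in range(T)) / T

def sgd_update(param, grad, lr):
    # Vanilla SGD; returns a new array rather than mutating the input.
    return param - lr * grad
```

Returning a new array in `sgd_update` is a design choice of this sketch so that the caller's inputs are left untouched; an in-place `param -= lr * grad` matches the rule as written.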

Example

python

Output:

python

Input Signature

Argument    Type
k           int
bh          np.ndarray
by          np.ndarray
h0          np.ndarray
lr          float
xs          np.ndarray
ys          np.ndarray
Whh         np.ndarray
Why         np.ndarray
Wxh         np.ndarray

Output Signature

Return Name    Type
value          tuple

Constraints

  • Numerically stable softmax (subtract max logit).

  • Return updated parameters as NumPy arrays.

  • Use NumPy only; vanilla SGD update.

Hint 1

Store intermediates from the forward pass: h[t], logits, and softmax probs p[t]. You’ll need them for BPTT (especially h[t-1] and x[t]).
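
A forward pass that keeps these intermediates might look like the sketch below. The recurrence is an assumption, since the problem's equations are not reproduced here: it uses the standard vanilla RNN form h[t] = tanh(Wxh x[t] + Whh h[t-1] + bh) with logits Why h[t] + by, 1-D vectors for `h0`, `bh`, and `by`, and `xs` of shape (T, D). The name `forward_pass` is hypothetical.

```python
import numpy as np

def forward_pass(xs, h0, Wxh, Whh, Why, bh, by):
    # Assumed shapes: xs (T, D), h0/bh (H,), by (V,),
    # Wxh (H, D), Whh (H, H), Why (V, H).
    T = len(xs)
    hs, logits_list, ps = [], [], []
    h_prev = h0
    for t in range(T):
        # Hidden state update (assumed tanh recurrence).
        h = np.tanh(Wxh @ xs[t] + Whh @ h_prev + bh)
        logits = Why @ h + by
        # Numerically stable softmax.
        e = np.exp(logits - np.max(logits))
        hs.append(h)
        logits_list.append(logits)
        ps.append(e / np.sum(e))
        h_prev = h
    return hs, logits_list, ps
```

Note that h[t-1] for t = 0 is `h0`, which is not stored in `hs`; the backward pass has to handle that case separately.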

Hint 2

Use the softmax + cross-entropy shortcut: for each time step t, set dp = p[t].copy(); dp[y[t]] -= 1, then scale by 1/T because the loss is averaged.
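
As a small illustration of this shortcut, the sketch below computes the gradient of the averaged loss with respect to one step's logits. The function name `dlogits_from_probs` is hypothetical, and `p` is assumed to be a 1-D probability vector.

```python
import numpy as np

def dlogits_from_probs(p, y, T):
    # Gradient of -log(p[y]) w.r.t. the logits is p - one_hot(y).
    dp = p.copy()
    dp[y] -= 1.0
    # The loss is averaged over T time steps, so each step contributes 1/T.
    return dp / T

# Example: p = [0.2, 0.3, 0.5], target class 2, sequence length T = 4
# gives approximately [0.05, 0.075, -0.125].
grad = dlogits_from_probs(np.array([0.2, 0.3, 0.5]), 2, 4)
```

Because the entries of `p` sum to 1, the entries of this gradient always sum to 0, which makes a quick sanity check.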

Hint 3

Implement truncation explicitly: when backpropagating from time t, only loop backward to t_start = max(0, t-k+1). Propagate dh through tanh and then through Whh.T each step, accumulating dWxh, dWhh, and dbh.
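
The three hints can be combined into a gradient routine along the following lines. This is a sketch under assumptions, not the official solution: it assumes the tanh recurrence h[t] = tanh(Wxh x[t] + Whh h[t-1] + bh) with logits Why h[t] + by, 1-D bias and state vectors, float parameters, and a hypothetical function name `truncated_bptt_grads`. It restarts the backward walk from every time step t, which costs O(T·k) but follows the hint literally.

```python
import numpy as np

def truncated_bptt_grads(xs, ys, h0, Wxh, Whh, Why, bh, by, k):
    T = len(xs)
    # Forward pass: store hidden states and softmax probabilities.
    hs, ps = [], []
    h_prev = h0
    for t in range(T):
        h = np.tanh(Wxh @ xs[t] + Whh @ h_prev + bh)
        logits = Why @ h + by
        e = np.exp(logits - np.max(logits))  # stable softmax
        hs.append(h)
        ps.append(e / np.sum(e))
        h_prev = h

    dWxh = np.zeros_like(Wxh, dtype=float)
    dWhh = np.zeros_like(Whh, dtype=float)
    dWhy = np.zeros_like(Why, dtype=float)
    dbh = np.zeros_like(bh, dtype=float)
    dby = np.zeros_like(by, dtype=float)

    for t in range(T):
        # Softmax + cross-entropy shortcut, scaled by 1/T.
        dlogits = ps[t].copy()
        dlogits[ys[t]] -= 1.0
        dlogits /= T
        dWhy += np.outer(dlogits, hs[t])
        dby += dlogits
        dh = Why.T @ dlogits

        # Walk back at most k steps: s = t, t-1, ..., t_start.
        t_start = max(0, t - k + 1)
        for s in range(t, t_start - 1, -1):
            dz = dh * (1.0 - hs[s] ** 2)          # through tanh
            h_before = hs[s - 1] if s > 0 else h0
            dWxh += np.outer(dz, xs[s])
            dWhh += np.outer(dz, h_before)
            dbh += dz
            dh = Whh.T @ dz                       # through Whh to step s-1
    return dWxh, dWhh, dWhy, dbh, dby
```

With k at least T, no path is cut and the result coincides with full BPTT, so a finite-difference check at k = T is a convenient way to validate the backward loop before trying smaller windows. The parameter update is then `param -= lr * grad` for each of the five parameters.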

Roles: ML Engineer, AI Engineer
Companies: General
Levels: manager, staff, senior
Tags: RNN, truncated-BPTT, backpropagation, softmax-cross-entropy