
51. Vision transformer patch embedding

Difficulty: medium | Company: General | Level: senior

Implement Vision Transformer (ViT) patch embedding, which turns an image into a sequence of patch vectors for transformer input. You’ll split the image into non-overlapping patches, flatten each patch, then apply a learned linear projection to get embeddings.

The patch embedding can be written as:

\[ E = XW + b \]

where \(X \in \mathbb{R}^{N \times (P^2 C)}\) is the matrix of flattened patches (one patch per row), \(W \in \mathbb{R}^{(P^2 C) \times D}\) is the projection matrix, \(b \in \mathbb{R}^{D}\) is the bias vector added to every row of \(XW\), \(N\) is the number of patches (for an \(H \times W\) image, \(N = HW/P^2\)), \(P\) is the patch size, \(C\) is the number of channels, and \(D\) is the embedding dimension.
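As a quick shape check, the sketch below plugs in an assumed 224×224 RGB image with P = 16 and D = 768 (a common ViT-Base-style configuration, used here only for illustration). It confirms the patch count and the flattened patch length, and shows that b is broadcast across all N rows of XW. Random arrays stand in for real patches and weights.

```python
import numpy as np

# Assumed example sizes (illustrative, not part of the problem statement).
H, W_img, C = 224, 224, 3   # image height, width, channels
P, D = 16, 768              # patch size, embedding dimension

N = (H // P) * (W_img // P)  # number of patches: 14 * 14 = 196
patch_dim = P * P * C        # flattened patch length: 16 * 16 * 3 = 768

rng = np.random.default_rng(0)
X = rng.standard_normal((N, patch_dim))  # stand-in for flattened patches
W = rng.standard_normal((patch_dim, D))  # projection matrix
b = rng.standard_normal(D)               # bias, shape (D,)

E = X @ W + b   # b is broadcast and added to every row of X @ W
print(N, patch_dim, E.shape)  # 196 768 (196, 768)
```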

Requirements

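Implement a function that takes image, patch_size, W, and b (listed under Input Signature) and returns the patch embeddings E as a NumPy array.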

Rules:

  • Split image into non-overlapping patches of size patch_size x patch_size.
  • Flatten each patch in row-major order, keeping channel values last (i.e., flatten over H, then W, then C within a patch).
  • Stack all flattened patches into a matrix X with shape (N, P*P*C) in top-to-bottom, left-to-right patch order.
  • Compute embeddings using E = X @ W + b (use NumPy for matrix math).
  • Return E as a NumPy array.
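A minimal loop-based sketch of the rules above. The function name `patch_embedding` and its argument order are hypothetical, and the sketch assumes the image height and width are multiples of patch_size (with `//`, any leftover rows or columns would simply be dropped).

```python
import numpy as np

def patch_embedding(image: np.ndarray, patch_size: int,
                    W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical name: ViT patch embedding via explicit loops.

    image: (H, W, C) array; W: (P*P*C, D); b: (D,). Returns E of shape (N, D).
    """
    P = patch_size
    H, W_img, C = image.shape
    patches = []
    # Outer loop walks patch rows (top to bottom), inner loop patch columns (left to right).
    for i in range(H // P):
        for j in range(W_img // P):
            patch = image[i * P:(i + 1) * P, j * P:(j + 1) * P, :]
            # C-order reshape flattens over H, then W, then C (channels last).
            patches.append(patch.reshape(-1))
    X = np.stack(patches)   # shape (N, P*P*C)
    return X @ W + b        # shape (N, D); b broadcast over rows
```

The explicit double loop mirrors the required patch order one-to-one, which makes it easy to check against the rules.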

Example
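A small illustrative case, with values chosen here for demonstration: a 4×4 single-channel image holding 0 to 15, split into 2×2 patches and projected to D = 2 with a hand-picked W and b. The rows of X show the top-to-bottom, left-to-right patch order and the row-major order inside each patch.

```python
import numpy as np

# Illustrative input: a 4x4 single-channel image holding 0..15, P = 2.
image = np.arange(16, dtype=float).reshape(4, 4, 1)
P = 2
H, W_img, C = image.shape

# Gather flattened patches top-to-bottom, left-to-right.
X = np.stack([
    image[i * P:(i + 1) * P, j * P:(j + 1) * P, :].reshape(-1)
    for i in range(H // P)
    for j in range(W_img // P)
])
# X rows: [0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]

# Hand-picked projection to D = 2 (also illustrative):
# column 0 sums the patch, column 1 picks its top-left pixel.
W = np.zeros((P * P * C, 2))
W[:, 0] = 1.0
W[0, 1] = 1.0
b = np.array([0.5, -1.0])

E = X @ W + b
# E rows: [10.5, -1.0], [18.5, 1.0], [42.5, 7.0], [50.5, 9.0]
print(X)
print(E)
```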
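Input Signature

  • W: np.ndarray of shape (P*P*C, D), the projection matrix
  • b: np.ndarray of shape (D,), the bias vector
  • image: np.ndarray of shape (H, W, C)
  • patch_size: int, the patch side length P

Output Signature

  • value: np.ndarray of shape (N, D), the patch embeddings E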

Constraints

  • Use NumPy for matrix multiplication.

  • Return NumPy array.

  • Flatten order: H, W, then C.
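To see why the flatten order matters, the sketch below flattens one illustrative 2×2 patch with C = 2 channels in the required H, W, C order and, for contrast only, in a channels-first (C, H, W) order. The same values land in different positions, so a W laid out for one order would not match the other.

```python
import numpy as np

# One illustrative patch: P = 2, C = 2, values 0..7 laid out as (H, W, C).
patch = np.arange(8).reshape(2, 2, 2)

# Required order: H, then W, then C (channel index varies fastest).
hwc = patch.reshape(-1)
# Contrast only: move channels first, then flatten as (C, H, W).
chw = patch.transpose(2, 0, 1).reshape(-1)

print(hwc.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]
print(chw.tolist())  # [0, 2, 4, 6, 1, 3, 5, 7]
```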

Hint 1

Get H, W, C from image.shape.

Hint 2

Loop over patches in top-to-bottom, left-to-right order, with i as the outer loop over range(H//P) and j as the inner loop over range(W//P), and slice image[i*P:(i+1)*P, j*P:(j+1)*P, :].

Hint 3

Flatten each patch with patch.reshape(-1) (row-major, channels last), stack into X with shape (N, P*P*C), then compute E = X @ W + b.
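The hints describe an explicit loop. An equivalent loop-free sketch (an alternative technique, not required by the problem) splits each spatial axis into a patch index and an in-patch offset, then transposes so the patch indices come first. The name `patch_embedding_vectorized` is hypothetical.

```python
import numpy as np

def patch_embedding_vectorized(image, patch_size, W, b):
    """Hypothetical loop-free variant; same row order as the hint's loops."""
    P = patch_size
    H, W_img, C = image.shape
    nh, nw = H // P, W_img // P
    # Crop any remainder (like the loops' // does), then split each spatial
    # axis into (patch index, offset): x[i, r, j, s, c] = image[i*P + r, j*P + s, c].
    x = image[:nh * P, :nw * P, :].reshape(nh, P, nw, P, C)
    # Reorder to (patch row, patch col, offset in H, offset in W, channel).
    x = x.transpose(0, 2, 1, 3, 4)
    X = x.reshape(nh * nw, P * P * C)   # rows top-to-bottom, left-to-right
    return X @ W + b

# Quick check against the loop from Hints 2 and 3 on random data.
rng = np.random.default_rng(0)
img = rng.standard_normal((8, 12, 3))
Wp, bp = rng.standard_normal((2 * 2 * 3, 5)), rng.standard_normal(5)
X_loop = np.stack([img[i * 2:(i + 1) * 2, j * 2:(j + 1) * 2, :].reshape(-1)
                   for i in range(4) for j in range(6)])
print(np.allclose(patch_embedding_vectorized(img, 2, Wp, bp), X_loop @ Wp + bp))  # True
```

The reshape after the transpose makes a copy, but it replaces N small slicing and stacking operations with a few whole-array calls.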

Roles: ML Engineer, AI Engineer
Companies: General
Levels: senior, entry
Tags: vision-transformer, patch-embedding, numpy-linear-algebra, tensor-reshaping