Bits and Bytes

Shreyas Srivastava

6 February 2023

Use einops to patchify image

Recently I came across einops while reading some deep learning PyTorch repos. It has a learning curve, but the code it produces tends to be more readable once you get used to the notation.

To learn more, please refer to the official einops docs.

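Example code: Vision transformer preprocessing

Implementing a Vision Transformer requires splitting the original image into patches so that the image can be embedded as a token sequence. Hence, input: (b, c, h, w), output: (b, c, nh, nw, ph, pw), where b = batch size, c = channels, h = height, w = width, ph = patch height, pw = patch width, nh = h / ph (number of patches along the height) and nw = w / pw (number of patches along the width). The height must be divisible by ph and the width by pw.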
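To make the shape bookkeeping concrete, here is a minimal sketch in plain Python that computes the output shape from the input shape and the patch size. The helper name patch_grid_shape and the 224x224 image with 16x16 patches are illustrative assumptions, not part of any library.

```python
def patch_grid_shape(image_shape, patch_size):
    """Return the (b, c, nh, nw, ph, pw) shape for a (b, c, h, w) input.

    Hypothetical helper for illustration only.
    """
    b, c, h, w = image_shape
    ph, pw = patch_size
    # The image must tile exactly into patches, otherwise it cannot be split evenly.
    if h % ph != 0 or w % pw != 0:
        raise ValueError(
            f"image size ({h}, {w}) is not divisible by patch size ({ph}, {pw})"
        )
    nh, nw = h // ph, w // pw
    return (b, c, nh, nw, ph, pw)


shape = patch_grid_shape((8, 3, 224, 224), (16, 16))
print(shape)  # (8, 3, 14, 14, 16, 16)
```

With these numbers each image is cut into 14 x 14 = 196 patches of 16 x 16 pixels per channel.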

Vanilla implementation in pytorch

We reshape the tensor to split the height into (nh, ph) and the width into (nw, pw), then permute to get the desired ordering of dimensions.

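import torch

def patchify(image, patch_size):
    b, c, h, w = image.shape
    ph, pw = patch_size
    nh, nw = h // ph, w // pw
    # Split height into (nh, ph) and width into (nw, pw)
    image_patches = torch.reshape(image, (b, c, nh, ph, nw, pw))
    # Reorder (b, c, nh, ph, nw, pw) -> (b, c, nh, nw, ph, pw)
    image_patches = torch.permute(image_patches, (0, 1, 2, 4, 3, 5))
    return image_patches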
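As a quick sanity check of the reshape-then-permute approach, the sketch below runs the same logic on a tiny 4x4 image whose pixel values are 0 to 15, so the contents of each 2x2 patch can be read off directly. The toy input is an assumption chosen for readability; the function is repeated so the snippet runs on its own.

```python
import torch


def patchify(image, patch_size):
    b, c, h, w = image.shape
    ph, pw = patch_size
    nh, nw = h // ph, w // pw
    # Split height into (nh, ph) and width into (nw, pw)
    image_patches = torch.reshape(image, (b, c, nh, ph, nw, pw))
    # Move the patch-grid axes next to each other: (b, c, nh, nw, ph, pw)
    return torch.permute(image_patches, (0, 1, 2, 4, 3, 5))


# One image, one channel, pixels numbered row by row:
#  0  1  2  3
#  4  5  6  7
#  8  9 10 11
# 12 13 14 15
image = torch.arange(16).reshape(1, 1, 4, 4)
patches = patchify(image, (2, 2))

print(patches.shape)        # torch.Size([1, 1, 2, 2, 2, 2])
print(patches[0, 0, 0, 0])  # top-left patch:     [[0, 1], [4, 5]]
print(patches[0, 0, 1, 1])  # bottom-right patch: [[10, 11], [14, 15]]
```

Indexing patches[b, c, i, j] returns the patch in row i and column j of the patch grid.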

Einops implementation in pytorch

pip install einops

import einops

def patchify_with_einops(image, patch_size):
    # patch_size is a (ph, pw) tuple, as in the vanilla implementation
    ph, pw = patch_size
    return einops.rearrange(image, 'b c (nh ph) (nw pw) -> b c nh nw ph pw', ph=ph, pw=pw)

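In the einops implementation above, the reshape and the permutation are expressed in a single line: the parenthesized groups (nh ph) and (nw pw) on the left split the height and width, and the right-hand side gives the output ordering. Besides the pattern string, we also need to provide the patch dimensions ph and pw so that einops can infer nh and nw.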
