Use einops to patchify image
by Shreyas Srivastava
Recently I came across einops while reading some deep learning PyTorch repos. It has a learning curve, but the code it produces tends to be more readable once you get used to the notation.
To learn more, please refer to the official einops docs.
Example code: Vision transformer preprocessing
Implementing a Vision Transformer requires splitting the original image into patches so that the image can be embedded as a token sequence.
Hence,
input: (b, c, h, w)
output: (b, c, nh, nw, ph, pw)
where b = batch size, c = channels, h = height, w = width, nh = number of patches along the height, nw = number of patches along the width, ph = patch height, pw = patch width
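To make the shape arithmetic concrete, here is a minimal sketch in plain Python. The helper name `patch_grid_shape` is hypothetical (it is not part of PyTorch or einops), and it assumes the patches tile the image exactly, with no padding or overlap.

```python
def patch_grid_shape(image_shape, patch_size):
    """Return the (b, c, nh, nw, ph, pw) shape for a (b, c, h, w) input.

    Hypothetical helper, used here only to illustrate the shape arithmetic.
    """
    b, c, h, w = image_shape
    ph, pw = patch_size
    # Assumption: the patch size divides the image size exactly.
    if h % ph != 0 or w % pw != 0:
        raise ValueError(
            f"image size {(h, w)} is not divisible by patch size {(ph, pw)}"
        )
    nh, nw = h // ph, w // pw
    return (b, c, nh, nw, ph, pw)


# A batch of two 224x224 RGB images cut into 16x16 patches.
print(patch_grid_shape((2, 3, 224, 224), (16, 16)))  # (2, 3, 14, 14, 16, 16)
```

So a 224x224 image with 16x16 patches gives a 14x14 grid of patches per channel.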
Vanilla implementation in pytorch
We reshape the tensor to add the patch dimensions, then permute to get the desired ordering of dimensions.
import torch

def patchify(image, patch_size):
    b, c, h, w = image.shape
    ph, pw = patch_size
    nh, nw = h // ph, w // pw
    # Split height into (nh, ph) and width into (nw, pw)
    image_patches = torch.reshape(image, (b, c, nh, ph, nw, pw))
    # Reorder the dimensions to (b, c, nh, nw, ph, pw)
    image_patches = torch.permute(image_patches, (0, 1, 2, 4, 3, 5))
    return image_patches
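To see why a reshape followed by a permute yields the intended patches, a quick sanity check can compare one patch against the matching slice of the original image. This is only a sketch: the tensor sizes and the patch indices `i`, `j` are arbitrary choices for illustration, and `patchify` is repeated so the snippet runs on its own.

```python
import torch


def patchify(image, patch_size):
    b, c, h, w = image.shape
    ph, pw = patch_size
    nh, nw = h // ph, w // pw
    image_patches = torch.reshape(image, (b, c, nh, ph, nw, pw))
    return torch.permute(image_patches, (0, 1, 2, 4, 3, 5))


# Small tensor with distinct values so that every pixel is identifiable.
image = torch.arange(2 * 3 * 8 * 12).reshape(2, 3, 8, 12)
ph, pw = 4, 6
patches = patchify(image, (ph, pw))

# Patch (i, j) should be the block of rows i*ph:(i+1)*ph
# and columns j*pw:(j+1)*pw of the original image.
i, j = 1, 0
expected = image[:, :, i * ph:(i + 1) * ph, j * pw:(j + 1) * pw]
patch_matches = torch.equal(patches[:, :, i, j], expected)

print(tuple(patches.shape), patch_matches)  # (2, 3, 2, 2, 4, 6) True
```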
Einops implementation in pytorch
pip install einops
import einops
def patchify_with_einops(image, patch_size):
    ph, pw = patch_size  # same (ph, pw) tuple that patchify expects
    return einops.rearrange(image, 'b c (nh ph) (nw pw) -> b c nh nw ph pw', ph=ph, pw=pw)
In the einops implementation above, the reshape and the permutation are expressed in a single line with a more compact notation. Besides the einops pattern, we also need to pass the patch dimensions (ph and pw) so that einops can infer nh and nw.