
How to modify base ViT architecture from Huggingface in Tensorflow


I am new to Hugging Face and want to adapt the same Transformer architecture used in ViT for image classification to my own domain. To do so, I need to change the input shape and the augmentations that are applied.
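
For reference, my current guess at changing the input shape is to override the model config at load time and let any mismatched weights be re-initialised; I am not sure this is the intended approach (the image_size, num_labels and ignore_mismatched_sizes values below are just my assumption of how it might work):

from transformers import ViTFeatureExtractor, TFViTForImageClassification

# Guess: extra kwargs are forwarded to ViTConfig, and weights whose shapes no
# longer match (position embeddings, classifier head) are re-initialised when
# ignore_mismatched_sizes=True.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224", size=384
)
model = TFViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    image_size=384,                # my domain's input resolution (example value)
    num_labels=10,                 # my number of classes (example value)
    ignore_mismatched_sizes=True,
)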

Here is the snippet from Hugging Face:

from transformers import ViTFeatureExtractor, TFViTForImageClassification
import tensorflow as tf
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

inputs = feature_extractor(images=image, return_tensors="tf")
outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
print("Predicted class:", model.config.id2label[int(predicted_class_idx)])

When I run model.summary()

I get the following results:

Model: "tf_vi_t_for_image_classification_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 vit (TFViTMainLayer)        multiple                  85798656  
                                                                 
 classifier (Dense)          multiple                  769000    
                                                                 
=================================================================
Total params: 86,567,656
Trainable params: 86,567,656
Non-trainable params: 0

As shown, the layers of the ViT base are encapsulated. Is there a way to unwrap the layers so that I can modify specific ones?


Solution

  • In your case, I would recommend looking at the source code here and tracing the classes that are called. For example, to get the layers of the Embeddings class, you can run:

    print(model.layers[0].embeddings.patch_embeddings.projection)
    print(model.layers[0].embeddings.dropout)
    
    <keras.layers.convolutional.Conv2D object at 0x7fea6264c6d0>
    <keras.layers.core.dropout.Dropout object at 0x7fea62d65110>
    

    Or if you want to get the layers of the first Attention block, try:

    print(model.layers[0].encoder.layer[0].attention.self_attention.query)
    print(model.layers[0].encoder.layer[0].attention.self_attention.key)
    print(model.layers[0].encoder.layer[0].attention.self_attention.value)
    print(model.layers[0].encoder.layer[0].attention.self_attention.dropout)
    print(model.layers[0].encoder.layer[0].attention.dense_output.dense)
    print(model.layers[0].encoder.layer[0].attention.dense_output.dropout)
    
    <keras.layers.core.dense.Dense object at 0x7fea62ec7f90>
    <keras.layers.core.dense.Dense object at 0x7fea62ec7b50>
    <keras.layers.core.dense.Dense object at 0x7fea62ec79d0>
    <keras.layers.core.dropout.Dropout object at 0x7fea62cf5c90>
    <keras.layers.core.dense.Dense object at 0x7fea62cf5250>
    <keras.layers.core.dropout.Dropout object at 0x7fea62cf5410>
    

    and so on.
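
  • Once a sub-layer has been located this way, one way to modify it is to assign a fresh Keras layer to the corresponding attribute, since these are all plain Keras layers (a rough sketch; the 10-class head and the 0.2 dropout rate below are arbitrary example values):

    import tensorflow as tf

    # Swap the 1000-class ImageNet head for a new, randomly initialised
    # 10-class head (example value).
    model.classifier = tf.keras.layers.Dense(10, name="classifier")

    # The same works deeper in the encoder, e.g. changing a dropout rate:
    model.layers[0].encoder.layer[0].attention.self_attention.dropout = \
        tf.keras.layers.Dropout(0.2)

    # Run a dummy forward pass so the new layers are built and get weights.
    dummy = tf.random.uniform((1, 3, 224, 224))
    _ = model(pixel_values=dummy)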