
How to set the window size for video swin transformer?


I am confused about the window size setting for the video Swin transformer. My input has the shape (200, 8, 8): the height and width are 8, and there are 200 single-channel frames. I need each patch to have dimensions of 1x1x10, i.e. a 1x1 spatial patch spanning 10 frames.

I refer to this code, and I created a dummy input dummy_x = torch.rand(1, 1, 200, 8, 8). The Swin transformer I set up uses the following tiny-variant blocks.

import torch

# SwinTransformer3D is taken from the Video Swin Transformer code referenced
# above; adjust the import to wherever that module lives in your project.
from swin_transformer import SwinTransformer3D

model = SwinTransformer3D(pretrained=None,
                          pretrained2d=True,
                          patch_size=(10, 1, 1),
                          in_chans=1,
                          embed_dim=96,
                          depths=[2, 2, 6, 2],
                          num_heads=[3, 6, 12, 24],
                          window_size=(20, 7, 7),
                          mlp_ratio=4.,
                          qkv_bias=True,
                          qk_scale=None,
                          drop_rate=0.,
                          attn_drop_rate=0.,
                          drop_path_rate=0.2,
                          norm_layer=torch.nn.LayerNorm,
                          patch_norm=True,
                          frozen_stages=-1,
                          use_checkpoint=False)
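To make my setup concrete, this is the patch-token grid my patch size implies (a minimal sketch of the arithmetic, assuming the dummy layout is (batch, channels, frames, height, width)):

    # With patch_size = (10, 1, 1) in (frames, height, width) order, the
    # (200, 8, 8) input is split into non-overlapping patch tokens:
    frames, height, width = 200, 8, 8
    pd, ph, pw = 10, 1, 1
    token_grid = (frames // pd, height // ph, width // pw)
    print(token_grid)  # (20, 8, 8)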

How should I set the window size for the depth (frame) dimension?

For example, the window size is set to (7, 7) in the 2D Swin transformer. What does the number 7 indicate there? And how should the window size be set for the 3D Swin transformer? In my settings I set it to (20, 7, 7). Is this correct? What should it be?


Solution

  • The window size is specified in units of patch tokens, not pixels. Your window size of (20, 7, 7) therefore means each window attends over 20 patches along the depth (frame) dimension and a 7x7 region of patches spatially. After patch embedding with patch_size=(10, 1, 1), your (200, 8, 8) input becomes a 20x8x8 token grid, so a cleaner choice is to derive the window size directly from that grid. So you can use the code below:

    input_shape = (200, 8, 8)  # (frames, height, width)

    patch_size = (10, 1, 1)    # (depth, height, width): 10 frames x 1 x 1 pixels

    # Elementwise division gives the patch-token grid:
    # (200 // 10, 8 // 1, 8 // 1) = (20, 8, 8)
    window_size = tuple(s // p for s, p in zip(input_shape, patch_size))
    
    model = SwinTransformer3D(
        patch_size=patch_size,
        in_chans=1,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
        patch_norm=False,
        frozen_stages=-1,
        use_checkpoint=False
    )

    print(model)
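As a quick sanity check (a minimal sketch; the exact output shape depends on the version of the SwinTransformer3D code you are using), you can push the dummy clip through the configured model:

    # Smoke test: with patch_size=(10, 1, 1) the token grid is 20 x 8 x 8, so
    # window_size=(20, 8, 8) makes each window span the whole grid, i.e.
    # attention is effectively global within every block.
    dummy_x = torch.rand(1, 1, 200, 8, 8)  # (batch, channels, frames, height, width)
    features = model(dummy_x)
    print(features.shape)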