I saw the following code for a ResNet CNN written in Python 3 and PyTorch:
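```python
# Residual is the residual-block nn.Module defined earlier in the source this code comes from.
def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # First block of a later stage: change the channel count and
            # use a 1x1 convolution on the shortcut with stride 2.
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk
```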
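To see what `resnet_block` actually hands back, here is a minimal sketch that runs without PyTorch. The `Residual` below is a hypothetical stand-in (a dataclass that only records its constructor arguments), not the real module; `resnet_block` itself is copied from the question.

```python
from dataclasses import dataclass


# Hypothetical stand-in for the real Residual module: it only stores the
# arguments it was constructed with, so we can inspect the returned list.
@dataclass
class Residual:
    input_channels: int
    num_channels: int
    use_1x1conv: bool = False
    strides: int = 1


def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # First block of a later stage: new channel count, 1x1 shortcut, stride 2.
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk


blk = resnet_block(64, 128, 2)
print(type(blk).__name__, len(blk))  # list 2
for r in blk:
    print(r)
# Residual(input_channels=64, num_channels=128, use_1x1conv=True, strides=2)
# Residual(input_channels=128, num_channels=128, use_1x1conv=False, strides=1)
```

The return value is an ordinary Python list of block objects, which is why it has to be unpacked before being handed to `nn.Sequential`.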
The stages are then built like this:
```python
from torch import nn

b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))
```
What does the `*` in `*resnet_block(...)` mean, and what does it do?
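Basically, `*iterable` unpacks the items of an iterable object into separate positional arguments. In your code, `resnet_block` returns a list of `Residual` modules, so `*resnet_block(...)` passes each module to `nn.Sequential` as its own argument rather than passing the list itself. This matters because `nn.Sequential` expects its submodules as separate positional arguments: passing the list directly would raise a `TypeError`, since a list is not an `nn.Module`.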
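The same mechanism can be seen in plain Python. `count_args` is a hypothetical helper that simply reports how many positional arguments it received, so the difference between passing a list and unpacking it is visible directly.

```python
def count_args(*args):
    """Return how many positional arguments were received."""
    return len(args)


modules = ["conv", "bn", "relu"]

print(count_args(modules))   # 1: the list itself is a single argument
print(count_args(*modules))  # 3: each item becomes its own argument

# count_args(*modules) is exactly the same call as:
print(count_args(modules[0], modules[1], modules[2]))  # 3

# Any iterable can be unpacked, not only lists:
print(count_args(*range(4)))             # 4
print(count_args(*(c for c in "ab")))    # 2
```

By the same logic, with `blk = resnet_block(64, 128, 2)`, the call `nn.Sequential(*blk)` is equivalent to `nn.Sequential(blk[0], blk[1])`.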