I was trying to run the diffusion probabilistic model code given here. While saving the images generated by the reverse diffusion process with this code block:
from torchvision.utils import save_image
epochs = 100
for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        # NOTE(review): `loss` is computed below but this loop never calls
        # optimizer.zero_grad() / loss.backward() / optimizer.step(), so the
        # model is never updated. Presumably those lines were trimmed from this
        # paste -- confirm against the full training script.
        batch_size = batch["pixel_values"].shape[0]
        batch = batch["pixel_values"].to(device)
        # Algorithm 1 line 3: sample t uniformly for every example in the batch
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()
        loss = p_losses(model, batch, t, loss_type="huber")
        if step % 100 == 0:
            print("Loss:", loss.item())
        # save generated images
        if step != 0 and step % save_and_sample_every == 0:
            milestone = step // save_and_sample_every
            batches = num_to_groups(4, batch_size)
            # Each call to `sample` returns a *list* (one entry per reverse
            # diffusion step), not a tensor, which is why torch.cat over the
            # raw results raised "expected Tensor ... but got list". Keep only
            # the last entry of each trajectory (presumably the t = 0 output)
            # and coerce it to a tensor in case the list holds numpy arrays.
            all_images_list = [
                torch.as_tensor(sample(model, batch_size=n, channels=channels)[-1])
                for n in batches
            ]
            all_images = torch.cat(all_images_list, dim=0)
            # Rescale, presumably from [-1, 1] to the [0, 1] range save_image
            # expects -- confirm the model's output range.
            all_images = (all_images + 1) * 0.5
            save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow=6)
I get the error `TypeError: expected Tensor as element 0 in argument 0, but got list`, raised on the line `all_images = torch.cat(all_images_list, dim=0)`.
I understand that `torch.cat` expects a sequence of tensors, but element 0 of `all_images_list` is itself a list rather than a tensor.
I have tried everything I could think of, but I could not find another way to save the generated images without getting this error.
I would be grateful if the community could help me with this.
@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the reverse diffusion chain starting from pure Gaussian noise.

    Sampling runs under ``torch.no_grad()`` so that chaining ``timesteps``
    calls to ``p_sample`` does not build an ever-growing autograd graph.

    Args:
        model: the network passed to ``p_sample``; the device of its first
            parameter is used for all sampling tensors.
        shape: shape of the starting noise tensor; ``shape[0]`` is the batch
            size. Presumably (batch, channels, height, width) -- confirm
            against the caller.

    Returns:
        A list with one CPU tensor per reverse step, ordered from
        ``t = timesteps - 1`` down to ``t = 0``; the last element is the
        output of the final step.
    """
    device = next(model.parameters()).device
    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        # Record every step. Without this append the function always returned
        # an empty list. Moving to CPU keeps the whole trajectory off the GPU.
        imgs.append(img.cpu())
    return imgs
@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the reverse diffusion chain starting from pure Gaussian noise.

    Sampling runs under ``torch.no_grad()`` so that chaining ``timesteps``
    calls to ``p_sample`` does not build an ever-growing autograd graph.

    Args:
        model: the network passed to ``p_sample``; the device of its first
            parameter is used for all sampling tensors.
        shape: shape of the starting noise tensor; ``shape[0]`` is the batch
            size. Presumably (batch, channels, height, width) -- confirm
            against the caller.

    Returns:
        A list with one CPU tensor per reverse step, ordered from
        ``t = timesteps - 1`` down to ``t = 0``; the last element is the
        output of the final step.
    """
    device = next(model.parameters()).device
    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        # Record every step. Without this append the function always returned
        # an empty list. Moving to CPU keeps the whole trajectory off the GPU.
        imgs.append(img.cpu())
    return imgs
Also replace the first two lines below (the original code) with the last two lines (their replacement):
# Original code to be replaced: each sample(...) call returns a list, so
# torch.cat below receives a list of lists and raises the TypeError.
all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
all_images = torch.cat(all_images_list, dim=0)
all_images_list = list(map(lambda n: sample(model, image_size = image_size, batch_size=n, channels=channels), batches))
# Each entry is the per-step trajectory returned for one group of images.
# Concatenating all_images_list[0] would stack every intermediate step of the
# first group and drop the rest. Take the last step of every group instead.
all_images = torch.cat([torch.as_tensor(imgs[-1]) for imgs in all_images_list], dim=0)
You will then find the generated images in the ./results folder.