I have two losses: one is the usual L1 loss and the second involves `torch.rfft()`:
```python
import torch

def dft_amp(img):
    # Two-sided 2-D FFT; in PyTorch 1.4 the last dim holds (real, imag)
    fft_im = torch.rfft(img, signal_ndim=2, onesided=False)
    fft_amp = fft_im[:, :, :, :, 0]**2 + fft_im[:, :, :, :, 1]**2
    return torch.sqrt(fft_amp)  # amplitude spectrum

l1_loss = torch.nn.L1Loss()
loss = l1_loss(pred, gt) + l1_loss(dft_amp(pred), dft_amp(gt))
loss.backward()
```
This runs for the first iteration, and neither loss is `nan`, but from the second iteration onwards the loss becomes `nan`.
If, however, only the plain L1 loss is kept and `l1_loss(dft_amp(pred), dft_amp(gt))` is omitted, training proceeds normally.
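One plausible cause, offered as an assumption rather than a confirmed diagnosis: if any frequency bin is exactly zero (for example, an all-zero input), the derivative of `torch.sqrt` there is infinite, and the chain rule multiplies that infinity by the zero derivative of the squared terms, giving `nan`. The forward loss stays finite, which would be consistent with a first iteration that looks fine followed by `nan` once the bad gradients have updated the weights. A minimal reproduction with plain tensors (no FFT involved; `re` and `im` are hypothetical stand-ins for one spectrum bin):

```python
import torch

# Real and imaginary parts of a single spectrum bin that happens to be exactly 0
re = torch.zeros(1, requires_grad=True)
im = torch.zeros(1, requires_grad=True)

amp = torch.sqrt(re**2 + im**2)  # forward value is 0.0, perfectly finite
amp.sum().backward()

# d sqrt(u)/du = 1 / (2 * sqrt(u)) is inf at u = 0, while d(re**2)/d(re) = 2 * re = 0,
# so the chain rule evaluates inf * 0, which is nan
print(re.grad)  # tensor([nan])
print(im.grad)  # tensor([nan])
```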
Does `torch.rfft()` support backpropagation? I am using PyTorch 1.4.0.
Any suggestions would be appreciated.
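The FFT is, as far as I know, differentiable in PyTorch, so the `sqrt` of a possibly zero power spectrum is a more likely culprit than the FFT itself. Below is a sketch of a numerically safer amplitude, assuming a small constant `eps` (a hypothetical choice, `1e-8` here) added inside the square root. It assumes PyTorch 1.8 or later and uses `torch.fft.fft2`, because later releases removed `torch.rfft` in favour of the `torch.fft` module; on 1.4.0 the same idea should apply by keeping `torch.rfft` and only adding `eps`.

```python
import torch

def dft_amp(img, eps=1e-8):
    # Two-sided 2-D FFT over the last two dims, analogous to rfft(..., onesided=False).
    # Assumes PyTorch >= 1.8, where torch.fft.fft2 returns a complex tensor.
    fft_im = torch.fft.fft2(img)
    power = fft_im.real**2 + fft_im.imag**2
    # eps keeps sqrt away from 0, where its derivative is infinite
    return torch.sqrt(power + eps)

l1_loss = torch.nn.L1Loss()

# Hypothetical data: an all-zero prediction makes every spectrum bin exactly 0
pred = torch.zeros(2, 3, 8, 8, requires_grad=True)
gt = torch.rand(2, 3, 8, 8)

loss = l1_loss(pred, gt) + l1_loss(dft_amp(pred), dft_amp(gt))
loss.backward()
print(torch.isfinite(pred.grad).all())  # tensor(True)
```

The `eps` slightly inflates amplitudes near zero; keeping it far below typical spectrum magnitudes makes that bias negligible.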
Question from: https://stackoverflow.com/questions/65932593/unable-to-backpropagate-through-torch-rfft