I have two losses: the usual L1 loss and a second one involving torch.rfft():
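import torch

def dft_amp(img):
    # Full (two-sided) 2-D FFT; the last dim holds (real, imaginary)
    fft_im = torch.rfft(img, signal_ndim=2, onesided=False)
    # Squared magnitude of every frequency bin
    fft_amp = fft_im[:, :, :, :, 0] ** 2 + fft_im[:, :, :, :, 1] ** 2
    return torch.sqrt(fft_amp)

l1_loss = torch.nn.L1Loss()
loss = l1_loss(pred, gt) + l1_loss(dft_amp(pred), dft_amp(gt))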
This runs for the first iteration, with neither loss being NaN, but from the second iteration onwards the loss becomes NaN.
If, however, only the plain L1 loss is kept and l1_loss(dft_amp(pred), dft_amp(gt)) is omitted, training proceeds normally.
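A minimal sketch that may help isolate the NaN, assuming PyTorch >= 1.8 (where torch.rfft was removed and torch.fft.fft2 takes its place) and assuming the problem lies in torch.sqrt rather than in the FFT itself. The derivative of sqrt(p) is 1/(2*sqrt(p)), which is infinite wherever a bin's power p is exactly zero; the chain rule then multiplies that by 2*re = 0, giving NaN. If this is what happens, it would fit the symptom: the first loss is computed before any backward pass, so it is finite, while the NaN gradient from that first backward pass corrupts the weights used in the second iteration. The helper names and the eps value below are hypothetical, and an all-zero prediction is used only because all of its frequency bins are exactly zero.

```python
import torch

def dft_amp_naive(img):
    # Same computation as dft_amp, rewritten for torch.fft (PyTorch >= 1.8)
    fft_im = torch.fft.fft2(img)                  # complex, same shape as img
    power = fft_im.real ** 2 + fft_im.imag ** 2   # squared magnitude per bin
    return torch.sqrt(power)                      # gradient is inf where power == 0

def dft_amp_eps(img, eps=1e-8):
    # Hypothetical variant: a small eps keeps sqrt's gradient finite at power == 0
    fft_im = torch.fft.fft2(img)
    power = fft_im.real ** 2 + fft_im.imag ** 2
    return torch.sqrt(power + eps)

def amp_loss_grad(amp_fn):
    # All-zero "prediction": every frequency bin has exactly zero power
    pred = torch.zeros(1, 1, 4, 4, requires_grad=True)
    gt = torch.rand(1, 1, 4, 4)
    loss = torch.nn.L1Loss()(amp_fn(pred), amp_fn(gt))
    loss.backward()
    return pred.grad

naive_grad = amp_loss_grad(dft_amp_naive)
eps_grad = amp_loss_grad(dft_amp_eps)
print(torch.isnan(naive_grad).any().item())   # True: NaN reaches pred.grad
print(torch.isfinite(eps_grad).all().item())  # True: gradient stays finite
```

Adding eps slightly biases the amplitude of near-zero bins; an alternative is torch.abs on the complex FFT output, though that is a different formulation from the one in the question.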
Does torch.rfft() support backpropagation? I am using PyTorch 1.4.0.
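One way to test whether an FFT-based loss backpropagates correctly is a numerical gradient check, which compares autograd's gradients against finite differences. The sketch below uses the newer torch.fft API (PyTorch >= 1.8) instead of torch.rfft, so it does not test the 1.4.0 function directly; the helper names are hypothetical. The FFT is linear, so its check is expected to pass, and on random inputs the square-root amplitude should pass as well, since no bin is exactly zero there.

```python
import torch
from torch.autograd import gradcheck

def fft2_as_real(img):
    # Stack real and imaginary parts into a trailing dim of size 2,
    # similar to the (..., 2) layout that torch.rfft returned
    out = torch.fft.fft2(img)
    return torch.stack([out.real, out.imag], dim=-1)

def fft2_amp(img):
    # Amplitude as in the question's dft_amp, without any eps
    out = torch.fft.fft2(img)
    return torch.sqrt(out.real ** 2 + out.imag ** 2)

torch.manual_seed(0)
# gradcheck expects double-precision inputs with requires_grad=True
x = torch.rand(1, 1, 4, 4, dtype=torch.float64, requires_grad=True)
print(gradcheck(fft2_as_real, (x,)))  # True
print(gradcheck(fft2_amp, (x,)))      # True
```

gradcheck raises an error on mismatch by default, so a True result means the analytical and numerical gradients agreed for this input.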
Any suggestions would be appreciated.
Question from: https://stackoverflow.com/questions/65932593/unable-to-backpropagate-through-torch-rfft