-
Notifications
You must be signed in to change notification settings - Fork 20
Open
Description
How can I solve this problem? The function below is not correct.
class FBPFunctionGPU(torch.autograd.Function):
    """Autograd wrapper around ``lct.fbp_gpu`` (filtered back-projection).

    ``forward`` maps a batch of projections (sinograms) ``input`` to images,
    written into the caller-supplied buffer ``vol``. ``backward`` maps an
    image gradient back to a sinogram gradient using ``lct.fbp_adjoint_gpu``.

    NOTE(review): ``param_id`` is accepted and saved but never passed to the
    ``lct`` calls. If ``lct`` supports several geometry/parameter sets, it
    presumably has to be forwarded to both calls; confirm against the
    ``lct`` API.
    """

    @staticmethod
    def forward(ctx, input, proj, vol, param_id):
        """Reconstruct ``vol[b]`` from ``input[b]`` for every batch index ``b``.

        Args:
            ctx: Autograd context.
            input: Batch of projections (sinograms). This is the only
                argument that receives a gradient.
            proj: Sinogram-shaped buffer. It is not used by the computation
                and is kept only for interface compatibility.
            vol: Image output buffer. It is written in place and returned.
            param_id: Saved for backward (see class note).

        Returns:
            ``vol`` after the reconstruction has been written into it.
        """
        # The GPU routine presumably reads raw memory, so a non-contiguous
        # sinogram (e.g. permuted or strided) must be compacted first.
        # contiguous() is a no-op when the tensor already is contiguous.
        sino = input.contiguous()
        for batch in range(sino.shape[0]):
            # Writes the image reconstructed from sino[batch] into vol[batch].
            lct.fbp_gpu(sino[batch], vol[batch])
        # NOTE(review): vol is a shared buffer, so outputs of successive calls
        # alias each other; vol is also assumed to be contiguous.
        ctx.save_for_backward(input, param_id)
        return vol

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input``; other arguments get ``None``.

        Args:
            ctx: Autograd context holding ``input`` and ``param_id``.
            grad_output: Gradient of the loss w.r.t. the returned image batch.

        Returns:
            A 4-tuple ``(grad_input, None, None, None)``, matching the four
            ``forward`` arguments.
        """
        input, _param_id = ctx.saved_tensors
        if not ctx.needs_input_grad[0]:
            return None, None, None, None
        # Autograd frequently delivers non-contiguous gradients (e.g. the
        # expanded gradient of .sum()), which a raw-pointer routine would
        # misread.
        grad_image = grad_output.contiguous()
        # Use a fresh buffer shaped exactly like `input`, as autograd
        # requires. Returning a shared, reused buffer (such as `proj`) would
        # alias gradients across calls.
        grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format)
        for batch in range(grad_input.shape[0]):
            # NOTE(review): FBP is linear in the sinogram, so the exact
            # gradient is the adjoint of the FBP operator applied to
            # grad_output (i.e. forward projection followed by the transposed
            # ramp filter). Confirm that lct.fbp_adjoint_gpu computes exactly
            # this; if it does not, this call must be replaced.
            lct.fbp_adjoint_gpu(grad_input[batch], grad_image[batch])
        return grad_input, None, None, None
What do I need to change to make this code correct?
Metadata
Metadata
Assignees
Labels
No labels