Atlas registration progress

At low resolution

Since last time, I found that I could not train the ICON_atlas loss to match the performance of ICON on OAI knees: the second step of multiscale training failed as folds ran away and lambda increased to infinity.

Outreach:

Presenting the same PowerPoint from last week's group meeting to the funky bunch on Wednesday. Several people from Kitware's AI team are coming: anything I should make sure to include?

Follow up to "New approach to patchwise registration"

I wanted to set up for applying Instant Neural Graphics Primitives to neural registration fields. To do this, I needed a valid regularization procedure. I started with this approach:

This doesn't work: it produces either very irregular mappings or the identity map:

colab notebook

New approach: use torch.autograd to compute the spatial gradients of the node "Approximation of original position in Image A - Position in Image A" with respect to the node "Position in Image A".

Square these partial derivatives and minimize the result. This is inspired by the gradient penalty for GANs espoused in "When do GANS actually converge", which we discussed a while back: each sample forces a neighborhood around it to be near zero, instead of just a single point.
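A minimal sketch of this penalty, assuming the two registration fields are plain PyTorch modules that map positions to positions. The small MLPs, the name gradient_icon_penalty, and the 2D setting are hypothetical stand-ins for illustration, not the code from the notebook:

```python
import torch

torch.manual_seed(0)

def make_field():
    # Hypothetical stand-in for a neural registration field: 2D position -> 2D position.
    return torch.nn.Sequential(
        torch.nn.Linear(2, 32), torch.nn.Tanh(), torch.nn.Linear(32, 2)
    )

phi_AB = make_field()
phi_BA = make_field()

def gradient_icon_penalty(phi_AB, phi_BA, positions):
    """Mean squared spatial gradient of the inverse consistency error."""
    positions = positions.clone().requires_grad_(True)
    # "Approximation of original position in Image A - Position in Image A"
    error = phi_AB(phi_BA(positions)) - positions  # shape (N, 2)
    penalty = 0.0
    for i in range(error.shape[1]):
        # Each sample's error depends only on its own position, so the gradient
        # of the summed i-th component yields per-sample partial derivatives.
        (grad_i,) = torch.autograd.grad(
            error[:, i].sum(), positions, create_graph=True
        )
        penalty = penalty + (grad_i ** 2).sum(dim=1).mean()
    return penalty

positions = torch.rand(64, 2)
loss = gradient_icon_penalty(phi_AB, phi_BA, positions)
loss.backward()  # create_graph=True lets the penalty be trained through
```

create_graph=True is what makes the squared gradients themselves differentiable with respect to the network weights. A similarity term between the warped and fixed image would still be needed alongside this penalty.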

FOR A SINGLE PAIR, THIS WORKS FOR REGISTRATION!

notebook

Figure panels: Image A, Image B, Grid, Warped B

GradientICON

While getting the above to work, I was impressed enough with the performance of the Jacobian penalty on the Inverse Consistency term to try it back on the standard convolutional ICON.

I computed the Jacobian using finite differences instead of torch.autograd since that was more convenient. The maps are only evaluated through linear interpolations, so finite differences are usually exact anyway.
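A toy 1D illustration of the "usually": a forward difference through linear interpolation matches the true slope whenever both sample points fall inside the same interpolation cell, and blends two slopes when the step straddles a grid node. The helper lerp_1d and the sample values are made up for this sketch:

```python
import torch

def lerp_1d(values, x):
    """Linearly interpolate `values` (sampled at integer nodes) at positions x."""
    x0 = torch.clamp(x.floor().long(), 0, values.shape[0] - 2)
    w = x - x0
    return (1 - w) * values[x0] + w * values[x0 + 1]

# Slopes between consecutive nodes are 1, 3 and 5.
values = torch.tensor([0.0, 1.0, 4.0, 9.0], dtype=torch.float64)
delta = 0.001

# Both samples lie in the cell [1, 2]: the difference quotient is the slope, 3.
x = torch.tensor([1.25], dtype=torch.float64)
fd_inside = (lerp_1d(values, x + delta) - lerp_1d(values, x)) / delta

# The step crosses the node at 2: the quotient mixes slopes 3 and 5.
x = torch.tensor([1.9995], dtype=torch.float64)
fd_crossing = (lerp_1d(values, x + delta) - lerp_1d(values, x)) / delta
```

With a small delta relative to the grid spacing, only a small fraction of sample points land close enough to a node for the second case to occur.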

delta = .001

if len(self.identityMap.shape) == 4:
    dx = torch.Tensor([[[[delta]], [[0.]]]]).to(config.device)
    dy = torch.Tensor([[[[0.]], [[delta]]]]).to(config.device)
    direction_vectors = (dx, dy)

elif len(self.identityMap.shape) == 5:
    dx = torch.Tensor([[[[[delta]]], [[[0.]]], [[[0.]]]]]).to(config.device)
    dy = torch.Tensor([[[[[0.]]], [[[delta]]], [[[0.]]]]]).to(config.device)
    dz = torch.Tensor([[[[[0.]]], [[[0.]]], [[[delta]]]]]).to(config.device)
    direction_vectors = (dx, dy, dz)

direction_losses = []
for d in direction_vectors:
    approximate_Iepsilon_d = self.phi_AB(self.phi_BA(Iepsilon + d))
    inverse_consistency_error_d = Iepsilon + d - approximate_Iepsilon_d
    grad_d_icon_error = (inverse_consistency_error - inverse_consistency_error_d) / delta
    direction_losses.append(torch.mean(grad_d_icon_error**2))

inverse_consistency_loss = sum(direction_losses)
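The excerpt above relies on surrounding class state (self.phi_AB, self.phi_BA, Iepsilon, inverse_consistency_error). A self-contained sketch of the same finite-difference loss, assuming phi_AB and phi_BA are any callables on coordinate tensors of shape (N, ndims, ...). The function name gradient_icon_loss and the toy maps are hypothetical:

```python
import torch

def gradient_icon_loss(phi_AB, phi_BA, Iepsilon, delta=0.001):
    """Finite-difference penalty on the gradient of the inverse consistency error."""
    ndims = Iepsilon.shape[1]
    inverse_consistency_error = Iepsilon - phi_AB(phi_BA(Iepsilon))
    direction_losses = []
    for axis in range(ndims):
        # Offset of size delta along one coordinate channel, broadcast over the grid.
        d = torch.zeros(
            [1, ndims] + [1] * ndims, dtype=Iepsilon.dtype, device=Iepsilon.device
        )
        d[0, axis] = delta
        shifted = Iepsilon + d
        error_d = shifted - phi_AB(phi_BA(shifted))
        grad_d = (inverse_consistency_error - error_d) / delta
        direction_losses.append(torch.mean(grad_d ** 2))
    return sum(direction_losses)

torch.manual_seed(0)
Iepsilon = torch.rand(1, 2, 8, 8, dtype=torch.float64)  # toy 2D sample coordinates

# Toy maps: a translation and its exact inverse give (numerically) zero loss.
loss_inverse = gradient_icon_loss(lambda x: x - 0.1, lambda x: x + 0.1, Iepsilon)

# Toy maps that are not inverses: the composition scales by 1.1, so the error
# has gradient -0.1 on the diagonal and the loss is 2 * (0.1 ** 2 / 2) = 0.01.
loss_scaled = gradient_icon_loss(lambda x: 1.1 * x, lambda x: x, Iepsilon)
```

Building the offsets in a loop over the coordinate channels avoids writing out separate nested-list tensors for the 2D and 3D cases.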

This works great in 2D, solving the hollow triangles and circles benchmark in 2 minutes instead of about an hour, with more reliable and higher-quality final results:

notebook

GradientICON in 3D

Trains like a dream with not much fussing, even at batch size 1 and 160 x 384 x 384 resolution.

Step 1 notebook

Step 1 (40 x 96 x 96): DICE 66

Step 2 notebook

Step 2 (80 x 192 x 192): DICE 71.3

Step 3 notebook

Step 3 (160 x 384 x 384): DICE 73.3
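For reference, a sketch of the Dice overlap between two binary segmentation masks. The report does not state which labels are scored or how they are averaged, so this shows only the basic formula, and the name dice_score is hypothetical. The numbers above would correspond to this value times 100:

```python
import torch

def dice_score(seg_a, seg_b):
    """Dice coefficient 2|A ∩ B| / (|A| + |B|) for two boolean masks."""
    intersection = (seg_a & seg_b).sum().item()
    total = seg_a.sum().item() + seg_b.sum().item()
    # Convention chosen here: two empty masks count as a perfect match.
    return 2.0 * intersection / total if total > 0 else 1.0

warped_seg = torch.tensor([True, True, False, False])
fixed_seg = torch.tensor([True, False, False, False])
score = dice_score(warped_seg, fixed_seg)  # 2 * 1 / (2 + 1)
```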
