In [2]:
import icon_registration as icon
import icon_registration.data
import icon_registration.networks as networks
from icon_registration.config import device

import numpy as np
import torch
import torchvision.utils
import matplotlib.pyplot as plt
In [3]:
# Training data: MNIST images of the digit 2. Only the first returned loader is kept.
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=2)

# One batch of images; its shape is used later to build the network's identity map.
sample_batch = next(iter(ds))[0]

# Preview the first 12 images as a 4-wide grid; [0] selects a single channel of the grid.
plt.imshow(torchvision.utils.make_grid(sample_batch[:12], nrow=4)[0])
plt.title("Sample training images (digit 2)")
plt.xticks([])
plt.yticks([])
# plt.show() renders the figure and keeps the AxesImage repr out of the cell output.
plt.show()
Out[3]:
<matplotlib.image.AxesImage at 0x7fabc477fc70>
No description has been provided for this image
In [4]:
# model.py

import icon_registration.constricon as constricon

# Shape used to build the network's identity map. Every sub-network below is 2D
# (ConvolutionalMatrixNet(dimension=2)), so the identity map must come from a 2D
# image-batch shape, not a 3D volume shape such as [1, 1, 128, 128, 128].
# Taken from the MNIST batch loaded in the previous cell.
input_shape = list(sample_batch.shape)


def _matrix_step(dimension=2):
    """Build one affine step: a ConvolutionalMatrixNet wrapped in ConsistentFromMatrix."""
    return constricon.ConsistentFromMatrix(
        networks.ConvolutionalMatrixNet(dimension=dimension)
    )


def make_network(input_shape=input_shape, n_steps=4, dimension=2):
    """Build the multi-step affine registration network with a bending-energy loss.

    The steps are chained as
    TwoStepInverseConsistent(s0, TwoStepInverseConsistent(s1, ... s_{n-1})),
    wrapped in constricon.FirstTransform and then icon.losses.BendingEnergyNet.

    Parameters
    ----------
    input_shape : sequence of int
        Shape passed to ``assign_identity_map``. Defaults to the shape of the
        MNIST batch loaded above.
    n_steps : int
        Number of ``ConsistentFromMatrix`` steps in the chain (default 4).
    dimension : int
        Spatial dimension passed to each ``ConvolutionalMatrixNet`` (default 2).

    Returns
    -------
    The ``BendingEnergyNet`` with its identity map assigned.

    Raises
    ------
    ValueError
        If ``n_steps`` is smaller than 1.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    # Build the steps outermost-first (same construction order as a hand-nested
    # expression), then fold them together from the innermost pair outwards.
    steps = [_matrix_step(dimension) for _ in range(n_steps)]
    chained = steps[-1]
    for step in reversed(steps[:-1]):
        chained = constricon.TwoStepInverseConsistent(step, chained)
    net = constricon.FirstTransform(chained)
    # LNCC(5) is the similarity passed to BendingEnergyNet; lmbda is presumably the
    # weight of the bending-energy penalty (confirm against icon_registration docs).
    net = icon.losses.BendingEnergyNet(net, icon.LNCC(5), lmbda=.03)
    net.assign_identity_map(input_shape)
    return net


net = make_network()
In [5]:
# Build the network's identity map from the shape of an actual training batch
# (sample_batch is a batch drawn from the digit-2 loader).
net.assign_identity_map(sample_batch.shape)
In [6]:
net.train()
net.to(device)

# The optimizer is created after .to(device) so it holds the moved parameters.
optim = torch.optim.Adam(net.parameters(), lr=0.001)
# The same loader is passed for both image streams.
curves = icon.train_datasets(net, optim, ds, ds, epochs=5)

# Plot the first three recorded terms of each entry in `curves`.
loss_history = np.array(curves)
# NOTE(review): assumes entries of `curves` are namedtuples of loss terms;
# falls back to generic labels if they carry no field names.
field_names = getattr(curves[0], "_fields", None) if len(curves) else None
labels = list(field_names[:3]) if field_names else [f"term {i}" for i in range(3)]

plt.close()
lines = plt.plot(loss_history[:, :3])
plt.legend(lines, labels)
plt.xlabel("index into curves")
plt.ylabel("loss")
plt.title("Training on digit 2")
plt.show()
100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:34<00:00,  6.83s/it]
Out[6]:
[<matplotlib.lines.Line2D at 0x7fabc0c87c70>,
 <matplotlib.lines.Line2D at 0x7fabc0c87dc0>,
 <matplotlib.lines.Line2D at 0x7fabc0c87f10>]
No description has been provided for this image
In [7]:
plt.close()


def show(tensor, n=6, nrow=3):
    """Display the first ``n`` images of a batch as a grid (first channel only), with no ticks.

    Parameters
    ----------
    tensor : torch.Tensor
        Batch of images; presumably (N, C, H, W) as produced by the MNIST loaders - confirm.
    n : int
        Number of images taken from the front of the batch (default 6).
    nrow : int
        Number of images per grid row (default 3).
    """
    grid = torchvision.utils.make_grid(tensor[:n], nrow=nrow)
    plt.imshow(grid[0].cpu().detach())
    plt.xticks([])
    plt.yticks([])


# Two separate draws from the digit-2 loader serve as moving (A) and fixed (B) images.
image_A = next(iter(ds))[0].to(device)
image_B = next(iter(ds))[0].to(device)

# Visualization-only forward pass: no gradients are needed. It populates
# net.warped_image_A and net.phi_AB_vectorfield, which are read below.
with torch.no_grad():
    net(image_A, image_B)

plt.subplot(2, 2, 1)
show(image_A)
plt.title("image A")
plt.subplot(2, 2, 2)
show(image_B)
plt.title("image B")
plt.subplot(2, 2, 3)
show(net.warped_image_A)
# Overlay contours of the first two components of phi_AB on the warped image.
phi_grid = torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3).cpu().detach()
plt.contour(phi_grid[0])
plt.contour(phi_grid[1])
plt.title("warped A + phi_AB")
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.title("warped A - B")
plt.tight_layout()
No description has been provided for this image
In [8]:
# Continue training the same network and optimizer on further digits.
# One loss figure per digit. The original closed the digit-5 figure before it
# was shown, so only the last curve ever rendered.
# `ds` and `curves` are rebound each iteration, ending on the last digit as before.
for digit in (5, 8):
    ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=digit)
    curves = icon.train_datasets(net, optim, ds, ds, epochs=5)
    plt.figure()
    plt.plot(np.array(curves)[:, :3])
    plt.xlabel("index into curves")
    plt.ylabel("loss")
    plt.title(f"Training on digit {digit}")
    plt.show()
100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:30<00:00,  6.07s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:32<00:00,  6.51s/it]
Out[8]:
[<matplotlib.lines.Line2D at 0x7fabc1e275b0>,
 <matplotlib.lines.Line2D at 0x7fabc1e27700>,
 <matplotlib.lines.Line2D at 0x7fabc1e27850>]
No description has been provided for this image
In [9]:
# Register pairs of images from digits the network was not trained on, one figure per digit.
# `ds`, image_A and image_B end up bound to the last digit, as in the original cell.
for digit in (6, 1):
    ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=digit)
    image_A = next(iter(ds))[0].to(device)
    image_B = next(iter(ds))[0].to(device)

    # Visualization-only forward pass; populates net.warped_image_A / net.phi_AB_vectorfield.
    with torch.no_grad():
        net(image_A, image_B)

    plt.subplot(2, 2, 1)
    show(image_A)
    plt.subplot(2, 2, 2)
    show(image_B)
    plt.subplot(2, 2, 3)
    show(net.warped_image_A)
    phi_grid = torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3).cpu().detach()
    plt.contour(phi_grid[0])
    plt.contour(phi_grid[1])
    plt.subplot(2, 2, 4)
    show(net.warped_image_A - image_B)
    plt.suptitle(f"digit {digit}")
    plt.tight_layout()
    plt.show()
No description has been provided for this image
No description has been provided for this image
In [ ]: