In [2]:
import icon_registration as icon
import icon_registration.data
import icon_registration.networks as networks
from icon_registration.config import device

import numpy as np
import torch
import torchvision.utils
import matplotlib.pyplot as plt
In [3]:
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=2)

sample_batch = next(iter(ds))[0]
plt.imshow(torchvision.utils.make_grid(sample_batch[:12], nrow=4)[0])
Out[3]:
<matplotlib.image.AxesImage at 0x7f0451808d90>
No description has been provided for this image
In [9]:
import icon_registration.carl as carl


unet = carl.NoDownsampleNet(dimension=2)
ar = carl.AttentionRegistration(unet, dimension=2)
ts = icon.FunctionFromVectorField(ar)



inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=2))

for _ in range(2):
     inner_net = icon.TwoStepRegistration(
         icon.DownsampleRegistration(inner_net, dimension=2),
         icon.FunctionFromVectorField(networks.tallUNet2(dimension=2))
     )
inner_net = icon.TwoStepRegistration(ts, inner_net)

net = icon.losses.DiffusionRegularizedNet(inner_net, icon.LNCC(sigma=4), lmbda=.5)
In [10]:
net.assign_identity_map(sample_batch.shape)
In [11]:
net.train()
net.to(device)

optim = torch.optim.Adam(net.parameters(), lr=0.001)
curves = icon.train_datasets(net, optim, ds, ds, epochs=5)
plt.close()
plt.plot(np.array(curves)[:, :3])
/usr/lib/python3.10/contextlib.py:103: FutureWarning: `torch.backends.cuda.sdp_kernel()` is deprecated. In the future, this context manager will be removed. Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, with updated signature.
  self.gen = func(*args, **kwds)
100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [04:40<00:00, 56.15s/it]
Out[11]:
[<matplotlib.lines.Line2D at 0x7f04485a6860>,
 <matplotlib.lines.Line2D at 0x7f04485a69b0>,
 <matplotlib.lines.Line2D at 0x7f04485a6b00>]
No description has been provided for this image
In [12]:
plt.close()

def show(tensor):
    plt.imshow(torchvision.utils.make_grid(tensor[:6], nrow=3)[0].cpu().detach())
    plt.xticks([])
    plt.yticks([])
image_A = next(iter(ds))[0].to(device)
image_B = next(iter(ds))[0].to(device)
net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
No description has been provided for this image
In [15]:
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=5)
curves = icon.train_datasets(net, optim, ds, ds, epochs=1)
plt.close()
plt.plot(np.array(curves)[:, :3])
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=8)
curves = icon.train_datasets(net, optim, ds, ds, epochs=1)
plt.close()
plt.plot(np.array(curves)[:, :3])
  0%|                                                                                             | 0/1 [00:00<?, ?it/s]
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
Cell In[15], line 2
      1 ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=5)
----> 2 curves = icon.train_datasets(net, optim, ds, ds, epochs=1)
      3 plt.close()
      4 plt.plot(np.array(curves)[:, :3])

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/icon_registration/train.py:112, in train_datasets(net, optimizer, d1, d2, epochs)
    109 image_B = B[0].to(icon_registration.config.device)
    110 optimizer.zero_grad()
--> 112 loss_object = net(image_A, image_B)
    114 loss_object.all_loss.backward()
    115 optimizer.step()

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/icon_registration/losses.py:459, in BendingEnergyNet.forward(self, image_A, image_B)
    455 # Tag used elsewhere for optimization.
    456 # Must be set at beginning of forward b/c not preserved by .cuda() etc
    457 self.identity_map.isIdentity = True
--> 459 self.phi_AB = self.regis_net(image_A, image_B)
    460 self.phi_AB_vectorfield = self.phi_AB(self.identity_map)
    462 similarity_loss = 2 * self.compute_similarity_measure(
    463     self.phi_AB_vectorfield, image_A, image_B
    464 )

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/icon_registration/network_wrappers.py:212, in TwoStepRegistration.forward(self, image_A, image_B)
    206 def forward(self, image_A, image_B):
    207     
    208     # Tag for shortcutting hack. Must be set at the beginning of 
    209     # forward because it is not preserved by .to(config.device)
    210     self.identity_map.isIdentity = True
--> 212     phi = self.netPhi(image_A, image_B)
    213     psi = self.netPsi(
    214         self.as_function(image_A)(phi(self.identity_map)), 
    215         image_B
    216     )
    217     return lambda tensor_of_coordinates: phi(psi(tensor_of_coordinates))

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/icon_registration/network_wrappers.py:112, in FunctionFromVectorField.forward(self, image_A, image_B)
    111 def forward(self, image_A, image_B):
--> 112     tensor_of_displacements = self.net(image_A, image_B)
    113     displacement_field = self.as_function(tensor_of_displacements)
    115     def transform(coordinates):

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/icon_registration/carl.py:397, in AttentionRegistration.forward(self, A, B)
    395 def forward(self, A, B):
    396     ft_A = self.featurize(A, recrop=False)
--> 397     ft_B = self.featurize(B)
    398     output = self.torch_attention(ft_A, ft_B)
    399     output = output - self.identity_map

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/icon_registration/carl.py:324, in AttentionRegistration.featurize(self, values, recrop)
    322     x = torch.nn.functional.pad(values, [padding, padding, padding, padding])
    323 x = self.net(x)
--> 324 x = 4 * x / (0.001 + torch.sqrt(torch.sum(x**2, dim=1, keepdims=True)))
    325 if recrop:
    326     x = self.crop(x)

File ~/projects/affine_generalization/venv/lib/python3.10/site-packages/torch/_tensor.py:39, in _handle_torch_function_and_wrap_type_error_to_not_implemented.<locals>.wrapped(*args, **kwargs)
     37     if has_torch_function(args):
     38         return handle_torch_function(wrapped, args, *args, **kwargs)
---> 39     return f(*args, **kwargs)
     40 except TypeError:
     41     return NotImplemented

OutOfMemoryError: CUDA out of memory. Tried to allocate 134.00 MiB. GPU 0 has a total capacity of 7.92 GiB of which 17.06 MiB is free. Including non-PyTorch memory, this process has 6.94 GiB memory in use. Of the allocated memory 5.98 GiB is allocated by PyTorch, and 815.89 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
In [14]:
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=6)
image_A = next(iter(ds))[0].to(device)[:12]
image_B = next(iter(ds))[0].to(device)[:12]
net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
plt.show()
ds, _ = icon_registration.data.get_dataset_mnist(split="train", number=1)
image_A = next(iter(ds))[0].to(device)[:12]
image_B = next(iter(ds))[0].to(device)[:12]
net(image_A, image_B)
plt.subplot(2, 2, 1)
show(image_A)
plt.subplot(2, 2, 2)
show(image_B)
plt.subplot(2, 2, 3)
show(net.warped_image_A)
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[0].cpu().detach())
plt.contour(torchvision.utils.make_grid(net.phi_AB_vectorfield[:6], nrow=3)[1].cpu().detach())
plt.subplot(2, 2, 4)
show(net.warped_image_A - image_B)
plt.tight_layout()
No description has been provided for this image
No description has been provided for this image
In [ ]:
 
In [ ]: