import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
# Setting seed
def set_seed(seed):
    """Seed the PyTorch random number generators and set the cuDNN flags.

    Args:
        seed: Value handed to both ``torch.manual_seed`` and
            ``torch.cuda.manual_seed``.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # cuDNN flags: benchmark mode off, deterministic mode on.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
# Defining the simplest possible CNN with masked weights
class SimpleModel(nn.Module):
    """Single 3x3 convolution (1 -> 1 channel) with an optional weight mask.

    Attributes:
        conv1: The ``nn.Conv2d`` layer that owns the weight and bias.
        mask: ``None`` (no masking) or a tensor multiplied element-wise
            with ``conv1.weight`` on every forward pass. Assign it with
            ``model.mask = tensor``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1)
        # A non-persistent buffer follows the module across ``.to(device)``
        # calls but stays out of ``state_dict()``, so checkpoints keep the
        # same keys as when the mask was a plain attribute.
        self.register_buffer("mask", None, persistent=False)

    def forward(self, x):
        """Convolve ``x`` with the (optionally masked) kernel of ``conv1``.

        Args:
            x: Input tensor accepted by ``F.conv2d`` for a 1-channel kernel.

        Returns:
            The convolution output, with ``conv1.bias`` added.
        """
        conv = self.conv1
        if self.mask is None:
            weight = conv.weight
        else:
            # Multiplying (rather than overwriting) keeps the op inside the
            # autograd graph, so gradients w.r.t. conv.weight see the mask.
            weight = conv.weight * self.mask
        # Read the hyper-parameters from the layer instead of repeating the
        # literals, so forward() cannot drift from the layer definition.
        return F.conv2d(
            x,
            weight,
            conv.bias,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
        )
# Function to compute the Hessian
def compute_Hessian(W, grad_W, mask):
    """Build the Hessian of a loss w.r.t. ``W``, restricted to unmasked rows.

    Row ``r`` of the result is the gradient of ``grad_W.flatten()[r]`` with
    respect to ``W`` (flattened in row-major order). This uses
    H[r, c] = d(dL/dW[r]) / dW[c]. Rows whose ``mask`` entry is zero are
    skipped and left as zeros.

    Args:
        W: Weight tensor of any rank that the loss was differentiated
            against.
        grad_W: First-order gradient dL/dW with the same shape as ``W``. It
            must still be attached to an autograd graph reaching ``W``
            (e.g. obtained with ``create_graph=True``).
        mask: Tensor with the same shape as ``W``; non-zero entries select
            the rows to compute.

    Returns:
        A NumPy array of shape ``(W.numel(), W.numel())``.

    Raises:
        ValueError: If ``mask`` or ``grad_W`` does not have the shape of ``W``.
    """
    if mask.shape != W.shape or grad_W.shape != W.shape:
        raise ValueError(
            f"W, grad_W and mask must share one shape, got {tuple(W.shape)}, "
            f"{tuple(grad_W.shape)} and {tuple(mask.shape)}"
        )

    total_elements = W.numel()
    Hessian_w = torch.zeros(
        (total_elements, total_elements), dtype=W.dtype, device=W.device
    )

    # Work on flat views so the same code handles weights of any rank;
    # reshape(-1) enumerates elements in the same row-major order as
    # flatten(), so row r lines up with W.flatten()[r].
    flat_grad = grad_W.reshape(-1)
    active_rows = mask.reshape(-1).nonzero(as_tuple=True)[0].tolist()

    for row in active_rows:
        # retain_graph=True: the same first-order graph is differentiated
        # once per row, so it must survive each backward pass.
        second_grad = torch.autograd.grad(flat_grad[row], W, retain_graph=True)[0]
        Hessian_w[row] = second_grad.reshape(-1)

    return Hessian_w.cpu().numpy()
# Running
if __name__ == "__main__":
    # Demo: prune a tiny CNN kernel to its top-k weights and print the
    # Hessian of the MSE loss with respect to the surviving weights.
    set_seed(42)

    # Create model and input
    model = SimpleModel()
    input_data = torch.randn(1, 1, 5, 5)  # Batch size 1, 1 I/O channel, 5x5 image
    target = torch.randn(1, 1, 5, 5)  # Same shape as input

    W = model.conv1.weight

    # Keep only the top_k largest-magnitude weights. The mask is built
    # straight from the topk indices, so exactly top_k entries survive even
    # when magnitudes tie or a kept weight happens to be zero.
    top_k = 3
    with torch.no_grad():
        top_indices = torch.topk(
            W.abs().flatten(), min(top_k, W.numel()), largest=True
        ).indices
        keep_mask = torch.zeros(W.numel(), dtype=torch.bool, device=W.device)
        keep_mask[top_indices] = True
        keep_mask = keep_mask.view_as(W)
        # In-place edit under no_grad instead of going through ``.data``.
        W.masked_fill_(~keep_mask, 0.0)
    mask = keep_mask.to(W.dtype)

    # Set the mask in the model and run the forward pass with it applied
    model.mask = mask
    output = model(input_data)

    # First-order gradient; create_graph=True keeps it differentiable so
    # compute_Hessian can take the second derivative.
    grad_W = torch.autograd.grad(F.mse_loss(output, target), W, create_graph=True)[0]

    # Compute the Hessian matrix
    Hessian_w = compute_Hessian(W, grad_W, mask)
    print("Hessian shape:", Hessian_w.shape)
    print(Hessian_w)