# Welcome! Share code as fast as possible.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

# Setting seed
def set_seed(seed):
    """Seed PyTorch's random generators and request deterministic cuDNN.

    Passes ``seed`` to ``torch.manual_seed`` and ``torch.cuda.manual_seed``,
    then turns cuDNN's deterministic flag on and its benchmark flag off.

    Args:
        seed: Value forwarded unchanged to both seeding functions.
    """
    cudnn = torch.backends.cudnn
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    # Deterministic kernels on, benchmark-driven kernel selection off.
    cudnn.deterministic, cudnn.benchmark = True, False

# Defining the simplest possible CNN with masked weights
class SimpleModel(nn.Module):
    """Single 1-in/1-out channel 3x3 convolution with an optional weight mask.

    Attributes:
        conv1: The ``nn.Conv2d`` layer that owns the weight and bias.
        mask: ``None`` (no masking) or a tensor multiplied element-wise into
            ``conv1.weight`` on every forward pass. Assign it with
            ``model.mask = tensor``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1)
        # A buffer (rather than a plain attribute) follows the module through
        # .to()/.cuda()/.double(); non-persistent keeps it out of state_dict,
        # so saved checkpoints look the same as before.
        self.register_buffer("mask", None, persistent=False)

    def forward(self, x):
        """Convolve ``x`` with the (optionally masked) weights of ``conv1``.

        Args:
            x: Input tensor accepted by ``F.conv2d`` for a 1-channel layer.

        Returns:
            The convolution output, computed with ``conv1.weight * mask`` when
            a mask is set and with ``conv1.weight`` otherwise.
        """
        conv = self.conv1
        weight = conv.weight
        if self.mask is not None:
            # Multiplying (instead of overwriting) keeps the graph connected
            # to conv1.weight so gradients still flow to the parameter.
            weight = weight * self.mask
        # Take the hyperparameters from the layer itself so they cannot drift
        # from the values passed to nn.Conv2d above.
        return F.conv2d(
            x,
            weight,
            conv.bias,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
        )

# Function to compute the Hessian
def compute_Hessian(W, grad_W, mask):
    """Build the Hessian of a loss w.r.t. ``W``, one row per unmasked weight.

    Row ``r`` of the result is the gradient of ``grad_W.flatten()[r]`` with
    respect to ``W`` (flattened in row-major order), i.e.
    ``H[r, c] = d(dL/dW[r]) / dW[c]``. Rows whose position in ``mask`` is zero
    are never differentiated and stay all-zero.

    Args:
        W: Tensor the loss was differentiated against. Any rank is supported.
        grad_W: First-order gradient of the loss w.r.t. ``W``; it must carry a
            graph (e.g. obtained with ``create_graph=True``) so that it can be
            differentiated again.
        mask: Tensor with the same number of dimensions as ``W``; its non-zero
            positions select which rows are computed.

    Returns:
        A NumPy array of shape ``(W.numel(), W.numel())``.

    Raises:
        RuntimeError: Propagated from ``torch.autograd.grad`` when a selected
            entry of ``grad_W`` cannot be differentiated at all.
    """
    total_elements = W.numel()
    Hessian_w = torch.zeros(
        (total_elements, total_elements), dtype=W.dtype, device=W.device
    )

    # Row-major strides of W's logical shape; they turn an N-d index into the
    # matching position in W.flatten() for any number of dimensions.
    strides = []
    step = 1
    for size in reversed(W.shape):
        strides.append(step)
        step *= size
    strides.reverse()

    # .tolist() brings every index over as plain Python ints in one transfer
    # instead of pulling 0-dim tensors out of the index matrix one at a time.
    for index in mask.nonzero(as_tuple=False).tolist():
        grad_element = grad_W[tuple(index)]

        # H[r, :] = d(dL/dW[r]) / dW. retain_graph keeps the first-order graph
        # alive for the remaining rows. allow_unused covers gradient entries
        # whose graph does not involve W; their Hessian row is exactly zero.
        second_grad = torch.autograd.grad(
            grad_element, W, retain_graph=True, allow_unused=True
        )[0]
        if second_grad is None:
            continue

        flat_idx = sum(i * s for i, s in zip(index, strides))
        Hessian_w[flat_idx] = second_grad.flatten()

    return Hessian_w.cpu().numpy()

# Running
if __name__ == "__main__":
    set_seed(42)

    # Create model and input. The parameters are initialised by the
    # constructor, so no warm-up forward pass is needed.
    model = SimpleModel()
    input_data = torch.randn(1, 1, 5, 5)  # Batch size 1, 1 I/O channel, 5x5 image
    target = torch.randn(1, 1, 5, 5)  # Same shape as input

    W = model.conv1.weight

    # Make W sparse: zero every weight whose magnitude is below the
    # top_k-th largest magnitude.
    top_k = 3
    with torch.no_grad():
        topk_result = torch.topk(W.abs().flatten(), top_k, largest=True)
        threshold = (
            topk_result.values[-1]
            if topk_result.values.numel() > 0
            else W.abs().max()
        )
        # In-place edit of the parameter itself (no .data), which is allowed
        # under no_grad and keeps autograd's bookkeeping consistent.
        W.masked_fill_(W.abs() < threshold, 0)

    # Create a mask that matches the sparsity of W and hand it to the model.
    mask = (W != 0).float()
    model.mask = mask

    # Forward pass with the mask applied.
    output = model(input_data)
    loss = F.mse_loss(output, target)

    # create_graph=True keeps the gradient differentiable so that
    # compute_Hessian can take second derivatives of it.
    grad_W = torch.autograd.grad(loss, W, create_graph=True)[0]

    # Compute the Hessian matrix
    Hessian_w = compute_Hessian(W, grad_W, mask)

    print("Hessian shape:", Hessian_w.shape)
    print(Hessian_w)