Fast Gradient Sign Attack


This post covers an MNIST implementation of the ICLR 2015 paper “Explaining and Harnessing Adversarial Examples” by Goodfellow et al.

Source: PyTorch

Import Libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

Pretrained Model

pretrained_model = 'model/lenet_mnist_model.pth'
epsilons = [0, 0.05, .1, .15, .2, .25, .3]

LeNet Model

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
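
A quick shape check (a sketch of my own, not from the tutorial) confirms where the 320 in fc1 comes from: after two 5×5 convolutions and two 2×2 max-pools, a 28×28 input shrinks to 20 channels of 4×4, and 20 × 4 × 4 = 320.

net = Net()
dummy = torch.randn(1, 1, 28, 28)   # one fake MNIST-sized image
print(net(dummy).shape)             # torch.Size([1, 10]): log-probabilities over the 10 digits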

MNIST Test Dataset

tfs = [transforms.ToTensor()]
transform = transforms.Compose(tfs)

ds_test = datasets.MNIST(root='data', download=True, train=False, transform=transform)

loader_test = torch.utils.data.DataLoader(ds_test, batch_size=1, shuffle=True)
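
Since only ToTensor() is applied (no normalization), pixel values stay in [0, 1], which is the range the clamp in fgsm_attack below assumes. A quick check (my addition, not from the tutorial):

img, label = next(iter(loader_test))
print(img.shape, label.item())             # torch.Size([1, 1, 28, 28]) and the digit label
print(img.min().item(), img.max().item())  # both within [0, 1]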

Load Pretrained Model

# Initialize the network
model = Net().to(device)

model.load_state_dict(torch.load(pretrained_model, map_location=device))

model.eval() # evaluation mode, so the dropout layers are disabled

FGSM Attack

  • Rather than working to minimize the loss by adjusting the weights based on the backpropagated gradients, the attack adjusts the input data to maximize the loss based on the same backpropagated gradients.
  • In other words, the attack uses the gradient of the loss w.r.t. the input data, then adjusts the input data in the direction that maximizes the loss, as formalized below.
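
In the paper's notation, the adversarial image is built from the clean image $x$, the true label $y$, and the model parameters $\theta$ by a single signed-gradient step:

$$x_{adv} = x + \epsilon \cdot \mathrm{sign}\left(\nabla_x J(\theta, x, y)\right)$$

This is exactly what fgsm_attack implements below, with an added clamp to keep pixel values in the valid [0, 1] range.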
def fgsm_attack(image, epsilon, data_grad):
    sign_data_grad = data_grad.sign()
    
    perturbed_image = image + epsilon * sign_data_grad
    
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    
    return perturbed_image
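
A minimal sanity check of fgsm_attack (my own sketch, using random stand-in tensors rather than real MNIST data): with epsilon = 0 the image comes back unchanged, which is why the eps = 0 run later reproduces the clean test accuracy.

img = torch.rand(1, 1, 28, 28)   # stand-in image with pixels already in [0, 1]
grad = torch.randn_like(img)     # stand-in for data.grad
assert torch.equal(fgsm_attack(img, 0.0, grad), img)  # eps = 0: clamp is a no-op, image unchanged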

Test Module

def test(model, loader_test, epsilon):
    correct = 0
    adv_examples = []
    
    for data, target in loader_test:
      
        data, target = data.to(device), target.to(device)
      
        data.requires_grad = True # For attack
        
        output = model(data)
        pred_init = output.max(1, keepdim=True)[1]
        
        if pred_init.item() != target.item():
            continue # No need to attack if model pred is not correct
        
        loss = F.nll_loss(output, target)
        model.zero_grad() # clear stale gradients before the backward pass
        loss.backward()   # populate data.grad with dLoss/dInput
        
        data_grad = data.grad.data
        
        data = fgsm_attack(data, epsilon, data_grad) # perturbed image
        
        output = model(data)
        pred_final = output.max(1, keepdim=True)[1]
        
        if pred_final.item() != target.item(): # Attack successful
            if len(adv_examples) < 5:
                adv_ex = data.squeeze().detach().cpu().numpy()
                adv_examples.append((pred_init.item(), pred_final.item(), adv_ex))
        else:
            correct += 1 # Attack Unsuccessful
            if len(adv_examples) < 5 and epsilon == 0: # save a few clean (eps = 0) examples as a baseline
                adv_ex = data.squeeze().detach().cpu().numpy()
                adv_examples.append((pred_init.item(), pred_final.item(), adv_ex))
    acc = correct*1.0 / len(loader_test)
    print('Eps: {} \t Accuracy = {}/{}={:.3f}'.format(epsilon, correct, len(loader_test), acc))
    return acc, adv_examples

Run Test

accuracies = []
examples = []

for eps in epsilons:
    acc, ex = test(model, loader_test, eps)
    
    accuracies.append(acc)
    examples.append(ex)

Eps: 0 	 Accuracy = 9810/10000=0.981
Eps: 0.05 	 Accuracy = 9426/10000=0.943
Eps: 0.1 	 Accuracy = 8510/10000=0.851
Eps: 0.15 	 Accuracy = 6826/10000=0.683
Eps: 0.2 	 Accuracy = 4301/10000=0.430
Eps: 0.25 	 Accuracy = 2082/10000=0.208
Eps: 0.3 	 Accuracy = 869/10000=0.087

print(epsilons)
print(accuracies)

[0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
[0.981, 0.9426, 0.851, 0.6826, 0.4301, 0.2082, 0.0869]

Visualization

plt.figure(figsize=(5,5))

plt.plot(epsilons, accuracies, '*-')

plt.yticks(np.arange(0, 1.1, step=0.1))
plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy");

plt.figure(figsize=(8,10))

cnt = 0
for i in range(len(epsilons)):
    for j in range(len(examples[i])):
        cnt += 1
        plt.subplot(len(epsilons), len(examples[i]), cnt)
        orig, adv, ex = examples[i][j]

        plt.imshow(ex, cmap='gray')
        
        plt.xticks([], [])
        plt.yticks([], [])
        plt.title("{} -> {}".format(orig, adv))
        
        if j == 0:
            plt.ylabel("Eps: {}".format(epsilons[i]), fontsize=14)
        
plt.tight_layout()

  • All predictions in the Eps: 0 row are correct, since eps = 0 applies no perturbation; as epsilon increases, the perturbations become visible and more predictions flip.