Implementation: Prototypical Networks for Few-Shot Learning¶

Setup¶

In [ ]:
#@title Install
!pip install easyfsl
In [2]:
#Import
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average
In [ ]:
#Preprocess and split dataset
image_size = 28

train_set = Omniglot(
    root="./data",
    background=True, # Omniglot's "background" split, used here as training data
    transform=transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)

test_set = Omniglot(
    root="./data",
    background=False, # Omniglot's "evaluation" split, used here as test data
    transform=transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)
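
Note that the two splits contain disjoint alphabets: the classes seen at test time never appear during training, which is what makes the evaluation genuinely few-shot. A quick check (illustrative; _characters is the internal class list of torchvision's Omniglot, also used further below):

In [ ]:
#Check that the splits hold disjoint character classes
# (964 training and 659 test classes in the standard Omniglot split)
print(f"Training classes: {len(train_set._characters)}")
print(f"Test classes: {len(test_set._characters)}")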

Prototypical Networks (3-way, 2-shot, no training, Euclidean distance)¶

Accuracy: 87.44%
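
The scoring rule implemented below is that of Prototypical Networks (Snell et al., 2017): each class prototype is the mean embedding of that class's support images, and each query is scored by its negative Euclidean distance to every prototype:

$$c_k = \frac{1}{|S_k|} \sum_{x_i \in S_k} f_\phi(x_i), \qquad s(x, k) = -\lVert f_\phi(x) - c_k \rVert_2$$

(The original paper uses the squared distance; since squaring is monotonic, the predicted class is identical either way, only the softmax probabilities differ.)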

In [ ]:
#Prototypical Networks
class PrototypicalNetworks(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        # Extract the features of support and query images
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Infer the number of different classes from the labels of the support set
        n_way = len(torch.unique(support_labels))
        # Prototype i is the mean of all instances of features corresponding to labels == i
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels == label)].mean(0)
                for label in range(n_way)
            ]
        )

        # Compute the euclidean distance from queries to prototypes
        dists = torch.cdist(z_query, z_proto)

        # Turning distances into classification scores is as simple as negating them:
        # the closer a query is to a prototype, the higher its score for that class
        scores = -dists
        return scores


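# ResNet-18 pretrained on ImageNet; replacing its final fully-connected layer
# with a Flatten turns it into a feature extractor outputting 512-dimensional embeddings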
convolutional_network = resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()
print(convolutional_network)

model = PrototypicalNetworks(convolutional_network).cuda()
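
As a quick sanity check (not part of the original notebook), we can run the model on random tensors shaped like a 3-way, 5-shot episode and confirm that it returns one score per (query, class) pair:

In [ ]:
#Sanity check on random inputs shaped like a 3-way 5-shot episode (illustrative)
model.eval()  # avoid updating BatchNorm running statistics with random data
dummy_support_images = torch.randn(3 * 5, 3, 28, 28).cuda()
dummy_support_labels = torch.arange(3).repeat_interleave(5).cuda()  # [0,0,0,0,0,1,1,...]
dummy_query_images = torch.randn(3 * 10, 3, 28, 28).cuda()

with torch.no_grad():
    dummy_scores = model(dummy_support_images, dummy_support_labels, dummy_query_images)
print(dummy_scores.shape)  # one row per query, one column per class: torch.Size([30, 3])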
In [29]:
#@title Hyperparameters
N_WAY = 3 # Number of classes per task
N_SHOT = 5 # Number of images per class (support set)
N_QUERY = 10 # Number of images per class (query set)
N_EVALUATION_TASKS = 1000

# The sampler needs a dataset with a "get_labels" method. Check the code if you have any doubt!
test_set.get_labels = lambda: [instance[1] for instance in test_set._flat_character_images]
test_sampler = TaskSampler(
    test_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)
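
Each batch drawn from this loader is a complete few-shot task: a tuple (support_images, support_labels, query_images, query_labels, class_ids), where the support tensors hold N_WAY × N_SHOT images with task-local labels in [0, N_WAY), the query tensors hold N_WAY × N_QUERY images, and class_ids maps the task-local labels back to the dataset's original class indices.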
In [30]:
(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels,
    example_class_ids,
) = next(iter(test_loader))

plot_images(example_support_images, "support images", images_per_row=N_SHOT)
plot_images(example_query_images, "query images", images_per_row=N_QUERY)
In [31]:
model.eval()
example_scores = model(
    example_support_images.cuda(),
    example_support_labels.cuda(),
    example_query_images.cuda(),
).detach()

_, example_predicted_labels = torch.max(example_scores, 1)

print("Ground Truth / Predicted")
for i in range(len(example_query_labels)):
    print(
        f"{test_set._characters[example_class_ids[example_query_labels[i]]]} / {test_set._characters[example_class_ids[example_predicted_labels[i]]]}"
    )
Ground Truth / Predicted
Syriac_(Serto)/character03 / Syriac_(Serto)/character03
Syriac_(Serto)/character03 / Syriac_(Serto)/character03
Syriac_(Serto)/character03 / Syriac_(Serto)/character03
Syriac_(Serto)/character03 / Syriac_(Serto)/character03
Syriac_(Serto)/character03 / Syriac_(Serto)/character03
Syriac_(Serto)/character03 / Syriac_(Serto)/character03
Syriac_(Serto)/character03 / Syriac_(Serto)/character03
Syriac_(Serto)/character03 / Syriac_(Serto)/character03
Syriac_(Serto)/character03 / Syriac_(Serto)/character03
Syriac_(Serto)/character03 / Syriac_(Serto)/character03
Manipuri/character03 / Tengwar/character02
Manipuri/character03 / Tengwar/character02
Manipuri/character03 / Tengwar/character02
Manipuri/character03 / Syriac_(Serto)/character03
Manipuri/character03 / Tengwar/character02
Manipuri/character03 / Tengwar/character02
Manipuri/character03 / Tengwar/character02
Manipuri/character03 / Manipuri/character03
Manipuri/character03 / Tengwar/character02
Manipuri/character03 / Manipuri/character03
Tengwar/character02 / Tengwar/character02
Tengwar/character02 / Tengwar/character02
Tengwar/character02 / Tengwar/character02
Tengwar/character02 / Tengwar/character02
Tengwar/character02 / Tengwar/character02
Tengwar/character02 / Tengwar/character02
Tengwar/character02 / Tengwar/character02
Tengwar/character02 / Tengwar/character02
Tengwar/character02 / Tengwar/character02
Tengwar/character02 / Tengwar/character02
In [32]:
def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
):
    """
    Returns the number of correct predictions of query labels, and the total number of predictions.
    """
    scores = model(
        support_images.cuda(), support_labels.cuda(), query_images.cuda()
    ).detach()
    predicted_labels = torch.max(scores, 1)[1]
    correct_count = (predicted_labels == query_labels.cuda()).sum().item()
    return correct_count, len(query_labels)


def evaluate(data_loader: DataLoader):
    # We'll count everything and compute the ratio at the end
    total_predictions = 0
    correct_predictions = 0

    # eval mode changes the behaviour of some layers (such as batch normalization or dropout)
    # no_grad() disables gradient tracking, so the computation graph is never built (lighter on memory)
    model.eval()
    with torch.no_grad():
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            class_ids,
        ) in tqdm(enumerate(data_loader), total=len(data_loader)):

            correct, total = evaluate_on_one_task(
                support_images, support_labels, query_images, query_labels
            )

            total_predictions += total
            correct_predictions += correct

    print(
        f"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions/total_predictions):.2f}%"
    )


evaluate(test_loader)
100%|██████████| 1000/1000 [00:34<00:00, 28.77it/s]
Model tested on 1000 tasks. Accuracy: 86.85%

Meta-Learning¶

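Until now, the backbone was used as-is, with its ImageNet-pretrained weights and no few-shot training at all. In this section we fine-tune it with episodic training, the meta-learning scheme of Prototypical Networks: each training episode mimics a test-time task, and the classification loss on the query set is backpropagated through the whole network.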
In [33]:
N_TRAINING_EPISODES = 40000
N_VALIDATION_TASKS = 1000

# As with the test set, TaskSampler needs the dataset to expose a get_labels method
train_set.get_labels = lambda: [instance[1] for instance in train_set._flat_character_images]

train_sampler = TaskSampler(
    train_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TRAINING_EPISODES
)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)
In [34]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def fit(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> float:
    optimizer.zero_grad()
    classification_scores = model(
        support_images.cuda(), support_labels.cuda(), query_images.cuda()
    )

    loss = criterion(classification_scores, query_labels.cuda())
    loss.backward()
    optimizer.step()

    return loss.item()
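
Because the scores are negative distances, this cross-entropy loss pulls each query's embedding towards the prototype of its true class y and pushes it away from the other prototypes:

$$\mathcal{L}(x, y) = -\log \frac{\exp(-\lVert f_\phi(x) - c_y \rVert)}{\sum_{k=1}^{N_\text{way}} \exp(-\lVert f_\phi(x) - c_k \rVert)}$$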
In [35]:
# Train the model yourself with this cell

log_update_frequency = 10

all_loss = []
model.train()
with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
    for episode_index, (
        support_images,
        support_labels,
        query_images,
        query_labels,
        _,
    ) in tqdm_train:
        loss_value = fit(support_images, support_labels, query_images, query_labels)
        all_loss.append(loss_value)

        if episode_index % log_update_frequency == 0:
            tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))
100%|██████████| 40000/40000 [38:20<00:00, 17.39it/s, loss=0.119]
In [ ]:
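# Alternatively, skip the training cell above and load the weights saved from an
# earlier run (this path points to the author's Google Drive)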
model.load_state_dict(torch.load("/content/drive/MyDrive/Projets/FewShot/resnet18_with_pretraining_3w_2s_40kt.pth", map_location="cuda"))
In [44]:
evaluate(test_loader)
100%|██████████| 1000/1000 [00:35<00:00, 28.13it/s]
Model tested on 1000 tasks. Accuracy: 98.83%