diff --git a/segmentation_models_pytorch/losses/focal.py b/segmentation_models_pytorch/losses/focal.py index 3beb9f34..6a9150c8 100644 --- a/segmentation_models_pytorch/losses/focal.py +++ b/segmentation_models_pytorch/losses/focal.py @@ -70,20 +70,21 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: elif self.mode == MULTICLASS_MODE: num_classes = y_pred.size(1) - loss = 0 - # Filter anchors with -1 label from loss computation + # If ignore_index parameter is passed, treat it as an extra class during one-hot encoding and remove it later + # One-hot encoding the labels allows us to vectorise the focal loss computation if self.ignore_index is not None: - not_ignored = y_true != self.ignore_index + y_true[y_true == self.ignore_index] = num_classes + y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes + 1) + y_true_one_hot = y_true_one_hot[ ... , : -1] - for cls in range(num_classes): - cls_y_true = (y_true == cls).long() - cls_y_pred = y_pred[:, cls, ...] + else: + y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes) - if self.ignore_index is not None: - cls_y_true = cls_y_true[not_ignored] - cls_y_pred = cls_y_pred[not_ignored] + y_true_one_hot = torch.permute(y_true_one_hot,(0,3,1,2)) - loss += self.focal_loss_fn(cls_y_pred, cls_y_true) + # Multiplying the loss by num_classes in order to stay consistent with the earlier loss computation which did not + # take a classwise average of the loss + loss = num_classes * self.focal_loss_fn(y_pred, y_true_one_hot) return loss diff --git a/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb b/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb new file mode 100644 index 00000000..e47ef9e0 --- /dev/null +++ b/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb @@ -0,0 +1,195 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "245a88c9", + "metadata": {}, + "outputs": [], + "source": [ + "from segmentation_models_pytorch.losses import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE\n", + "from time import time\n", + "from typing import Optional\n", + "import torch\n", + "import segmentation_models_pytorch\n", + "\n", + "class FocalLossVectorised(segmentation_models_pytorch.losses.FocalLoss):\n", + " def __init__(\n", + " self,\n", + " mode: str,\n", + " alpha: Optional[float] = None,\n", + " gamma: Optional[float] = 2.0,\n", + " ignore_index: Optional[int] = None,\n", + " reduction: Optional[str] = \"mean\",\n", + " normalized: bool = False,\n", + " reduced_threshold: Optional[float] = None,\n", + " ):\n", + " \n", + " super().__init__(mode = mode,alpha = alpha,gamma = gamma, ignore_index = ignore_index,reduction = reduction,\n", + " normalized = normalized,reduced_threshold = reduced_threshold)\n", + " \n", + " def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:\n", + " if self.mode in {BINARY_MODE, MULTILABEL_MODE}:\n", + " y_true = y_true.view(-1)\n", + " y_pred = y_pred.view(-1)\n", + "\n", + " if self.ignore_index is not None:\n", + " # Filter predictions with ignore label from loss computation\n", + " not_ignored = y_true != self.ignore_index\n", + " y_pred = y_pred[not_ignored]\n", + " y_true = y_true[not_ignored]\n", + "\n", + " loss = self.focal_loss_fn(y_pred, y_true)\n", + "\n", + " elif self.mode == MULTICLASS_MODE:\n", + " num_classes = y_pred.size(1)\n", + "\n", + " if self.ignore_index is not None:\n", + " y_true[y_true == self.ignore_index] = num_classes\n", + " y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes + 1)\n", + " y_true_one_hot = y_true_one_hot[ : , : , : , : -1]\n", + "\n", + " else: \n", + " y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes)\n", + "\n", + " y_true_one_hot = torch.permute(y_true_one_hot,(0,3,1,2))\n", + " loss = num_classes * self.focal_loss_fn(y_pred, y_true_one_hot)\n", + "\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9c64c3ea", + "metadata": {}, + "outputs": [], + "source": [ + "num_classes = 20\n", + "batch_size = 128\n", + "resolution = 512\n", + "device = 'cuda:1' if torch.cuda.is_available() else 'cpu'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d4d3a5f5", + "metadata": {}, + "outputs": [], + "source": [ + "vectorised_loss_fn = FocalLossVectorised(mode = 'multiclass',ignore_index = num_classes)\n", + "loss_fn = segmentation_models_pytorch.losses.FocalLoss(mode = 'multiclass',ignore_index = num_classes)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5e49a5b8", + "metadata": {}, + "outputs": [], + "source": [ + "predictions = torch.randn((batch_size,num_classes,resolution,resolution)).to(device = device)\n", + "labels = torch.randint(low = 0,high = num_classes+1,size = (batch_size,resolution,resolution)).to(device = device)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "36a0f89b", + "metadata": {}, + "outputs": [], + "source": [ + "def benchmark(function,predictions,labels,benchmark_iterations = 100):\n", + " start_time = time()\n", + "\n", + " for _ in range(benchmark_iterations):\n", + " loss = function(predictions,labels)\n", + "\n", + " end_time = time()\n", + "\n", + " average_time_taken = (end_time - start_time) / (benchmark_iterations)\n", + "\n", + " print(f\"Average time taken by function {function} is {average_time_taken} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "8de20667", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average time taken by function FocalLoss() is 0.3390256547927856 seconds\n" + ] + } + ], + "source": [ + "benchmark(loss_fn,predictions,labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6da16fc2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average time taken by function FocalLossVectorised() is 0.11771584510803222 seconds\n" + ] + } + ], + "source": [ + "benchmark(vectorised_loss_fn,predictions,labels)" + ] + }, + { + "cell_type": "markdown", + "id": "a1083cd4", + "metadata": {}, + "source": [ + "##### CHECKING THAT OUTPUT OF NEW CLASS IS CONSISTENT WITH THE OLD ONE" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8fba3182", + "metadata": {}, + "outputs": [], + "source": [ + "output_from_vectorised_fn = vectorised_loss_fn(predictions,labels)\n", + "output_from_old_fn = loss_fn(predictions,labels)\n", + "\n", + "assert torch.allclose(output_from_vectorised_fn,output_from_old_fn)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torch", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}