Блог школы DeepSchool

CVAT SDK PyTorch Adapter

CVAT SDK PyTorch Adapter

Введение

В предыдущей статье я описывал ускоренный процесс разметки с использованием CVAT и Fiftyone. Мы развернули сервис разметки и получили предварительно размеченные данные.

Сейчас мы рассмотрим некоторые возможности новой SDK PyTorch от команды CVAT. Основная идея — использовать CVAT для получения ваших данных в формате PyTorch Dataset object. Это помогает решать следующие проблемы:

  • ручное скачивание каждого проекта или задания в один из форматов CVAT;
  • написание кода для конвертации в нужный формат для обучения сети;
  • объединение новых размеченных данных с текущими.

Все перечисленные выше проблемы отнимают много времени и сил. Поэтому давайте разберемся с тем, как мы можем их автоматизировать.

Датасеты

В этой статье мы рассмотрим два важных класса:
  1. ProjectVisionDataset — он позволяет загрузить все Tasks в конкретном проекте в единый датасет;
  2. TaskVisionDataset — он позволяет загрузить определенный Task.

Оба класса наследуются от torch.utils.data.Dataset и возвращают датасет, состоящий из:

  • sample[0] — изображение в формате PIL.Image.Image ;
  • sample[1] — аннотации и другие вспомогательные данные в формате cvat_sdk.pytorch.Target.
  • Что представляет из себя sample[1]? По сути, это просто все поля, доступные в CVAT.
Если вам нужны данные в других форматах, то для этого можно использовать трансформации (transforms).

Transforms

Для получения датасета в необходимом формате можно обратиться к transforms с целью изменения возвращаемых объектов.

Для этого доступны следующие варианты:

  • передача функций трансформации изображений и аннотаций в виде tuple в параметр transforms;
  • передача функций трансформации изображений в transform(для изображений) и в target_transform(для аннотаций).

Для преобразования изображений можно использовать библиотеки albumentations, torchvision.transforms, kornia и т.д.

Для преобразования аннотаций доступны классы ExtractBoundingBoxes и ExtractSingleLabelIndex .

Предположим, что нам нужны изображения в формате torch.Tensor, а для аннотаций необходимы только метки классов и bboxes.

Для этого мы можем использовать следующие классы трансформации:

import torchvision.transforms as transforms
from cvat_sdk.pytorch import TaskVisionDataset, ExtractBoundingBoxes
dataset = TaskVisionDataset(client, TASK_ID, transform=transforms.ToTensor(),target_transform=ExtractBoundingBoxes())
И тогда мы получим желаемый результат в виде torch.Tensor
>>> dataset[0]

(tensor([[[1.0000, 1.0000, 0.9922,  ..., 0.3059, 0.3059, 0.3059],
          [1.0000, 1.0000, 0.9922,  ..., 0.3098, 0.3137, 0.3137],
          [1.0000, 0.9961, 0.9961,  ..., 0.3176, 0.3176, 0.3216],
          ...,
          [0.0784, 0.0784, 0.0745,  ..., 0.2039, 0.2078, 0.2078],
          [0.0784, 0.0784, 0.0745,  ..., 0.2039, 0.2078, 0.2078],
          [0.0784, 0.0784, 0.0745,  ..., 0.2039, 0.2039, 0.2078]]]),
 {'boxes': tensor([[391.9600, 194.4400, 503.5600, 429.6400],
          [ 70.3600, 129.6400, 334.3600, 494.4400]]),
  'labels': tensor([0, 0])})
Важно отметить: данные в CVAT сохраняются в кэш. Если данные на сервисе разметки обновляются (например, были добавлены новые метки или bbox), то при повторной загрузке данных из CVAT будут скачаны только измененные части. Это позволит сократить время загрузки и улучшить эффективность работы с платформой.

Теперь давайте рассмотрим пример получения датасета с CVAT с использованием transforms.Compose:

import torchvision.transforms as transforms
from cvat_sdk import make_client
from cvat_sdk.pytorch import ProjectVisionDataset, ExtractSingleLabelIndex


# Создаем клиента
with make_client(host=CVAT_HOST, credentials=(CVAT_USER, CVAT_PASS)) as client:
    # Получаем все задания проекта с номером 52521 по фильтру 'Train'
    train_set = ProjectVisionDataset(client, project_id=52521,
				# Указываем для скачивания только Train 
        include_subsets=['Train'],
				# Применяем к изображениям трансформации
        transform=transforms.Compose([
                transforms.RandomResizedCrop((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]),
				# Применяем к аннотациям трансформации
        target_transform=ExtractSingleLabelIndex())


    image, target = train_set[0]
Таким способом мы получили объект torch.utils.data.Dataset, который может быть передан в torch.utils.data.Dataloader для начала обучения. Кстати, в репозитории блога CVAT можно найти пример встроенного скачивания данных с помощью CVAT в пайплайн обучения и инференса на PyTorch.

Выводы

Хоть CVAT SDK PyTorch еще не обзавелся большим набором фичей, он уже предлагает нам удобно работать с данными: освобождает от рутины ручного скачивания, конвертирования и обновления данных.

Среди недостатков можно отметить:

  • недавний релиз (следовательно, нужно быть готовым к тому, что может что-то пойти не так);
  • пока не поддерживаемый видео-формат и формат трекинга объектов;
  • совместимость только с версией CVAT 2.3.0 и выше.

Ссылки

Раздел 3