Servire un Modello PyTorch con FastAPI e Docker

Servire un Modello PyTorch utilizzando FastAPI e Docker

Foto di SpaceX su Unsplash

Scopri come sviluppare un completo servizio di Machine Learning

Introduzione

Lavorare su progetti personali di Machine/Deep Learning è molto piacevole. Di sera, di fronte al computer portatile, si scrivono cose che ci piacciono, si leggono articoli interessanti e non c’è nessuna scadenza da rispettare. Tutti sappiamo che la programmazione è buona solo se non si fa per lavoro! 😂

Comunque, anche se è solo un progetto personale, una delle più grandi soddisfazioni arriva quando quello che hai fatto inizia ad essere usato anche da altre persone. Quindi, c’è bisogno che tu renda il tuo modello disponibile agli altri e impari gli strumenti giusti per farlo. In questo articolo ti mostrerò come servire un modello di Deep Learning sviluppato in PyTorch usando FastAPI e Docker.

Immagine di Autore

Imposta il server con FastAPI

Prima di tutto, creiamo un modello di visione artificiale. Questo modello sarà in grado di riconoscere immagini di gatti e immagini di pesci. Per fare ciò prendiamo una rete pre-addestrata di tipo ResNet50 e cambiamo l’ultimo livello di classificazione in modo che l’output sia binario.

Ho messo il seguente codice in un file chiamato model.py

from torchvision import modelsimport torch.nn as nnCatfishClasses = ["gatto", "pesce"]CatfishModel = models.resnet50()CatfishModel.fc = nn.Sequential(    nn.Linear(CatfishModel.fc.in_features, 500),    nn.ReLU(),    nn.Dropout(),    nn.Linear(500, 2))

Adesso istanziamo il server con FastAPI al quale i client possono connettersi per chiedere al nostro modello di fare previsioni.

from PIL import Imagefrom torchvision import transformsimport torchimport osfrom fastapi import FastAPIfrom fastapi.responses import JSONResponsefrom .model import CatfishModel, CatfishClassesfrom io import BytesIOfrom fastapi import HTTPExceptionimport requestsapp = FastAPI()def open_image(image_path):    # Aggiungi qui qualsiasi logica necessaria di preprocessing delle immagini    image = Image.open(image_path)    return image  def load_model():    return [email protected]("/")def status():    return {"status"…