Perfeziona i modelli Whisper su Amazon SageMaker con LoRA

Perfeziona i modelli di Whisper su Amazon SageMaker con LoRA

Whisper è un modello di riconoscimento automatico del parlato (ASR) che è stato addestrato utilizzando 680.000 ore di dati supervisionati provenienti da internet, che comprendono una gamma di lingue e compiti. Una delle sue limitazioni è la bassa performance nelle lingue a bassa risorsa come il marathi e le lingue dravidiche, che possono essere risolte con il fine-tuning. Tuttavia, il fine-tuning di un modello Whisper è diventato una sfida considerevole, sia in termini di risorse di calcolo che di requisiti di archiviazione. Cinque-dieci esecuzioni complete di fine-tuning per i modelli Whisper richiedono circa 100 ore di GPU A100 (40 GB SXM4) (varia in base alle dimensioni del modello e ai parametri del modello) e ogni checkpoint fine-tuned necessita circa 7 GB di spazio di archiviazione. Questa combinazione di elevate richieste di calcolo e archiviazione può rappresentare ostacoli significativi, soprattutto in ambienti con risorse limitate, rendendo spesso estremamente difficile ottenere risultati significativi.

L’Adattamento a basso rango, noto anche come LoRA, adotta un approccio unico al fine-tuning del modello. Mantiene i pesi del modello pre-addestrato in uno stato statico e introduce matrici di decomposizione del rango addestrabili in ogni livello della struttura del Transformer. Questo metodo può ridurre il numero di parametri addestrabili necessari per i compiti successivi di un fattore di 10.000 e ridurre il requisito di memoria GPU del 3. Per quanto riguarda la qualità del modello, LoRA ha dimostrato di eguagliare o addirittura superare le prestazioni dei metodi tradizionali di fine-tuning, nonostante operi con meno parametri addestrabili (vedi i risultati dall’originale paper di LoRA). Offre anche il vantaggio di un aumento della velocità di addestramento. A differenza dei metodi adapter, LoRA non introduce latenza aggiuntiva durante l’inferenza, mantenendo così l’efficienza del modello durante la fase di implementazione. Il fine-tuning di Whisper utilizzando LoRA ha mostrato risultati promettenti. Prendiamo ad esempio Whisper-Large-v2: eseguire 3 epoche con un dataset di 12 ore di Common Voice su una GPU con 8 GB di memoria richiede 6-8 ore, che è 5 volte più veloce del fine-tuning completo con prestazioni comparabili.

Amazon SageMaker è una piattaforma ideale per implementare il fine-tuning di Whisper utilizzando LoRA. Amazon SageMaker consente di creare, addestrare e implementare modelli di machine learning per qualsiasi caso d’uso con un’infrastruttura, strumenti e flussi di lavoro completamente gestiti. Ulteriori vantaggi dell’addestramento del modello possono includere costi di addestramento inferiori con Managed Spot Training, librerie di addestramento distribuito per dividere i modelli e i dataset di addestramento tra le istanze GPU di AWS e altro. I modelli SageMaker addestrati possono essere facilmente implementati per l’inferenza direttamente su SageMaker. In questo articolo, presentiamo una guida passo-passo per implementare il fine-tuning con LoRA su SageMaker. Il codice sorgente associato a questa implementazione può essere trovato su GitHub.

Prepara il dataset per il fine-tuning

Utilizziamo la lingua a bassa risorsa marathi per il compito di fine-tuning. Utilizzando la libreria “Hugging Face datasets”, è possibile scaricare e dividere il dataset Common Voice in dataset di addestramento e di test. Ecco il codice:

from datasets import load_dataset, DatasetDictlanguage = "Marathi"language_abbr = "mr"task = "transcribe"dataset_name = "mozilla-foundation/common_voice_11_0"common_voice = DatasetDict()common_voice["train"] = load_dataset(dataset_name, language_abbr, split="train+validation", use_auth_token=True)common_voice["test"] = load_dataset(dataset_name, language_abbr, split="test", use_auth_token=True)

Il modello di riconoscimento del parlato Whisper richiede che gli input audio siano file WAV a 16 kHz, mono, a 16 bit, con interi firmati. Poiché il dataset Common Voice ha un tasso di campionamento di 48 kHz, è necessario ridurre il campionamento dei file audio. Quindi è necessario applicare l’estrattore di caratteristiche di Whisper all’audio per estrarre le caratteristiche dello spettrogramma log-mel, e applicare il tokenizer di Whisper alle caratteristiche frammentate per convertire ogni frase nella trascrizione in un ID di token. Ecco il codice:

from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)

def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]
    
    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    
    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    
    return batch

# apply the data preparation function to all of our fine-tuning dataset samples using dataset's .map method
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=2)
common_voice.save_to_disk("marathi-common-voice-processed")

!aws s3 cp --recursive "marathi-common-voice-processed" s3://<Your-S3-Bucket>

Dopo aver elaborato tutti i campioni di addestramento, carica i dati elaborati su Amazon S3, in modo che durante l’utilizzo dei dati di addestramento elaborati nella fase di raffinamento, puoi utilizzare FastFile per montare direttamente il file S3 anziché copiarlo nel disco locale:

from sagemaker.inputs import TrainingInput

training_input_path=s3uri
training = TrainingInput(s3_data_type='S3Prefix', # Opzioni disponibili: S3Prefix | ManifestFile | AugmentedManifestFile
                         s3_data=training_input_path,
                         distribution='FullyReplicated', # Opzioni disponibili: FullyReplicated | ShardedByS3Key
                         input_mode='FastFile')

Allenare il modello

A titolo di esempio, utilizziamo whisper-large-v2 come modello pre-addestrato (whisper v3 è ora disponibile), che può essere importato attraverso la libreria transformers di Hugging Face. Puoi utilizzare la quantizzazione a 8 bit per migliorare ulteriormente l’efficienza di addestramento. La quantizzazione a 8 bit offre un’ottimizzazione della memoria arrotondando da floating point a interi a 8 bit. È una tecnica di compressione del modello comunemente utilizzata per ottenere risparmi di memoria ridotti senza sacrificare troppo la precisione durante l’infereza.

Per caricare il modello pre-addestrato in formato quantizzato a 8 bit, aggiungiamo semplicemente l’argomento load_in_8bit=True durante l’istanziazione del modello, come mostrato nel seguente codice. Ciò caricherà i pesi del modello quantizzati a 8 bit, riducendo l’impronta di memoria.

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")

Utilizziamo l’implementazione LoRA dal pacchetto peft di Hugging Face. Ci sono quattro passaggi per raffinare un modello utilizzando LoRA:

  1. Istanziare un modello di base (come abbiamo fatto nell’ultimo passaggio).
  2. Creare una configurazione (LoraConfig) in cui sono definiti i parametri specifici di LoRA.
  3. Avvolgere il modello di base con get_peft_model() per ottenere un PeftModel trainabile.
  4. Allenare il PeftModel come modello di base.

Guarda il seguente codice:

from peft import LoraConfig, get_peft_model

config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
model = get_peft_model(model, config)

training_args = Seq2SeqTrainingArguments(
    output_dir=args.model_dir,
    per_device_train_batch_size=int(args.train_batch_size),
    gradient_accumulation_steps=1,
    learning_rate=float(args.learning_rate),
    warmup_steps=args.warmup_steps,
    num_train_epochs=args.num_train_epochs,
    evaluation_strategy="epoch",
    fp16=True,
    per_device_eval_batch_size=args.eval_batch_size,
    generation_max_length=128,
    logging_steps=25,
    remove_unused_columns=False,
    label_names=["labels"],
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset["train"],
    eval_dataset=train_dataset.get("test", train_dataset["test"]),
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
)

Per eseguire un job di addestramento di SageMaker, utilizziamo il nostro proprio container Docker. Puoi scaricare l’immagine Docker da GitHub, dove ffmpeg4 e git-lfs sono confezionati insieme ad altri requisiti Python. Per saperne di più su come adattare il proprio container Docker per funzionare con SageMaker, consulta Adattare il proprio container di addestramento. Quindi puoi utilizzare l’Estimator di Hugging Face per avviare un job di addestramento di SageMaker:

OUTPUT_PATH= f's3://{BUCKET}/{PREFIX}/{TRAINING_JOB_NAME}/output/'
huggingface_estimator = HuggingFace(entry_point='train.sh',source_dir='./src',output_path= OUTPUT_PATH,instance_type=instance_type,instance_count=1,# transformers_version='4.17.0',# pytorch_version='1.10.2',py_version='py310',image_uri=<ECR-PATH>,role=ROLE,metric_definitions = metric_definitions,volume_size=200,distribution=distribution,keep_alive_period_in_seconds=1800,environment=environment,)
huggingface_estimator.fit(job_name=TRAINING_JOB_NAME, wait=False)

L’implementazione di LoRA ci ha permesso di eseguire il compito di fine-tuning di Whisper su un’unica istanza GPU (ad esempio, ml.g5.2xlarge). In confronto, il compito di fine-tuning completo di Whisper richiede più GPU (ad esempio, ml.p4d.24xlarge) e un tempo di formazione molto più lungo. In particolare, il nostro esperimento ha dimostrato che il compito di fine-tuning completo richiede 24 volte più ore di GPU rispetto all’approccio LoRA.

Valutare le prestazioni del modello

Per valutare le prestazioni del modello Whisper sintonizzato, calcoliamo il tasso di errore sulle parole (WER) su un set di test separato. Il WER misura la differenza tra la trascrizione prevista e la trascrizione di riferimento. Un WER più basso indica migliori prestazioni. Puoi eseguire lo script seguente sul modello pre-addestrato e sul modello sintonizzato e confrontare la differenza di WER:

metric = evaluate.load("wer")
eval_dataloader = DataLoader(common_voice["test"], batch_size=8, collate_fn=data_collator)
model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            generated_tokens = (model.generate(input_features=batch["input_features"].to("cuda"),decoder_input_ids=batch["labels"][:, :4].to("cuda"),max_new_tokens=255,).cpu().numpy())
            labels = batch["labels"].cpu().numpy()
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
            metric.add_batch(predictions=decoded_preds,references=decoded_labels,)
            del generated_tokens, labels, batch
            gc.collect()
wer = 100 * metric.compute()
print(f"{wer=}")

Conclusioni

In questo post, abbiamo dimostrato il fine-tuning di Whisper, un modello di riconoscimento del parlato all’avanguardia. In particolare, abbiamo usato PEFT LoRA di Hugging Face e abilitato la quantizzazione a 8 bit per un’addestramento efficiente. Abbiamo anche dimostrato come eseguire il lavoro di addestramento su SageMaker.

Anche se questo è un primo passo importante, ci sono diverse modalità in cui puoi sviluppare ulteriormente il modello Whisper per migliorarlo. In futuro, considera di utilizzare l’addestramento distribuito di SageMaker per scalare l’addestramento su un dataset molto più grande. Ciò permetterà al modello di addestrarsi su dati più variati e completi, migliorando l’accuratezza. Puoi anche ottimizzare la latenza durante il servizio del modello Whisper, per abilitare il riconoscimento del parlato in tempo reale. Inoltre, potresti estendere il lavoro per gestire trascrizioni audio più lunghe, che richiede cambiamenti all’architettura del modello e agli schemi di addestramento.

Ringraziamenti

Gli autori estendono la loro gratitudine a Paras Mehra, John Sol ed Evandro Franco per il loro prezioso feedback e la revisione del post.