Paper sulle Graph Attention Networks spiegato con illustrazioni e implementazione in PyTorch

Paper sulle Graph Attention Networks con illustrazioni e implementazione in PyTorch

Un esauriente e illustrato percorso nel paper delle “Graph Attention Networks” di Veličković et al. con l’implementazione in PyTorch del modello proposto.

Illustrazione del livello di passaggio dei messaggi in una Graph Attention Network — immagine dell'autore

Introduzione

Le reti neurali per i grafi (GNN, Graph Neural Networks) sono una potente classe di reti neurali che operano su dati strutturati come grafi. Apprendono le rappresentazioni dei nodi (embedding) aggregando informazioni dal vicinato locale di un nodo. Questo concetto è noto come “message passing” nella letteratura sull’apprendimento delle rappresentazioni dei grafi.

I messaggi (embedding) vengono passati tra i nodi del grafo attraverso strati multipli della GNN. Ogni nodo aggrega i messaggi dai suoi vicini per aggiornare la sua rappresentazione. Questo processo viene ripetuto attraverso gli strati, consentendo ai nodi di ottenere rappresentazioni che codificano informazioni più ricche sul grafo. Alcune delle varianti importanti delle GNN includono GraphSAGE [2], Graph Convolution Network [3], ecc. Puoi esplorare ulteriori varianti delle GNN qui.

Illustrazione semplice di un singolo passaggio di messaggi — immagine dell'autore

Graph Attention Networks (GAT) [1] sono una classe speciale di GNN proposta per migliorare questo schema di passaggio dei messaggi. Hanno introdotto un meccanismo di attenzione apprendibile che consente a un nodo di decidere quali nodi vicini sono più importanti quando aggrega i messaggi dal loro vicinato locale assegnando un peso tra ogni nodo sorgente e destinazione invece di aggregare informazioni da tutti i vicini con pesi uguali.

Empiricamente, le Graph Attention Networks hanno dimostrato di superare molti altri modelli GNN in compiti come la classificazione dei nodi, la previsione dei collegamenti e la classificazione dei grafi. Hanno dimostrato prestazioni all’avanguardia su diversi dataset di grafi di benchmark.

In questo post, esamineremo la parte cruciale del paper originale delle “Graph Attention Networks” di Veličković et al. [1], spiegheremo queste parti e implementeremo contemporaneamente le nozioni proposte nel paper utilizzando il framework PyTorch per comprendere meglio l’intuizione del metodo GAT.

Puoi anche accedere al codice completo utilizzato in questo post, che contiene il codice di addestramento e validazione in questo repository GitHub.

Andando Attraverso il Paper

Sezione 1 — Introduzione

Dopo una revisione generale dei metodi esistenti nella letteratura sull’apprendimento delle rappresentazioni dei grafi nella Sezione 1, “Introduzione”, viene introdotta la Graph Attention Network (GAT). Gli autori menzionano:

  1. Una visione generale del meccanismo di attenzione incorporato.
  2. Tre proprietà delle GAT, ovvero un calcolo efficiente, una generalizzabilità a tutti i nodi e una utilità nell’apprendimento induttivo.
  3. Confronti e dataset su cui hanno valutato le prestazioni delle GAT.
Sezione selezionata del paper originale delle GAT

Dopo aver confrontato il loro approccio con alcuni metodi esistenti e menzionando le somiglianze e le differenze generali tra di essi, passano alla sezione successiva del paper.

Sezione 2 — Architettura GAT

In questa sezione, che costituisce la parte principale del paper, viene presentata in dettaglio l’architettura delle Graph Attention Network. Per procedere con la spiegazione, si assume che l’architettura proposta funzioni su un grafo con N nodi (V = {vᵢ}; i=1,…,N) e che ogni nodo sia rappresentato da un vettore hᵢ di F elementi, con impostazioni arbitrarie di collegamenti tra i nodi.

Esempio di grafo di input - immagine dell'autore

Gli autori iniziano caratterizzando un singolo Graph Attention Layer e come funziona, che diventa il blocco di costruzione di una Graph Attention Network. In generale, uno strato GAT singolo dovrebbe prendere in input un grafo con i suoi nodi e incorporamenti (rappresentazioni) dati, propagare le informazioni ai nodi vicini locali e produrre una rappresentazione aggiornata dei nodi.

sezione selezionata dal documento originale GAT

Come evidenziato sopra, per farlo, prima di tutto, affermano che tutti i vettori delle caratteristiche di input del nodo (hᵢ) allo strato GA vengono trasformati linearmente (cioè moltiplicati da una matrice di pesi W), in PyTorch, di solito viene fatto come segue:

Trasformazione lineare delle caratteristiche del nodo - immagine dell'autore
import torchfrom torch import nn# in_features -> F e out_feature -> F'in_features = ...out_feature = ...# istanzia la matrice di pesi apprendibile W (FxF')W = nn.Parameter(torch.empty(size=(in_features, out_feature)))# Inizializza la matrice di pesi Wnn.init.xavier_normal_(W)# moltiplica W e h (h è una matrice NxF delle caratteristiche di input di tutti i nodi)h_transformed = torch.mm(h, W)

Ora, tenendo presente che abbiamo ottenuto una versione trasformata delle nostre caratteristiche (incorporamenti) di input dei nodi, saltiamo avanti di qualche passo per osservare e capire qual è il nostro obiettivo finale in uno strato GAT.

Come descritto nel documento, alla fine di uno strato di attenzione del grafo, per ogni nodo i, dobbiamo ottenere un nuovo vettore di caratteristiche che sia più consapevole della struttura e del contesto del suo vicinato.

Ciò avviene calcolando una somma ponderata delle caratteristiche dei nodi vicini seguita da una funzione di attivazione non lineare σ. Questa somma ponderata è anche nota come “Fase di Aggregazione” nelle operazioni generali degli strati GNN, secondo la letteratura di Machine Learning su grafi.

Questi pesi αᵢⱼ ∈ [0, 1] sono appresi e calcolati da un meccanismo di attenzione che indica l’importanza delle caratteristiche del vicino j per il nodo i durante la propagazione del messaggio e l’aggregazione.

sezione selezionata dal documento originale GAT

Ora vediamo come questi pesi di attenzione αᵢⱼ vengono calcolati per ogni coppia di nodi i e il suo vicino j:

In breve, i pesi di attenzione αᵢⱼ vengono calcolati come segue:

sezione selezionata dal documento originale GAT

Dove gli eᵢⱼ sono punteggi di attenzione e la funzione Softmax viene applicata in modo che tutti i pesi siano nell’intervallo [0, 1] e sommino a 1.

Le valutazioni di attenzione eᵢⱼ sono ora calcolate tra ogni nodo i e i suoi vicini j ∈ N attraverso la funzione di attenzione a(…) come segue:

Sezione selezionata dall'articolo originale di GAT

Dove || indica la concatenazione di due embedding di nodi trasformati, e a è un vettore di parametri apprendibili (cioè parametri di attenzione) di dimensione 2 * F’ (due volte la dimensione degli embedding trasformati).

E (aᵀ) è la trasposta del vettore a, risultando nell’espressione completa aᵀ [Whᵢ|| Whⱼ] che rappresenta il prodotto (interno) tra “a” e la concatenazione degli embedding trasformati.

L’intera operazione è illustrata di seguito:

Calcolo delle valutazioni di attenzione in GAT - immagine dell'autore

In PyTorch, per ottenere queste valutazioni, adottiamo un approccio leggermente diverso. Poiché è più efficiente calcolare eᵢⱼ tra tutte le coppie di nodi e quindi selezionare solo quelle che rappresentano i bordi esistenti tra i nodi. Per calcolare tutti eᵢⱼ:

# istanzia il vettore di parametri di attenzione apprendibili `a`a = nn.Parameter(torch.empty(size=(2 * out_feature, 1)))# Inizializza il vettore di parametri `a`nn.init.xavier_normal_(a)# abbiamo ottenuto `h_transformed` nello snippet di codice precedente# calcoliamo il prodotto scalare di tutti gli embedding dei nodi# e la prima metà dei parametri del vettore di attenzione (corrispondenti ai messaggi dei vicini)source_scores = torch.matmul(h_transformed, self.a[:out_feature, :])# calcoliamo il prodotto scalare di tutti gli embedding dei nodi# e la seconda metà dei parametri del vettore di attenzione (corrispondenti al nodo target)target_scores = torch.matmul(h_transformed, self.a[out_feature:, :])# addizione broadcast e = source_scores + target_scores.Te = self.leakyrelu(e)

Nell’ultima parte dello snippet di codice (# addizione broadcast), vengono sommati tutti i punteggi uno a uno di sorgente e destinazione, ottenendo una matrice NxN contenente tutti i punteggi di eᵢⱼ. (illustrato di seguito)

Calcolo parallelo vettorializzato dei punteggi di attenzione tra tutti i nodi in GAT - immagine dell'autore

Fino ad ora, abbiamo supposto che il grafo sia completamente connesso e abbiamo calcolato i punteggi di attenzione tra tutte le possibili coppie di nodi. Per affrontare questo problema, dopo che l’attivazione LeakyReLU viene applicata ai punteggi di attenzione, i punteggi di attenzione vengono mascherati in base ai bordi esistenti nel grafo, il che significa che vengono mantenuti solo i punteggi che corrispondono ai bordi esistenti.

Ciò può essere fatto assegnando un punteggio negativo elevato (per approssimare -∞) agli elementi della matrice dei punteggi tra nodi con bordi non esistenti in modo che i relativi pesi di attenzione diventino zero dopo softmax.

Possiamo ottenere ciò utilizzando la matrice di adiacenza del grafo. La matrice di adiacenza è una matrice NxN con 1 nella riga i e nella colonna j se c’è un bordo tra il nodo i e j e 0 altrove. Quindi creiamo la maschera assegnando -∞ agli elementi nulli della matrice di adiacenza e assegnando 0 altrove. Successivamente, aggiungiamo la maschera alla nostra matrice di punteggi e applichiamo la funzione softmax sulle sue righe.

connectivity_mask = -9e16 * torch.ones_like(e) # Maschera di connettività
e = torch.where(adj_mat > 0, e, connectivity_mask) # Punteggi di attenzione mascherati
# I coefficienti di attenzione vengono calcolati come softmax sulle righe
attention = F.softmax(e, dim=-1)

Infine, secondo il paper, dopo aver ottenuto i punteggi di attenzione e averli mascherati con gli archi esistenti, otteniamo i pesi di attenzione αᵢⱼ eseguendo il softmax sulle righe della matrice dei punteggi.

selected section from the original GAT paper
Illustration of applying connectivity mask and softmax to attention scores to attain attention coefficients — image by author.

E come discusso in precedenza, calcoliamo la somma pesata delle rappresentazioni dei nodi:

# Le rappresentazioni finali dei nodi vengono calcolate come una media pesata delle caratteristiche dei suoi vicini
h_prime = torch.matmul(attention, h_transformed)

Infine, il paper introduce il concetto di attenzione multi-head, in cui tutte le operazioni discusse vengono eseguite attraverso flussi paralleli di operazioni, dove i risultati finali delle attenzioni sono o mediati o concatenati.

selected section from the original GAT paper

Il processo di attenzione e aggregazione multi-head è illustrato di seguito:

An illustration of multi-head attention (with K = 3 heads) by node 1 in its neighborhood. Different arrow styles and colors denote independent attention computations. The aggregated features from each head are concatenated or averaged to obtain h’. — Image from the original paper

Per riassumere l’implementazione in una forma modulare più pulita (come un modulo PyTorch) e per incorporare la funzionalità di attenzione multi-head, l’intera implementazione del Graph Attention Layer viene fatta come segue:

import torch
from torch import nn
import torch.nn.functional as F

###################################  DEFINIZIONE DELLO STRATO DI ATTENZIONE GRAFICA (GAT)    ###################################

class GraphAttentionLayer(nn.Module):
    def __init__(self, in_features: int, out_features: int,
                 n_heads: int, concat: bool = False, dropout: float = 0.4,
                 leaky_relu_slope: float = 0.2):
        super(GraphAttentionLayer, self).__init__()
        self.n_heads = n_heads # Numero di teste di attenzione
        self.concat = concat # se concatenare le teste di attenzione finali
        self.dropout = dropout # Tasso di dropout
        if concat: # concatenare le teste di attenzione
            self.out_features = out_features # Numero di caratteristiche di output per nodo
            assert out_features % n_heads == 0 # Assicurarsi che out_features sia un multiplo di n_heads
            self.n_hidden = out_features // n_heads
        else: # mediare l'output sulle teste di attenzione (Usato nel paper principale)
            self.n_hidden = out_features
        # Applicare una trasformazione lineare condivisa, parametrizzata da una matrice di pesi W, ad ogni nodo
        # Inizializza la matrice di pesi W
        self.W = nn.Parameter(torch.empty(size=(in_features, self.n_hidden * n_heads)))
        # Inizializza i pesi di attenzione a
        self.a = nn.Parameter(torch.empty(size=(n_heads, 2 * self.n_hidden, 1)))
        self.leakyrelu = nn.LeakyReLU(leaky_relu_slope) # Funzione di attivazione LeakyReLU
        self.softmax = nn.Softmax(dim=1) # Funzione di attivazione softmax per i coefficienti di attenzione
        self.reset_parameters() # Reimposta i parametri

    def reset_parameters(self):
        nn.init.xavier_normal_(self.W)
        nn.init.xavier_normal_(self.a)

    def _get_attention_scores(self, h_transformed: torch.Tensor):
        source_scores = torch.matmul(h_transformed, self.a[:, :self.n_hidden, :])
        target_scores = torch.matmul(h_transformed, self.a[:, self.n_hidden:, :])
        # Addizione broadcast
        # (n_heads, n_nodes, 1) + (n_heads, 1, n_nodes) = (n_heads, n_nodes, n_nodes)
        e = source_scores + target_scores.transpose(1, 2)
        return self.leakyrelu(e)

    def forward(self,  h: torch.Tensor, adj_mat: torch.Tensor):
        n_nodes = h.shape[0]
        # Applica la trasformazione lineare alla caratteristica del nodo -> W h
        # Forma di output (n_nodes, n_hidden * n_heads)
        h_transformed = torch.mm(h, self.W)
        h_transformed = F.dropout(h_transformed, self.dropout, training=self.training)
        # Suddividere le teste di attenzione ridimensionando il tensore e mettendo la dimensione delle teste per prima
        # Forma di output (n_heads, n_nodes, n_hidden)
        h_transformed = h_transformed.view(n_nodes, self.n_heads, self.n_hidden).permute(1, 0, 2)

        # Ottenere i punteggi di attenzione
        # Forma di output (n_heads, n_nodes, n_nodes)
        e = self._get_attention_scores(h_transformed)

        # Imposta il punteggio di attenzione per gli archi non esistenti a -9e15 (MASCHERAMENTO DEGLI ARCHI NON ESISTENTI)
        connectivity_mask = -9e16 * torch.ones_like(e)
        e = torch.where(adj_mat > 0, e, connectivity_mask) # Punteggi di attenzione mascherati

        # I coefficienti di attenzione vengono calcolati come softmax sulle righe
        # per ogni colonna j della matrice dei punteggi di attenzione e
        attention = F.softmax(e, dim=-1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        # Le rappresentazioni finali dei nodi vengono calcolate come una media pesata delle caratteristiche dei suoi vicini
        h_prime = torch.matmul(attention, h_transformed)

        # Concatenazione / media delle teste di attenzione
        # Forma di output (n_nodes, out_features)
        if self.concat:
            h_prime = h_prime.permute(1, 0, 2).contiguous().view(n_nodes, self.out_features)
        else:
            h_prime = h_prime.mean(dim=0)
        return h_prime

In seguito, gli autori fanno un confronto tra GAT e alcune delle altre metodologie/architetture GNN esistenti. Argomentano che:

  1. GAT è computazionalmente più efficiente rispetto a alcuni metodi esistenti grazie alla capacità di calcolare i pesi di attenzione e di eseguire l’aggregazione locale in parallelo.
  2. GAT può assegnare differenti importanze ai vicini di un nodo durante l’aggregazione dei messaggi, il che può consentire un aumento della capacità del modello e aumentare l’interpretabilità.
  3. GAT considera l’intero vicinato dei nodi (non richiede il campionamento dai vicini) e non assume alcun ordinamento tra i nodi.
  4. GAT può essere riformulato come un’istanza particolare di MoNet (Monti et al., 2016) impostando la funzione pseudo-coordinata come u(x, y) = f(x)||f(y), dove f(x) rappresenta le caratteristiche del nodo x (potenzialmente trasformate da MLP) e || è la concatenazione; e la funzione di peso come wj(u) = softmax(MLP(u))

Sezione 3 — Valutazione

Nella terza sezione del paper, innanzitutto, gli autori descrivono i benchmark, i dataset e i compiti su cui è valutato il GAT. Successivamente presentano i risultati della valutazione del modello.

Apprendimento transduttivo vs. Apprendimento induttivoI dataset utilizzati come benchmark in questo paper sono differenziati in due tipi di compiti, Transduttivo e Induttivo.

  • Apprendimento induttivo: È un tipo di compito di apprendimento supervisionato in cui un modello viene addestrato solo su un insieme di esempi di addestramento etichettati e il modello addestrato viene valutato e testato su esempi che non sono stati osservati durante l’addestramento. È il tipo di apprendimento noto come apprendimento supervisionato comune.
  • Apprendimento transduttivo: In questo tipo di compito, tutti i dati, inclusi gli esempi di addestramento, di convalida e di test, vengono utilizzati durante l’addestramento. Ma in ogni fase, solo l’insieme corrispondente di etichette viene accesso dal modello. Significa che durante l’addestramento, il modello viene addestrato solo utilizzando la loss risultante dagli esempi e dalle etichette di addestramento, ma le caratteristiche di test e di convalida vengono utilizzate per il passaggio dei messaggi. Questo avviene principalmente a causa delle informazioni strutturali e contestuali esistenti negli esempi.

DatasetNel paper, vengono utilizzati quattro dataset di benchmark per valutare i GAT, tre dei quali corrispondono all’apprendimento transduttivo, e un altro viene utilizzato come compito di apprendimento induttivo.

I dataset di apprendimento transduttivo, ovvero i dataset Cora, Citeseer e Pubmed (Sen et al., 2008), sono tutti grafo delle citazioni in cui i nodi sono documenti pubblicati e gli archi (connessioni) sono citazioni tra di essi, e le caratteristiche dei nodi sono elementi di una rappresentazione bag-of-words di un documento. Il dataset di apprendimento induttivo è un dataset di interazione proteina-proteina (PPI) che contiene grafi di diversi tegumenti umani (Zitnik & Leskovec, 2017). I dataset sono descritti più in dettaglio di seguito:

Riassunto dei dataset utilizzati nei nostri esperimenti - tratto dal paper originale.

Setup & Risultati

  • Per i tre compiti transduttivi, l’impostazione utilizzata per l’addestramento è:Utilizzano 2 strati GAT — lo strato 1 utilizza- K = 8 attenzione head- F’ = 8 dimensione delle caratteristiche di output per head- attivazione ELUe per il secondo strato [Cora & Citeseer | Pubmed]- [1 | 8] attenzione head con C numero di classi dimensione di output- attivazione Softmax per la probabilità di classificazione di outpute per l’intera rete- Dropout con p = 0.6– regolarizzazione L2 con λ = [0.0005 | 0.001]
  • Per i tre compiti transduttivi, l’impostazione utilizzata per l’addestramento è:Tre strati — – Strato 1 & 2: K = 4 | F’ = 256 | ELU – Strato 3: K = 6 | F’ = C classi | Sigmoid (multi-label)senza regolarizzazione e dropout

L’implementazione della prima configurazione in PyTorch è riportata di seguito utilizzando il layer che abbiamo definito in precedenza:

class GAT(nn.Module):    def __init__(self,        in_features,        n_hidden,        n_heads,        num_classes,        concat=False,        dropout=0.4,        leaky_relu_slope=0.2):        super(GAT, self).__init__()        # Definizione dei layer Graph Attention        self.gat1 = GraphAttentionLayer(            in_features=in_features, out_features=n_hidden, n_heads=n_heads,            concat=concat, dropout=dropout, leaky_relu_slope=leaky_relu_slope            )                self.gat2 = GraphAttentionLayer(            in_features=n_hidden, out_features=num_classes, n_heads=1,            concat=False, dropout=dropout, leaky_relu_slope=leaky_relu_slope            )    def forward(self, input_tensor: torch.Tensor , adj_mat: torch.Tensor):        # Applicazione del primo Graph Attention layer        x = self.gat1(input_tensor, adj_mat)        x = F.elu(x) # Applicazione della funzione di attivazione ELU all'output del primo layer        # Applicazione del secondo Graph Attention layer        x = self.gat2(x, adj_mat)        return F.softmax(x, dim=1) # Applicazione della funzione di attivazione softmax

Dopo i test, gli autori riportano le seguenti performance per le quattro valutazioni mostrando i risultati comparabili dei GAT rispetto ai metodi GNN esistenti.

Riepilogo dei risultati in termini di accuratezza di classificazione per Cora, Citeseer e Pubmed - dal paper originale.
Riepilogo dei risultati in termini di punteggi F1 micro-averaged, per il dataset PPI - dal paper originale.

Conclusioni

In conclusione, in questo post ho cercato di seguire un approccio dettagliato e facile da seguire per spiegare il paper “Graph Attention Networks” di Veličković et al. utilizzando illustrazioni per aiutare i lettori a comprendere le idee principali di questi network e perché sono importanti per lavorare con dati strutturati a grafo complessi (ad esempio, reti sociali o molecole). Inoltre, il post include un’implementazione pratica del modello utilizzando PyTorch, un framework di programmazione popolare. Leggendo il post e provando il codice, spero che i lettori possano acquisire una solida comprensione di come funzionano i GAT e come possono essere applicati in scenari reali. Spero che questo post sia stato utile e stimoli a esplorare ulteriormente questa entusiasmante area di ricerca.

Inoltre, è possibile accedere al codice completo utilizzato in questo post, contenente il codice di addestramento e di convalida in questo repository GitHub.

Sarei felice di sentire eventuali idee o suggerimenti/cambiamenti sul post.

Riferimenti

[1] — Graph Attention Networks (2017), Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, Yoshua Bengio. arXiv:1710.10903v3

[2] — Inductive Representation Learning on Large Graphs (2017), William L. Hamilton, Rex Ying, Jure Leskovec. arXiv:1706.02216v4

[3] — Semi-Supervised Classification with Graph Convolutional Networks (2016), Thomas N. Kipf, Max Welling. arXiv:1609.02907v4