PyTorch 2.0 e torch.compile: compilazione del grafo senza riscrivere il codice
In una frase PyTorch 2.0 introduce torch.compile basato su TorchDynamo e il backend Inductor, offrendo fino a 2x di speedup su transformer senza modifiche al codice, rendendo PyTorch competitivo con XLA/JAX in produzione.
PyTorch ha sempre avuto un grande vantaggio: il codice gira riga per riga come normale Python, il che lo rende facilissimo da debuggare. Il problema è che questo approccio lascia molta performance sul tavolo rispetto ai framework che compilano tutto prima di eseguire.
PyTorch 2.0 introduce torch.compile, una sola funzione che puoi applicare al tuo modello. Sotto la superficie, un sistema chiamato TorchDynamo analizza il tuo codice Python al volo, cattura il grafo delle operazioni e lo passa a un backend di ottimizzazione chiamato Inductor, che genera codice kernel ottimizzato per CPU o GPU.
Il risultato pratico è notevole: sugli stessi modelli e hardware, torch.compile ottiene in media tra il 30% e il 200% di speedup rispetto al codice non compilato, senza che tu debba cambiare nemmeno una riga del tuo modello. Per chi addestra grandi transformer, questo si traduce direttamente in meno ore di compute e meno costi. PyTorch finalmente può competere con JAX, che aveva questo vantaggio da sempre, ma mantenendo la facilità d'uso che aveva reso PyTorch dominante nella ricerca.
Aziende
Meta
Tool
—
Tag
Fonti