Cómo guardar y retomar entrenamientos de ML con checkpointing en Python

Llevas seis horas entrenando un modelo en una GPU alquilada. La precisión mejora, las pérdidas bajan, todo va bien. Entonces la conexión se cae, o la instancia se interrumpe por mantenimiento, o simplemente cierras la terminal sin querer. Sin checkpointing, vuelves a cero. Con checkpointing, retomas desde donde lo dejaste.

Es una de esas cosas que parece evidente hasta que la necesitas por primera vez y no la tienes.

El problema de las GPUs efímeras

Entrenar en local con tu propia GPU tiene sus ventajas, pero muchos proyectos acaban en instancias de nube como Lambda Labs, RunPod o Vast.ai. Son económicas: puedes conseguir una RTX 3090 por menos de 0,30 € la hora, o una A100 por alrededor de 1-2 €. El problema es que esas instancias son efímeras por naturaleza.

Un entrenamiento de NLP sobre un dataset grande puede durar fácilmente 12-24 horas. Un fine-tuning de un modelo de visión, varios días. Durante ese tiempo pueden pasar muchas cosas: un timeout de red, un reinicio de servidor, que se te acabe el crédito, o simplemente que quieras pausar y retomar mañana. Sin guardar estado, cualquiera de esos eventos borra todo el progreso.

La solución es guardar un checkpoint cada N épocas: un archivo con el estado del modelo, el optimizador y los metadatos suficientes para continuar el entrenamiento como si nada hubiera pasado.

Checkpointing con PyTorch

Qué guardar en el checkpoint

Un checkpoint de PyTorch útil incluye al menos estos cuatro elementos:

  • model.state_dict(): los pesos del modelo.
  • optimizer.state_dict(): el estado del optimizador, con los momentos acumulados de Adam o SGD. Sin esto, el optimizador arranca de cero y el entrenamiento puede desestabilizarse.
  • La época actual, para saber desde dónde retomar.
  • La mejor métrica registrada hasta ahora, para no sobreescribir un buen modelo con uno peor.

Con torch.save() puedes guardar un diccionario Python directamente en disco:

import torch

def save_checkpoint(model, optimizer, epoch, best_metric, path):
    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'best_metric': best_metric,
    }
    torch.save(checkpoint, path)

Ejemplo completo con bucle de entrenamiento

Aquí tienes un bucle típico que guarda checkpoint al final de cada época:

import torch
import torch.nn as nn

model = MyModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

start_epoch = 0
best_metric = 0.0
checkpoint_path = 'checkpoint_last.pt'

# Retomar si existe un checkpoint previo
try:
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1
    best_metric = checkpoint['best_metric']
    print(f"Retomando desde época {start_epoch}")
except FileNotFoundError:
    print("Sin checkpoint previo, empezando desde cero")

for epoch in range(start_epoch, NUM_EPOCHS):
    model.train()
    for batch in train_loader:
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Guardar checkpoint al final de cada época
    save_checkpoint(model, optimizer, epoch, best_metric, checkpoint_path)
    print(f"Época {epoch} completada, checkpoint guardado")

Para cargar el checkpoint y retomar, torch.load() devuelve el diccionario tal cual, y load_state_dict() aplica los pesos:

checkpoint = torch.load('checkpoint_last.pt', map_location=device)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1

El parámetro map_location=device es importante: si guardaste en GPU y cargas en CPU (o al revés), PyTorch lo gestiona sin errores.

Guardar solo el mejor modelo

Guardar cada época está bien para poder retomar, pero no siempre quieres el último modelo: quieres el mejor. Si la métrica de validación empeora en las últimas épocas (overfitting), el checkpoint final no es el que te interesa.

La solución es comparar la métrica actual con la mejor registrada y sobreescribir solo si mejora:

def save_best_checkpoint(model, optimizer, epoch, current_metric, best_metric, path_best):
    if current_metric > best_metric:
        best_metric = current_metric
        checkpoint = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_metric': best_metric,
        }
        torch.save(checkpoint, path_best)
        print(f"Mejor modelo guardado (métrica: {best_metric:.4f})")
    return best_metric

En el bucle de entrenamiento, llamas a esta función después de calcular la métrica de validación:

    # Al final de cada época, tras evaluar en validación
    val_accuracy = evaluate(model, val_loader, device)
    best_metric = save_best_checkpoint(
        model, optimizer, epoch,
        current_metric=val_accuracy,
        best_metric=best_metric,
        path_best='checkpoint_best.pt'
    )
    # Guardar también el último (para poder retomar)
    save_checkpoint(model, optimizer, epoch, best_metric, 'checkpoint_last.pt')

Con esto tienes dos archivos: checkpoint_last.pt para retomar el entrenamiento si se interrumpe, y checkpoint_best.pt para inferencia con el mejor modelo obtenido.

Checkpointing con Keras y TensorFlow

En Keras la gestión de checkpoints es más directa gracias al callback ModelCheckpoint, que se pasa al método fit():

import tensorflow as tf

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='checkpoint_best.h5',
    monitor='val_accuracy',
    save_best_only=True,
    save_weights_only=False,
    mode='max',
    verbose=1
)

model.fit(
    train_dataset,
    epochs=50,
    validation_data=val_dataset,
    callbacks=[checkpoint_cb]
)

Con save_best_only=True, el callback solo sobreescribe el archivo cuando la métrica monitorizada mejora. El parámetro mode='max' indica que "mejor" significa valor más alto (para accuracy); usa 'min' para pérdidas.

Formato HDF5 vs SavedModel

La extensión .h5 guarda en formato HDF5, que incluye arquitectura, pesos y configuración del optimizador en un solo archivo. Es cómodo pero más lento en modelos grandes. El formato SavedModel (por defecto en TF2 si usas un directorio como filepath) es más eficiente y compatible con TensorFlow Serving:

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='checkpoints/epoch_{epoch:02d}',
    save_freq='epoch'
)

Para retomar desde un checkpoint guardado con Keras:

model = tf.keras.models.load_model('checkpoint_best.h5')

# O si solo guardaste pesos (save_weights_only=True):
model = build_model()  # recrear la arquitectura primero
model.load_weights('checkpoint_best.h5')

Gradient checkpointing: ahorrar memoria de GPU

Este es un problema diferente al anterior. No se trata de guardar progreso entre sesiones, sino de reducir el consumo de VRAM durante el entrenamiento.

PyTorch guarda las activaciones intermedias de toda la red para calcular los gradientes en la pasada hacia atrás. Con modelos grandes (transformers de cientos de millones de parámetros), esto puede agotar la memoria de la GPU aunque el modelo en sí quepa en VRAM.

torch.utils.checkpoint.checkpoint() resuelve esto recalculando las activaciones durante la pasada hacia atrás en lugar de guardarlas. Consume más tiempo de cómputo, pero reduce la memoria significativamente:

from torch.utils.checkpoint import checkpoint

class MyTransformerLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = SelfAttention()
        self.ffn = FeedForward()

    def forward(self, x):
        # Sin gradient checkpointing:
        # x = self.attention(x)
        # x = self.ffn(x)

        # Con gradient checkpointing:
        x = checkpoint(self.attention, x)
        x = checkpoint(self.ffn, x)
        return x

El trade-off es claro: reduces el uso de VRAM entre un 30 y un 60% a cambio de un 20-30% más de tiempo de entrenamiento por la recomputación. Tiene sentido cuando tienes un modelo que no cabe en GPU a batch size razonable, pero sí con gradient checkpointing.

Los modelos de Hugging Face lo soportan directamente con una línea:

model.gradient_checkpointing_enable()

Subir checkpoints a almacenamiento remoto

Guardar en disco local de la instancia tiene un problema: si destruyes la instancia para no pagar cuando no entrenas, el disco desaparece con ella. Por eso conviene subir los checkpoints a almacenamiento externo.

AWS S3 con boto3

import boto3

s3 = boto3.client('s3')

def upload_checkpoint_s3(local_path, bucket, s3_key):
    s3.upload_file(local_path, bucket, s3_key)
    print(f"Checkpoint subido a s3://{bucket}/{s3_key}")

# Llamar después de guardar localmente
upload_checkpoint_s3(
    local_path='checkpoint_best.pt',
    bucket='mi-bucket-ml',
    s3_key='experimentos/modelo-v2/checkpoint_best.pt'
)

Para bajar el checkpoint al empezar una nueva sesión:

s3.download_file('mi-bucket-ml', 'experimentos/modelo-v2/checkpoint_best.pt', 'checkpoint_best.pt')

Google Cloud Storage

from google.cloud import storage

client = storage.Client()
bucket = client.bucket('mi-bucket-ml')

def upload_to_gcs(local_path, gcs_path):
    blob = bucket.blob(gcs_path)
    blob.upload_from_filename(local_path)
    print(f"Subido a gs://mi-bucket-ml/{gcs_path}")

Hugging Face Hub

Si trabajas con modelos transformers, la opción más cómoda es subir directamente al Hub:

from transformers import AutoModelForSequenceClassification
from huggingface_hub import HfApi

# Subir el modelo completo
model.push_to_hub("mi-usuario/mi-modelo-finetuned")

# O solo los pesos
api = HfApi()
api.upload_file(
    path_or_fileobj="checkpoint_best.pt",
    path_in_repo="checkpoint_best.pt",
    repo_id="mi-usuario/mi-modelo-finetuned",
    repo_type="model"
)

El Hub guarda historial de versiones, así que cada subida queda registrada. Puedes hacer el repositorio privado si no quieres que sea público.

Estrategia práctica para un entrenamiento largo

Con todo lo anterior, una estrategia que funciona bien en la práctica combina varias piezas:

  • Checkpoint cada N épocas (por ejemplo cada 5) para poder retomar si algo falla, sin saturar el disco con un archivo por época.
  • Checkpoint del mejor modelo siempre, sobreescribiendo solo cuando mejora la métrica de validación.
  • Subida automática a S3 o GCS después de cada checkpoint, para que sobreviva a la destrucción de la instancia.
  • Limpieza de checkpoints viejos: solo conservar los últimos 2-3 para no llenar el disco.

Antes de retomar un entrenamiento desde un checkpoint, conviene hacer una comprobación mínima de integridad: verificar que el archivo tiene un tamaño razonable (no quedó a medias por un corte de luz) y que las claves esperadas están presentes:

def verify_checkpoint(path, expected_keys=('model', 'optimizer', 'epoch', 'best_metric')):
    import os
    if not os.path.exists(path):
        return False, "Archivo no encontrado"
    if os.path.getsize(path) < 1024:  # menos de 1 KB es sospechoso
        return False, "Archivo demasiado pequeño"
    try:
        ck = torch.load(path, map_location='cpu')
        missing = [k for k in expected_keys if k not in ck]
        if missing:
            return False, f"Claves ausentes: {missing}"
        return True, "OK"
    except Exception as e:
        return False, str(e)

valid, msg = verify_checkpoint('checkpoint_last.pt')
if not valid:
    print(f"Checkpoint inválido: {msg}. Empezando desde cero.")

No es necesario ser exhaustivo en la verificación, pero sí evitar cargar un archivo corrupto que haga fallar el entrenamiento horas después.

Para más sobre Python para computación científica y numérica, o si te interesa la optimización del rendimiento de Python, tienes más artículos en programacion.net.

Imagen: Pexels / Google DeepMind

COMPARTE ESTE ARTÍCULO

COMPARTIR EN FACEBOOK
COMPARTIR EN TWITTER
COMPARTIR EN LINKEDIN
COMPARTIR EN WHATSAPP