Llevas seis horas entrenando un modelo en una GPU alquilada. La precisión mejora, las pérdidas bajan, todo va bien. Entonces la conexión se cae, o la instancia se interrumpe por mantenimiento, o simplemente cierras la terminal sin querer. Sin checkpointing, vuelves a cero. Con checkpointing, retomas desde donde lo dejaste.
Es una de esas cosas que parece evidente hasta que la necesitas por primera vez y no la tienes.
El problema de las GPUs efímeras
Entrenar en local con tu propia GPU tiene sus ventajas, pero muchos proyectos acaban en instancias de nube como Lambda Labs, RunPod o Vast.ai. Son económicas: puedes conseguir una RTX 3090 por menos de 0,30 la hora, o una A100 por alrededor de 1-2 . El problema es que esas instancias son efímeras por naturaleza.
Un entrenamiento de NLP sobre un dataset grande puede durar fácilmente 12-24 horas. Un fine-tuning de un modelo de visión, varios días. Durante ese tiempo pueden pasar muchas cosas: un timeout de red, un reinicio de servidor, que se te acabe el crédito, o simplemente que quieras pausar y retomar mañana. Sin guardar estado, cualquiera de esos eventos borra todo el progreso.
La solución es guardar un checkpoint cada N épocas: un archivo con el estado del modelo, el optimizador y los metadatos suficientes para continuar el entrenamiento como si nada hubiera pasado.
Checkpointing con PyTorch
Qué guardar en el checkpoint
Un checkpoint de PyTorch útil incluye al menos estos cuatro elementos:
- model.state_dict(): los pesos del modelo.
- optimizer.state_dict(): el estado del optimizador, con los momentos acumulados de Adam o SGD. Sin esto, el optimizador arranca de cero y el entrenamiento puede desestabilizarse.
- La época actual, para saber desde dónde retomar.
- La mejor métrica registrada hasta ahora, para no sobreescribir un buen modelo con uno peor.
Con torch.save() puedes guardar un diccionario Python directamente en disco:
import torch
def save_checkpoint(model, optimizer, epoch, best_metric, path):
checkpoint = {
'epoch': epoch,
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'best_metric': best_metric,
}
torch.save(checkpoint, path)
Ejemplo completo con bucle de entrenamiento
Aquí tienes un bucle típico que guarda checkpoint al final de cada época:
import torch
import torch.nn as nn
model = MyModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
start_epoch = 0
best_metric = 0.0
checkpoint_path = 'checkpoint_last.pt'
# Retomar si existe un checkpoint previo
try:
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
best_metric = checkpoint['best_metric']
print(f"Retomando desde época {start_epoch}")
except FileNotFoundError:
print("Sin checkpoint previo, empezando desde cero")
for epoch in range(start_epoch, NUM_EPOCHS):
model.train()
for batch in train_loader:
inputs, labels = batch
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# Guardar checkpoint al final de cada época
save_checkpoint(model, optimizer, epoch, best_metric, checkpoint_path)
print(f"Época {epoch} completada, checkpoint guardado")
Para cargar el checkpoint y retomar, torch.load() devuelve el diccionario tal cual, y load_state_dict() aplica los pesos:
checkpoint = torch.load('checkpoint_last.pt', map_location=device)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
El parámetro map_location=device es importante: si guardaste en GPU y cargas en CPU (o al revés), PyTorch lo gestiona sin errores.
Guardar solo el mejor modelo
Guardar cada época está bien para poder retomar, pero no siempre quieres el último modelo: quieres el mejor. Si la métrica de validación empeora en las últimas épocas (overfitting), el checkpoint final no es el que te interesa.
La solución es comparar la métrica actual con la mejor registrada y sobreescribir solo si mejora:
def save_best_checkpoint(model, optimizer, epoch, current_metric, best_metric, path_best):
if current_metric > best_metric:
best_metric = current_metric
checkpoint = {
'epoch': epoch,
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'best_metric': best_metric,
}
torch.save(checkpoint, path_best)
print(f"Mejor modelo guardado (métrica: {best_metric:.4f})")
return best_metric
En el bucle de entrenamiento, llamas a esta función después de calcular la métrica de validación:
# Al final de cada época, tras evaluar en validación
val_accuracy = evaluate(model, val_loader, device)
best_metric = save_best_checkpoint(
model, optimizer, epoch,
current_metric=val_accuracy,
best_metric=best_metric,
path_best='checkpoint_best.pt'
)
# Guardar también el último (para poder retomar)
save_checkpoint(model, optimizer, epoch, best_metric, 'checkpoint_last.pt')
Con esto tienes dos archivos: checkpoint_last.pt para retomar el entrenamiento si se interrumpe, y checkpoint_best.pt para inferencia con el mejor modelo obtenido.
Checkpointing con Keras y TensorFlow
En Keras la gestión de checkpoints es más directa gracias al callback ModelCheckpoint, que se pasa al método fit():
import tensorflow as tf
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
filepath='checkpoint_best.h5',
monitor='val_accuracy',
save_best_only=True,
save_weights_only=False,
mode='max',
verbose=1
)
model.fit(
train_dataset,
epochs=50,
validation_data=val_dataset,
callbacks=[checkpoint_cb]
)
Con save_best_only=True, el callback solo sobreescribe el archivo cuando la métrica monitorizada mejora. El parámetro mode='max' indica que "mejor" significa valor más alto (para accuracy); usa 'min' para pérdidas.
Formato HDF5 vs SavedModel
La extensión .h5 guarda en formato HDF5, que incluye arquitectura, pesos y configuración del optimizador en un solo archivo. Es cómodo pero más lento en modelos grandes. El formato SavedModel (por defecto en TF2 si usas un directorio como filepath) es más eficiente y compatible con TensorFlow Serving:
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
filepath='checkpoints/epoch_{epoch:02d}',
save_freq='epoch'
)
Para retomar desde un checkpoint guardado con Keras:
model = tf.keras.models.load_model('checkpoint_best.h5')
# O si solo guardaste pesos (save_weights_only=True):
model = build_model() # recrear la arquitectura primero
model.load_weights('checkpoint_best.h5')
Gradient checkpointing: ahorrar memoria de GPU
Este es un problema diferente al anterior. No se trata de guardar progreso entre sesiones, sino de reducir el consumo de VRAM durante el entrenamiento.
PyTorch guarda las activaciones intermedias de toda la red para calcular los gradientes en la pasada hacia atrás. Con modelos grandes (transformers de cientos de millones de parámetros), esto puede agotar la memoria de la GPU aunque el modelo en sí quepa en VRAM.
torch.utils.checkpoint.checkpoint() resuelve esto recalculando las activaciones durante la pasada hacia atrás en lugar de guardarlas. Consume más tiempo de cómputo, pero reduce la memoria significativamente:
from torch.utils.checkpoint import checkpoint
class MyTransformerLayer(nn.Module):
def __init__(self):
super().__init__()
self.attention = SelfAttention()
self.ffn = FeedForward()
def forward(self, x):
# Sin gradient checkpointing:
# x = self.attention(x)
# x = self.ffn(x)
# Con gradient checkpointing:
x = checkpoint(self.attention, x)
x = checkpoint(self.ffn, x)
return x
El trade-off es claro: reduces el uso de VRAM entre un 30 y un 60% a cambio de un 20-30% más de tiempo de entrenamiento por la recomputación. Tiene sentido cuando tienes un modelo que no cabe en GPU a batch size razonable, pero sí con gradient checkpointing.
Los modelos de Hugging Face lo soportan directamente con una línea:
model.gradient_checkpointing_enable()
Subir checkpoints a almacenamiento remoto
Guardar en disco local de la instancia tiene un problema: si destruyes la instancia para no pagar cuando no entrenas, el disco desaparece con ella. Por eso conviene subir los checkpoints a almacenamiento externo.
AWS S3 con boto3
import boto3
s3 = boto3.client('s3')
def upload_checkpoint_s3(local_path, bucket, s3_key):
s3.upload_file(local_path, bucket, s3_key)
print(f"Checkpoint subido a s3://{bucket}/{s3_key}")
# Llamar después de guardar localmente
upload_checkpoint_s3(
local_path='checkpoint_best.pt',
bucket='mi-bucket-ml',
s3_key='experimentos/modelo-v2/checkpoint_best.pt'
)
Para bajar el checkpoint al empezar una nueva sesión:
s3.download_file('mi-bucket-ml', 'experimentos/modelo-v2/checkpoint_best.pt', 'checkpoint_best.pt')
Google Cloud Storage
from google.cloud import storage
client = storage.Client()
bucket = client.bucket('mi-bucket-ml')
def upload_to_gcs(local_path, gcs_path):
blob = bucket.blob(gcs_path)
blob.upload_from_filename(local_path)
print(f"Subido a gs://mi-bucket-ml/{gcs_path}")
Hugging Face Hub
Si trabajas con modelos transformers, la opción más cómoda es subir directamente al Hub:
from transformers import AutoModelForSequenceClassification
from huggingface_hub import HfApi
# Subir el modelo completo
model.push_to_hub("mi-usuario/mi-modelo-finetuned")
# O solo los pesos
api = HfApi()
api.upload_file(
path_or_fileobj="checkpoint_best.pt",
path_in_repo="checkpoint_best.pt",
repo_id="mi-usuario/mi-modelo-finetuned",
repo_type="model"
)
El Hub guarda historial de versiones, así que cada subida queda registrada. Puedes hacer el repositorio privado si no quieres que sea público.
Estrategia práctica para un entrenamiento largo
Con todo lo anterior, una estrategia que funciona bien en la práctica combina varias piezas:
- Checkpoint cada N épocas (por ejemplo cada 5) para poder retomar si algo falla, sin saturar el disco con un archivo por época.
- Checkpoint del mejor modelo siempre, sobreescribiendo solo cuando mejora la métrica de validación.
- Subida automática a S3 o GCS después de cada checkpoint, para que sobreviva a la destrucción de la instancia.
- Limpieza de checkpoints viejos: solo conservar los últimos 2-3 para no llenar el disco.
Antes de retomar un entrenamiento desde un checkpoint, conviene hacer una comprobación mínima de integridad: verificar que el archivo tiene un tamaño razonable (no quedó a medias por un corte de luz) y que las claves esperadas están presentes:
def verify_checkpoint(path, expected_keys=('model', 'optimizer', 'epoch', 'best_metric')):
import os
if not os.path.exists(path):
return False, "Archivo no encontrado"
if os.path.getsize(path) < 1024: # menos de 1 KB es sospechoso
return False, "Archivo demasiado pequeño"
try:
ck = torch.load(path, map_location='cpu')
missing = [k for k in expected_keys if k not in ck]
if missing:
return False, f"Claves ausentes: {missing}"
return True, "OK"
except Exception as e:
return False, str(e)
valid, msg = verify_checkpoint('checkpoint_last.pt')
if not valid:
print(f"Checkpoint inválido: {msg}. Empezando desde cero.")
No es necesario ser exhaustivo en la verificación, pero sí evitar cargar un archivo corrupto que haga fallar el entrenamiento horas después.
Para más sobre Python para computación científica y numérica, o si te interesa la optimización del rendimiento de Python, tienes más artículos en programacion.net.
Imagen: Pexels / Google DeepMind
