Inspirado en el estilo de un "HRM" (High-level Reasoning Model), diseñado específicamente para texto en español.
Define un modelo de lenguaje basado en Transformer con un enfoque de dos niveles: un Planner (planificador) de alto nivel que genera un embedding de "plan" a partir de un prompt, y un Executor (ejecutor) de bajo nivel que genera texto condicionado por ese plan.
Está diseñado para:
- Preentrenamiento: Entrenar el modelo como un modelo de lenguaje puro en un corpus de texto en español.
- Fine-tuning: Ajustar el modelo con prompts y respuestas específicas (formato TSV: prompt \t target) o como modelo de lenguaje si no hay respuestas.
- Generación: Generar texto a partir de un prompt utilizando técnicas de muestreo como top-k, top-p y temperatura.
El modelo utiliza un tokenizador byte-level (UTF-8) para manejar texto en español, lo que permite una representación robusta de caracteres y es especialmente útil para idiomas con caracteres especiales.
Función: Convierte texto en español a una secuencia de enteros (tokens) y viceversa.
Detalles:
- Usa codificación UTF-8 a nivel de bytes (0 a 255) más 4 tokens especiales: PAD (256), BOS (257, inicio de secuencia), EOS (258, fin de secuencia) y PLAN (259, para inyectar el embedding del plan).
- encode: Convierte un string a una lista de enteros (bytes UTF-8, con BOS y EOS opcionales).
- decode: Convierte una lista de enteros a un string, ignorando tokens especiales.
- Ventaja: Es simple, no depende de un vocabulario predefinido y maneja bien caracteres no estándar.
Función: Carga y preprocesa datos para entrenamiento.
Detalles:
- load_corpus_lines: Lee un archivo de texto línea por línea (UTF-8, ignorando errores).
- dataset_lm: Convierte un archivo de texto en una lista de secuencias tokenizadas para preentrenamiento (modelo de lenguaje puro).
- dataset_tsv: Lee un archivo TSV (prompt \t target) para fine-tuning. Si no hay target, usa el propio prompt como objetivo (modo LM puro). Los datos se tokenizan con ByteTokenizer y se preparan para el modelo.
El modelo utiliza bloques estándar de Transformer, adaptados para su arquitectura de dos niveles:
- RMSNorm: Normalización de capa basada en la raíz cuadrada media (RMS), alternativa a LayerNorm, más eficiente.
- MHA (Multi-Head Attention): Atención multi-cabeza, con soporte para atención causal (para el decodificador) y atención cruzada (para incorporar el plan, aunque no se usa en este caso).
- FFN (Feed-Forward Network): Red feed-forward con activación GEGLU (Gated Linear Unit), que combina una activación GELU con una puerta lineal.
- DecoderBlock: Bloque Transformer causal con atención propia y FFN.
- EncoderBlock: Bloque Transformer no causal (para el Planner).
- Entrada: Prompt tokenizado (prompt_ids).
- Función: Genera un embedding de "plan" que resume el prompt.
Estructura Embedding de tokens y posiciones.
- n_layers_enc bloques Transformer no causales (EncoderBlock).
- Mean pooling sobre la secuencia para obtener un vector resumen.
- Proyección a plan_dim (dimensión del plan). Opcionalmente, usa Gumbel-Softmax para discretizar el plan en un codebook de n_plan_codes vectores.
Salida: Un vector de plan (plan_vec) y, si se usa Gumbel-Softmax, logits para el codebook.
- Entrada: Secuencia de entrada (y_in) y el vector de plan (plan_vec).
- Función: Genera texto de manera autoregresiva, condicionado por el plan.
Estructura:
Embedding de tokens y posiciones.
- Proyección del plan_vec a un "token" sintético ([PLAN]) que se concatena como prefijo.
- n_layers_dec bloques Transformer causales (DecoderBlock).
- Capa final (lm_head) para predecir el siguiente token.
Salida: Logits para la distribución de probabilidad sobre el vocabulario.
Preentrenamiento (train_lm):
- Usa un corpus de texto en español (archivo TXT).
- Entrena el modelo como un modelo de lenguaje puro, prediciendo el siguiente token en secuencias de longitud max_len.
- El Planner ve la misma secuencia que el Executor (condición débil).
- Calcula la pérdida de cross-entropy, con una pequeña penalización opcional para el codebook (si se usa Gumbel-Softmax).
- Guarda checkpoints por época y un checkpoint final.
Fine-tuning (train_ft):
- Carga un modelo preentrenado desde un checkpoint.
- Usa un archivo TSV (prompt \t target) para ajustar el modelo a tareas específicas.
- Si no hay target, entrena como LM puro.
- Similar al preentrenamiento, pero con prompts y targets explícitos.
Optimización:
- Usa AdamW con learning rate configurable.
- Soporta mixed precision (AMP) en CUDA para mayor eficiencia.
- Aplica dropout para regularización.
Detalles:
Tokeniza el prompt y genera el plan_vec con el Planner.
Usa el Executor para generar tokens autoregresivamente, comenzando con BOS.
Aplica muestreo con:
- Top-k: Selecciona los k tokens más probables.
- Top-p (nucleus sampling): Selecciona un conjunto de tokens cuya probabilidad acumulada supera p.
Temperatura:
- Controla la aleatoriedad (valores bajos hacen el muestreo más determinista).
- Detiene la generación al alcanzar max_new_tokens o el token EOS.
- Decodifica la salida a texto usando el tokenizador.
- pretrain: Entrena el modelo en un corpus de texto.
- finetune: Ajusta un modelo preentrenado con un archivo TSV.
- generate: Genera texto a partir de un prompt y un checkpoint.
Argumentos:
- Configuraciones como batch_size, max_len, lr, d_model, n_heads, epochs, etc.
- Opciones para desactivar CUDA (--cpu) o Gumbel-Softmax (--no_gumbel).
- Parámetros de generación como temperature, top_k, top_p.