A high-performance table recognition system using an End-to-End Multi-Task Learning (MTL) architecture that transitions from an "all-token" generative foundation to a precision-optimized hybrid model.
- Unified Sequence Paradigm: Serializes document images into autoregressive sequences `y = {c, b, t, <Sep>}`, where:
  - `c` = cell content (character-level tokens)
  - `b` = bounding box tokens (discrete spatial coordinates)
  - `t` = structural HTML tags (`<table>`, `<tr>`, `<td>`, etc.)
- Coordinate Discretization: Scales all coordinates to a fixed 1024 × 1280 grid
- Token Quartet: Represents cell bounding boxes as discrete spatial tokens: `{<Xmin>, <Ymin>, <Xmax>, <Ymax>}`
- Right-Shifted Tokens: Synchronized training via right-shifted target tokens
- Unified Cross-Entropy Loss: Treats spatial tokens and text characters with equal priority
- Auxiliary Regression Head: Linear layer with Sigmoid activation for continuous normalized coordinates (x, y, w, h)
- L1 + IoU Loss: Optimizes spatial heads using L1 loss for coordinate distance and IoU loss for overlap precision
- Column Consistency Loss: Minimizes prediction variance across tokens in the same logical column
- DREAM Parallel Decoding: Uses N element queries in a feature aggregator to generate sequences for multiple elements simultaneously
- Multi-Token Inference: Decoder predicts n tokens simultaneously via additional linear layers
- Multi-Cell Parallelism: For tables, concatenates all cell contents and predicts next-tokens for every cell simultaneously
- ConvStem: Convolutional stem using stride-2, 3x3 convolutions to balance receptive field and sequence length
- NoPE (No Positional Encoding): Removes explicit 1D positional embeddings; relies on causal attention mask for relative positioning
- Token Compression: Pixel-shuffle and compression to reduce vision token length by up to 20%
- HTML Refiner: Non-causal attention module between structure and content decoders allowing cells to share dense structural features
- B-I-IB Tagging: Beginning-Inside-InsideBelow tagging for semantic continuity (ready for implementation)
- Global Context Attention (GCAttention): Multi-aspect global context attention after residual blocks in encoder
- Encoder: Swin-B, ResNet-31, or ConvStem backbone with optional GCAttention
- Decoder: Transformer decoder with NoPE, HTML refiner, and optional parallel decoding
- Loss Function:
L = λ₁ CE_struc + λ₂ CE_cont + λ₃ L1_bbox + λ₄ IoU + λ₅ Consistency
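As a rough illustration of how the weighted sum above could be assembled, the following plain-Python sketch computes an L1 term and an IoU term for one box pair and combines all five terms. The function names, the dictionary keys, and the corner-format boxes `(xmin, ymin, xmax, ymax)` are assumptions made for this example. The cross-entropy and consistency values are passed in as precomputed numbers.

```python
def l1_loss(pred, target):
    """Mean absolute difference over the four box coordinates."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def iou_loss(pred, target):
    """1 - IoU for two boxes given as (xmin, ymin, xmax, ymax)."""
    inter_w = max(0.0, min(pred[2], target[2]) - max(pred[0], target[0]))
    inter_h = max(0.0, min(pred[3], target[3]) - max(pred[1], target[1]))
    inter = inter_w * inter_h
    area_pred = (pred[2] - pred[0]) * (pred[3] - pred[1])
    area_target = (target[2] - target[0]) * (target[3] - target[1])
    union = area_pred + area_target - inter
    iou = inter / union if union > 0 else 0.0
    return 1.0 - iou


def total_loss(terms, weights):
    """Weighted sum of loss terms; both dicts share the same keys (hypothetical names)."""
    return sum(weights[name] * terms[name] for name in weights)


pred_box, gt_box = (0.0, 0.0, 2.0, 2.0), (1.0, 1.0, 3.0, 3.0)
terms = {
    "ce_struc": 1.0,      # placeholder value for the structure cross-entropy
    "ce_cont": 2.0,       # placeholder value for the content cross-entropy
    "l1_bbox": l1_loss(pred_box, gt_box),
    "iou": iou_loss(pred_box, gt_box),
    "consistency": 0.0,   # placeholder value for the column consistency term
}
weights = {name: 1.0 for name in terms}  # lambda values, all 1.0 for illustration
loss = total_loss(terms, weights)
```

In a real training loop these terms would be tensors and the weights would come from the `lambda_*` settings; the sketch only shows the arithmetic.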
`pip install -r requirements.txt`

The system expects data in JSON format:
{
"image_path": "path/to/image.jpg",
"table": {
"cells": [
{
"content": "Cell text",
"bbox": [xmin, ymin, xmax, ymax],
"is_header": false
}
],
"image_width": 1024,
"image_height": 1280
}
}

- Prepare your data in the expected JSON format
- Update `config.yaml` with your data paths and hyperparameters
- Run training:

`python train.py --config config.yaml`

To resume from a checkpoint:

`python train.py --config config.yaml --resume checkpoints/latest.pth`

Key configuration options in `config.yaml`:
- `encoder_backbone`: Choose between "swin_b", "resnet31", or "convstem"
- `use_hybrid_regression`: Enable hybrid regression heads for continuous coordinates
- `use_parallel_decoder`: Enable DREAM parallel decoding for faster inference
- `token_compression`: Set to 0.8 for 20% token reduction
- Loss weights: Adjust `lambda_*` values to balance different loss components
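The options above can be sanity-checked before training starts. The sketch below is a hypothetical helper, not part of the project: it assumes the YAML file has already been loaded into a flat Python dict, that `token_compression` is the fraction of vision tokens kept, and that loss weights use keys beginning with `lambda_` (the specific key `lambda_iou` is invented for the example).

```python
VALID_BACKBONES = {"swin_b", "resnet31", "convstem"}


def validate_config(cfg):
    """Return a list of human-readable problems; an empty list means the config looks sane."""
    errors = []
    if cfg.get("encoder_backbone") not in VALID_BACKBONES:
        errors.append("encoder_backbone must be one of " + ", ".join(sorted(VALID_BACKBONES)))
    for flag in ("use_hybrid_regression", "use_parallel_decoder"):
        if not isinstance(cfg.get(flag), bool):
            errors.append(flag + " must be true or false")
    ratio = cfg.get("token_compression")
    is_number = isinstance(ratio, (int, float)) and not isinstance(ratio, bool)
    if not is_number or not 0 < ratio <= 1:
        errors.append("token_compression must be in (0, 1]")
    for key, value in cfg.items():
        if key.startswith("lambda_") and value < 0:
            errors.append(key + " must be non-negative")
    return errors


def compressed_length(num_tokens, ratio):
    """Vision tokens remaining after compression, assuming ratio is the kept fraction."""
    return round(num_tokens * ratio)


example_cfg = {
    "encoder_backbone": "swin_b",
    "use_hybrid_regression": True,
    "use_parallel_decoder": False,
    "token_compression": 0.8,
    "lambda_iou": 1.0,  # hypothetical key name
}
problems = validate_config(example_cfg)
kept = compressed_length(1000, example_cfg["token_compression"])
```

With `token_compression` at 0.8, a sequence of 1000 vision tokens would shrink to 800 under this reading, which matches the 20% reduction stated above.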
- `tsr.data.serialization`: Sequence serialization and coordinate discretization
- `tsr.data.dataset`: Dataset classes for loading table data
- `tsr.models.encoder`: Visual encoders (Swin-B, ResNet-31, ConvStem)
- `tsr.models.decoder`: Transformer decoder with NoPE and parallel decoding
- `tsr.models.model`: Main E2E MTL model
- `tsr.losses.losses`: Multi-task loss functions
- `tsr.training.trainer`: Training utilities
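To make the coordinate discretization step concrete, here is a minimal sketch of how a record in the JSON format might be mapped onto the fixed 1024 × 1280 grid and turned into a token quartet. It is not the actual `tsr.data.serialization` API: the function names and the token spelling (`<X_50>`, `<Y_100>`) are assumptions for illustration, as is the choice to floor and clamp to the last grid bin.

```python
GRID_W, GRID_H = 1024, 1280  # fixed grid size stated in this README


def to_bin(value, size, bins):
    """Scale a pixel coordinate to the grid and clamp it to a valid bin index."""
    return min(bins - 1, max(0, int(value / size * bins)))


def discretize_bbox(bbox, image_width, image_height):
    """Map (xmin, ymin, xmax, ymax) in pixels to integer grid coordinates."""
    xmin, ymin, xmax, ymax = bbox
    return (
        to_bin(xmin, image_width, GRID_W),
        to_bin(ymin, image_height, GRID_H),
        to_bin(xmax, image_width, GRID_W),
        to_bin(ymax, image_height, GRID_H),
    )


def bbox_tokens(bbox, image_width, image_height):
    """Render the quartet as token strings (hypothetical token naming)."""
    x0, y0, x1, y1 = discretize_bbox(bbox, image_width, image_height)
    return ["<X_%d>" % x0, "<Y_%d>" % y0, "<X_%d>" % x1, "<Y_%d>" % y1]


def record_to_tokens(record):
    """Pair each cell's content with its bounding box tokens."""
    table = record["table"]
    width, height = table["image_width"], table["image_height"]
    return [
        (cell["content"], bbox_tokens(cell["bbox"], width, height))
        for cell in table["cells"]
    ]


record = {
    "image_path": "path/to/image.jpg",
    "table": {
        "cells": [{"content": "Cell text", "bbox": [100, 200, 300, 400], "is_header": False}],
        "image_width": 2048,
        "image_height": 2560,
    },
}
pairs = record_to_tokens(record)
```

Here the image is twice the grid size in each dimension, so every pixel coordinate is halved: the box `[100, 200, 300, 400]` becomes grid coordinates `(50, 100, 150, 200)`.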
- Architecture: Swin-B or ResNet-31 Encoder + Transformer Decoder (NoPE)
- Trigger Mechanism: MTL-TabNet
- Parallelization: DREAM Parallel Decoder
- Loss Weights:
L = λ₁ CE_struc + λ₂ CE_cont + λ₃ L1_bbox + λ₄ IoU + λ₅ Consistency
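The Consistency term penalizes prediction variance within a logical column. One possible reading, sketched below in plain Python, groups a predicted coordinate (here the left edge, `xmin`) by column index and averages the per-column population variance. The choice of `xmin`, the averaging over columns, and the function name are assumptions rather than the project's confirmed formulation.

```python
from collections import defaultdict
from statistics import pvariance


def column_consistency_loss(xmins, column_ids):
    """Average, over columns, of the variance of predicted xmin values in that column."""
    groups = defaultdict(list)
    for x, col in zip(xmins, column_ids):
        groups[col].append(x)
    if not groups:
        return 0.0
    # pvariance of a single value is 0, so one-cell columns contribute nothing
    return sum(pvariance(values) for values in groups.values()) / len(groups)


# Column 0 is perfectly aligned; column 1 has left edges that disagree.
aligned_and_ragged = column_consistency_loss([0.1, 0.1, 0.5, 0.7], [0, 0, 1, 1])
perfectly_aligned = column_consistency_loss([0.1, 0.1, 0.5, 0.5], [0, 0, 1, 1])
```

Cells whose left edges agree within a column drive the term toward zero, so the model is nudged toward vertically aligned boxes.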
[Add your license here]