Skip to content

Commit

Permalink
add custom page support
Browse files · Browse the repository at this point in the history
  • Loading branch information
lukas-blecher committed Sep 22, 2023
1 parent 0299bc5 commit 748360b
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 11 deletions.
1 change: 0 additions & 1 deletion lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def on_validation_epoch_end(self):
self.validation_step_outputs.clear()

def configure_optimizers(self):

def _get_device_count():
if torch.cuda.is_available():
return torch.cuda.device_count()
Expand Down
2 changes: 1 addition & 1 deletion nougat/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
LICENSE file in the root directory of this source tree.
"""

__version__ = "0.1.11"
__version__ = "0.1.12"
1 change: 1 addition & 0 deletions nougat/dataset/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

logging.getLogger("pypdfium2").setLevel(logging.WARNING)


def rasterize_paper(
pdf: Union[Path, bytes],
outpath: Optional[Path] = None,
Expand Down
6 changes: 4 additions & 2 deletions nougat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def inference(
image: Image.Image = None,
image_tensors: Optional[torch.Tensor] = None,
return_attentions: bool = False,
early_stopping: bool = True
early_stopping: bool = True,
):
"""
Generate a token sequence in an auto-regressive manner.
Expand Down Expand Up @@ -602,7 +602,9 @@ def inference(
output_scores=True,
output_attentions=return_attentions,
do_sample=False,
stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()] if early_stopping else []),
stopping_criteria=StoppingCriteriaList(
[StoppingCriteriaScores()] if early_stopping else []
),
)
output["repetitions"] = decoder_output.sequences.clone()
output["sequences"] = decoder_output.sequences.clone()
Expand Down
7 changes: 4 additions & 3 deletions nougat/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import random
from typing import Dict, Tuple, Callable
from PIL import Image, UnidentifiedImageError
from typing import List, Optional

import torch
import pypdf
Expand Down Expand Up @@ -79,13 +80,13 @@ class LazyDataset(Dataset):
name (str): Name of the PDF document.
"""

def __init__(self, pdf, prepare: Callable):
def __init__(self, pdf, prepare: Callable, pages: Optional[List[int]] = None):
super().__init__()
self.prepare = prepare
self.name = str(pdf)
self.init_fn = partial(rasterize_paper, pdf)
self.init_fn = partial(rasterize_paper, pdf, pages=pages)
self.dataset = None
self.size = len(pypdf.PdfReader(pdf).pages)
self.size = len(pypdf.PdfReader(pdf).pages) if pages is None else len(pages)

def __len__(self):
return self.size
Expand Down
27 changes: 24 additions & 3 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_args():
"--model",
"-m",
type=str,
default='0.1.0-small',
default="0.1.0-small",
help=f"Model tag to use.",
)
parser.add_argument("--out", "-o", type=Path, help="Output directory.")
Expand All @@ -78,6 +78,12 @@ def get_args():
action="store_false",
help="Don't apply failure detection heuristic.",
)
parser.add_argument(
"--pages",
"-p",
type=str,
help="Provide page numbers like '1-4,7' for pages 1 through 4 and page 7. Only works for single PDF input.",
)
parser.add_argument("pdf", nargs="+", type=Path, help="PDF(s) to process.")
args = parser.parse_args()
if args.checkpoint is None or not args.checkpoint.exists():
Expand All @@ -99,6 +105,17 @@ def get_args():
]
except:
pass
if args.pages and len(args.pdf) == 1:
pages = []
for p in args.pages.split(","):
if "-" in p:
start, end = p.split("-")
pages.extend(range(int(start)-1, int(end)))
else:
pages.append(int(p)-1)
args.pages = pages
else:
args.pages = None
return args


Expand All @@ -124,7 +141,9 @@ def main():
continue
try:
dataset = LazyDataset(
pdf, partial(model.encoder.prepare_input, random_padding=False)
pdf,
partial(model.encoder.prepare_input, random_padding=False),
args.pages,
)
except pypdf.errors.PdfStreamError:
logging.info(f"Could not load file {str(pdf)}.")
Expand All @@ -143,7 +162,9 @@ def main():
file_index = 0
page_num = 0
for i, (sample, is_last_page) in enumerate(tqdm(dataloader)):
model_output = model.inference(image_tensors=sample, early_stopping=args.skipping)
model_output = model.inference(
image_tensors=sample, early_stopping=args.skipping
)
# check if model output is faulty
for j, output in enumerate(model_output["predictions"]):
if page_num == 0:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def read_long_description():
"fuzzysearch",
"unidecode",
"htmlmin",
"pdfminer.six>=20221105"
"pdfminer.six>=20221105",
],
},
entry_points={
Expand Down

0 comments on commit 748360b

Please sign in to comment.