Skip to content

Commit

Permalink
Import move_to_device from nougat.utils.device + backward compatibility for PyTorch builds without the MPS backend
Browse files Browse the repository at this point in the history
  • Loading branch information
lukas-blecher committed Sep 18, 2023
1 parent eb85ece commit b8edda9
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 12 deletions.
1 change: 1 addition & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nougat.utils.dataset import ImageDataset
from nougat.utils.checkpoint import get_checkpoint
from nougat.dataset.rasterize import rasterize_paper
from nougat.utils.device import move_to_device
from tqdm import tqdm


Expand Down
3 changes: 1 addition & 2 deletions nougat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,10 +572,9 @@ def inference(
if image_tensors is None:
image_tensors = self.encoder.prepare_input(image).unsqueeze(0)

if self.device.type != "cpu":
image_tensors = image_tensors.to(self.device)
if self.device.type != "mps":
image_tensors = image_tensors.to(torch.bfloat16)
image_tensors = image_tensors.to(self.device)

last_hidden_state = self.encoder(image_tensors)

Expand Down
16 changes: 13 additions & 3 deletions nougat/utils/device.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import torch


def move_to_device(model):
    """
    Move ``model`` to the preferred available device and return it.

    Selection order:
      1. CUDA, if available: the model is moved to ``"cuda"`` and cast to
         ``torch.bfloat16``.
      2. MPS, if this PyTorch build exposes the backend and reports it as
         available: the model is moved to ``"mps"`` with no dtype cast.
      3. Otherwise: the model is only cast to ``torch.bfloat16``.

    Args:
        model: Object exposing a ``.to(...)`` method
            (presumably a ``torch.nn.Module``).

    Returns:
        The result of the ``.to(...)`` call(s) on ``model``.
    """
    if torch.cuda.is_available():
        return model.to("cuda").to(torch.bfloat16)

    # Backward compatibility: some PyTorch builds have no
    # ``torch.backends.mps``. Keep the ``try`` body down to the probe itself
    # so an AttributeError raised by ``model.to`` is never swallowed.
    try:
        mps_available = torch.backends.mps.is_available()
    except AttributeError:
        mps_available = False

    if mps_available:
        return model.to("mps")
    return model.to(torch.bfloat16)
9 changes: 2 additions & 7 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,12 @@
from nougat.metrics import compute_metrics
from nougat.utils.checkpoint import get_checkpoint
from nougat.utils.dataset import NougatDataset
from nougat.utils.device import move_to_device
from lightning_module import NougatDataPLModule

def move_to_device(model):
    """
    Move ``model`` to CUDA or MPS when one is available.

    No dtype cast is applied here. When neither accelerator is available,
    ``model`` is returned as-is.

    Args:
        model: Object exposing a ``.to(...)`` method
            (presumably a ``torch.nn.Module``).

    Returns:
        ``model.to("cuda")``, ``model.to("mps")``, or ``model`` itself.
    """
    if torch.cuda.is_available():
        return model.to("cuda")

    # Some PyTorch builds have no ``torch.backends.mps``; treat a missing
    # backend as "not available" instead of raising AttributeError.
    try:
        mps_available = torch.backends.mps.is_available()
    except AttributeError:
        mps_available = False

    if mps_available:
        return model.to("mps")
    return model

def test(args):
pretrained_model = NougatModel.from_pretrained(args.checkpoint).to(torch.bfloat16)
pretrained_model = NougatModel.from_pretrained(args.checkpoint)
pretrained_model = move_to_device(pretrained_model)

pretrained_model.eval()
Expand Down

0 comments on commit b8edda9

Please sign in to comment.