Skip to content

Commit

Permalink
Merge branch 'main' into feature/add-mps-support
Browse files (browse the repository at this point in the history)
  • Loading branch information
erip authored Sep 17, 2023
2 parents b2b0d2f + f5d2cd5 commit eb85ece
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion nougat/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
__version__ = "0.1.7"

__version__ = "0.1.8"
2 changes: 2 additions & 0 deletions nougat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,10 +571,12 @@ def inference(

if image_tensors is None:
image_tensors = self.encoder.prepare_input(image).unsqueeze(0)

if self.device.type != "cpu":
image_tensors = image_tensors.to(self.device)
if self.device.type != "mps":
image_tensors = image_tensors.to(torch.bfloat16)

last_hidden_state = self.encoder(image_tensors)

encoder_outputs = ModelOutput(
Expand Down
2 changes: 1 addition & 1 deletion nougat/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_checkpoint(
checkpoint = checkpoint.parent
if download and (not checkpoint.exists() or len(os.listdir(checkpoint)) < 5):
checkpoint.mkdir(parents=True, exist_ok=True)
download_checkpoint(checkpoint, model_tag=model_tag)
download_checkpoint(checkpoint, model_tag=model_tag or MODEL_TAG)
return checkpoint


Expand Down
2 changes: 1 addition & 1 deletion predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_args():
"--model",
"-m",
type=str,
default=None,
default='0.1.0-small',
help=f"Model tag to use.",
)
parser.add_argument("--out", "-o", type=Path, help="Output directory.")
Expand Down

0 comments on commit eb85ece

Please sign in to comment.