Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 175 additions & 110 deletions kandinsky2/kandinsky2_1_model.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions kandinsky2/kandinsky2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def __init__(
)

self.use_fp16 = self.config["model_config"]["use_fp16"]

if self.device in [torch.device('mps'), torch.device('cpu')]:
self.use_fp16 = False
if self.config["image_enc_params"] is not None:
self.use_image_enc = True
self.scale = self.config["image_enc_params"]["scale"]
Expand All @@ -65,7 +66,7 @@ def __init__(

self.config["model_config"]["cache_text_emb"] = True
self.model = create_model(**self.config["model_config"])
self.model.load_state_dict(torch.load(model_path), strict=False)
self.model.load_state_dict(torch.load(model_path, map_location=self.device), strict=False)
if self.use_fp16:
self.model.convert_to_fp16()
self.text_encoder1 = self.text_encoder1.half()
Expand Down
10 changes: 9 additions & 1 deletion kandinsky2/model/gaussian_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ def p_mean_variance(
B, C = x.shape[:2]
assert t.shape == (B,)
s_t = self._scale_timesteps(t)

x = x.float()

model_output = model(x, s_t, **model_kwargs)

if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
Expand Down Expand Up @@ -392,6 +395,7 @@ def p_sample_loop(
device=None,
progress=False,
init_step=None,
callback=None
):
"""
Generate samples from the model.
Expand Down Expand Up @@ -420,6 +424,7 @@ def p_sample_loop(
device=device,
progress=progress,
init_step=init_step,
callback=callback
):
final = sample
return final["sample"]
Expand All @@ -435,6 +440,7 @@ def p_sample_loop_progressive(
device=None,
progress=False,
init_step=None,
callback=None
):
"""
Generate samples from the model and yield intermediate samples from
Expand Down Expand Up @@ -473,6 +479,8 @@ def p_sample_loop_progressive(
)
yield out
img = out["sample"]
if callback is not None:
callback({"i":i, "denoised":img})

def ddim_sample(
self,
Expand Down Expand Up @@ -822,7 +830,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
res = th.from_numpy(arr).float().to(device=timesteps.device)[timesteps].to(dtype=th.float32)
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
32 changes: 22 additions & 10 deletions kandinsky2/model/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,21 @@
from einops import repeat
from tqdm import tqdm
from functools import partial
import platform

def get_torch_device():
    """Pick the torch device this module should run on.

    On macOS the MPS backend is used when it is available, otherwise the CPU.
    On every other platform the current CUDA device is used when CUDA is
    available, otherwise the CPU.

    Returns:
        torch.device: the selected device.
    """
    # platform.system() is "Darwin" on every macOS/Python combination, whereas
    # the human-readable platform.platform() string only starts with "macOS"
    # on newer Python versions.
    if platform.system() == "Darwin":
        # torch.backends.mps does not exist on older torch builds.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    if torch.cuda.is_available():
        # Spell out the device type instead of relying on the bare-ordinal form.
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")

device = get_torch_device()

def apply_init_step(timesteps, init_step=None):
if init_step is None:
Expand Down Expand Up @@ -75,8 +89,8 @@ def __init__(self, model, old_diffusion, schedule="linear", **kwargs):

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
if attr.device != device:
attr = attr.to(device)
setattr(self, name, attr)

def make_schedule(
Expand All @@ -98,7 +112,7 @@ def make_schedule(
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to("cuda")
to_torch = lambda x: x.clone().detach().to(torch.float32).to(device)

self.register_buffer(
"betas", to_torch(torch.from_numpy(self.old_diffusion.betas))
Expand Down Expand Up @@ -223,7 +237,6 @@ def ddim_sampling(
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = "cuda"
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
Expand Down Expand Up @@ -278,7 +291,7 @@ def ddim_sampling(
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
img_callback({"i": i, "denoised": img, "x":pred_x0})

if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
Expand Down Expand Up @@ -341,8 +354,8 @@ def __init__(self, model, old_diffusion, schedule="linear", **kwargs):

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
if attr.device != device:
attr = attr.to(device)
setattr(self, name, attr)

def make_schedule(
Expand All @@ -366,7 +379,7 @@ def make_schedule(
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to("cuda")
to_torch = lambda x: x.clone().detach().to(torch.float32).to(device)

self.register_buffer(
"betas", to_torch(torch.from_numpy(self.old_diffusion.betas))
Expand Down Expand Up @@ -492,7 +505,6 @@ def plms_sampling(
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = "cuda"
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
Expand Down Expand Up @@ -560,7 +572,7 @@ def plms_sampling(
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
img_callback({"i": i, "denoised": img, "x":pred_x0})

if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
Expand Down
30 changes: 22 additions & 8 deletions kandinsky2/model/text_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,21 @@
)
import transformers
import os


import platform

def get_torch_device():
    """Pick the torch device this module should run on.

    On macOS the MPS backend is used when it is available, otherwise the CPU.
    On every other platform the current CUDA device is used when CUDA is
    available, otherwise the CPU.

    Returns:
        torch.device: the selected device.
    """
    # platform.system() is "Darwin" on every macOS/Python combination, whereas
    # the human-readable platform.platform() string only starts with "macOS"
    # on newer Python versions.
    if platform.system() == "Darwin":
        # torch.backends.mps does not exist on older torch builds.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    if torch.cuda.is_available():
        # Spell out the device type instead of relying on the bare-ordinal form.
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")

device = get_torch_device()
def attention(q, k, v, d_k):
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
scores = F.softmax(scores, dim=-1)
Expand Down Expand Up @@ -128,20 +141,21 @@ def __init__(self, model_path, model_name, **kwargs):
self.model_name = model_name
if self.model_name == "clip":
self.model = ImagenCLIP()
self.model.load_state_dict(torch.load(model_path))
self.model.load_state_dict(torch.load(model_path, map_location=device))
elif self.model_name == "T5EncoderModel":
self.model = T5EncoderModel.from_pretrained(model_path)
self.model = T5EncoderModel.from_pretrained(model_path, map_location=device)
elif self.model_name == "MT5EncoderModel":
self.model = MT5EncoderModel.from_pretrained(model_path)
self.model = MT5EncoderModel.from_pretrained(model_path, map_location=device)
elif self.model_name == "BertModel":
self.model = BertModel.from_pretrained(model_path)
self.model = BertModel.from_pretrained(model_path, map_location=device)
elif self.model_name == "multiclip":
self.model = MultilingualCLIP(model_path, **kwargs)
self.model.load_state_dict(
torch.load(os.path.join(model_path, "pytorch_model.bin")), strict=False
torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location=device), strict=False
)
elif self.model_name == "xlm_roberta":
self.model = XLMRobertaModel.from_pretrained(model_path).half()
if device not in [torch.device('mps'), torch.device('cpu')]:
self.model = XLMRobertaModel.from_pretrained(model_path).half()
self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
Expand Down
23 changes: 23 additions & 0 deletions kandinsky2/model/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,25 @@
timestep_embedding,
zero_module,
)
import platform

def get_torch_device():
    """Pick the torch device this module should run on.

    On macOS the MPS backend is used when it is available, otherwise the CPU.
    On every other platform the current CUDA device is used when CUDA is
    available, otherwise the CPU.

    Returns:
        torch.device: the selected device.
    """
    # platform.system() is "Darwin" on every macOS/Python combination, whereas
    # the human-readable platform.platform() string only starts with "macOS"
    # on newer Python versions.
    if platform.system() == "Darwin":
        # torch.backends.mps does not exist on older torch builds.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    if torch.cuda.is_available():
        # Spell out the device type instead of relying on the bare-ordinal form.
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")

# Device selected once at import time for this module.
device = get_torch_device()

# `dt` is the dtype that TimestepEmbedSequential.forward casts its input to
# when running on MPS or CPU.
# The previous test, `device in [torch.device('mps') or torch.device('cpu')]`,
# built a one-element list: `a or b` yields `a` when `a` is truthy, so the CPU
# device was never part of the comparison and `dt` stayed undefined on CPU.
if device.type in ("mps", "cpu"):
    dt = torch.float32


class TimestepBlock(nn.Module):
Expand All @@ -35,6 +54,10 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""

def forward(self, x, emb, encoder_out=None):
if device in [torch.device('mps') or torch.device('cpu')]:
if x.dtype != dt:
x = x.to(dt)
#x = x.float()
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
Expand Down
3 changes: 2 additions & 1 deletion kandinsky2/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torch.nn as nn
import importlib
import platform


def _extract_into_tensor(arr, timesteps, broadcast_shape):
Expand Down Expand Up @@ -36,7 +37,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
beta_start = scale * 0.0001
beta_end = scale * 0.02
return np.linspace(
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float32
)
elif schedule_name == "cosine":
return betas_for_alpha_bar(
Expand Down
1 change: 1 addition & 0 deletions kandinsky2/vqgan/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def init_from_ckpt(self, path, ignore_keys=list()):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
del sd
print(f"Restored from {path}")

def encode(self, x):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"regex",
"numpy",
"blobfile",
"transformers==4.23.1",
"transformers>=4.29.2",
"torchvision",
"omegaconf",
"pytorch_lightning",
Expand Down