diff --git a/kandinsky2/kandinsky2_1_model.py b/kandinsky2/kandinsky2_1_model.py
index 2ae2ddd..eebcc23 100644
--- a/kandinsky2/kandinsky2_1_model.py
+++ b/kandinsky2/kandinsky2_1_model.py
@@ -19,18 +19,22 @@ class Kandinsky2_1:
-
+
     def __init__(
-        self,
-        config,
-        model_path,
-        prior_path,
-        device,
-        task_type="text2img"
+        self,
+        config,
+        model_path,
+        prior_path,
+        device,
+        task_type="text2img"
     ):
         self.config = config
-        self.device = device
+        b_device = device
+        self.device = "cpu"
+        self.use_fp16 = self.config["model_config"]["use_fp16"]
+        if self.device in [torch.device('mps'), torch.device('cpu')]:
+            self.use_fp16 = False
         self.task_type = task_type
         self.clip_image_size = config["clip_image_size"]
         if task_type == "text2img":
@@ -54,17 +58,17 @@ def __init__(
             clip_mean,
             clip_std,
         )
-        self.prior.load_state_dict(torch.load(prior_path), strict=False)
+        self.prior.load_state_dict(torch.load(prior_path, map_location="cpu"), strict=False)
         if self.use_fp16:
-            self.prior = self.prior.half()
+            self.prior.half()# = self.prior.half()
         self.text_encoder = TextEncoder(**self.config["text_enc_params"])
         if self.use_fp16:
-            self.text_encoder = self.text_encoder.half()
+            self.text_encoder.half()# = self.text_encoder.half()
         self.clip_model, self.preprocess = clip.load(
-            config["clip_name"], device=self.device, jit=False
+            config["clip_name"], device="cpu", jit=False
         )
-        self.clip_model.eval()
+        #self.clip_model.eval()

         if self.config["image_enc_params"] is not None:
             self.use_image_enc = True
@@ -80,28 +84,40 @@ def __init__(
         elif self.config["image_enc_params"]["name"] == "MOVQ":
             self.image_encoder = MOVQ(**self.config["image_enc_params"]["params"])
             self.image_encoder.load_state_dict(
-                torch.load(self.config["image_enc_params"]["ckpt_path"])
+                torch.load(self.config["image_enc_params"]["ckpt_path"], map_location='cpu')
             )
-            self.image_encoder.eval()
+            #self.image_encoder.eval()
         else:
             self.use_image_enc = False
-
+
         self.config["model_config"]["cache_text_emb"] = True
         self.model = create_model(**self.config["model_config"])
-        self.model.load_state_dict(torch.load(model_path))
+        self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
         if self.use_fp16:
             self.model.convert_to_fp16()
-            self.image_encoder = self.image_encoder.half()
+            self.image_encoder.half()# = self.image_encoder.half()
             self.model_dtype = torch.float16
         else:
             self.model_dtype = torch.float32
-
-        self.image_encoder = self.image_encoder.to(self.device).eval()
-        self.text_encoder = self.text_encoder.to(self.device).eval()
-        self.prior = self.prior.to(self.device).eval()
-        self.model.eval()
-        self.model.to(self.device)
+
+
+        self.clip_model.to("cpu")# = self.clip_model.to("cpu")
+        self.image_encoder.to("cpu")# = self.image_encoder.eval().to("cpu")# .to(self.device).eval()
+        self.text_encoder.to("cpu")# = self.text_encoder.eval().to("cpu")#.to(self.device).eval()
+        self.prior.to("cpu")# = self.prior.eval().to("cpu")#.to(self.device).eval()
+        self.model.to("cpu")# = self.model.eval().to("cpu")
+        self.device = b_device
+
+
+        del clip_mean
+        del clip_std
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+        #self.model.to(self.device)

     def get_new_h_w(self, h, w):
         new_h = h // 64
@@ -132,12 +148,12 @@ def encode_text(self, text_encoder, tokenizer, prompt, batch_size):

     @torch.no_grad()
     def generate_clip_emb(
-        self,
-        prompt,
-        batch_size=1,
-        prior_cf_scale=4,
-        prior_steps="25",
-        negative_prior_prompt="",
+        self,
+        prompt,
+        batch_size=1,
+        prior_cf_scale=4,
+        prior_steps="25",
+        negative_prior_prompt="",
     ):
         prompts_batch = [prompt for _ in range(batch_size)]
         prior_cf_scales_batch = [prior_cf_scale] * len(prompts_batch)
@@ -149,6 +165,7 @@ def generate_clip_emb(
         cf_token, cf_mask = self.tokenizer2.padded_tokens_and_mask(
             [negative_prior_prompt], max_txt_length
         )
+
         if not (cf_token.shape == tok.shape):
             cf_token = cf_token.expand(tok.shape[0], -1)
             cf_mask = cf_mask.expand(tok.shape[0], -1)
@@ -165,6 +182,8 @@ def generate_clip_emb(
         txt_feat_seq = x
         txt_feat = (x[torch.arange(x.shape[0]), tok.argmax(dim=-1)] @ self.clip_model.text_projection)
         txt_feat, txt_feat_seq = txt_feat.float().to(self.device), txt_feat_seq.float().to(self.device)
+        self.prior = self.prior.to(self.device)
+
         img_feat = self.prior(
             txt_feat,
             txt_feat_seq,
@@ -172,6 +191,8 @@ def generate_clip_emb(
             prior_cf_scales_batch,
             timestep_respacing=prior_steps,
         )
+        self.prior = self.prior.to("cpu")
+
         return img_feat.to(self.model_dtype)

     @torch.no_grad()
@@ -182,20 +203,21 @@ def encode_images(self, image, is_pil=False):

     @torch.no_grad()
     def generate_img(
-        self,
-        prompt,
-        img_prompt,
-        batch_size=1,
-        diffusion=None,
-        guidance_scale=7,
-        init_step=None,
-        noise=None,
-        init_img=None,
-        img_mask=None,
-        h=512,
-        w=512,
-        sampler="ddim_sampler",
-        num_steps=50,
+        self,
+        prompt,
+        img_prompt,
+        batch_size=1,
+        diffusion=None,
+        guidance_scale=7,
+        init_step=None,
+        noise=None,
+        init_img=None,
+        img_mask=None,
+        h=512,
+        w=512,
+        sampler="ddim_sampler",
+        num_steps=50,
+        callback=None,
     ):
         new_h, new_w = self.get_new_h_w(h, w)
         full_batch_size = batch_size * 2
@@ -205,12 +227,17 @@ def generate_img(
             init_img = init_img.half()
         if img_mask is not None and self.use_fp16:
             img_mask = img_mask.half()
+        self.text_encoder = self.text_encoder.to(self.device)
         model_kwargs["full_emb"], model_kwargs["pooled_emb"] = self.encode_text(
             text_encoder=self.text_encoder,
             tokenizer=self.tokenizer1,
             prompt=prompt,
             batch_size=batch_size,
         )
+
+        self.text_encoder = self.text_encoder.to("cpu")
+
+
         model_kwargs["image_emb"] = img_prompt

         if self.task_type == "inpainting":
@@ -221,15 +248,29 @@ def generate_img(

         def model_fn(x_t, ts, **kwargs):
             half = x_t[: len(x_t) // 2]
+            x_t = x_t.detach().cpu()
+            del x_t
             combined = torch.cat([half, half], dim=0)
+            if not self.use_fp16:
+                combined = combined.to(dtype=torch.float32)
+                ts = ts.to(dtype=torch.float32)
             model_out = self.model(combined, ts, **kwargs)
             eps, rest = model_out[:, :4], model_out[:, 4:]
             cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
             half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
             eps = torch.cat([half_eps, half_eps], dim=0)
+
+            half_eps = half_eps.detach().to("cpu")
+            cond_eps = cond_eps.detach().to("cpu")
+            del half_eps
+            del cond_eps
+
+
             if sampler == "p_sampler":
                 return torch.cat([eps, rest], dim=1)
             else:
+                rest = rest.detach().to("cpu")
+                del rest
                 return eps

         if noise is not None:
@@ -241,7 +282,7 @@ def denoised_fun(x_start):
         else:
             def denoised_fun(x):
                 return x.clamp(-2, 2)
-
+        self.model = self.model.to(self.device)
         if sampler == "p_sampler":
             self.model.del_cache()
             samples = diffusion.p_sample_loop(
@@ -253,6 +294,7 @@ def denoised_fun(x):
                 model_kwargs=model_kwargs,
                 init_step=init_step,
                 denoised_fn=denoised_fun,
+                callback=callback
             )[:batch_size]
             self.model.del_cache()
         else:
@@ -270,7 +312,7 @@ def denoised_fun(x):
                 )
             else:
                 raise ValueError("Only ddim_sampler and plms_sampler is available")
-
+
             self.model.del_cache()
             samples, _ = sampler.sample(
                 num_steps,
@@ -279,16 +321,31 @@ def denoised_fun(x):
                 shape=(full_batch_size, 4, new_h, new_w),
                 device=self.device,
                 conditioning=model_kwargs,
                 x_T=noise,
                 init_step=init_step,
+                img_callback=callback
+
             )
             self.model.del_cache()
             samples = samples[:batch_size]
-
+        self.model = self.model.to("cpu")
+
         if self.use_image_enc:
             if self.use_fp16:
                 samples = samples.half()
+            self.image_encoder.to(self.device)
             samples = self.image_encoder.decode(samples / self.scale)
-
+        self.image_encoder = self.image_encoder.to("cpu")
+
+        samples = samples[:, :, :h, :w]
+
+        del diffusion
+        del sampler
+        del noise
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+
         return process_images(samples)

     @torch.no_grad()
@@ -298,20 +355,22 @@ def create_zero_img_emb(self, batch_size):

     @torch.no_grad()
     def generate_text2img(
-        self,
-        prompt,
-        num_steps=100,
-        batch_size=1,
-        guidance_scale=7,
-        h=512,
-        w=512,
-        sampler="ddim_sampler",
-        prior_cf_scale=4,
-        prior_steps="25",
-        negative_prior_prompt="",
-        negative_decoder_prompt="",
+        self,
+        prompt,
+        num_steps=100,
+        batch_size=1,
+        guidance_scale=7,
+        h=512,
+        w=512,
+        sampler="ddim_sampler",
+        prior_cf_scale=4,
+        prior_steps="25",
+        negative_prior_prompt="",
+        negative_decoder_prompt="",
+        callback=None,
     ):
         # generate clip embeddings
+        self.clip_model = self.clip_model.to(self.device)
         image_emb = self.generate_clip_emb(
             prompt,
             batch_size=batch_size,
@@ -329,15 +388,16 @@ def generate_text2img(
                 prior_steps=prior_steps,
                 negative_prior_prompt=negative_prior_prompt,
             )
+        self.clip_model = self.clip_model.to("cpu")
         image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)
-
+
         # load diffusion
         config = deepcopy(self.config)
         if sampler == "p_sampler":
             config["diffusion_config"]["timestep_respacing"] = str(num_steps)
         diffusion = create_gaussian_diffusion(**config["diffusion_config"])
-
+
         return self.generate_img(
             prompt=prompt,
             img_prompt=image_emb,
@@ -348,26 +408,28 @@ def generate_text2img(
             sampler=sampler,
             num_steps=num_steps,
             diffusion=diffusion,
+            callback=callback
         )

     @torch.no_grad()
     def mix_images(
-        self,
-        images_texts,
-        weights,
-        num_steps=100,
-        batch_size=1,
-        guidance_scale=7,
-        h=512,
-        w=512,
-        sampler="ddim_sampler",
-        prior_cf_scale=4,
-        prior_steps="25",
-        negative_prior_prompt="",
-        negative_decoder_prompt="",
+        self,
+        images_texts,
+        weights,
+        num_steps=100,
+        batch_size=1,
+        guidance_scale=7,
+        h=512,
+        w=512,
+        sampler="ddim_sampler",
+        prior_cf_scale=4,
+        prior_steps="25",
+        negative_prior_prompt="",
+        negative_decoder_prompt="",
+        callback=None,
     ):
         assert len(images_texts) == len(weights) and len(images_texts) > 0
-
+
         # generate clip embeddings
         image_emb = None
         for i in range(len(images_texts)):
@@ -393,7 +455,7 @@ def mix_images(
                 )
             else:
                 image_emb = image_emb + self.encode_images(images_texts[i], is_pil=True) * weights[i]
-
+
         image_emb = image_emb.repeat(batch_size, 1)
         if negative_decoder_prompt == "":
             zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
@@ -406,7 +468,7 @@ def mix_images(
                 negative_prior_prompt=negative_prior_prompt,
             )
         image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)
-
+
         # load diffusion
         config = deepcopy(self.config)
         if sampler == "p_sampler":
@@ -422,22 +484,24 @@ def mix_images(
             sampler=sampler,
             num_steps=num_steps,
             diffusion=diffusion,
+            callback=callback
         )

     @torch.no_grad()
     def generate_img2img(
-        self,
-        prompt,
-        pil_img,
-        strength=0.7,
-        num_steps=100,
-        batch_size=1,
-        guidance_scale=7,
-        h=512,
-        w=512,
-        sampler="ddim_sampler",
-        prior_cf_scale=4,
-        prior_steps="25",
+        self,
+        prompt,
+        pil_img,
+        strength=0.7,
+        num_steps=100,
+        batch_size=1,
+        guidance_scale=7,
+        h=512,
+        w=512,
+        sampler="ddim_sampler",
+        prior_cf_scale=4,
+        prior_steps="25",
+        callback=None
     ):
         # generate clip embeddings
         image_emb = self.generate_clip_emb(
@@ -448,18 +512,18 @@ def generate_img2img(
         )
         zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
         image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)
-
+
         # load diffusion
         config = deepcopy(self.config)
         if sampler == "p_sampler":
             config["diffusion_config"]["timestep_respacing"] = str(num_steps)
         diffusion = create_gaussian_diffusion(**config["diffusion_config"])
-
+
         image = prepare_image(pil_img, h=h, w=w).to(self.device)
         if self.use_fp16:
             image = image.half()
         image = self.image_encoder.encode(image) * self.scale
-
+
         start_step = int(diffusion.num_timesteps * (1 - strength))
         image = q_sample(
             image,
@@ -467,7 +531,7 @@ def generate_img2img(
             schedule_name=config["diffusion_config"]["noise_schedule"],
             num_steps=config["diffusion_config"]["steps"],
         )
-
+
         image = image.repeat(2, 1, 1, 1)
         return self.generate_img(
             prompt=prompt,
@@ -481,24 +545,25 @@ def generate_img2img(
             diffusion=diffusion,
             noise=image,
             init_step=start_step,
+            callback=callback
         )

     @torch.no_grad()
     def generate_inpainting(
-        self,
-        prompt,
-        pil_img,
-        img_mask,
-        num_steps=100,
-        batch_size=1,
-        guidance_scale=7,
-        h=512,
-        w=512,
-        sampler="ddim_sampler",
-        prior_cf_scale=4,
-        prior_steps="25",
-        negative_prior_prompt="",
-        negative_decoder_prompt="",
+        self,
+        prompt,
+        pil_img,
+        img_mask,
+        num_steps=100,
+        batch_size=1,
+        guidance_scale=7,
+        h=512,
+        w=512,
+        sampler="ddim_sampler",
+        prior_cf_scale=4,
+        prior_steps="25",
+        negative_prior_prompt="",
+        negative_decoder_prompt="",
     ):
         # generate clip embeddings
         image_emb = self.generate_clip_emb(
@@ -510,7 +575,7 @@ def generate_inpainting(
         )
         zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
         image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)
-
+
         # load diffusion
         config = deepcopy(self.config)
         if sampler == "p_sampler":
@@ -532,7 +597,7 @@ def generate_inpainting(
             img_mask = img_mask.half()
         image = image.repeat(2, 1, 1, 1)
         img_mask = img_mask.repeat(2, 1, 1, 1)
-
+
         return self.generate_img(
             prompt=prompt,
             img_prompt=image_emb,
diff --git a/kandinsky2/kandinsky2_model.py b/kandinsky2/kandinsky2_model.py
index 74db51b..53888f3 100644
--- a/kandinsky2/kandinsky2_model.py
+++ b/kandinsky2/kandinsky2_model.py
@@ -47,7 +47,8 @@ def __init__(
         )

         self.use_fp16 = self.config["model_config"]["use_fp16"]
-
+        if self.device in [torch.device('mps'), torch.device('cpu')]:
+            self.use_fp16 = False
         if self.config["image_enc_params"] is not None:
             self.use_image_enc = True
             self.scale = self.config["image_enc_params"]["scale"]
@@ -65,7 +66,7 @@ def __init__(

         self.config["model_config"]["cache_text_emb"] = True
         self.model = create_model(**self.config["model_config"])
-        self.model.load_state_dict(torch.load(model_path), strict=False)
+        self.model.load_state_dict(torch.load(model_path, map_location=self.device), strict=False)
         if self.use_fp16:
             self.model.convert_to_fp16()
             self.text_encoder1 = self.text_encoder1.half()
diff --git a/kandinsky2/model/gaussian_diffusion.py b/kandinsky2/model/gaussian_diffusion.py
index b5449e1..149ba6f 100644
--- a/kandinsky2/model/gaussian_diffusion.py
+++ b/kandinsky2/model/gaussian_diffusion.py
@@ -248,6 +248,9 @@ def p_mean_variance(
         B, C = x.shape[:2]
         assert t.shape == (B,)
         s_t = self._scale_timesteps(t)
+
+        x = x.float()
+
         model_output = model(x, s_t, **model_kwargs)

         if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
@@ -392,6 +395,7 @@ def p_sample_loop(
         device=None,
         progress=False,
         init_step=None,
+        callback=None
     ):
         """
         Generate samples from the model.
@@ -420,6 +424,7 @@ def p_sample_loop(
             device=device,
             progress=progress,
             init_step=init_step,
+            callback=callback
         ):
             final = sample
         return final["sample"]
@@ -435,6 +440,7 @@ def p_sample_loop_progressive(
         device=None,
         progress=False,
         init_step=None,
+        callback=None
     ):
         """
         Generate samples from the model and yield intermediate samples from
@@ -473,6 +479,8 @@ def p_sample_loop_progressive(
             )
             yield out
             img = out["sample"]
+            if callback is not None:
+                callback({"i":i, "denoised":img})

     def ddim_sample(
         self,
@@ -822,7 +830,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
                             dimension equal to the length of timesteps.
     :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
     """
-    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+    res = th.from_numpy(arr).float().to(device=timesteps.device)[timesteps].to(dtype=th.float32)
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res.expand(broadcast_shape)
diff --git a/kandinsky2/model/samplers.py b/kandinsky2/model/samplers.py
index 0b4db1d..ed49b45 100644
--- a/kandinsky2/model/samplers.py
+++ b/kandinsky2/model/samplers.py
@@ -6,7 +6,21 @@
 from einops import repeat
 from tqdm import tqdm
 from functools import partial
+import platform
+def get_torch_device():
+    if "macOS" in platform.platform():
+        if torch.backends.mps.is_available():
+            return torch.device("mps")
+        else:
+            return torch.device("cpu")
+    else:
+        if torch.cuda.is_available():
+            return torch.device(torch.cuda.current_device())
+        else:
+            return torch.device("cpu")
+
+device = get_torch_device()


 def apply_init_step(timesteps, init_step=None):
     if init_step is None:
@@ -75,8 +89,8 @@ def __init__(self, model, old_diffusion, schedule="linear", **kwargs):
     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != device:
+                attr = attr.to(device)
             setattr(self, name, attr)

     def make_schedule(
@@ -98,7 +112,7 @@ def make_schedule(
         assert (
             alphas_cumprod.shape[0] == self.ddpm_num_timesteps
         ), "alphas have to be defined for each timestep"
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to("cuda")
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(device)
         self.register_buffer(
             "betas", to_torch(torch.from_numpy(self.old_diffusion.betas))
@@ -223,7 +237,6 @@ def ddim_sampling(
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
     ):
-        device = "cuda"
         b = shape[0]
         if x_T is None:
             img = torch.randn(shape, device=device)
@@ -278,7 +291,7 @@ def ddim_sampling(
             if callback:
                 callback(i)
             if img_callback:
-                img_callback(pred_x0, i)
+                img_callback({"i": i, "denoised": img, "x":pred_x0})

             if index % log_every_t == 0 or index == total_steps - 1:
                 intermediates["x_inter"].append(img)
@@ -341,8 +354,8 @@ def __init__(self, model, old_diffusion, schedule="linear", **kwargs):
     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != device:
+                attr = attr.to(device)
             setattr(self, name, attr)

     def make_schedule(
@@ -366,7 +379,7 @@ def make_schedule(
         assert (
             alphas_cumprod.shape[0] == self.ddpm_num_timesteps
         ), "alphas have to be defined for each timestep"
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to("cuda")
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(device)
         self.register_buffer(
             "betas", to_torch(torch.from_numpy(self.old_diffusion.betas))
@@ -492,7 +505,6 @@ def plms_sampling(
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
     ):
-        device = "cuda"
         b = shape[0]
         if x_T is None:
             img = torch.randn(shape, device=device)
@@ -560,7 +572,7 @@ def plms_sampling(
             if callback:
                 callback(i)
             if img_callback:
-                img_callback(pred_x0, i)
+                img_callback({"i": i, "denoised": img, "x":pred_x0})

             if index % log_every_t == 0 or index == total_steps - 1:
                 intermediates["x_inter"].append(img)
diff --git a/kandinsky2/model/text_encoders.py b/kandinsky2/model/text_encoders.py
index 664e8c9..1d5cbcc 100644
--- a/kandinsky2/model/text_encoders.py
+++ b/kandinsky2/model/text_encoders.py
@@ -12,8 +12,21 @@
 )
 import transformers
 import os
-
-
+import platform
+
+def get_torch_device():
+    if "macOS" in platform.platform():
+        if torch.backends.mps.is_available():
+            return torch.device("mps")
+        else:
+            return torch.device("cpu")
+    else:
+        if torch.cuda.is_available():
+            return torch.device(torch.cuda.current_device())
+        else:
+            return torch.device("cpu")
+
+device = get_torch_device()
 def attention(q, k, v, d_k):
     scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
     scores = F.softmax(scores, dim=-1)
@@ -128,20 +141,23 @@ def __init__(self, model_path, model_name, **kwargs):
         self.model_name = model_name
         if self.model_name == "clip":
             self.model = ImagenCLIP()
-            self.model.load_state_dict(torch.load(model_path))
+            self.model.load_state_dict(torch.load(model_path, map_location=device))
         elif self.model_name == "T5EncoderModel":
-            self.model = T5EncoderModel.from_pretrained(model_path)
+            self.model = T5EncoderModel.from_pretrained(model_path, map_location=device)
         elif self.model_name == "MT5EncoderModel":
-            self.model = MT5EncoderModel.from_pretrained(model_path)
+            self.model = MT5EncoderModel.from_pretrained(model_path, map_location=device)
         elif self.model_name == "BertModel":
-            self.model = BertModel.from_pretrained(model_path)
+            self.model = BertModel.from_pretrained(model_path, map_location=device)
         elif self.model_name == "multiclip":
             self.model = MultilingualCLIP(model_path, **kwargs)
             self.model.load_state_dict(
-                torch.load(os.path.join(model_path, "pytorch_model.bin")), strict=False
+                torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location=device), strict=False
             )
         elif self.model_name == "xlm_roberta":
-            self.model = XLMRobertaModel.from_pretrained(model_path).half()
+            if device not in [torch.device('mps'), torch.device('cpu')]:
+                self.model = XLMRobertaModel.from_pretrained(model_path).half()
+            else:
+                self.model = XLMRobertaModel.from_pretrained(model_path)
         self.model.eval()
         for param in self.model.parameters():
             param.requires_grad = False
diff --git a/kandinsky2/model/unet.py b/kandinsky2/model/unet.py
index 6a5b6cf..7d77f3a 100644
--- a/kandinsky2/model/unet.py
+++ b/kandinsky2/model/unet.py
@@ -14,6 +14,25 @@
     timestep_embedding,
     zero_module,
 )
+import platform
+
+def get_torch_device():
+    if "macOS" in platform.platform():
+        if torch.backends.mps.is_available():
+            return torch.device("mps")
+        else:
+            return torch.device("cpu")
+    else:
+        if torch.cuda.is_available():
+            return torch.device(torch.cuda.current_device())
+        else:
+            return torch.device("cpu")
+
+device = get_torch_device()
+
+
+if device in [torch.device('mps'), torch.device('cpu')]:
+    dt = torch.float32


 class TimestepBlock(nn.Module):
@@ -35,6 +54,10 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     """

     def forward(self, x, emb, encoder_out=None):
+        if device in [torch.device('mps'), torch.device('cpu')]:
+            if x.dtype != dt:
+                x = x.to(dt)
+            #x = x.float()
         for layer in self:
             if isinstance(layer, TimestepBlock):
                 x = layer(x, emb)
diff --git a/kandinsky2/model/utils.py b/kandinsky2/model/utils.py
index c79aad9..9ddbd39 100644
--- a/kandinsky2/model/utils.py
+++ b/kandinsky2/model/utils.py
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 import importlib
+import platform


 def _extract_into_tensor(arr, timesteps, broadcast_shape):
@@ -36,7 +37,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
         beta_start = scale * 0.0001
         beta_end = scale * 0.02
         return np.linspace(
-            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float32
         )
     elif schedule_name == "cosine":
         return betas_for_alpha_bar(
diff --git a/kandinsky2/vqgan/autoencoder.py b/kandinsky2/vqgan/autoencoder.py
index 6fcbb39..e69dd41 100644
--- a/kandinsky2/vqgan/autoencoder.py
+++ b/kandinsky2/vqgan/autoencoder.py
@@ -134,6 +134,7 @@ def init_from_ckpt(self, path, ignore_keys=list()):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
         self.load_state_dict(sd, strict=False)
+        del sd
         print(f"Restored from {path}")

     def encode(self, x):
diff --git a/setup.py b/setup.py
index d832a27..fce322e 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@
         "regex",
         "numpy",
         "blobfile",
-        "transformers==4.23.1",
+        "transformers>=4.29.2",
         "torchvision",
         "omegaconf",
         "pytorch_lightning",
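
Usage sketch (not part of the patch): a minimal way to exercise the new `callback` parameter this change threads through generate_text2img(). The loader call (get_kandinsky2), the device string, and the prompt are assumptions about the surrounding project and are not shown in this diff; only the generate_text2img() signature and the callback payload keys ("i", "denoised", and, for the ddim/plms samplers, "x") come from the code above.

    import torch
    from kandinsky2 import get_kandinsky2  # assumed project-level helper, not part of this diff

    def on_step(info):
        # p_sampler passes {"i": step, "denoised": latents};
        # ddim_sampler/plms_sampler additionally pass "x" (pred_x0).
        print(f"step {info['i']}: denoised latent shape {tuple(info['denoised'].shape)}")

    # Mirrors the device preference added in get_torch_device(): MPS on macOS,
    # then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu")

    model = get_kandinsky2(device, task_type="text2img", model_version="2.1")
    images = model.generate_text2img(
        "red cat, 4k photo",
        num_steps=50,
        batch_size=1,
        guidance_scale=4,
        h=512,
        w=512,
        sampler="p_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        callback=on_step,
    )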