Stable Diffusion 代码 (一)

前言

在网上没找到对stable diffusion代码的解读。这里记录一下自己读代码的过程和进度作为一个备忘。

官方代码

related contents:

stable Diffusion 代码(二)

Stable Diffusion 代码 (三)

Stable Diffusion 代码(四)

模型的初始化和载入

从prompt生成图片时的命令为

python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms

就从入口txt2img.py开始阅读。跳过传入参数的parser部分

# 设定随机seed
seed_everything(opt.seed)

config=OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")

其中 opt.config= configs/stable-diffusion/v1-inference.yaml,指向一个预定义好的配置文件ckpt是预先下载好的模型

然后看load_model_from_config函数,这一函数就定义在同一个文件(txt2img.py文件)中,但是它调用了ldm.util中的两个方法。这里一起写出来

def instantiate_from_config(config):
return get_obj_from_str(config["target"])(**config.get("params", dict()))

defget_obj_from_str(string,reload=False):
module, cls = string.rsplit(".", 1)
return getattr(importlib.import_module(module, package=None), cls)

defload_model_from_config(config,ckpt):
pl_sd = torch.load(ckpt, map_location="cpu")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
return model

实际上等效于

from ldm.models.diffusion.ddpm import LatentDiffusion
model = LatentDiffusion(**config.model.get("params", dict()))
model.load_state_dict(torch.load(ckpt,map_location="cpu")["state_dict"],strict=False)

原code使用importlib.import_module,来读取字典中的模块名称进行灵活的import。从方便理解代码运行和算法原理的视角来看,在实际使用LatentDiffusion时,上下两种写法是完全等效的。

这里多说一句,Config字典类似于

Config = { target: path1.path2.module_1_name,
params: { para_1 : value_a,
para_2 : value_b,
module_2:{ target: path1.path2.module_2_name,
params: { para_3 : value_c,
module_3:{ target: path1.path2.module_3_name,
params : {para_4: value_d }
}}}}}

get_obj_from_str接收config字典中target对应的值来导入对应的模块,

在 instantiate_from_config 返回对应的类的实例,返回的实例是以params对应的值初始化的params对应的值是同等格式的字典。

也就是说,config中可以像上面的例子一样,设置好嵌套的各个模块,并且在模块实例化时读取传入的config,在模块的__init__中继续调用instantiate_from_config就可以实现各个模块嵌套式的实例化。具体的例子可以看第三篇。

# 初始化模型的全部逻辑:

fromldm.models.diffusion.ddpmimportLatentDiffusion
import torch
from omegaconf import OmegaConf

# 读取config
config = OmegaConf.load(f"{opt.config}")

# 初始化模型并传入config中的参数
model = LatentDiffusion(**config.model.get("params", dict()))
model.load_state_dict(torch.load(ckpt, map_location="cpu")["state_dict"], strict=False)

device=torch.device("cuda")
model = model.to(device)

图像生成的准备和图像的生成

有了model之后是sampler的初始化 (基于命令行传入的 --plms,执行判断语句的第一条)

sampler = PLMSSampler(model)

紧接着,原代码提供了两种输入prompt的方法,分别是命令行输入和从文件读取,不关键。总之最后prompt进入了data这个变量

data = [batch_size * [prompt]]

到这里,我们有了

model-[LatentDiffusion]sampler-[PLMSSampler] prompt

这样就可以开始生成图片了。

这里有两个重要的部分,一个是PLMSSampler的定义,一个是LatentDiffusion的定义。我们先将这两个模块视作黑箱,假定它们能完美的完成各自的任务,之后再详细看它们的代码。

这里先简单回忆一下classifier-free guidance的方法: ϵ(x,t)=ϵ(x,t|ϕ)+α⋅(ϵ(x,t|c)−ϵ(x,t|ϕ))epsilon(x, t)= epsilon(x,t ~| ~phi) + alphacdot (epsilon(x,t~|~ c) -epsilon(x,t~ |~ phi))

因此除了prompt,也就是上式中c所对应的条件,还需要unconditional的ϕphi 。

c = model.get_learned_conditioning(prompts)
uc=model.get_learned_conditioning(batch_size*[""])

这里可以看到model中的一个方法 get_learned_conditioning() : 输入text, 输出text的embedding 。

之后就是图像的生成了。图像的生成调用sampler实例的sample方法。这里为了直观的理解省略了几个参数,完整的参数和具体的各个参数的作用在后面sampler的代码解读部分再说。

samples_ddim, _ = sampler.sample(S=50,
conditioning=c,
batch_size=1,
shape=[4,64,64],
unconditional_guidance_scale=7.5,
unconditional_conditioning=uc,
eta=opt.ddim_eta)
x_samples_ddim=model.decode_first_stage(samples_ddim)

到这里为止,diffusion的任务已经结束了,x_samples_ddim 再经过基本的图像处理就是最终的结果。

以上就是txt2img.py文件的全部内容。这一部分绝大多数代码都是数据的读写和准备工作,核心逻辑部分比较少,还是比较好理解的。

接下来进入plms文件去看sampler的代码实现。

版权声明:ai机器人 发表于 2023年3月28日 am2:36。
转载请注明:Stable Diffusion 代码 (一) | AI工具箱

相关文章

暂无评论

您必须登录才能参与评论!
立即登录
暂无评论...