Implementation Pipeline of Stable Diffusion with ControlNet
@zilla0717
本文梳理了用ControlNet控制Stable Diffusion輸出的實現(xiàn)思路驼鹅。
分析對象
StableDiffusion WebUI
ControlNet Extension for StableDiffusion WebUI
ControlNet作為StableDiffusion WebUI的擴(kuò)展擦酌,遵照其擴(kuò)展開發(fā)規(guī)則寥殖。
參考資料
【StableDiffusion WebUI源碼分析 — 知乎】
1. Gradio的基本用法
2. txt2img的實現(xiàn)
3. 模型加載的過程
4. 啟動流程
5. 多語言的實現(xiàn)方式
6. 腳本的實現(xiàn)方式
7. 擴(kuò)展的實現(xiàn)方式
8. Lora功能的實現(xiàn)方式
【StableDiffusion WebUI的Wiki】
【gradio UI component】
1. 實現(xiàn)擴(kuò)展的一般流程
插件目錄下,各文件、子目錄作用如下:
-
install.py
:若有則自動執(zhí)行柱衔,用于完成依賴庫的安裝堪唐。 - 子目錄
scripts
:放py腳本,插件目錄會被追加到sys.path
缅疟。建議腳本中用scripts.basedir()
來獲取當(dāng)前插件目錄分别,因為用戶可能重命名插件。 -
style.css
和子目錄javascript
中的js文件會被加載到頁面上存淫。 -
preload.py
:若有耘斩,則在程序解析命令之前加載。在該文件里的preload
函數(shù)中追加與該擴(kuò)展有關(guān)的命令行參數(shù)纫雁。如:
def preload(parser):
parser.add_argument("--wildcards-dir", type=str, default=None)
下面說明如何編寫一個py腳本煌往,以“旋轉(zhuǎn)生成的圖片”這一腳本為例(分析見注釋)。
- import必要的包和函數(shù)(這部分不需要改動)
import modules.scripts as scripts
import gradio as gr
import os
from modules import images
from modules.processing import process_images, Processed
from modules.processing import Processed
from modules.shared import opts, cmd_opts, state
- 定義
Script
類轧邪,后續(xù)的title()
刽脖、show()
、ui()
忌愚、run()
都是該類的函數(shù)
class Script(scripts.Script)
-
title()
:定義腳本名稱(顯示在該插件的下拉菜單里)
def title(self):
return "Rotate Output"
-
show()
:其返回值控制該選項何時出現(xiàn)在下拉菜單
def show(self, is_img2img):
# 只有在img2img 界面才在下拉菜單顯示該功能
return is_img2img
-
ui()
:定義這個腳本在UI上怎么展示曲管,其返回值被用作參數(shù)
多數(shù)UI組件返回的是boolean。
def ui(self, is_img2img):
angle = gr.Slider(minimum=0.0, maximum=360.0, step=1, value=0,
label="Angle")
overwrite = gr.Checkbox(False, label="Overwrite existing files")
return [angle, overwrite]
-
run()
:獲取UI傳回的參數(shù)硕糊,做額外的計算過程
該函數(shù)在這個腳本在下拉菜單中被選中時被調(diào)用院水,它必須進(jìn)行所有處理并返回帶有結(jié)果的Processed
對象(與processing.process_images()
返回的結(jié)果相同)。
通常處理過程是調(diào)用process_images()
完成的简十。- 入?yún)?
-
p
(類型為StableDiffusionProcessing
的對象實例)
StableDiffusionProcessing
定義參見module/processing.py
檬某,定義了它以及子類StableDiffusionProcessingTxt2Img
和StableDiffusionProcessingImg2Img
。 -
ui()
返回的參數(shù)
-
-
run()
內(nèi)部可以自定義函數(shù)和引入額外的包螟蝙。 - 對圖片執(zhí)行運算的函數(shù)以由
process_images()
返回的Processed
對象proc
和ui()
獲取的參數(shù) 為入?yún)⒒帜眨紙D片在proc.images
,返回處理后的proc
胰默。
- 入?yún)?
def run(self, p, angle, overwrite):
def rotate(im, angle):
from PIL import Image
raf = im
if angle != 0:
raf = raf.rotate(angle, expand=True)
return raf
basename = ""
if(not overwrite):
if angle != 0:
basename += "rotated_" + str(angle)
else:
p.do_not_save_samples = True
proc = process_images(p)
for i in range(len(proc.images)):
proc.images[i] = rotate(proc.images[i], angle)
images.save_image(proc.images[i], p.outpath_samples, basename, proc.seed + i, proc.prompt, opts.samples_format, info= proc.info, p=p)
return proc
-
process()
:獲取UI傳回的參數(shù)场斑,做額外的計算過程
該函數(shù)類似run()
漓踢,區(qū)別是它在開始執(zhí)行總是可見的腳本前被調(diào)用,即在圖像處理前被調(diào)用漏隐。
before_process_batch()
喧半、process_batch()
、postprocess_batch()
等函數(shù)的作用見modules/scripts.py
青责。
2. ControlNet擴(kuò)展的UI實現(xiàn)和回調(diào)方法
controlnet.py
的寫法類似上面的例子挺据,其ui()
實現(xiàn)如下:
def ui(self, is_img2img):
self.infotext_fields = []
self.paste_field_names = []
controls = ()
max_models = shared.opts.data.get("control_net_max_models_num", 1)
elem_id_tabname = ("img2img" if is_img2img else "txt2img") + "_controlnet"
with gr.Group(elem_id=elem_id_tabname):
with gr.Accordion(f"ControlNet {controlnet_version.version_flag}", open = False, elem_id="controlnet"):
if max_models > 1:
with gr.Tabs(elem_id=f"{elem_id_tabname}_tabs"):
for i in range(max_models):
with gr.Tab(f"ControlNet Unit {i}",
elem_classes=['cnet-unit-tab']):
controls += (self.uigroup(f"ControlNet-{i}", is_img2img, elem_id_tabname),)
else:
with gr.Column():
controls += (self.uigroup(f"ControlNet", is_img2img, elem_id_tabname),)
if shared.opts.data.get("control_net_sync_field_args", False):
for _, field_name in self.infotext_fields:
self.paste_field_names.append(field_name)
return controls
在api.py
中,可以看到 在web app啟動(on_app_started
)時就會調(diào)用controlnet_api()
方法爽柒。
try:
import modules.script_callbacks as script_callbacks
script_callbacks.on_app_started(controlnet_api)
except:
pass
controlnet_api()
中定義了一些異步的方法(其中獲取插件模型列表吴菠、版本、設(shè)置等信息的方法由GET請求調(diào)用浩村,detect()
由POST請求調(diào)用)做葵,實現(xiàn)如下:
def controlnet_api(_: gr.Blocks, app: FastAPI):
@app.get("/controlnet/version")
async def version():
return {"version": external_code.get_api_version()}
@app.get("/controlnet/model_list")
async def model_list():
up_to_date_model_list = external_code.get_models(update=True)
logger.debug(up_to_date_model_list)
return {"model_list": up_to_date_model_list}
@app.get("/controlnet/module_list")
async def module_list(alias_names: bool = False):
_module_list = external_code.get_modules(alias_names)
logger.debug(_module_list)
return {
"module_list": _module_list,
"module_detail": external_code.get_modules_detail(alias_names)
}
@app.get("/controlnet/settings")
async def settings():
max_models_num = external_code.get_max_models_num()
return {"control_net_max_models_num":max_models_num}
cached_cn_preprocessors = global_state.cache_preprocessors(global_state.cn_preprocessor_modules)
@app.post("/controlnet/detect")
async def detect(
controlnet_module: str = Body("none", title='Controlnet Module'),
controlnet_input_images: List[str] = Body([], title='Controlnet Input Images'),
controlnet_processor_res: int = Body(512, title='Controlnet Processor Resolution'),
controlnet_threshold_a: float = Body(64, title='Controlnet Threshold a'),
controlnet_threshold_b: float = Body(64, title='Controlnet Threshold b')
):
controlnet_module = global_state.reverse_preprocessor_aliases.get(controlnet_module, controlnet_module)
if controlnet_module not in cached_cn_preprocessors:
raise HTTPException(
status_code=422, detail="Module not available")
if len(controlnet_input_images) == 0:
raise HTTPException(
status_code=422, detail="No image selected")
logger.info(f"Detecting {str(len(controlnet_input_images))} images with the {controlnet_module} module.")
results = []
processor_module = cached_cn_preprocessors[controlnet_module]
for input_image in controlnet_input_images:
img = external_code.to_base64_nparray(input_image)
results.append(processor_module(img, res=controlnet_processor_res, thr_a=controlnet_threshold_a, thr_b=controlnet_threshold_b)[0])
global_state.cn_preprocessor_unloadable.get(controlnet_module, lambda: None)()
results64 = list(map(encode_to_base64, results))
return {"images": results64, "info": "Success"}
3. ControlNet擴(kuò)展的功能實現(xiàn)
原始的Stable Diffusion 由三個模型構(gòu)成:text encoder模型(CLIPTextModel)、UNet模型和VAE 模型心墅。ControlNet是在UNet網(wǎng)絡(luò)上新增的旁路酿矢,用于增加額外的條件控制Stable Diffusion的輸出。
在
controlnet.py
的Script
類的process()
中怎燥,實現(xiàn)了網(wǎng)絡(luò)結(jié)構(gòu)的注入瘫筐。process()
在圖像處理前被調(diào)用,此處unet
為原先網(wǎng)絡(luò)的結(jié)構(gòu)铐姚,UnetHook
為新定義的結(jié)構(gòu)策肝,通過UnetHook.hook()
改變原始的UNet。
sd_ldm = p.sd_model
unet = sd_ldm.model.diffusion_model
......
self.latest_network = UnetHook(lowvram=hook_lowvram)
self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p)
self.detected_map = detected_maps
self.post_processors = post_processors
UnetHook.hook方法隐绵,model即是原先的網(wǎng)絡(luò)之众,hook方法先將原先的模型的forward方法保存起來(model._original_forward = model.forward),然后給它重新賦值依许,賦值為自行實現(xiàn)的forward2棺禾。
- 文本生成圖片
text2img流程
text_embedding = text_encoder(prompt)
for i in steps:
predict_noise = unet(text_embedding, timestamp,latent)
latent_new = DDPM(latent, timestamp) # 求解器
img = vae_decoder(latent)
- img2img的流程
原始的img2img
如圖片卡通風(fēng)格轉(zhuǎn)換
img_info = vae_encoder(img)
latent_init = handle(img_info)
其他類似text2img
unet 我們可以拆開為 uencoder和udecoder。
controlnet_information = contorlnet(controlnet_img, timestamp, latent,text_embedding )
encoder_info = uencoder(timestamp, latent,text_embedding)
信息融合:
decoder_input = controlnet_information * rate + encoder_info
predict_noise = decoder(decoder_input, timestamp, latent,text_embedding )
其他流程與text2img相同
img2paint(with mask)
要梳理什么:
- controlnet的pipeline具體實現(xiàn)峭跳,參考:onnxweb(一個repo)的diffusion 和 diffusers 的 controlnet
需要考慮的是? - controlnet的根據(jù)參數(shù)功能和實現(xiàn)(我有一版本膘婶,晚點發(fā))