使用 ONNX Runtime 與 DirectML 加速 Stable Diffusing 模型推理
Posted on Sun 14 January 2024 in Machine Learning
背景
在先前的實測中,已經利用 ONNX Runtime 與 DirectML 來加速 ResNet50 影像分類器的推理,所以接下來就測試比較複雜也比較實用的生成式模型。 這裡嘗試使用 ONNX Runtime 與 DirectML 來實現對 Stable Diffusion 的加速推理。
Stable Diffusion
Stable Diffusion 是一個最近比較熱門的生成式 AI 模型,目標是利用一段文字的提示詞 (prompt),來產生一張符合提示詞的影像。
實現方法是藉由 CLIP 模型的文字編碼器來引導擴散模型的輸出,擴散模型會將一個隨機雜訊影像 ( 空白畫布 ) 去噪來產生一張人類可解讀的影像。
擴散過程會由一個 unet
網路來實現,文字編碼起產生的特徵會在擴散過程中加入,來控制輸出的結果。
擴散模型的輸出最後會在通過一個 vae
的解碼器來產生最後的影像。
實作
這裡以 ONNX 官方的 C# 範例為基礎來實作。 由於範例的文件比較舊,我做了一些修改以適用最新版本的 ONNX 與執行環境。
環境準備
- 安裝 Visual Studio 2022 (VS) 或更新的版本。
- 從 GitHub 複製倉庫:
git clone https://github.com/cassiebreviu/StableDiffusion.git
- 從 Hugging Face 下載 Stable Diffusion 模型檔案:
git lfs install git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 -b onnx
- 複製包含
model.onnx
模型檔案的unet
,vae_decoder
,text_encoder
,safety_checker
資料夾到 C# 方案的資料夾\StableDiffusion\StableDiffusion
內。 - 切換到
direct-ML-EP
分支,更新方案中的相依性Microsoft.ML
,Microsoft.ML.OnnxRuntime.DirectML
,Microsoft.ML.OnnxRuntime.Extensions
,Microsoft.ML.OnnxRuntime.Managed
到最新版本。 - 若開發環境是 X86-64 處理器,則目標平台設為
X64
,若目標平台為 ARM 處理器,則必須新增ARM64
的設定檔,且由於Microsoft.ML
似乎預設不支援 ARM 處理器,但可在專案檔案中加入屬性來通過編譯:xml <EnableMLUnsupportedPlatformTargetCheck>false</EnableMLUnsupportedPlatformTargetCheck>
程式碼
由於範例比較舊,直接編譯會有問題,需要對程式碼做一些修改,要修改的部分是使用 Microsoft.ML.OnnxRuntime.Extensions
實現將文字轉化成特徵 (Tokenization) 的程式碼,在最新版改變了 API 使用方式。
在 TextProcessing.cs
中的第 32 行將
sessionOptions.RegisterCustomOpLibraryV2(config.OrtExtensionsPath, out var ibraryHandle);
替換成
sessionOptions.RegisterOrtExtensions();
建議再將 ortextensions.dll
從方案裝移除。
在主程式 Program.cs
修改擴散模型提示詞與設定:
using StableDiffusion.ML.OnnxRuntime;
namespace StableDiffusion
{
public class Program
{
static void Main(string[] args)
{
//test how long this takes to execute
var watch = System.Diagnostics.Stopwatch.StartNew();
//Default args
var prompt = "a fireplace in an old cabin in the woods";
Console.WriteLine(prompt);
var config = new StableDiffusionConfig
{
// Number of denoising steps
NumInferenceSteps = 15,
// Scale for classifier-free guidance
GuidanceScale = 7.5,
// Set your preferred Execution Provider. Currently (GPU, DirectML, CPU) are supported in this project.
// ONNX Runtime supports many more than this. Learn more here: https://onnxruntime.ai/docs/execution-providers/
// The config is defaulted to CUDA. You can override it here if needed.
// To use DirectML EP intall the Microsoft.ML.OnnxRuntime.DirectML and uninstall Microsoft.ML.OnnxRuntime.GPU
ExecutionProviderTarget = StableDiffusionConfig.ExecutionProvider.DirectML,
// Set GPU Device ID.
DeviceId = 0,
// Update paths to your models
TextEncoderOnnxPath = @"Path\to\text_encoder\model.onnx",
UnetOnnxPath = @"Path\to\unet\model.onnx",
VaeDecoderOnnxPath = @"Path\to\vae_decoder\model.onnx",
SafetyModelPath = @"Path\to\safety_checker\model.onnx",
};
// Inference Stable Diff
var image = UNet.Inference(prompt, config);
// If image failed or was unsafe it will return null.
if (image == null)
{
Console.WriteLine("Unable to create image, please try again.");
}
// Stop the timer
watch.Stop();
var elapsedMs = watch.ElapsedMilliseconds;
Console.WriteLine("Time taken: " + elapsedMs + "ms");
}
}
}
輸出結果範例 (RTX 4070):
a fireplace in an old cabin in the woods
49406 320 23050 530 550 896 14178 530 518 6267 49407
Seed generated: 68791426
scaled model input 2.1315572 at step 0. Max 4.6911726 Min-3.6267598
20.710197
latents result after step 0 min -35.111282 max 46.04094
scaled model input 2.1276011 at step 1. Max 4.7298803 Min-3.6070542
14.298233
latents result after step 1 min -24.175587 max 33.47158
scaled model input 2.1174793 at step 2. Max 4.9569325 Min-3.5802538
10.329241
latents result after step 2 min -17.471031 max 25.68005
scaled model input 2.117433 at step 3. Max 5.264258 Min-3.5814576
7.299632
latents result after step 3 min -13.020624 max 19.017113
scaled model input 1.9937395 at step 4. Max 5.1941204 Min-3.5563068
5.682288
latents result after step 4 min -10.168077 max 15.260122
scaled model input 1.9952072 at step 5. Max 5.3582473 Min-3.5702908
4.252401
latents result after step 5 min -8.037247 max 12.389794
scaled model input 1.8565629 at step 6. Max 5.4092813 Min-3.508995
3.4030435
latents result after step 6 min -6.6028404 max 10.449694
scaled model input 1.7910485 at step 7. Max 5.499756 Min-3.475127
2.6942093
latents result after step 7 min -5.608853 max 9.02561
scaled model input 1.6612248 at step 8. Max 5.5651083 Min-3.4583673
2.1611075
latents result after step 8 min -4.8157053 max 7.9344597
scaled model input 1.5207505 at step 9. Max 5.5834026 Min-3.3887653
1.8328636
latents result after step 9 min -4.338108 max 7.0801086
scaled model input 1.4373326 at step 10. Max 5.552225 Min-3.4019465
1.4414997
latents result after step 10 min -4.011564 max 6.3761377
scaled model input 1.2330464 at step 11. Max 5.4540935 Min-3.4314573
1.2209218
latents result after step 11 min -3.704286 max 5.804214
scaled model input 1.1176652 at step 12. Max 5.313336 Min-3.3910048
1.020186
latents result after step 12 min -3.4659088 max 5.3420415
scaled model input 0.9829046 at step 13. Max 5.1468234 Min-3.3392518
0.815398
latents result after step 13 min -3.1990318 max 4.928313
scaled model input 0.8150513 at step 14. Max 4.9262176 Min-3.1976717
0.81261224
latents result after step 14 min -3.2003698 max 4.921097
Image saved to: C:\Users\ya-ti\Projects\StableDiffusion\StableDiffusion\bin\x64\Debug\net6.0\sd_image_20240115145501437.png
Time taken: 22118ms
心得
經過實測,此範例在 NVIDIA RTX 4070、Intel Iris Xe Graphics (Intel Core i5-11300H),都可以順利執行加速推理。 然而,這次在高通的 Adreno 8CX Gen 3 (Microsoft SQ 3),會回報 runtime error,而且每一次跑回報的錯誤都有些差異,主要有兩類的錯誤:
- Fatal error
Fatal error. System.AccessViolationException: Attempted to read or write protected memory. This is often an indication that other memory is corrupt.
- ONNX Runtime error
1200890 [E:onnxruntime:, sequential_executor.cc:514 onnxruntime::ExecuteKernel] Non-zero status code returned while running Conv node. Name:'Conv_3482' Status Message:
我猜也許是高通的驅動還不夠完善。
在這個 Stable Diffusion 的範例,DirectML 的相容性在 X86 處理器平台不錯,但是目前在 ARM 處理器上的相容性還不及 X86 處理器。 希望微軟與高通持續改善 Windows on ARM ,盡早追上 X86 處理器的開發體驗。