使用 ONNX Runtime 與 DirectML 加速 Stable Diffusing 模型推理

Posted on Sun 14 January 2024 in Machine Learning

背景

在先前的實測中,已經利用 ONNX Runtime 與 DirectML 來加速 ResNet50 影像分類器的推理,所以接下來就測試比較複雜也比較實用的生成式模型。 這裡嘗試使用 ONNX Runtime 與 DirectML 來實現對 Stable Diffusion 的加速推理。

Stable Diffusion

Stable Diffusion 是一個最近比較熱門的生成式 AI 模型,目標是利用一段文字的提示詞 (prompt),來產生一張符合提示詞的影像。 實現方法是藉由 CLIP 模型的文字編碼器來引導擴散模型的輸出,擴散模型會將一個隨機雜訊影像 ( 空白畫布 ) 去噪來產生一張人類可解讀的影像。 擴散過程會由一個 unet 網路來實現,文字編碼起產生的特徵會在擴散過程中加入,來控制輸出的結果。 擴散模型的輸出最後會在通過一個 vae 的解碼器來產生最後的影像。

實作

這裡以 ONNX 官方的 C# 範例為基礎來實作。 由於範例的文件比較舊,我做了一些修改以適用最新版本的 ONNX 與執行環境。

環境準備

  1. 安裝 Visual Studio 2022 (VS) 或更新的版本。
  2. 從 GitHub 複製倉庫: git clone https://github.com/cassiebreviu/StableDiffusion.git
  3. 從 Hugging Face 下載 Stable Diffusion 模型檔案git lfs install git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 -b onnx
  4. 複製包含 model.onnx 模型檔案的 unet, vae_decoder, text_encoder, safety_checker 資料夾到 C# 方案的資料夾 \StableDiffusion\StableDiffusion 內。
  5. 切換到 direct-ML-EP 分支,更新方案中的相依性 Microsoft.ML, Microsoft.ML.OnnxRuntime.DirectML, Microsoft.ML.OnnxRuntime.Extensions, Microsoft.ML.OnnxRuntime.Managed 到最新版本。
  6. 若開發環境是 X86-64 處理器,則目標平台設為 X64,若目標平台為 ARM 處理器,則必須新增 ARM64 的設定檔,且由於 Microsoft.ML 似乎預設不支援 ARM 處理器,但可在專案檔案中加入屬性來通過編譯: xml <EnableMLUnsupportedPlatformTargetCheck>false</EnableMLUnsupportedPlatformTargetCheck>

程式碼

由於範例比較舊,直接編譯會有問題,需要對程式碼做一些修改,要修改的部分是使用 Microsoft.ML.OnnxRuntime.Extensions 實現將文字轉化成特徵 (Tokenization) 的程式碼,在最新版改變了 API 使用方式。 在 TextProcessing.cs 中的第 32 行將

sessionOptions.RegisterCustomOpLibraryV2(config.OrtExtensionsPath, out var ibraryHandle);

替換成

sessionOptions.RegisterOrtExtensions();

建議再將 ortextensions.dll 從方案裝移除。

在主程式 Program.cs 修改擴散模型提示詞與設定:

using StableDiffusion.ML.OnnxRuntime;

namespace StableDiffusion
{
    public class Program
    {
        static void Main(string[] args)
        {
            //test how long this takes to execute
            var watch = System.Diagnostics.Stopwatch.StartNew();

            //Default args
            var prompt = "a fireplace in an old cabin in the woods";
            Console.WriteLine(prompt);

            var config = new StableDiffusionConfig
            {
                // Number of denoising steps
                NumInferenceSteps = 15,
                // Scale for classifier-free guidance
                GuidanceScale = 7.5,
                // Set your preferred Execution Provider. Currently (GPU, DirectML, CPU) are supported in this project.
                // ONNX Runtime supports many more than this. Learn more here: https://onnxruntime.ai/docs/execution-providers/
                // The config is defaulted to CUDA. You can override it here if needed.
                // To use DirectML EP intall the Microsoft.ML.OnnxRuntime.DirectML and uninstall Microsoft.ML.OnnxRuntime.GPU
                ExecutionProviderTarget = StableDiffusionConfig.ExecutionProvider.DirectML,
                // Set GPU Device ID.
                DeviceId = 0,
                // Update paths to your models
                TextEncoderOnnxPath = @"Path\to\text_encoder\model.onnx",
                UnetOnnxPath = @"Path\to\unet\model.onnx",
                VaeDecoderOnnxPath = @"Path\to\vae_decoder\model.onnx",
                SafetyModelPath = @"Path\to\safety_checker\model.onnx",
            };

            // Inference Stable Diff
            var image = UNet.Inference(prompt, config);

            // If image failed or was unsafe it will return null.
            if (image == null)
            {
                Console.WriteLine("Unable to create image, please try again.");
            }
            // Stop the timer
            watch.Stop();
            var elapsedMs = watch.ElapsedMilliseconds;
            Console.WriteLine("Time taken: " + elapsedMs + "ms");

        }

    }
}

輸出結果範例 (RTX 4070):

a fireplace in an old cabin in the woods
49406 320 23050 530 550 896 14178 530 518 6267 49407
Seed generated: 68791426
scaled model input 2.1315572 at step 0. Max 4.6911726 Min-3.6267598
20.710197
latents result after step 0 min -35.111282 max 46.04094
scaled model input 2.1276011 at step 1. Max 4.7298803 Min-3.6070542
14.298233
latents result after step 1 min -24.175587 max 33.47158
scaled model input 2.1174793 at step 2. Max 4.9569325 Min-3.5802538
10.329241
latents result after step 2 min -17.471031 max 25.68005
scaled model input 2.117433 at step 3. Max 5.264258 Min-3.5814576
7.299632
latents result after step 3 min -13.020624 max 19.017113
scaled model input 1.9937395 at step 4. Max 5.1941204 Min-3.5563068
5.682288
latents result after step 4 min -10.168077 max 15.260122
scaled model input 1.9952072 at step 5. Max 5.3582473 Min-3.5702908
4.252401
latents result after step 5 min -8.037247 max 12.389794
scaled model input 1.8565629 at step 6. Max 5.4092813 Min-3.508995
3.4030435
latents result after step 6 min -6.6028404 max 10.449694
scaled model input 1.7910485 at step 7. Max 5.499756 Min-3.475127
2.6942093
latents result after step 7 min -5.608853 max 9.02561
scaled model input 1.6612248 at step 8. Max 5.5651083 Min-3.4583673
2.1611075
latents result after step 8 min -4.8157053 max 7.9344597
scaled model input 1.5207505 at step 9. Max 5.5834026 Min-3.3887653
1.8328636
latents result after step 9 min -4.338108 max 7.0801086
scaled model input 1.4373326 at step 10. Max 5.552225 Min-3.4019465
1.4414997
latents result after step 10 min -4.011564 max 6.3761377
scaled model input 1.2330464 at step 11. Max 5.4540935 Min-3.4314573
1.2209218
latents result after step 11 min -3.704286 max 5.804214
scaled model input 1.1176652 at step 12. Max 5.313336 Min-3.3910048
1.020186
latents result after step 12 min -3.4659088 max 5.3420415
scaled model input 0.9829046 at step 13. Max 5.1468234 Min-3.3392518
0.815398
latents result after step 13 min -3.1990318 max 4.928313
scaled model input 0.8150513 at step 14. Max 4.9262176 Min-3.1976717
0.81261224
latents result after step 14 min -3.2003698 max 4.921097
Image saved to: C:\Users\ya-ti\Projects\StableDiffusion\StableDiffusion\bin\x64\Debug\net6.0\sd_image_20240115145501437.png
Time taken: 22118ms

sd_image

心得

經過實測,此範例在 NVIDIA RTX 4070、Intel Iris Xe Graphics (Intel Core i5-11300H),都可以順利執行加速推理。 然而,這次在高通的 Adreno 8CX Gen 3 (Microsoft SQ 3),會回報 runtime error,而且每一次跑回報的錯誤都有些差異,主要有兩類的錯誤:

  • Fatal error
Fatal error. System.AccessViolationException: Attempted to read or write protected memory. This is often an indication that other memory is corrupt.
  • ONNX Runtime error
1200890 [E:onnxruntime:, sequential_executor.cc:514 onnxruntime::ExecuteKernel] Non-zero status code returned while running Conv node. Name:'Conv_3482' Status Message:

我猜也許是高通的驅動還不夠完善。

在這個 Stable Diffusion 的範例,DirectML 的相容性在 X86 處理器平台不錯,但是目前在 ARM 處理器上的相容性還不及 X86 處理器。 希望微軟與高通持續改善 Windows on ARM ,盡早追上 X86 處理器的開發體驗。

分享到: DiasporaTwitterFacebookLinkedInHackerNewsEmailReddit