使用 ONNX Runtime 與 Qualcomm QNN 加速影像分類模型推理

Posted on Wed 24 January 2024 in Machine Learning

2024-05-24 更新:由於 ONNX QNN EP 文件更新很多,因此文章因應而大幅更新。

2024-08-15 更新:更正 MobileNet 相關的內文錯誤

2024-08-19 更新:更新模型量化的相關內容與結果

背景

由於最近生成式 AI 的興起,如 ChatGPT 與 DALL·E,微軟攜手廠商們開始推廣 AI PC,Intel、AMD 與高通等廠商,開始直接在處理器當中內建 Neural Processing Unit (NPU) 來加速 AI 模型的推理,且相較於在 GPU 推理有著更低的功耗。 在之前的測試中,實現了使用 ONNX Runtime 與 DirectML 實現對影像分類器的加速推理,這裡想嘗試使用基於 NPU 的推理加速。

ONNX Runtime

ONNX Runtime 是一種跨平台的 AI 加速框架,支援主流的作業系統與程式語言,以及主流的深度學習框架所訓練出來的模型,如 PyTorch 與 TensorFlow。 ONNX 支援多種的加速方式 (Execution Providers, EP),常見的像是 CPU, CUDA, DirectML 等等,最近也加入各家廠商的 NPU 支援,其餘的 EP 與細節可以參考官方的文件。 這裡以 Qualcomm - QNN 來嘗試實作,它使用 Qualcomm AI Engine Direct SDK (QNN SDK) 作為 ONNX Runtime 的後端。 根據文件,目前在 Qualcomm SC8280X 與 SM8350 兩款晶片 建置與測試,對應的行銷用型號應為 Snapdragon 8cx Gen 3 與 Snapdragon 888,這裡的測試平台是 Microsoft SQ3 ( 微軟客製化的 Snapdragon 8cx Gen 3)。

實作

這裡嘗試用官方的 ResNet50 C# 範例為基礎來實作。 由於範例的文件比較舊,我做了一些修改以適用最新版本的 ONNX 與執行環境。

環境準備

  1. 安裝 Visual Studio 2022 (VS) 或更新的版本。
  2. 在 VS 中創建一個 C# Console App,.NET 版本選擇 6.0 或更高。
  3. 在方案的相依性安裝支援 QNN 的 ONNX Runtime
  4. 安裝前處理用的相依性
  5. 下載 ONNX 模型檔案( 提醒:為了要能用 NPU 來推理,模型檔案的 Opset version 需要大於等於 11)
  6. 下載一張待推理的影像,或其他影像。
  7. 準備一個有安裝 onnxruntime-qnn 的 Python 環境。

前處理

固定輸入的形狀

由於要使用 QNN EP 存在一些限制,第一個是 QNN 不支援動態的 batch size 輸入,因此需要將模型的輸入 shape 固定成 1,ONNX 有提供工具來將輸入的形狀改成固定。 首先利用 Netron 找出模型的輸入名稱與形狀,以這裡用的模型為例:

Netron

輸入名稱是 data,batch size 是 N,我們要將 N 改成 1,轉換工具的用法如下:

python -m onnxruntime.tools.make_dynamic_shape_fixed --dim_param N --dim_value 1 resnet50-v1-12.onnx resnet50-v1-12.fixed.onnx

轉換完成之後可以再用 Netron 確認是否已經成功修改。

模型量化 (Model quantization)

由於 NPU 只接受量化過的模型,可能需要將模型做量化,參考網頁的範例,需要準備校正用的資料來做模型量化,為了讀圖方便而經過修改之後,在這個例子的讀檔案實現如下:

# resnet50_data_reader.py

import numpy
import onnxruntime
import os
from onnxruntime.quantization import CalibrationDataReader
from PIL import Image


def _preprocess_images(images_folder: str, height: int, width: int, size_limit=0):
    """
    Loads a batch of images and preprocess them
    parameter images_folder: path to folder storing images
    parameter height: image height in pixels
    parameter width: image width in pixels
    parameter size_limit: number of images to load. Default is 0 which means all images are picked.
    return: list of matrices characterizing multiple images
    """
    batch_filenames = []
    for root, dirs, files in os.walk(images_folder):
        for file in files:
            if size_limit > 0 and len(batch_filenames) >= size_limit:
                break

            batch_filenames.append(os.path.join(root, file))

    unconcatenated_batch_data = []

    for image_filepath in batch_filenames:
        pillow_img = Image.new("RGB", (width, height))
        pillow_img.paste(Image.open(image_filepath).resize((width, height)))
        input_data = numpy.float32(pillow_img) - numpy.array(
            [123.68, 116.78, 103.94], dtype=numpy.float32
        )
        nhwc_data = numpy.expand_dims(input_data, axis=0)
        nchw_data = nhwc_data.transpose(0, 3, 1, 2)  # ONNX Runtime standard
        unconcatenated_batch_data.append(nchw_data)
    batch_data = numpy.concatenate(
        numpy.expand_dims(unconcatenated_batch_data, axis=0), axis=0
    )
    return batch_data

class ResNet50DataReader(CalibrationDataReader):
    def __init__(self, calibration_image_folder: str, model_path: str):
        self.enum_data = None

        # Use inference session to get input shape.
        session = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
        (_, _, height, width) = session.get_inputs()[0].shape

        # Convert image to input data
        self.nhwc_data_list = _preprocess_images(
            calibration_image_folder, height, width, size_limit=0
        )
        self.input_name = session.get_inputs()[0].name
        self.datasize = len(self.nhwc_data_list)

    def get_next(self):
        if self.enum_data is None:
            self.enum_data = iter(
                [{self.input_name: nhwc_data} for nhwc_data in self.nhwc_data_list]
            )
        return next(self.enum_data, None)

    def rewind(self):
        self.enum_data = None

量化模型用的程式碼如下:

# quantize_model.py

import numpy as np
import onnx
from resnet50_data_reader import ResNet50DataReader
from onnxruntime.quantization import QuantType, quantize
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config, qnn_preprocess_model

if __name__ == "__main__":
    input_model_path = "resnet50-v1-12.fixed.onnx"  # Replace with your actual model
    output_model_path = "resnet50-v1-12.qdq.onnx"  # Name of final quantized model
    calibration_image_folder = "test_images" # Path to calibration data
    my_data_reader = ResNet50DataReader(calibration_image_folder, input_model_path)

    # Pre-process the original float32 model.
    preproc_model_path = "model.preproc.onnx"
    model_changed = qnn_preprocess_model(input_model_path, preproc_model_path)
    model_to_quantize = preproc_model_path if model_changed else input_model_path

    # Generate a suitable quantization configuration for this model.
    # Note that we're choosing to use uint16 activations and uint8 weights.
    qnn_config = get_qnn_qdq_config(model_to_quantize,
                                    my_data_reader,
                                    activation_type=QuantType.QUInt16,  # uint16 activations
                                    weight_type=QuantType.QUInt8)       # uint8 weights

    # Quantize the model.
    quantize(model_to_quantize, output_model_path, qnn_config)

這裡準備了 10000 張來自 ImageNet dataset 的影像來當作校正資料,其中每一類別隨機挑選了 10 張影像。執行 quantize_model.py 後就可以得到量化後的模型 resnet50-v1-12.qdq.onnx

推理程式碼

  1. 先準備分類標籤
  2. 主程式 Program.cs
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.OnnxRuntime.Tensors;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.PixelFormats;
using SixLabors.ImageSharp.Processing;
using static System.Net.Mime.MediaTypeNames;

namespace Microsoft.ML.OnnxRuntime.ResNet50v2Sample
{   
    class Program
    {
        public static void Main(string[] args)
        {
            // Read paths
            string modelFilePath = "Path\\to\\model\\resnet50-v1-12.qdq.onnx";
            string imageFilePath = "Path\\to\\image\\dog.jpeg";

            // Read image
            using Image<Rgb24> image = SixLabors.ImageSharp.Image.Load<Rgb24>(imageFilePath);

            // Resize image
            image.Mutate(x =>
            {
                x.Resize(new ResizeOptions
                {
                    Size = new Size(224, 224),
                    Mode = ResizeMode.Crop
                });
            });

            // Preprocess image
            Tensor<float> input = new DenseTensor<float>(new[] { 1, 3, 224, 224 });
            var mean = new[] { 0.485f, 0.456f, 0.406f };
            var stddev = new[] { 0.229f, 0.224f, 0.225f };
            image.ProcessPixelRows(accessor =>
            {
                for (int y = 0; y < accessor.Height; y++)
                {
                    Span<Rgb24> pixelSpan = accessor.GetRowSpan(y);
                    for (int x = 0; x < accessor.Width; x++)
                    {
                        input[0, 0, y, x] = ((pixelSpan[x].R / 255f) - mean[0]) / stddev[0];
                        input[0, 1, y, x] = ((pixelSpan[x].G / 255f) - mean[1]) / stddev[1];
                        input[0, 2, y, x] = ((pixelSpan[x].B / 255f) - mean[2]) / stddev[2];
                    }
                }
            });

            // Setup inputs
            var inputs = new List<NamedOnnxValue>
            {
                NamedOnnxValue.CreateFromTensor("data", input)
            };

            // Run inference
            Dictionary<string, string> qnn_options = new Dictionary<string, string>();
            qnn_options.Add("backend_path", "QnnHtp.dll"); //"QnnHtp.dll" for NPU or "QnnCpu.dll" for CPU
            using var session_options = new SessionOptions();
            session_options.AppendExecutionProvider("QNN", qnn_options);
;
            using var session = new InferenceSession(modelFilePath, session_options);
            using IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results = session.Run(inputs);

            // Postprocess to get softmax vector
            IEnumerable<float> output = results.First().AsEnumerable<float>();
            float sum = output.Sum(x => (float)Math.Exp(x));
            IEnumerable<float> softmax = output.Select(x => (float)Math.Exp(x) / sum);

            // Extract top 10 predicted classes
            IEnumerable<Prediction> top10 = softmax.Select((x, i) => new Prediction { Label = LabelMap.Labels[i], Confidence = x })
                               .OrderByDescending(x => x.Confidence)
                               .Take(10);

            // Print results to console
            Console.WriteLine("Top 10 predictions...");
            Console.WriteLine("--------------------------------------------------------------");
            foreach (var t in top10)
            {
                Console.WriteLine($"Label: {t.Label}, Confidence: {t.Confidence}");
            }
        }
    }
}

首先使用 QnnCpu.dll 來測試推理 resnet50-v1-12.fixed.onnx,輸出結果範例:

Error in cpuinfo: Unknown chip model name 'Microsoft SQ3'.
Please add new Windows on Arm SoC/chip support to arm/windows/init.c!
unknown ARM CPU part 0xd4b ignored
unknown ARM CPU part 0xd4b ignored
unknown ARM CPU part 0xd4b ignored
unknown ARM CPU part 0xd4b ignored
unknown ARM CPU part 0xd4c ignored
unknown ARM CPU part 0xd4c ignored
unknown ARM CPU part 0xd4c ignored
unknown ARM CPU part 0xd4c ignored
Top 10 predictions...
--------------------------------------------------------------
Label: Golden Retriever, Confidence: 0.8330527
Label: Kuvasz, Confidence: 0.058261674
Label: Saluki, Confidence: 0.052930113
Label: Flat-Coated Retriever, Confidence: 0.005992945
Label: English Setter, Confidence: 0.0042325333
Label: Afghan Hound, Confidence: 0.0037076203
Label: Irish Setter, Confidence: 0.0036869524
Label: Clumber Spaniel, Confidence: 0.003038115
Label: Curly-coated Retriever, Confidence: 0.003019722
Label: Sussex Spaniel, Confidence: 0.0028180785

ONNX Runtime 提示無法辨識處理器型號,但依然可以正確地執行。

不幸地,當後端指定為 QnnHtp.dll 時,卻回報 runtime error,無法正常執行,看起來是跑 BatchNormalization 運算時出現錯誤:

Microsoft.ML.OnnxRuntime.OnnxRuntimeException
  HResult=0x80131500
  Message=[ErrorCode:Fail] Node 'BatchNormalization' OpType:BatchNormalization with domain:com.ms.internal.nhwc was inserted using the NHWC format as requested by QNNExecutionProvider, but was not selected by that EP. This means the graph is now invalid as there will not be an EP able to run the node. This could be a bug in layout transformer, or in the GetCapability implementation of the EP.
  Source=Microsoft.ML.OnnxRuntime
  StackTrace:
   at Microsoft.ML.OnnxRuntime.NativeApiStatus.VerifySuccess(IntPtr nativeStatus)
   at Microsoft.ML.OnnxRuntime.InferenceSession.Init(String modelPath, SessionOptions options, PrePackedWeightsContainer prepackedWeightsContainer)
   at Microsoft.ML.OnnxRuntime.InferenceSession..ctor(String modelPath, SessionOptions options)
   at Microsoft.ML.OnnxRuntime.ResNet50v2Sample.Program.Main(String[] args) in C:\Users\ya-ti\Repos\ONNX_ResNet\ONNX_ResNet\Program.cs:line 63

由於 ResNet 無法順利地用 NPU 來推理,這裡嘗試換成另一個範例模型 MobileNet V2 來測試。 經過前面所述的前處理後,將模型量化為 mobilenetv2-12.qdq.onnx,使用 QnnHtp.dll 執行的結果如下:

Error in cpuinfo: Unknown chip model name 'Microsoft SQ3'.
Please add new Windows on Arm SoC/chip support to arm/windows/init.c!
unknown ARM CPU part 0xd4b ignored
unknown ARM CPU part 0xd4b ignored
unknown ARM CPU part 0xd4b ignored
unknown ARM CPU part 0xd4b ignored
unknown ARM CPU part 0xd4c ignored
unknown ARM CPU part 0xd4c ignored
unknown ARM CPU part 0xd4c ignored
unknown ARM CPU part 0xd4c ignored
Starting stage: Graph Preparation Initializing
Completed stage: Graph Preparation Initializing (1436 us)
Starting stage: Graph Transformations and Optimizations
Completed stage: Graph Transformations and Optimizations (191067 us)
Starting stage: Graph Sequencing for Target
Completed stage: Graph Sequencing for Target (31089 us)
Starting stage: VTCM Allocation
Completed stage: VTCM Allocation (21108 us)
Starting stage: Parallelization Optimization
Completed stage: Parallelization Optimization (3288 us)
Starting stage: Finalizing Graph Sequence
Completed stage: Finalizing Graph Sequence (3288 us)
Starting stage: Completion
Completed stage: Completion (217 us)
Top 10 predictions...
--------------------------------------------------------------
Label: Golden Retriever, Confidence: 0.59277284
Label: Kuvasz, Confidence: 0.08672298
Label: Clumber Spaniel, Confidence: 0.08043996
Label: Saluki, Confidence: 0.06905809
Label: Otterhound, Confidence: 0.058527183
Label: Sussex Spaniel, Confidence: 0.02276213
Label: English Setter, Confidence: 0.02189854
Label: Tibetan Terrier, Confidence: 0.011933949
Label: Pyrenean Mountain Dog, Confidence: 0.0105469
Label: Afghan Hound, Confidence: 0.010070729

成功地使用 NPU 來加速推理,且分類的結果正確 (CPU: Label: Golden Retriever, Confidence: 0.43511143)。 有趣的是量化後的模型得到的分數比未量化的模型還要高,但這裡只測試了一張影像,可能沒有代表性,也許是 overfitting。

ONNX 的 GitHub 也有上傳用標準 ONNX 量化的 MobileNet V2,這裡也嘗試直接用 NPU 推理這個模型,部分輸出如下:

2024-05-24 13:59:17.9089996 [W:onnxruntime:, qnn_model_wrapper.cc:240 onnxruntime::qnn::QnnModelWrapper::CreateQnnNode] QNN.backendValidateOpConfig() failed for node `Add_26` of type `ElementWiseAdd` with error code 3110

2024-05-24 13:59:17.9167208 [W:onnxruntime:, qnn_execution_provider.cc:364 onnxruntime::QNNExecutionProvider::IsNodeSupported] Add node `Add_26` is not supported: base_op_builder.cc:162 onnxruntime::qnn::BaseOpBuilder::ProcessOutputs Failed to add node.

2024-05-24 13:59:18.2300141 [W:onnxruntime:, qnn_execution_provider.cc:364 onnxruntime::QNNExecutionProvider::IsNodeSupported] Add node `Gemm_104_Add` is not supported: base_op_builder.cc:162 onnxruntime::qnn::BaseOpBuilder::ProcessOutputs Failed to add node.
Starting stage: Graph Preparation Initializing
Completed stage: Graph Preparation Initializing (742 us)
Starting stage: Graph Transformations and Optimizations
Completed stage: Graph Transformations and Optimizations (30250 us)
Starting stage: Graph Sequencing for Target
Completed stage: Graph Sequencing for Target (6511 us)
Starting stage: VTCM Allocation
Completed stage: VTCM Allocation (4765 us)
Starting stage: Parallelization Optimization
Completed stage: Parallelization Optimization (697 us)
Starting stage: Finalizing Graph Sequence
Completed stage: Finalizing Graph Sequence (1064 us)
Starting stage: Completion
Completed stage: Completion (144 us)
2024-05-24 13:59:18.6360876 [W:onnxruntime:, session_state.cc:1166 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2024-05-24 13:59:18.6589396 [W:onnxruntime:, session_state.cc:1168 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
Top 10 predictions...
--------------------------------------------------------------
Label: Golden Retriever, Confidence: 0.2638858
Label: Kuvasz, Confidence: 0.22160825
Label: Saluki, Confidence: 0.1875134
Label: Clumber Spaniel, Confidence: 0.07285915
Label: Otterhound, Confidence: 0.031498097
Label: borzoi, Confidence: 0.027598232
Label: English Setter, Confidence: 0.02686022
Label: Sussex Spaniel, Confidence: 0.019354017
Label: Afghan Hound, Confidence: 0.017119078
Label: Pyrenean Mountain Dog, Confidence: 0.010476385

雖然出現了額外的警告訊息,提示某些運算不支援,還是有成功跑出結果。雖然分類結果正確,但是數值跟未經過量化模型的數值有明顯的差異,可能體現出了量化誤差。 從分數可看出兩種量化模型的表現差異很大,雖然不知道官方是用什麼資料做模型量化,但可以看出校正資料的影響很大,要準備對預期目標有代表性的資料才會有較好的效果。

心得

經過測試,ONNX Runtime 與 QNN EP 的組合,要使用 NPU 需要做不少額外的前置處理,且不支援部分運算或是存在 bug。 此外,量化模型需要使用 QNN 版 ONNX 專用的量化 API,似乎不能直接相容用一般 ONNX API 量化的模型,否則推理可能會有問題,是一個比較麻煩的地方。希望微軟與高通可以讓 QNN EP 在 ONNX 的開發體驗更友好。

分享到: DiasporaTwitterFacebookLinkedInHackerNewsEmailReddit