maemaewaterの日記

エンジニア兼ゲーマーの人の日記です。PHP/Python/JavaScript/C#/C++などによるプログラムに関することを主に書いています。

UWPアプリケーションでONNXを使用する

UWPアプリケーション(Universal Windows Platform)での機械学習で使用されるONNX形式のモデルの使い方についてです。いくつか問題になるところがありましたので、その問題点と解決方法についてになります。

モデルファイルの読み込み

UWPアプリケーションでは、ファイルの読み込みを行う際にはユーザーの操作(ファイルやフォルダの選択ダイアログからなど)によってなど読み込む場所が限定されています。モデルなどはAssetに入れて置き読み込むのが良いと思われますが、ML.NET(Microsoft.ML.OnnxRuntime)の場合はモデルのファイルをパスで指定するため簡単にAssetから読み込みができるようではなさそうでした。

どうしようかと思っていたところ調べてみると、なんとUWPで利用できるWindows AI Platformがあらかじめ用意されていました。こちらを使うのが素直な感じなのですね。

docs.microsoft.com

そのようなわけで、Windows AI PlatformのWindows.AI.MachineLearningを使っていくのが良さそうです。

モデルファイルのバージョン

Windows AI PlatformでもONNX形式のファイルを取り扱えますが、実は問題が発生していまいました。読み込もうとしていたONNXファイルで定義されているオペレーターのバージョンが対応していないというエラーになります。以下のページにバージョンについて書かれております。

github.com

docs.microsoft.com

Windows.AIの方では、バージョン8までの対応となっているため、このバージョンより新しい場合には変換することで対応できます。MicrosoftWindows.AIのページにも書かれているのですが、ONNX(https://github.com/onnx/onnx)の方で変換のライブラリが用意されています。Pythonをあらかじめインストールしておく必要がありますが、次の手順で変換が行えます(Windowsでももちろん動きます)。

pip install onnx

変換のプログラム

import onnx
from onnx import version_converter

model = onnx.load("model.onnx")
converted_model = version_converter.convert_version(model, 8)

onnx.save(converted_model, "model_8.onnx")

ONNXを使用する例

UWPのプロジェクトでAssetに変換したONNXのファイルを追加します。ここで、プロパティでビルドアクションがコンテンツになっていない場合は変更します(変更しないとファイルが見つからないというエラーになります)。

次がONNXを利用するコードになります。

using Windows.AI.MachineLearning;
float data = new float[16 * 16]; // 入力

// dataの初期化をします

LearningModel model;
LearningModelSession session;

var modelFile = await StorageFile.GetFileFromApplicationUriAsync(new Uri("ms-appx:///Assets/model_8.onnx"));
model = await LearningModel.LoadFromStorageFileAsync(modelFile);

session = new LearningModelSession(model, new LearningModelDevice(LearningModelDeviceKind.Default));

LearningModelBinding binding = new LearningModelBinding(session);


TensorFloat tensor = TensorFloat.CreateFromArray(new long[]{ 1, 1, 16, 16 }, data);
binding.Bind("0", tensor);

var modelOutput = await session.EvaluateAsync(binding, "run");
List<float> v = new List<float>();
foreach (var item in modelOutput.Outputs)
{
   TensorFloat outTensor = (TensorFloat)item.Value;
   v = outTensor.GetAsVectorView().ToList();
}

// v.IndexOf(v.Max()) で結果を取得