-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathOnnxClassifier.cpp
More file actions
308 lines (263 loc) · 10.1 KB
/
OnnxClassifier.cpp
File metadata and controls
308 lines (263 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
/**
* @file OnnxClassifier.cpp
* @brief ONNX 模型分类器实现
*
* 核心流程:
* loadModel() -> 创建 Ort::Env + Ort::Session
* classify() -> preprocess() -> Ort::Session::Run() -> 解析输出
*/
#include "OnnxClassifier.h"
#include "ImageUtils.h"
#if __has_include(<onnxruntime_cxx_api.h>)
#include <onnxruntime_cxx_api.h>
#elif __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
#else
#error "onnxruntime_cxx_api.h not found"
#endif
#include <QDebug>
#include <QMap>
#include <QRegularExpression>
#include <opencv2/imgproc.hpp>
#include <algorithm>
#include <array>
#include <cmath>
namespace {
/**
 * @brief Fetch a single custom-metadata entry from an ONNX model.
 *
 * The lookup is best effort: if the key is absent, or ONNX Runtime throws
 * while looking it up, an empty QString is returned instead of propagating
 * the error.
 *
 * @param metadata Model metadata object to query.
 * @param key      Name of the custom metadata entry.
 * @return The entry decoded as UTF-8, or an empty string when unavailable.
 */
QString readCustomMetadataValue(Ort::ModelMetadata &metadata, const char *key) {
    QString text;
    try {
        Ort::AllocatorWithDefaultOptions defaultAllocator;
        const auto rawValue = metadata.LookupCustomMetadataMapAllocated(key, defaultAllocator);
        if (const char *utf8 = rawValue.get()) {
            text = QString::fromUtf8(utf8);
        }
    } catch (const std::exception &) {
        // Intentionally ignored: a missing/unreadable entry is not an error here.
    }
    return text;
}
/**
 * @brief Extract an ordered class-name list from a metadata string.
 *
 * Scans @p text for entries of the form `<digits> : '<name>'` or
 * `<digits> : "<name>"` (e.g. `{0: 'cat', 1: "it's a dog"}`). The closing
 * quote must be the same character as the opening quote, so a name may
 * contain the other quote character.
 *
 * Entries whose index does not fit in an int, or exceeds an internal sanity
 * limit, are skipped so that corrupt metadata cannot trigger an integer
 * overflow or a gigantic allocation. If the same index occurs more than once
 * the last occurrence wins.
 *
 * @param text Raw metadata value.
 * @return Names for indices 0..maxIndex; indices that were not present get
 *         the placeholder "class_<index>". Empty list when nothing matched.
 */
QStringList parseNamesFromMetadata(const QString &text) {
    // Sanity limit on class indices taken from (potentially corrupt) metadata.
    constexpr int kMaxClassIndex = 1000000;

    // Group 1: index, group 2: opening quote, group 3: name, \2: same quote.
    const QRegularExpression pattern(
        R"re((\d+)\s*:\s*(['"])((?:(?!\2).)*)\2)re",
        QRegularExpression::DotMatchesEverythingOption);

    QMap<int, QString> namesByIndex;
    auto matches = pattern.globalMatch(text);
    while (matches.hasNext()) {
        const QRegularExpressionMatch match = matches.next();
        bool converted = false;
        const int index = match.captured(1).toInt(&converted);
        if (!converted || index > kMaxClassIndex) {
            continue;
        }
        namesByIndex.insert(index, match.captured(3));
    }
    if (namesByIndex.isEmpty()) {
        return {};
    }

    const int lastIndex = namesByIndex.lastKey();
    QStringList names;
    names.reserve(lastIndex + 1);
    for (int i = 0; i <= lastIndex; ++i) {
        const auto found = namesByIndex.constFind(i);
        if (found != namesByIndex.cend()) {
            names << found.value();
        } else {
            // Fill gaps so list position always equals class index.
            names << QString("class_%1").arg(i);
        }
    }
    return names;
}
} // namespace
/// Constructs a classifier; this definition performs no work of its own.
OnnxClassifier::OnnxClassifier() = default;
/// Destroys the classifier; members clean themselves up.
OnnxClassifier::~OnnxClassifier() = default;
bool OnnxClassifier::loadModel(const QString &modelPath) {
m_loaded = false;
m_session.reset();
m_modelClassNames.clear();
try {
// 创建 ONNX Runtime 运行环境
m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "YOLOClassifier");
// 配置推理选项:4 线程 + 全图优化
Ort::SessionOptions sessionOptions;
sessionOptions.SetIntraOpNumThreads(4);
sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
// 加载模型文件
#ifdef _WIN32
const std::wstring nativeModelPath = modelPath.toStdWString();
m_session = std::make_unique<Ort::Session>(
*m_env, nativeModelPath.c_str(), sessionOptions
);
#else
const std::string nativeModelPath = modelPath.toStdString();
m_session = std::make_unique<Ort::Session>(
*m_env, nativeModelPath.c_str(), sessionOptions
);
#endif
// 验证模型输入
size_t inputCount = m_session->GetInputCount();
if (inputCount == 0) {
qWarning() << "No input found in ONNX model";
return false;
}
// 打印模型输入形状(调试用)
auto inputTypeInfo = m_session->GetInputTypeInfo(0);
auto tensorInfo = inputTypeInfo.GetTensorTypeAndShapeInfo();
auto shape = tensorInfo.GetShape();
if (shape.size() == 4 && shape[2] > 0 && shape[3] > 0) {
m_inputHeight = static_cast<int>(shape[2]);
m_inputWidth = static_cast<int>(shape[3]);
}
qDebug() << "Model loaded:" << modelPath;
QStringList dims;
for (const auto dim : shape) {
dims << QString::number(dim);
}
qDebug() << "Input shape:" << dims.join(" x ");
qDebug() << "Preprocess mode: OpenCV decode -> short-edge resize -> center crop -> RGB -> [0,1]";
try {
auto metadata = m_session->GetModelMetadata();
const QString task = readCustomMetadataValue(metadata, "task");
const QString imgsz = readCustomMetadataValue(metadata, "imgsz");
const QString names = readCustomMetadataValue(metadata, "names");
const QString args = readCustomMetadataValue(metadata, "args");
if (!task.isEmpty()) {
qDebug() << "Metadata task:" << task;
}
if (!imgsz.isEmpty()) {
qDebug() << "Metadata imgsz:" << imgsz;
}
if (!args.isEmpty()) {
qDebug() << "Metadata args:" << args;
}
if (!names.isEmpty()) {
m_modelClassNames = parseNamesFromMetadata(names);
qDebug() << "Metadata names parsed:" << m_modelClassNames;
}
} catch (const std::exception &e) {
qWarning() << "Failed to read ONNX metadata:" << e.what();
}
m_loaded = true;
return true;
} catch (const std::exception &e) {
qCritical() << "Failed to load model:" << e.what();
return false;
}
}
/**
 * @brief Report whether a model is currently loaded.
 * @return The loaded flag, which loadModel() sets to true only on success.
 */
bool OnnxClassifier::isLoaded() const {
    return m_loaded;
}
/**
 * @brief Set caller-supplied class names.
 *
 * When this list is non-empty, classify() uses it in preference to the names
 * parsed from the model metadata.
 *
 * @param names Class names, indexed by class id.
 */
void OnnxClassifier::setClassNames(const QStringList &names) {
    m_classNames = names;
}
/**
 * @brief Class names parsed from the loaded model's metadata.
 * @return A copy of the parsed list; empty if none were found.
 */
QStringList OnnxClassifier::modelClassNames() const {
    return m_modelClassNames;
}
/**
 * @brief Classify one image file with the loaded model.
 *
 * Pipeline: preprocess() -> build a [1, 3, H, W] float tensor ->
 * Ort::Session::Run() -> pick the highest-scoring class.
 *
 * The size of the output's last dimension is taken as the number of classes.
 * Class names come from setClassNames() when that list is non-empty,
 * otherwise from the model metadata; indices without a name are reported as
 * "class_<index>".
 *
 * @param imagePath Path of the image to classify.
 * @return Filled Result on success. A default-constructed Result is returned
 *         (and the reason logged) when no model is loaded, preprocessing
 *         fails, the model output is unusable, or ONNX Runtime/OpenCV throws.
 */
OnnxClassifier::Result OnnxClassifier::classify(const QString &imagePath) {
    Result result;
    if (!m_loaded) {
        qWarning() << "Model not loaded";
        return result;
    }
    try {
        // 1. Preprocess: short-edge resize, center crop, RGB, scale to [0,1].
        auto input = preprocess(imagePath);
        if (input.empty()) {
            qWarning() << "Failed to preprocess image:" << imagePath;
            return result;
        }

        // 2. Names of the first input/output node.
        Ort::AllocatorWithDefaultOptions allocator;
        auto inputName = m_session->GetInputNameAllocated(0, allocator);
        auto outputName = m_session->GetOutputNameAllocated(0, allocator);
        const char *inputNames[] = { inputName.get() };
        const char *outputNames[] = { outputName.get() };

        // 3. Wrap the preprocessed buffer as an NCHW tensor [1, 3, H, W].
        //    The tensor borrows `input`, which must outlive the Run() call.
        auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
        const std::array<int64_t, 4> inputShape{
            1, 3, static_cast<int64_t>(m_inputHeight), static_cast<int64_t>(m_inputWidth)
        };
        auto inputTensor = Ort::Value::CreateTensor<float>(
            memoryInfo, input.data(), input.size(),
            inputShape.data(), inputShape.size()
        );

        // 4. Run inference.
        auto outputTensors = m_session->Run(
            Ort::RunOptions{nullptr},
            inputNames, &inputTensor, 1,
            outputNames, 1
        );

        // 5. Validate the output before touching its buffer.
        if (outputTensors.empty() || !outputTensors[0].IsTensor()) {
            qWarning() << "Model did not produce a tensor output";
            return result;
        }
        auto outputInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
        if (outputInfo.GetElementType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
            qWarning() << "Unsupported model output element type (expected float32)";
            return result;
        }
        const auto outputShape = outputInfo.GetShape();
        const size_t elementCount = outputInfo.GetElementCount();
        if (outputShape.empty() || outputShape.back() <= 0
            || static_cast<size_t>(outputShape.back()) > elementCount) {
            qWarning() << "Model output has no class scores";
            return result;
        }
        const int numClasses = static_cast<int>(outputShape.back());
        const float *scores = outputTensors[0].GetTensorMutableData<float>();

        // 6. Highest score; the first maximum wins on ties.
        const float *best = std::max_element(scores, scores + numClasses);
        const int bestIdx = static_cast<int>(best - scores);

        // User-supplied names take precedence over names from model metadata.
        const QStringList &activeNames = m_classNames.isEmpty() ? m_modelClassNames : m_classNames;
        const auto nameFor = [&activeNames](int index) -> QString {
            return index < activeNames.size()
                ? activeNames[index]
                : QString("class_%1").arg(index);
        };

        result.className = nameFor(bestIdx);
        result.confidence = *best;

        // 7. Report every class with its score.
        for (int i = 0; i < numClasses; ++i) {
            result.allScores.push_back({nameFor(i), scores[i]});
        }
        return result;
    } catch (const std::exception &e) {
        qCritical() << "Classification failed:" << e.what();
        return result;
    }
}
std::vector<float> OnnxClassifier::preprocess(const QString &imagePath) {
cv::Mat bgr = ImageUtils::loadColorImage(imagePath);
if (bgr.empty() || bgr.cols <= 0 || bgr.rows <= 0) {
return {};
}
// Ultralytics 分类推理默认流程:
// 1) 将短边缩放到目标尺寸
// 2) 从中间裁出目标大小
const float scale = std::max(
static_cast<float>(m_inputWidth) / static_cast<float>(bgr.cols),
static_cast<float>(m_inputHeight) / static_cast<float>(bgr.rows)
);
// 对齐 torchvision Resize(int) 的主逻辑:长边采用向下取整,同时保证短边 >= 目标尺寸。
const int resizedWidth = std::max(
m_inputWidth,
static_cast<int>(std::floor(static_cast<float>(bgr.cols) * scale))
);
const int resizedHeight = std::max(
m_inputHeight,
static_cast<int>(std::floor(static_cast<float>(bgr.rows) * scale))
);
cv::Mat resized;
const bool isDownsample = resizedWidth < bgr.cols || resizedHeight < bgr.rows;
cv::resize(
bgr,
resized,
cv::Size(resizedWidth, resizedHeight),
0.0,
0.0,
isDownsample ? cv::INTER_AREA : cv::INTER_LINEAR
);
const int cropX = std::max(0, static_cast<int>(std::lround((resized.cols - m_inputWidth) / 2.0)));
const int cropY = std::max(0, static_cast<int>(std::lround((resized.rows - m_inputHeight) / 2.0)));
const cv::Rect roi(cropX, cropY, m_inputWidth, m_inputHeight);
cv::Mat cropped = resized(roi);
cv::Mat rgb;
cv::cvtColor(cropped, rgb, cv::COLOR_BGR2RGB);
cv::Mat normalized;
rgb.convertTo(normalized, CV_32FC3, 1.0 / 255.0);
std::vector<float> input(3 * m_inputWidth * m_inputHeight);
const int planeSize = m_inputWidth * m_inputHeight;
for (int y = 0; y < m_inputHeight; ++y) {
const auto *row = normalized.ptr<cv::Vec3f>(y);
for (int x = 0; x < m_inputWidth; ++x) {
const cv::Vec3f &pixel = row[x];
const int offset = y * m_inputWidth + x;
input[0 * planeSize + offset] = pixel[0];
input[1 * planeSize + offset] = pixel[1];
input[2 * planeSize + offset] = pixel[2];
}
}
return input;
}