一、 环境准备与依赖配置
在 Windows 下,最省心的方式是通过 Visual Studio 的 NuGet 包管理器 来安装带有 DirectML 支持的 ONNX Runtime。
1. 过修改 NuGet.Config 设置代理(没有梯子会很卡,甚至失败)
1.1 找到配置文件
NuGet 的全局配置文件一般存放在以下路径:
C:\Users\你的用户名\AppData\Roaming\NuGet\NuGet.Config
<?xml version="1.0" encoding="utf-8"?> <configuration> <packageSources> <add key="nuget.org" value="https://api.nuget.org/v3/index.json" protocolVersion="3" /> </packageSources> <config> <add key="http_proxy" value="http://127.0.0.1:10808" /> <add key="https_proxy" value="http://127.0.0.1:10808" /> </config> </configuration>
重启 Visual Studio
保存文件后,彻底关闭并重新打开 Visual Studio。再次进入 NuGet 包管理器,搜索 Microsoft.ML.OnnxRuntime.DirectML,此时流量就会走你的代理软件,搜索和下载就能正常进行了
2. 安装 NuGet 包
打开你的 Visual Studio C++ 项目,进入 工具 -> NuGet 包管理器 -> 管理解决方案的 NuGet 程序包,搜索并安装以下包:
Microsoft.ML.OnnxRuntime.DirectML (核心:包含 ONNX Runtime 引擎和 DirectML 推理提供程序)。
注意:不要安装通用的 Microsoft.ML.OnnxRuntime(只有 CPU 版本的 C++ 接口可能没有集成 DML)。安装 DirectML 专用版后,它会自动下载 onnxruntime.dll 和 DirectML.dll 及其头文件。
安装完成后,项目根目录下会生成一个 packages 文件夹,里面包含了 onnxruntime.dll、DirectML.dll 和相关的头文件。
demo
#include <iostream>
#include <vector>
#include <string>
#include <chrono>
#include <thread>
#include <atomic>
#include <fstream>
#include <chrono>
#include <thread>
#include <filesystem>
#include <opencv2/opencv.hpp>
#include <onnxruntime_cxx_api.h>
#include <dml_provider_factory.h>
#include "ThreadSafeQueue.hpp" // 引入刚才创建的头文件
// 保持 RTDETR 类定义结构,包装核心算法
class RTDETRDirectMLDetector {
private:
Ort::Env env;
Ort::SessionOptions session_options;
std::unique_ptr<Ort::Session> session;
float conf_threshold;
int input_size = 640;
std::string input_name;
std::vector<std::string> output_names;
std::vector<const char*> input_node_names;
std::vector<const char*> output_node_names;
std::vector<std::string> class_names;
public:
struct DetectionResult {
cv::Rect bbox;
float confidence;
int class_id;
std::string class_name;
};
RTDETRDirectMLDetector(const std::wstring& onnx_path, float conf_thresh = 0.35f)
: env(ORT_LOGGING_LEVEL_WARNING, "RT-DETR_Live"), conf_threshold(conf_thresh)
{
class_names = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
"boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
"giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
"skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
"skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
"fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
"broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
"potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
"remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
"refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
"toothbrush"
};
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
// 绑定 DirectML
int device_id = 0;
auto statusPtr = OrtSessionOptionsAppendExecutionProvider_DML(session_options, device_id);
session = std::make_unique<Ort::Session>(env, onnx_path.c_str(), session_options);
Ort::AllocatorWithDefaultOptions allocator;
auto in_name_alloc = session->GetInputNameAllocated(0, allocator);
input_name = std::string(in_name_alloc.get());
input_node_names.push_back(input_name.c_str());
size_t num_outputs = session->GetOutputCount();
for (size_t i = 0; i < num_outputs; ++i) {
auto out_name_alloc = session->GetOutputNameAllocated(i, allocator);
output_names.push_back(std::string(out_name_alloc.get()));
}
for (const auto& name : output_names) {
output_node_names.push_back(name.c_str());
}
}
std::vector<float> preprocess(const cv::Mat& img) {
cv::Mat resized, rgb;
cv::resize(img, resized, cv::Size(input_size, input_size));
cv::cvtColor(resized, rgb, cv::COLOR_BGR2RGB);
std::vector<float> input_tensor_values(1 * 3 * input_size * input_size);
int channel_size = input_size * input_size;
for (int c = 0; c < 3; ++c) {
for (int h = 0; h < input_size; ++h) {
for (int w = 0; w < input_size; ++w) {
input_tensor_values[c * channel_size + h * input_size + w] = rgb.at<cv::Vec3b>(h, w)[c] / 255.0f;
}
}
}
return input_tensor_values;
}
std::vector<DetectionResult> postprocess(const Ort::Value& output_tensor, int original_w, int original_h) {
const float* raw_output = output_tensor.GetTensorData<float>();
auto shape = output_tensor.GetTensorTypeAndShapeInfo().GetShape();
int num_dets = static_cast<int>(shape[1]);
int num_elements = static_cast<int>(shape[2]);
int num_classes = num_elements - 4;
std::vector<DetectionResult> results;
for (int i = 0; i < num_dets; ++i) {
const float* det = raw_output + (i * num_elements);
float cx = det[0], cy = det[1], w = det[2], h = det[3];
const float* class_scores = det + 4;
auto max_elem = std::max_element(class_scores, class_scores + num_classes);
float confidence = *max_elem;
int class_id = static_cast<int>(std::distance(class_scores, max_elem));
if (confidence < conf_threshold) continue;
std::string class_name = (class_id < class_names.size()) ? class_names[class_id] : "unknown";
if (class_name != "person") continue; // 业务需求:只要人
int x1 = static_cast<int>((cx - w / 2.0f) * original_w);
int y1 = static_cast<int>((cy - h / 2.0f) * original_h);
int x2 = static_cast<int>((cx + w / 2.0f) * original_w);
int y2 = static_cast<int>((cy + h / 2.0f) * original_h);
DetectionResult res{ cv::Rect(cv::Point(x1, y1), cv::Point(x2, y2)), confidence, class_id, class_name };
results.push_back(res);
}
return results;
}
std::vector<DetectionResult> detect(const cv::Mat& img) {
std::vector<float> input_tensor_values = preprocess(img);
std::vector<int64_t> input_shape = { 1, 3, input_size, input_size };
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
memory_info, input_tensor_values.data(), input_tensor_values.size(), input_shape.data(), input_shape.size()
);
auto outputs = session->Run(Ort::RunOptions{ nullptr }, input_node_names.data(), &input_tensor, 1, output_node_names.data(), output_node_names.size());
return postprocess(outputs[0], img.cols, img.rows);
}
};
// ==========================================
// 全局多线程控制变量
// ==========================================
std::atomic<bool> g_running(true);
ThreadSafeQueue<cv::Mat> g_capture_queue(3); // 限制缓冲大小,防止产生累积延迟
struct RenderFrame {
cv::Mat frame;
std::vector<RTDETRDirectMLDetector::DetectionResult> results;
double fps;
};
ThreadSafeQueue<RenderFrame> g_render_queue(3);
// 1. 采集线程:只管以最快速度抓取 RTMP 帧
void capture_thread_func(const std::string& rtmp_url) {
cv::VideoCapture cap(rtmp_url, rtmp_url == "0" ? cv::CAP_DSHOW : cv::CAP_FFMPEG);
// 针对网络直播流的重大优化参数
cap.set(cv::CAP_PROP_BUFFERSIZE, 1);
if (!cap.isOpened()) {
std::cerr << " 错误: 无法打开 RTMP 视频流: " << rtmp_url << std::endl;
g_running = false;
return;
}
std::cout << " 成功连接 RTMP 视频流,开始采集..." << std::endl;
cv::Mat tmp_frame;
while (g_running) {
if (!cap.read(tmp_frame) || tmp_frame.empty()) {
std::cerr << "警告: 视频流断开或读取空帧" << std::endl;
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
g_capture_queue.push(tmp_frame.clone()); // 深度拷贝,安全送入队列
}
cap.release();
}
// 2. 识别/推理线程:全力运行 GPU 进行检测
void inference_thread_func(const std::wstring& model_path) {
std::cout << " 初始化推理线程与 DirectML 引擎..." << std::endl;
RTDETRDirectMLDetector detector(model_path, 0.45f);
// 预热一帧,激发 DML 算子编译
cv::Mat dummy(640, 640, CV_8UC3, cv::Scalar(0));
detector.detect(dummy);
std::cout << " DirectML 引擎预热完毕,推理线程就绪。" << std::endl;
cv::Mat local_frame;
auto last_time = std::chrono::high_resolution_clock::now();
while (g_running) {
if (!g_capture_queue.pop(local_frame)) break;
// 执行 GPU 加速推理
auto results = detector.detect(local_frame);
// 计算当前推理实际的 FPS
auto now = std::chrono::high_resolution_clock::now();
double duration = std::chrono::duration<double>(now - last_time).count();
last_time = now;
double current_fps = 1.0 / (duration > 0 ? duration : 0.01);
// 将结果包装,打包送给渲染队列
RenderFrame r_frame{ local_frame, results, current_fps };
g_render_queue.push(r_frame);
}
}
// ==========================================
// 主线程/渲染线程:负责绘制和最终的 UI 刷新
// ==========================================
int main() {
std::string line;
// 读取url.txt
std::ifstream url_file("url.txt");
if (url_file.is_open()) {
std::getline(url_file, line);
url_file.close();
}
std::string RTMP_URL = line.empty() ? "rtsp://username:password@192.168.10.2:554/stream?rtsp_transport=tcp" : line; // 替换为你的 RTMP 直播流地址
std::wstring MODEL_PATH = L"rtdetr-l.onnx"; // 替换为你的模型路径
// 启动多线程工作流
std::thread capture_thread(capture_thread_func, RTMP_URL);
std::thread inference_thread(inference_thread_func, MODEL_PATH);
cv::namedWindow("DirectML + RT-DETR 多线程高实时检测", cv::WINDOW_NORMAL);
RenderFrame display_data;
while (g_running) {
// 从结果队列拿到已经识别完的帧和结果
if (!g_render_queue.pop(display_data)) break;
cv::Mat canvas = display_data.frame;
int person_count = 0;
// 计算单帧推理耗时(毫秒)
double latency_ms = 1000.0 / display_data.fps;
// 遍历并绘制检测到的人
for (const auto& res : display_data.results) {
if (res.class_name == "person") person_count++;
// 加大粗细:边界框线宽改为 3 (原来是 2)
cv::rectangle(canvas, res.bbox, cv::Scalar(0, 0, 255), 3);
// 加大字体:格式化标签文本
char label_buf[64];
sprintf_s(label_buf, "%s: %.2f", res.class_name.c_str(), res.confidence);
std::string label(label_buf);
int base_line;
// 加大字体:字号(fontScale)改为 1.5,线宽改为 2
cv::Size text_size = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 1.5, 2, &base_line);
// 绘制标签背景实心矩形
cv::rectangle(canvas,
cv::Point(res.bbox.x, res.bbox.y - text_size.height - 8),
cv::Point(res.bbox.x + text_size.width, res.bbox.y),
cv::Scalar(0, 0, 255),
cv::FILLED);
// 绘制标签文字
cv::putText(canvas, label,
cv::Point(res.bbox.x, res.bbox.y - 4),
cv::FONT_HERSHEY_SIMPLEX, 1.5, cv::Scalar(255, 255, 255), 2);
}
// 左上角状态栏绘制 (多行大字 OSD)
// 提升视觉可读性:字号改为 1.0 (大字),线宽改为 3,使用显眼的绿色/黄色
float font_scale = 1.0;
std::string fps_str = "FPS: " + std::to_string(display_data.fps).substr(0, 5);
std::string latency_str = "Latency: " + std::to_string(latency_ms).substr(0, 5) + " ms";
std::string count_str = "Count: " + std::to_string(person_count);
// 逐行绘制到左上角,每行留出 40 像素的间隔
cv::putText(canvas, fps_str, cv::Point(20, 40), cv::FONT_HERSHEY_SIMPLEX, font_scale, cv::Scalar(0, 255, 0), 3);
cv::putText(canvas, latency_str, cv::Point(20, 80), cv::FONT_HERSHEY_SIMPLEX, font_scale, cv::Scalar(0, 255, 255), 3);
cv::putText(canvas, count_str, cv::Point(20, 120), cv::FONT_HERSHEY_SIMPLEX, font_scale, cv::Scalar(0, 255, 0), 3);
// 3. 展现画布
cv::imshow("DirectML + RT-DETR 多线程高实时检测", canvas);
// 检测退出事件(按 ESC 键退出)
if (cv::waitKey(1) == 27) {
std::cout << "接收到退出指令,正在释放线程..." << std::endl;
g_running = false;
break;
}
}
// 善后:停止并回收所有子线程
g_capture_queue.stop();
g_render_queue.stop();
if (capture_thread.joinable()) capture_thread.join();
if (inference_thread.joinable()) inference_thread.join();
cv::destroyAllWindows();
std::cout << "程序安全退出。" << std::endl;
return 0;
}下载模型并导出onnx格式
from ultralytics import RTDETR
model = RTDETR("rtdetr-l.pt")
# 导出 ONNX(关键参数)
model.export(
format="onnx",
imgsz=640,
opset=16,
nms=False, # 不包含 NMS,在 C++ 中手动处理
simplify=True,
dynamic=False
) 收藏的用户(0) X
正在加载信息~
推荐阅读
最新回复 (0)
站点信息
- 文章2324
- 用户1338
- 访客12461775
每日一句
Wheat awns pierce the dawn. Farmers bend to pick up gold.
麦芒刺破晨光,农人弯腰拾起金黄。
麦芒刺破晨光,农人弯腰拾起金黄。