nndeploy C++ API  0.2.0
nndeploy C++ API
drawlabel.h
Go to the documentation of this file.
1 #ifndef _NNDEPLOY_CLASSIFICATION_DRAWLABEL_H_
2 #define _NNDEPLOY_CLASSIFICATION_DRAWLABEL_H_
3 
6 #include "nndeploy/dag/node.h"
10 
11 namespace nndeploy {
12 namespace classification {
13 
// DAG node that overlays classification results as text on a copy of the
// input image.
//
// Both constructors register two input types (cv::Mat, then
// ClassificationResult) and one output type (cv::Mat). run() reads the image
// from inputs_[0], draws one "class: <id> score: <score>" line per entry of
// result->labels_, and publishes the annotated copy on outputs_[0].
14 class DrawLable : public dag::Node {
15  public:
 // Constructs the node with only a name; sets the node key and description
 // and registers the input/output type info.
16  DrawLable(const std::string &name) : Node(name) {
17  key_ = "nndeploy::classification::DrawLable";
18  desc_ = "Draw classification labels on input cv::Mat image based on classification results[cv::Mat->cv::Mat]";
19  this->setInputTypeInfo<cv::Mat>();
20  this->setInputTypeInfo<ClassificationResult>();
21  this->setOutputTypeInfo<cv::Mat>();
22  }
 // Constructs the node with a name plus explicit input and output edges,
 // which are forwarded to the Node base constructor. The body performs the
 // same key/description/type-info setup as the single-argument constructor.
23  DrawLable(const std::string &name, std::vector<dag::Edge *> inputs,
24  std::vector<dag::Edge *> outputs)
25  : Node(name, inputs, outputs) {
26  key_ = "nndeploy::classification::DrawLable";
27  desc_ = "Draw classification labels on input cv::Mat image based on classification results[cv::Mat->cv::Mat]";
28  this->setInputTypeInfo<cv::Mat>();
29  this->setInputTypeInfo<ClassificationResult>();
30  this->setOutputTypeInfo<cv::Mat>();
31  }
32  virtual ~DrawLable() {}
33 
 // Draws every classification label onto a copy of the input image and
 // hands that copy to outputs_[0]. Always returns base::kStatusCodeOk.
 //
 // NOTE(review): this listing skips from line 35 to line 38, and `result`
 // is used below without a visible declaration. Presumably it is the
 // ClassificationResult fetched from inputs_[1] -- TODO confirm against the
 // real header.
34  virtual base::Status run() {
 // NOTE(review): input_mat is dereferenced below without a null check.
35  cv::Mat *input_mat = inputs_[0]->getCvMat(this);
38  // Iterate over each classification result
 // The drawing is done on a freshly allocated copy, so *input_mat itself is
 // not drawn on.
39  cv::Mat *output_mat = new cv::Mat();
40  input_mat->copyTo(*output_mat);
 // NOTE(review): `int i` is compared against labels_.size() (a std::vector
 // size), i.e. a signed/unsigned comparison.
41  for (int i = 0; i < result->labels_.size(); i++) {
42  auto label = result->labels_[i];
43 
44  // Convert the class label and its confidence score to a string
45  std::string text = "class: " + std::to_string(label.label_ids_) +
46  " score: " + std::to_string(label.scores_);
47 
48  // Draw the text in the top-left corner of the image
49  // Measure the text size so it is not clipped; a larger font is used
 // Font scale 2.0 and thickness 4 here match the cv::putText call below.
50  int baseline = 0;
51  cv::Size text_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 2.0, 4, &baseline);
52 
53  // Keep enough margin at the top of the image so the text is not clipped
 // Each label i is placed one (text height + 10) step below the previous
 // one, starting at y = 50, but never above text height + 10.
54  int y_position = std::max(text_size.height + 10, 50 + i * (text_size.height + 10));
55 
56  // Add a background rectangle behind the text to improve readability
 // The rectangle is black, padded by 5 px on each side of the text box,
 // and drawn with thickness -1 (filled).
57  cv::Point text_origin(30, y_position);
58  cv::Rect background_rect(text_origin.x - 5, text_origin.y - text_size.height - 5,
59  text_size.width + 10, text_size.height + baseline + 10);
60  cv::rectangle(*output_mat, background_rect, cv::Scalar(0, 0, 0), -1);
61 
62  // Draw the text with a larger, bolder font for better visibility
63  cv::putText(*output_mat, text, text_origin,
64  cv::FONT_HERSHEY_SIMPLEX, 2.0, cv::Scalar(0, 255, 0), 4);
65  }
66  // cv::imwrite("draw_label_node.jpg", *input_mat);
 // NOTE(review): the `false` argument presumably tells the edge the Mat is
 // not external, so the edge takes ownership of the `new`-allocated
 // output_mat -- confirm against dag::Edge::set, otherwise this leaks.
67  outputs_[0]->set(output_mat, false);
68  return base::kStatusCodeOk;
69  }
70 };
71 
72 } // namespace classification
73 } // namespace nndeploy
74 
75 #endif
std::vector< ClassificationLableResult > labels_
Definition: result.h:29
DrawLable(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
Definition: drawlabel.h:23
virtual base::Status run()
Run node (pure virtual function)
Definition: drawlabel.h:34
DrawLable(const std::string &name)
Definition: drawlabel.h:16
Node base class.
Definition: node.h:171
std::string desc_
Node description.
Definition: node.h:1294
virtual base::Param * getParam()
Get parameter.
std::vector< Edge * > outputs_
Output edge list.
Definition: node.h:1318
Node(const std::string &name)
Constructor.
std::string key_
Node key.
Definition: node.h:1290
std::vector< Edge * > inputs_
Input edge list.
Definition: node.h:1317
@ kStatusCodeOk
Definition: status.h:13