nndeploy C++ API  0.2.0
nndeploy C++ API
detector.h
Go to the documentation of this file.
1 #ifndef _NNDEPLOY_DETECT_DETECTER_OCR_OCR_H_
2 #define _NNDEPLOY_DETECT_DETECTER_OCR_OCR_H_
3 
4 #include "nndeploy/base/any.h"
5 #include "nndeploy/base/common.h"
7 #include "nndeploy/base/log.h"
8 #include "nndeploy/base/macro.h"
9 #include "nndeploy/base/object.h"
11 #include "nndeploy/base/status.h"
12 #include "nndeploy/base/string.h"
13 #include "nndeploy/dag/edge.h"
14 #include "nndeploy/dag/graph.h"
15 #include "nndeploy/dag/node.h"
16 #include "nndeploy/detect/result.h"
17 #include "nndeploy/device/buffer.h"
18 #include "nndeploy/device/device.h"
20 #include "nndeploy/device/tensor.h"
21 #include "nndeploy/infer/infer.h"
23 #include "nndeploy/ocr/result.h"
26 
27 namespace nndeploy {
28 namespace ocr {
29 
31  public:
32  int version_ = -1;
33  // using base::Param::serialize;
34  // virtual base::Status serialize(rapidjson::Value &json,
35  // rapidjson::Document::AllocatorType &allocator);
36  // using base::Param::deserialize;
37  // virtual base::Status deserialize(rapidjson::Value &json);
38 };
39 
41  public:
47  int h_ = -1;
48  int w_ = -1;
49  int max_side_len_ = 960;
50  bool normalize_ = true;
51  float scale_[3] = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
52  float mean_[3] = {0.485f, 0.456f, 0.406f};
53  float std_[3] = {0.229f, 0.224f, 0.225f};
54 
56  int top_ = 0;
57  int bottom_ = 0;
58  int left_ = 0;
59  int right_ = 0;
60  base::Scalar2d border_val_ = 0.0;
61 
64  rapidjson::Value &json,
65  rapidjson::Document::AllocatorType &allocator) override {
66  std::string src_pixel_type_str = base::pixelTypeToString(src_pixel_type_);
67  json.AddMember("src_pixel_type_",
68  rapidjson::Value(src_pixel_type_str.c_str(), allocator),
69  allocator);
70  std::string dst_pixel_type_str = base::pixelTypeToString(dst_pixel_type_);
71  json.AddMember("dst_pixel_type_",
72  rapidjson::Value(dst_pixel_type_str.c_str(), allocator),
73  allocator);
74  std::string interp_type_str = base::interpTypeToString(interp_type_);
75  json.AddMember("interp_type_",
76  rapidjson::Value(interp_type_str.c_str(), allocator),
77  allocator);
78  std::string data_type_str = base::dataTypeToString(data_type_);
79  json.AddMember("data_type_",
80  rapidjson::Value(data_type_str.c_str(), allocator),
81  allocator);
82  std::string data_format_str = base::dataFormatToString(data_format_);
83  json.AddMember("data_format_",
84  rapidjson::Value(data_format_str.c_str(), allocator),
85  allocator);
86  json.AddMember("h_", h_, allocator);
87  json.AddMember("w_", w_, allocator);
88 
89  json.AddMember("max_side_len_", max_side_len_, allocator);
90  json.AddMember("normalize_", normalize_, allocator);
91 
92  rapidjson::Value scale_array(rapidjson::kArrayType);
93  rapidjson::Value mean_array(rapidjson::kArrayType);
94  rapidjson::Value std_array(rapidjson::kArrayType);
95  for (int i = 0; i < 3; i++) {
96  scale_array.PushBack(scale_[i], allocator);
97  mean_array.PushBack(mean_[i], allocator);
98  std_array.PushBack(std_[i], allocator);
99  }
100  json.AddMember("scale_", scale_array, allocator);
101  json.AddMember("mean_", mean_array, allocator);
102  json.AddMember("std_", std_array, allocator);
103 
104  std::string border_type_str = base::borderTypeToString(border_type_);
105  json.AddMember("border_type_",
106  rapidjson::Value(border_type_str.c_str(), allocator),
107  allocator);
108  json.AddMember("top_", top_, allocator);
109  json.AddMember("bottom_", bottom_, allocator);
110  json.AddMember("left_", left_, allocator);
111  json.AddMember("right_", right_, allocator);
112 
113  rapidjson::Value border_val_array(rapidjson::kArrayType);
114  for (int i = 0; i < 4; i++) {
115  border_val_array.PushBack(border_val_.val_[i], allocator);
116  }
117  json.AddMember("border_val_", border_val_array, allocator);
118 
119  return base::kStatusCodeOk;
120  }
121 
123  virtual base::Status deserialize(rapidjson::Value &json) override {
124  if (json.HasMember("src_pixel_type_") &&
125  json["src_pixel_type_"].IsString()) {
126  src_pixel_type_ =
127  base::stringToPixelType(json["src_pixel_type_"].GetString());
128  }
129  if (json.HasMember("dst_pixel_type_") &&
130  json["dst_pixel_type_"].IsString()) {
131  dst_pixel_type_ =
132  base::stringToPixelType(json["dst_pixel_type_"].GetString());
133  }
134  if (json.HasMember("interp_type_") && json["interp_type_"].IsString()) {
135  interp_type_ = base::stringToInterpType(json["interp_type_"].GetString());
136  }
137  if (json.HasMember("data_type_") && json["data_type_"].IsString()) {
138  data_type_ = base::stringToDataType(json["data_type_"].GetString());
139  }
140  if (json.HasMember("data_format_") && json["data_format_"].IsString()) {
141  data_format_ = base::stringToDataFormat(json["data_format_"].GetString());
142  }
143  if (json.HasMember("h_") && json["h_"].IsInt()) {
144  h_ = json["h_"].GetInt();
145  }
146  if (json.HasMember("w_") && json["w_"].IsInt()) {
147  w_ = json["w_"].GetInt();
148  }
149 
150  if (json.HasMember("max_side_len_") && json["max_side_len_"].IsInt()) {
151  max_side_len_ = json["max_side_len_"].GetInt();
152  }
153  if (json.HasMember("normalize_") && json["normalize_"].IsBool()) {
154  normalize_ = json["normalize_"].GetBool();
155  }
156 
157  if (json.HasMember("scale_") && json["scale_"].IsArray()) {
158  const rapidjson::Value &scale_array = json["scale_"];
159  for (int i = 0; i < 3 && i < scale_array.Size(); i++) {
160  if (scale_array[i].IsFloat()) {
161  scale_[i] = scale_array[i].GetFloat();
162  }
163  }
164  }
165  if (json.HasMember("mean_") && json["mean_"].IsArray()) {
166  const rapidjson::Value &mean_array = json["mean_"];
167  for (int i = 0; i < 3 && i < mean_array.Size(); i++) {
168  if (mean_array[i].IsFloat()) {
169  mean_[i] = mean_array[i].GetFloat();
170  }
171  }
172  }
173  if (json.HasMember("std_") && json["std_"].IsArray()) {
174  const rapidjson::Value &std_array = json["std_"];
175  for (int i = 0; i < 3 && i < std_array.Size(); i++) {
176  if (std_array[i].IsFloat()) {
177  std_[i] = std_array[i].GetFloat();
178  }
179  }
180  }
181 
182  if (json.HasMember("border_type_") && json["border_type_"].IsString()) {
183  border_type_ = base::stringToBorderType(json["border_type_"].GetString());
184  }
185  if (json.HasMember("top_") && json["top_"].IsInt()) {
186  top_ = json["top_"].GetInt();
187  }
188  if (json.HasMember("bottom_") && json["bottom_"].IsInt()) {
189  bottom_ = json["bottom_"].GetInt();
190  }
191  if (json.HasMember("left_") && json["left_"].IsInt()) {
192  left_ = json["left_"].GetInt();
193  }
194  if (json.HasMember("right_") && json["right_"].IsInt()) {
195  right_ = json["right_"].GetInt();
196  }
197 
198  if (json.HasMember("border_val_") && json["border_val_"].IsArray()) {
199  const rapidjson::Value &border_val_array = json["border_val_"];
200  for (int i = 0; i < 4 && i < border_val_array.Size(); i++) {
201  if (border_val_array[i].IsFloat()) {
202  border_val_.val_[i] = border_val_array[i].GetFloat();
203  }
204  }
205  }
206 
207  return base::kStatusCodeOk;
208  }
209 };
210 
212  public:
213  DetectorPreProcess(const std::string &name) : dag::Node(name) {
214  key_ = "nndeploy::ocr::DetectorPreProcess";
215  desc_ =
216  "ocr detectorpreprocess cv::Mat to "
217  "device::Tensor[resize->pad->normalize->transpose]";
218  param_ = std::make_shared<DetectorPreProcessParam>();
219  this->setInputTypeInfo<cv::Mat>();
220  this->setOutputTypeInfo<device::Tensor>();
221  }
222  DetectorPreProcess(const std::string &name, std::vector<dag::Edge *> inputs,
223  std::vector<dag::Edge *> outputs)
224  : dag::Node(name, inputs, outputs) {
225  key_ = "nndeploy::ocr::DetectorPreProcess";
226  desc_ =
227  "ocr detectorpreprocess cv::Mat to "
228  "device::Tensor[resize->pad->normalize->transpose]";
229  param_ = std::make_shared<DetectorPreProcessParam>();
230  this->setInputTypeInfo<cv::Mat>();
231  this->setOutputTypeInfo<device::Tensor>();
232  }
// Virtual destructor with an empty body.
233  virtual ~DetectorPreProcess() {}
234 
// Runs the node and reports the outcome as a base::Status. Only the
// declaration appears here; no definition is visible in this listing.
235  virtual base::Status run();
236 };
237 
239  public:
240  int version_ = 3;
241  double det_db_thresh_ = 0.3;
242  double det_db_box_thresh_ = 0.6;
243  double det_db_unclip_ratio_ = 1.5;
244  std::string det_db_score_mode_ = "slow";
245  bool use_dilation_ = false;
246 
248  virtual base::Status serialize(rapidjson::Value &json,
249  rapidjson::Document::AllocatorType &allocator);
251  virtual base::Status deserialize(rapidjson::Value &json);
252 };
253 
255  public:
256  DetectorPostProcess(const std::string &name) : dag::Node(name) {
257  key_ = "nndeploy::ocr::DetectorPostProcess";
258  desc_ = "PPOcrDetv3/v4/v5 postprocess[device::Tensor->OcrResult]";
259  param_ = std::make_shared<DetectorPostParam>();
260  this->setInputTypeInfo<device::Tensor>();
261  this->setOutputTypeInfo<OCRResult>();
262  }
263  DetectorPostProcess(const std::string &name, std::vector<dag::Edge *> inputs,
264  std::vector<dag::Edge *> outputs)
265  : dag::Node(name, inputs, outputs) {
266  key_ = "nndeploy::ocr::DetectorPostProcess";
267  desc_ = "PPOcrDetv3/v4/v5 postprocess[device::Tensor->OcrResult]";
268  param_ = std::make_shared<DetectorPostParam>();
269  this->setInputTypeInfo<device::Tensor>();
270  this->setOutputTypeInfo<OCRResult>();
271  }
// Virtual destructor with an empty body.
272  virtual ~DetectorPostProcess() {}
// Runs the node and reports the outcome as a base::Status. Only the
// declaration appears here; no definition is visible in this listing.
274  virtual base::Status run();
275 };
276 
277 // class NNDEPLOY_CC_API DetectBBoxResult : public base::Param {
278 // public:
279 // DetectBBoxResult(){};
280 // virtual ~DetectBBoxResult() {
281 // if (mask_ != nullptr) {
282 // delete mask_;
283 // mask_ = nullptr;
284 // }
285 // };
286 // int index_;
287 // int label_id_;
288 // float score_;
289 // std::array<float, 4> bbox_; // xmin, ymin, xmax, ymax
290 // device::Tensor *mask_ = nullptr;
291 // };
292 
293 // class NNDEPLOY_CC_API DetectResult : public base::Param {
294 // public:
295 // DetectResult(){};
296 // virtual ~DetectResult(){};
297 // std::vector<DetectBBoxResult> bboxs_;
298 // };
299 
301  public:
302  DetectorGraph(const std::string &name) : dag::Graph(name) {
303  key_ = "nndeploy::ocr::DetectorGraph";
304  desc_ =
305  "PPOcrDetv3/v4/v5 "
306  "graph[cv::Mat->preprocess->infer->postprocess->OcrResult]";
307  this->setInputTypeInfo<cv::Mat>();
308  this->setOutputTypeInfo<OCRResult>();
309  pre_ = dynamic_cast<DetectorPreProcess *>(
310  this->createNode<DetectorPreProcess>("preprocess"));
311  infer_ =
312  dynamic_cast<infer::Infer *>(this->createNode<infer::Infer>("infer"));
313  post_ = dynamic_cast<DetectorPostProcess *>(
314  this->createNode<DetectorPostProcess>("postprocess"));
315  }
316 
317  DetectorGraph(const std::string &name, std::vector<dag::Edge *> inputs,
318  std::vector<dag::Edge *> outputs)
319  : dag::Graph(name, inputs, outputs) {
320  key_ = "nndeploy::ocr::DetectorGraph";
321  desc_ =
322  "PPOcrDetv3/v4/v5 "
323  "graph[cv::Mat->preprocess->infer->postprocess->OcrResult]";
324  this->setInputTypeInfo<cv::Mat>();
325  this->setOutputTypeInfo<OCRResult>();
326  pre_ = dynamic_cast<DetectorPreProcess *>(
327  this->createNode<DetectorPreProcess>("preprocess"));
328  infer_ =
329  dynamic_cast<infer::Infer *>(this->createNode<infer::Infer>("infer"));
330  post_ = dynamic_cast<DetectorPostProcess *>(
331  this->createNode<DetectorPostProcess>("postprocess"));
332  }
333 
// Virtual destructor with an empty body; pre_/infer_/post_ are not deleted here.
334  virtual ~DetectorGraph() {}
335 
336  base::Status make(const dag::NodeDesc &pre_desc,
337  const dag::NodeDesc &infer_desc,
338  base::InferenceType inference_type,
339  const dag::NodeDesc &post_desc) {
340  this->setNodeDesc(pre_, pre_desc);
341  this->setNodeDesc(infer_, infer_desc);
342  this->setNodeDesc(post_, post_desc);
343  this->defaultParam();
344  base::Status status = infer_->setInferenceType(inference_type);
345  if (status != base::kStatusCodeOk) {
346  NNDEPLOY_LOGE("Failed to set inference type");
347  return status;
348  }
349  return base::kStatusCodeOk;
350  }
351 
353  DetectorPreProcessParam *pre_param =
354  dynamic_cast<DetectorPreProcessParam *>(pre_->getParam());
358 
359  DetectorPostParam *post_param =
360  dynamic_cast<DetectorPostParam *>(post_->getParam());
361  post_param->det_db_thresh_ = 0.3;
362  post_param->det_db_box_thresh_ = 0.6;
363  post_param->det_db_unclip_ratio_ = 1.5;
364  post_param->det_db_score_mode_ = "slow";
365  post_param->use_dilation_ = false;
366  post_param->version_ = 5;
367 
368  return base::kStatusCodeOk;
369  }
370 
372  base::Status status = infer_->setInferenceType(inference_type);
373  if (status != base::kStatusCodeOk) {
374  NNDEPLOY_LOGE("Failed to set inference type");
375  return status;
376  }
377  return base::kStatusCodeOk;
378  }
380  base::ModelType model_type, bool is_path,
381  std::vector<std::string> &model_value) {
382  auto param = dynamic_cast<inference::InferenceParam *>(infer_->getParam());
383  param->device_type_ = device_type;
384  param->model_type_ = model_type;
385  param->is_path_ = is_path;
386  param->model_value_ = model_value;
387  return base::kStatusCodeOk;
388  }
389 
391  DetectorPreProcessParam *param =
392  dynamic_cast<DetectorPreProcessParam *>(pre_->getParam());
393  param->src_pixel_type_ = pixel_type;
394  return base::kStatusCodeOk;
395  }
396 
397  base::Status setDbThresh(float threshold) {
398  DetectorPostParam *param =
399  dynamic_cast<DetectorPostParam *>(post_->getParam());
400  param->det_db_thresh_ = threshold;
401  return base::kStatusCodeOk;
402  }
403 
404  base::Status setDbBoxThresh(float threshold) {
405  DetectorPostParam *param =
406  dynamic_cast<DetectorPostParam *>(post_->getParam());
407  param->det_db_box_thresh_ = threshold;
408  return base::kStatusCodeOk;
409  }
410 
412  DetectorPostParam *param =
413  dynamic_cast<DetectorPostParam *>(post_->getParam());
414  param->det_db_unclip_ratio_ = ratio;
415  return base::kStatusCodeOk;
416  }
417 
418  base::Status setDbScoreMode(const std::string &mode) {
419  DetectorPostParam *param =
420  dynamic_cast<DetectorPostParam *>(post_->getParam());
421  param->det_db_score_mode_ = mode;
422  return base::kStatusCodeOk;
423  }
424 
426  DetectorPostParam *param =
427  dynamic_cast<DetectorPostParam *>(post_->getParam());
428  param->use_dilation_ = value;
429  return base::kStatusCodeOk;
430  }
431 
432  base::Status setVersion(int version) {
433  DetectorPostParam *param =
434  dynamic_cast<DetectorPostParam *>(post_->getParam());
435  param->version_ = version;
436  return base::kStatusCodeOk;
437  }
438 
439  std::vector<dag::Edge *> forward(std::vector<dag::Edge *> inputs) {
440  std::vector<dag::Edge *> pre_outputs = (*pre_)(inputs);
441  std::vector<dag::Edge *> infer_outputs = (*infer_)(pre_outputs);
442  std::vector<dag::Edge *> post_outputs = (*post_)(infer_outputs);
443  return post_outputs;
444  }
445 
446  private:
// Child node named "preprocess"; assigned in the constructors.
447  dag::Node *pre_ = nullptr;
// Child node named "infer"; assigned in the constructors.
448  infer::Infer *infer_ = nullptr;
// Child node named "postprocess"; assigned in the constructors.
449  dag::Node *post_ = nullptr;
450 };
451 } // namespace ocr
452 } // namespace nndeploy
453 
454 #endif
virtual base::Status deserialize(rapidjson::Value &json)
virtual std::string serialize()
Template class for a 4-element vector. Scalar_ and Scalar can be used just as typical 4-element vecto...
Definition: type.h:421
Directed Acyclic Graph Node.
Definition: graph.h:31
Node description class.
Definition: node.h:35
Node base class.
Definition: node.h:171
InferenceParam is the base class of all inference param.
base::Status setVersion(int version)
Definition: detector.h:432
DetectorGraph(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
Definition: detector.h:317
std::vector< dag::Edge * > forward(std::vector< dag::Edge * > inputs)
Definition: detector.h:439
base::Status setDbUnclipRatio(float ratio)
Definition: detector.h:411
DetectorGraph(const std::string &name)
Definition: detector.h:302
base::Status make(const dag::NodeDesc &pre_desc, const dag::NodeDesc &infer_desc, base::InferenceType inference_type, const dag::NodeDesc &post_desc)
Definition: detector.h:336
base::Status setDbScoreMode(const std::string &mode)
Definition: detector.h:418
base::Status setDbUseDilation(bool value)
Definition: detector.h:425
virtual base::Status defaultParam()
Set default parameters.
Definition: detector.h:352
base::Status setSrcPixelType(base::PixelType pixel_type)
Definition: detector.h:390
base::Status setInferenceType(base::InferenceType inference_type)
Definition: detector.h:371
base::Status setDbBoxThresh(float threshold)
Definition: detector.h:404
base::Status setDbThresh(float threshold)
Definition: detector.h:397
base::Status setInferParam(base::DeviceType device_type, base::ModelType model_type, bool is_path, std::vector< std::string > &model_value)
Definition: detector.h:379
virtual base::Status deserialize(rapidjson::Value &json)
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
virtual base::Status run()
Run node (pure virtual function)
DetectorPostProcess(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
Definition: detector.h:263
DetectorPostProcess(const std::string &name)
Definition: detector.h:256
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) override
Definition: detector.h:63
virtual base::Status deserialize(rapidjson::Value &json) override
Definition: detector.h:123
DetectorPreProcess(const std::string &name)
Definition: detector.h:213
virtual base::Status run()
Run node (pure virtual function)
DetectorPreProcess(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
Definition: detector.h:222
#define NNDEPLOY_LOGE(fmt,...)
Definition: log.h:59
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
std::string dataFormatToString(DataFormat data_format)
PixelType stringToPixelType(const std::string &pixel_type_str)
@ kStatusCodeOk
Definition: status.h:13
BorderType stringToBorderType(const std::string &border_type_str)
DataType stringToDataType(const std::string &str)
@ kInterpTypeLinear
Definition: type.h:57
DataFormat stringToDataFormat(const std::string &str)
std::string borderTypeToString(BorderType border_type)
std::string dataTypeToString(DataType data_type)
DataType dataTypeOf< float >()
std::string pixelTypeToString(PixelType pixel_type)
std::string interpTypeToString(InterpType interp_type)
@ kBorderTypeConstant
Definition: type.h:71
InterpType stringToInterpType(const std::string &interp_type_str)
@ kDataFormatNCHW
Definition: common.h:146
@ kPixelTypeBGR
Definition: type.h:15