nndeploy C++ API  0.2.0
nndeploy C++ API
recognizer.h
Go to the documentation of this file.
1 #ifndef _NNDEPLOY_DETECT_RECOGNIZER_OCR_OCR_H_
2 #define _NNDEPLOY_DETECT_RECOGNIZER_OCR_OCR_H_
3 
4 #include "nndeploy/base/any.h"
5 #include "nndeploy/base/common.h"
7 #include "nndeploy/base/log.h"
8 #include "nndeploy/base/macro.h"
9 #include "nndeploy/base/object.h"
11 #include "nndeploy/base/status.h"
12 #include "nndeploy/base/string.h"
13 #include "nndeploy/dag/edge.h"
14 #include "nndeploy/dag/graph.h"
15 #include "nndeploy/dag/node.h"
16 #include "nndeploy/device/buffer.h"
17 #include "nndeploy/device/device.h"
19 #include "nndeploy/device/tensor.h"
20 #include "nndeploy/infer/infer.h"
22 #include "nndeploy/ocr/result.h"
25 
26 namespace nndeploy {
27 namespace ocr {
28 
30  public:
31  int version_ = -1;
32  // using base::Param::serialize;
33  // virtual base::Status serialize(rapidjson::Value &json,
34  // rapidjson::Document::AllocatorType &allocator);
35  // using base::Param::deserialize;
36  // virtual base::Status deserialize(rapidjson::Value &json);
37 };
38 
40  public:
46  int h_ = -1;
47  int w_ = -1;
48  bool normalize_ = true;
49  int rec_batch_size_ = 6;
50  std::vector<int> rec_image_shape_ = {3, 48, 320};
51 
52  float scale_[3] = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
53  float mean_[3] = {0.5f, 0.5f, 0.5f};
54  float std_[3] = {0.5f, 0.5f, 0.5f};
55 
57  int top_ = 0;
58  int bottom_ = 0;
59  int left_ = 0;
60  int right_ = 0;
61  base::Scalar2d border_val_ = 0.0;
62 
65  rapidjson::Value &json,
66  rapidjson::Document::AllocatorType &allocator) override {
67  std::string src_pixel_type_str = base::pixelTypeToString(src_pixel_type_);
68  json.AddMember("src_pixel_type_",
69  rapidjson::Value(src_pixel_type_str.c_str(), allocator),
70  allocator);
71  std::string dst_pixel_type_str = base::pixelTypeToString(dst_pixel_type_);
72  json.AddMember("dst_pixel_type_",
73  rapidjson::Value(dst_pixel_type_str.c_str(), allocator),
74  allocator);
75  std::string interp_type_str = base::interpTypeToString(interp_type_);
76  json.AddMember("interp_type_",
77  rapidjson::Value(interp_type_str.c_str(), allocator),
78  allocator);
79  std::string data_type_str = base::dataTypeToString(data_type_);
80  json.AddMember("data_type_",
81  rapidjson::Value(data_type_str.c_str(), allocator),
82  allocator);
83  std::string data_format_str = base::dataFormatToString(data_format_);
84  json.AddMember("data_format_",
85  rapidjson::Value(data_format_str.c_str(), allocator),
86  allocator);
87  json.AddMember("h_", h_, allocator);
88  json.AddMember("w_", w_, allocator);
89  json.AddMember("rec_batch_size_", rec_batch_size_, allocator);
90  rapidjson::Value rec_image_shape(rapidjson::kArrayType);
91  for (int i = 0; i < 3; i++) {
92  rec_image_shape.PushBack(rec_image_shape_[i], allocator);
93  }
94  json.AddMember("rec_image_shape_", rec_image_shape, allocator);
95  json.AddMember("normalize_", normalize_, allocator);
96 
97  rapidjson::Value scale_array(rapidjson::kArrayType);
98  rapidjson::Value mean_array(rapidjson::kArrayType);
99  rapidjson::Value std_array(rapidjson::kArrayType);
100  for (int i = 0; i < 3; i++) {
101  scale_array.PushBack(scale_[i], allocator);
102  mean_array.PushBack(mean_[i], allocator);
103  std_array.PushBack(std_[i], allocator);
104  }
105  json.AddMember("scale_", scale_array, allocator);
106  json.AddMember("mean_", mean_array, allocator);
107  json.AddMember("std_", std_array, allocator);
108 
109  std::string border_type_str = base::borderTypeToString(border_type_);
110  json.AddMember("border_type_",
111  rapidjson::Value(border_type_str.c_str(), allocator),
112  allocator);
113  json.AddMember("top_", top_, allocator);
114  json.AddMember("bottom_", bottom_, allocator);
115  json.AddMember("left_", left_, allocator);
116  json.AddMember("right_", right_, allocator);
117 
118  rapidjson::Value border_val_array(rapidjson::kArrayType);
119  for (int i = 0; i < 4; i++) {
120  border_val_array.PushBack(border_val_.val_[i], allocator);
121  }
122  json.AddMember("border_val_", border_val_array, allocator);
123 
124  return base::kStatusCodeOk;
125  }
126 
128  virtual base::Status deserialize(rapidjson::Value &json) override {
129  if (json.HasMember("src_pixel_type_") &&
130  json["src_pixel_type_"].IsString()) {
131  src_pixel_type_ =
132  base::stringToPixelType(json["src_pixel_type_"].GetString());
133  }
134  if (json.HasMember("dst_pixel_type_") &&
135  json["dst_pixel_type_"].IsString()) {
136  dst_pixel_type_ =
137  base::stringToPixelType(json["dst_pixel_type_"].GetString());
138  }
139  if (json.HasMember("interp_type_") && json["interp_type_"].IsString()) {
140  interp_type_ = base::stringToInterpType(json["interp_type_"].GetString());
141  }
142  if (json.HasMember("data_type_") && json["data_type_"].IsString()) {
143  data_type_ = base::stringToDataType(json["data_type_"].GetString());
144  }
145  if (json.HasMember("data_format_") && json["data_format_"].IsString()) {
146  data_format_ = base::stringToDataFormat(json["data_format_"].GetString());
147  }
148  if (json.HasMember("h_") && json["h_"].IsInt()) {
149  h_ = json["h_"].GetInt();
150  }
151  if (json.HasMember("w_") && json["w_"].IsInt()) {
152  w_ = json["w_"].GetInt();
153  }
154 
155  if (json.HasMember("rec_batch_size_") && json["rec_batch_size_"].IsInt()) {
156  rec_batch_size_ = json["rec_batch_size_"].GetInt();
157  }
158  if (json.HasMember("rec_image_shape_") &&
159  json["rec_image_shape_"].IsArray()) {
160  const rapidjson::Value &rec_image_shape_array = json["rec_image_shape_"];
161  for (int i = 0; i < 3 && i < rec_image_shape_array.Size(); i++) {
162  if (rec_image_shape_array[i].IsInt()) {
163  rec_image_shape_[i] = rec_image_shape_array[i].GetInt();
164  }
165  }
166  }
167 
168  if (json.HasMember("normalize_") && json["normalize_"].IsBool()) {
169  normalize_ = json["normalize_"].GetBool();
170  }
171 
172  if (json.HasMember("scale_") && json["scale_"].IsArray()) {
173  const rapidjson::Value &scale_array = json["scale_"];
174  for (int i = 0; i < 3 && i < scale_array.Size(); i++) {
175  if (scale_array[i].IsFloat()) {
176  scale_[i] = scale_array[i].GetFloat();
177  }
178  }
179  }
180  if (json.HasMember("mean_") && json["mean_"].IsArray()) {
181  const rapidjson::Value &mean_array = json["mean_"];
182  for (int i = 0; i < 3 && i < mean_array.Size(); i++) {
183  if (mean_array[i].IsFloat()) {
184  mean_[i] = mean_array[i].GetFloat();
185  }
186  }
187  }
188  if (json.HasMember("std_") && json["std_"].IsArray()) {
189  const rapidjson::Value &std_array = json["std_"];
190  for (int i = 0; i < 3 && i < std_array.Size(); i++) {
191  if (std_array[i].IsFloat()) {
192  std_[i] = std_array[i].GetFloat();
193  }
194  }
195  }
196 
197  if (json.HasMember("border_type_") && json["border_type_"].IsString()) {
198  border_type_ = base::stringToBorderType(json["border_type_"].GetString());
199  }
200  if (json.HasMember("top_") && json["top_"].IsInt()) {
201  top_ = json["top_"].GetInt();
202  }
203  if (json.HasMember("bottom_") && json["bottom_"].IsInt()) {
204  bottom_ = json["bottom_"].GetInt();
205  }
206  if (json.HasMember("left_") && json["left_"].IsInt()) {
207  left_ = json["left_"].GetInt();
208  }
209  if (json.HasMember("right_") && json["right_"].IsInt()) {
210  right_ = json["right_"].GetInt();
211  }
212 
213  if (json.HasMember("border_val_") && json["border_val_"].IsArray()) {
214  const rapidjson::Value &border_val_array = json["border_val_"];
215  for (int i = 0; i < 4 && i < border_val_array.Size(); i++) {
216  if (border_val_array[i].IsFloat()) {
217  border_val_.val_[i] = border_val_array[i].GetFloat();
218  }
219  }
220  }
221 
222  return base::kStatusCodeOk;
223  }
224 };
225 
227  public:
228  RecognizerPreProcess(const std::string &name) : dag::Node(name) {
229  key_ = "nndeploy::ocr::RecognizerPreProcess";
230  desc_ =
231  "ocr recognizer preprocess cv::Mat to "
232  "device::Tensor[resize->pad->normalize->transpose]";
233  param_ = std::make_shared<RecognizerPreProcessParam>();
234  this->setInputTypeInfo<OCRResult>();
235  this->setOutputTypeInfo<device::Tensor>();
236  }
237  RecognizerPreProcess(const std::string &name, std::vector<dag::Edge *> inputs,
238  std::vector<dag::Edge *> outputs)
239  : dag::Node(name, inputs, outputs) {
240  key_ = "nndeploy::ocr::RecognizerPreProcess";
241  desc_ =
242  "ocr recognizer preprocess cv::Mat to "
243  "device::Tensor[resize->pad->normalize->transpose]";
244  param_ = std::make_shared<RecognizerPreProcessParam>();
245  this->setInputTypeInfo<OCRResult>();
246  this->setOutputTypeInfo<device::Tensor>();
247  }
249 
250  virtual base::Status run();
251 };
252 
254  public:
255  int version_ = 5;
256  double rec_thresh_ = 0.2;
257  std::string character_path_;
258 
260  virtual base::Status serialize(rapidjson::Value &json,
261  rapidjson::Document::AllocatorType &allocator);
263  virtual base::Status deserialize(rapidjson::Value &json);
264 };
265 
267  public:
268  RecognizerPostProcess(const std::string &name) : dag::Node(name) {
269  key_ = "nndeploy::ocr::RecognizerPostProcess";
270  desc_ = "PPOcrRecv3/v4/v5 postprocess[device::Tensor->DetectResult]";
271  param_ = std::make_shared<RecognizerPostParam>();
272  this->setInputTypeInfo<device::Tensor>();
273  this->setOutputTypeInfo<OCRResult>();
274  }
275  RecognizerPostProcess(const std::string &name,
276  std::vector<dag::Edge *> inputs,
277  std::vector<dag::Edge *> outputs)
278  : dag::Node(name, inputs, outputs) {
279  key_ = "nndeploy::ocr::RecognizerPostProcess";
280  desc_ = "PPOcrRecv3/v4/v5 postprocess[device::Tensor->DetectResult]";
281  param_ = std::make_shared<RecognizerPostParam>();
282  this->setInputTypeInfo<device::Tensor>();
283  this->setOutputTypeInfo<OCRResult>();
284  }
287  virtual base::Status run();
288 };
289 
291  public:
292  RecognizerGraph(const std::string &name) : dag::Graph(name) {
293  key_ = "nndeploy::ocr::RecognizerGraph";
294  desc_ =
295  "PPOcrRecv3/v4/v5 "
296  "graph[cv::Mat->preprocess->infer->postprocess->OcrResult]";
297  this->setInputTypeInfo<OCRResult>();
298  this->setOutputTypeInfo<OCRResult>();
299  pre_ = dynamic_cast<RecognizerPreProcess *>(
300  this->createNode<RecognizerPreProcess>("preprocess"));
301  infer_ =
302  dynamic_cast<infer::Infer *>(this->createNode<infer::Infer>("infer"));
303  post_ = dynamic_cast<RecognizerPostProcess *>(
304  this->createNode<RecognizerPostProcess>("postprocess"));
305  }
306 
307  RecognizerGraph(const std::string &name, std::vector<dag::Edge *> inputs,
308  std::vector<dag::Edge *> outputs)
309  : dag::Graph(name, inputs, outputs) {
310  key_ = "nndeploy::ocr::RecognizerGraph";
311  desc_ =
312  "PPOcrRecv3/v4/v5 "
313  "graph[cv::Mat->preprocess->infer->postprocess->DetectResult]";
314  this->setInputTypeInfo<OCRResult>();
315  this->setOutputTypeInfo<OCRResult>();
316  pre_ = dynamic_cast<RecognizerPreProcess *>(
317  this->createNode<RecognizerPreProcess>("preprocess"));
318  infer_ =
319  dynamic_cast<infer::Infer *>(this->createNode<infer::Infer>("infer"));
320  post_ = dynamic_cast<RecognizerPostProcess *>(
321  this->createNode<RecognizerPostProcess>("postprocess"));
322  }
323  virtual ~RecognizerGraph() {}
324 
326  RecognizerPreProcessParam *pre_param =
327  dynamic_cast<RecognizerPreProcessParam *>(pre_->getParam());
331  pre_param->rec_batch_size_ = 6;
332  pre_param->rec_image_shape_ = {3, 48, 320};
333  RecognizerPostParam *post_param =
334  dynamic_cast<RecognizerPostParam *>(post_->getParam());
335  post_param->rec_thresh_ = 0.6;
336  post_param->version_ = 5;
337 
338  return base::kStatusCodeOk;
339  }
340 
341  base::Status make(const dag::NodeDesc &pre_desc,
342  const dag::NodeDesc &infer_desc,
343  base::InferenceType inference_type,
344  const dag::NodeDesc &post_desc) {
345  this->setNodeDesc(pre_, pre_desc);
346  this->setNodeDesc(infer_, infer_desc);
347  this->setNodeDesc(post_, post_desc);
348  this->defaultParam();
349  base::Status status = infer_->setInferenceType(inference_type);
350  if (status != base::kStatusCodeOk) {
351  NNDEPLOY_LOGE("Failed to set inference type");
352  return status;
353  }
354 
355  return base::kStatusCodeOk;
356  }
357 
359  base::Status status = infer_->setInferenceType(inference_type);
360  if (status != base::kStatusCodeOk) {
361  NNDEPLOY_LOGE("Failed to set inference type");
362  return status;
363  }
364  return base::kStatusCodeOk;
365  }
367  base::ModelType model_type, bool is_path,
368  std::vector<std::string> &model_value) {
369  auto param = dynamic_cast<inference::InferenceParam *>(infer_->getParam());
370  param->device_type_ = device_type;
371  param->model_type_ = model_type;
372  param->is_path_ = is_path;
373  param->model_value_ = model_value;
374  return base::kStatusCodeOk;
375  }
376 
377  base::Status setCharacterPath(const std::string &character_path) {
378  if (character_path.empty()) {
379  return base::kStatusCodeErrorInvalidParam; // 可以加一些校验
380  }
381  auto param = dynamic_cast<RecognizerPostParam *>(post_->getParam());
382  param->character_path_ = character_path;
383  return base::kStatusCodeOk;
384  }
385 
386  base::Status setRecThresh(float threshold) {
387  RecognizerPostParam *param =
388  dynamic_cast<RecognizerPostParam *>(post_->getParam());
389  param->rec_thresh_ = threshold;
390  return base::kStatusCodeOk;
391  }
392  base::Status setVersion(int version) {
393  RecognizerPostParam *param =
394  dynamic_cast<RecognizerPostParam *>(post_->getParam());
395  param->version_ = version;
396  return base::kStatusCodeOk;
397  }
398 
401  dynamic_cast<RecognizerPreProcessParam *>(pre_->getParam());
402  param->src_pixel_type_ = pixel_type;
403  return base::kStatusCodeOk;
404  }
405  std::vector<dag::Edge *> forward(std::vector<dag::Edge *> inputs) {
406  inputs = (*pre_)(inputs);
407  inputs = (*infer_)(inputs);
408  std::vector<dag::Edge *> outputs = (*post_)(inputs);
409  return outputs;
410  }
411 
412  private:
413  dag::Node *pre_ = nullptr;
414  infer::Infer *infer_ = nullptr;
415  dag::Node *post_ = nullptr;
416 };
417 } // namespace ocr
418 } // namespace nndeploy
419 
420 #endif
virtual base::Status deserialize(rapidjson::Value &json)
virtual std::string serialize()
Template class for a 4-element vector. Scalar_ and Scalar can be used just as typical 4-element vecto...
Definition: type.h:421
Directed Acyclic Graph Node.
Definition: graph.h:31
Node description class.
Definition: node.h:35
Node base class.
Definition: node.h:171
InferenceParam is the base class of all inference param.
RecognizerGraph(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
Definition: recognizer.h:307
base::Status make(const dag::NodeDesc &pre_desc, const dag::NodeDesc &infer_desc, base::InferenceType inference_type, const dag::NodeDesc &post_desc)
Definition: recognizer.h:341
base::Status setCharacterPath(const std::string &character_path)
Definition: recognizer.h:377
virtual base::Status defaultParam()
Set default parameters.
Definition: recognizer.h:325
base::Status setInferParam(base::DeviceType device_type, base::ModelType model_type, bool is_path, std::vector< std::string > &model_value)
Definition: recognizer.h:366
RecognizerGraph(const std::string &name)
Definition: recognizer.h:292
base::Status setVersion(int version)
Definition: recognizer.h:392
base::Status setRecThresh(float threshold)
Definition: recognizer.h:386
base::Status setInferenceType(base::InferenceType inference_type)
Definition: recognizer.h:358
std::vector< dag::Edge * > forward(std::vector< dag::Edge * > inputs)
Definition: recognizer.h:405
base::Status setSrcPixelType(base::PixelType pixel_type)
Definition: recognizer.h:399
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
virtual base::Status deserialize(rapidjson::Value &json)
RecognizerPostProcess(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
Definition: recognizer.h:275
virtual base::Status run()
Run node (pure virtual function)
RecognizerPostProcess(const std::string &name)
Definition: recognizer.h:268
virtual base::Status deserialize(rapidjson::Value &json) override
Definition: recognizer.h:128
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) override
Definition: recognizer.h:64
RecognizerPreProcess(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
Definition: recognizer.h:237
virtual base::Status run()
Run node (pure virtual function)
RecognizerPreProcess(const std::string &name)
Definition: recognizer.h:228
#define NNDEPLOY_LOGE(fmt,...)
Definition: log.h:59
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
std::string dataFormatToString(DataFormat data_format)
PixelType stringToPixelType(const std::string &pixel_type_str)
@ kStatusCodeOk
Definition: status.h:13
@ kStatusCodeErrorInvalidParam
Definition: status.h:21
BorderType stringToBorderType(const std::string &border_type_str)
DataType stringToDataType(const std::string &str)
@ kInterpTypeLinear
Definition: type.h:57
DataFormat stringToDataFormat(const std::string &str)
std::string borderTypeToString(BorderType border_type)
std::string dataTypeToString(DataType data_type)
DataType dataTypeOf< float >()
std::string pixelTypeToString(PixelType pixel_type)
std::string interpTypeToString(InterpType interp_type)
@ kBorderTypeConstant
Definition: type.h:71
InterpType stringToInterpType(const std::string &interp_type_str)
@ kDataFormatNCHW
Definition: common.h:146
@ kPixelTypeBGR
Definition: type.h:15