nndeploy C++ API  0.2.0
nndeploy C++ API
classifier.h
Go to the documentation of this file.
1 #ifndef _NNDEPLOY_DETECT_CLASSIFIER_OCR_OCR_H_
2 #define _NNDEPLOY_DETECT_CLASSIFIER_OCR_OCR_H_
3 
4 #include "nndeploy/base/any.h"
5 #include "nndeploy/base/common.h"
7 #include "nndeploy/base/log.h"
8 #include "nndeploy/base/macro.h"
9 #include "nndeploy/base/object.h"
11 #include "nndeploy/base/status.h"
12 #include "nndeploy/base/string.h"
13 #include "nndeploy/dag/edge.h"
14 #include "nndeploy/dag/graph.h"
15 #include "nndeploy/dag/node.h"
16 #include "nndeploy/detect/result.h"
17 #include "nndeploy/device/buffer.h"
18 #include "nndeploy/device/device.h"
20 #include "nndeploy/device/tensor.h"
21 #include "nndeploy/infer/infer.h"
23 #include "nndeploy/ocr/result.h"
26 
27 namespace nndeploy {
28 namespace ocr {
29 
31  public:
32  int version_ = -1;
34  // virtual base::Status serialize(rapidjson::Value &json,
35  // rapidjson::Document::AllocatorType &allocator);
37  // virtual base::Status deserialize(rapidjson::Value &json);
38 };
39 
41  public:
47  std::vector<int> cls_image_shape_ = {3, 48, 192};
48  int h_ = -1;
49  int w_ = -1;
50  bool normalize_ = true;
51  float scale_[3] = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
52  float mean_[3] = {0.0f, 0.0f, 0.0f};
53  float std_[3] = {1.0f, 1.0f, 1.0f};
54 
56  int top_ = 0;
57  int bottom_ = 0;
58  int left_ = 0;
59  int right_ = 0;
60  base::Scalar2d border_val_ = 0.0;
61 
64  rapidjson::Value &json,
65  rapidjson::Document::AllocatorType &allocator) override {
66  std::string src_pixel_type_str = base::pixelTypeToString(src_pixel_type_);
67  json.AddMember("src_pixel_type_",
68  rapidjson::Value(src_pixel_type_str.c_str(), allocator),
69  allocator);
70  std::string dst_pixel_type_str = base::pixelTypeToString(dst_pixel_type_);
71  json.AddMember("dst_pixel_type_",
72  rapidjson::Value(dst_pixel_type_str.c_str(), allocator),
73  allocator);
74  std::string interp_type_str = base::interpTypeToString(interp_type_);
75  json.AddMember("interp_type_",
76  rapidjson::Value(interp_type_str.c_str(), allocator),
77  allocator);
78  std::string data_type_str = base::dataTypeToString(data_type_);
79  json.AddMember("data_type_",
80  rapidjson::Value(data_type_str.c_str(), allocator),
81  allocator);
82  std::string data_format_str = base::dataFormatToString(data_format_);
83  json.AddMember("data_format_",
84  rapidjson::Value(data_format_str.c_str(), allocator),
85  allocator);
86  json.AddMember("h_", h_, allocator);
87  json.AddMember("w_", w_, allocator);
88  json.AddMember("normalize_", normalize_, allocator);
89 
90  rapidjson::Value cls_image_shape(rapidjson::kArrayType);
91  for (int i = 0; i < 3; i++) {
92  cls_image_shape.PushBack(cls_image_shape_[i], allocator);
93  }
94  json.AddMember("cls_image_shape_", cls_image_shape, allocator);
95 
96  rapidjson::Value scale_array(rapidjson::kArrayType);
97  rapidjson::Value mean_array(rapidjson::kArrayType);
98  rapidjson::Value std_array(rapidjson::kArrayType);
99  for (int i = 0; i < 3; i++) {
100  scale_array.PushBack(scale_[i], allocator);
101  mean_array.PushBack(mean_[i], allocator);
102  std_array.PushBack(std_[i], allocator);
103  }
104  json.AddMember("scale_", scale_array, allocator);
105  json.AddMember("mean_", mean_array, allocator);
106  json.AddMember("std_", std_array, allocator);
107 
108  std::string border_type_str = base::borderTypeToString(border_type_);
109  json.AddMember("border_type_",
110  rapidjson::Value(border_type_str.c_str(), allocator),
111  allocator);
112  json.AddMember("top_", top_, allocator);
113  json.AddMember("bottom_", bottom_, allocator);
114  json.AddMember("left_", left_, allocator);
115  json.AddMember("right_", right_, allocator);
116 
117  rapidjson::Value border_val_array(rapidjson::kArrayType);
118  for (int i = 0; i < 4; i++) {
119  border_val_array.PushBack(border_val_.val_[i], allocator);
120  }
121  json.AddMember("border_val_", border_val_array, allocator);
122 
123  return base::kStatusCodeOk;
124  }
125 
127  virtual base::Status deserialize(rapidjson::Value &json) override {
128  if (json.HasMember("src_pixel_type_") &&
129  json["src_pixel_type_"].IsString()) {
130  src_pixel_type_ =
131  base::stringToPixelType(json["src_pixel_type_"].GetString());
132  }
133  if (json.HasMember("dst_pixel_type_") &&
134  json["dst_pixel_type_"].IsString()) {
135  dst_pixel_type_ =
136  base::stringToPixelType(json["dst_pixel_type_"].GetString());
137  }
138  if (json.HasMember("interp_type_") && json["interp_type_"].IsString()) {
139  interp_type_ = base::stringToInterpType(json["interp_type_"].GetString());
140  }
141  if (json.HasMember("data_type_") && json["data_type_"].IsString()) {
142  data_type_ = base::stringToDataType(json["data_type_"].GetString());
143  }
144  if (json.HasMember("data_format_") && json["data_format_"].IsString()) {
145  data_format_ = base::stringToDataFormat(json["data_format_"].GetString());
146  }
147  if (json.HasMember("h_") && json["h_"].IsInt()) {
148  h_ = json["h_"].GetInt();
149  }
150  if (json.HasMember("w_") && json["w_"].IsInt()) {
151  w_ = json["w_"].GetInt();
152  }
153  if (json.HasMember("normalize_") && json["normalize_"].IsBool()) {
154  normalize_ = json["normalize_"].GetBool();
155  }
156 
157  if (json.HasMember("cls_image_shape_") &&
158  json["cls_image_shape_"].IsArray()) {
159  const rapidjson::Value &cls_image_shape_array = json["cls_image_shape_"];
160  for (int i = 0; i < 3 && i < cls_image_shape_array.Size(); i++) {
161  if (cls_image_shape_array[i].IsInt()) {
162  cls_image_shape_[i] = cls_image_shape_array[i].GetInt();
163  }
164  }
165  }
166 
167  if (json.HasMember("scale_") && json["scale_"].IsArray()) {
168  const rapidjson::Value &scale_array = json["scale_"];
169  for (int i = 0; i < 3 && i < scale_array.Size(); i++) {
170  if (scale_array[i].IsFloat()) {
171  scale_[i] = scale_array[i].GetFloat();
172  }
173  }
174  }
175  if (json.HasMember("mean_") && json["mean_"].IsArray()) {
176  const rapidjson::Value &mean_array = json["mean_"];
177  for (int i = 0; i < 3 && i < mean_array.Size(); i++) {
178  if (mean_array[i].IsFloat()) {
179  mean_[i] = mean_array[i].GetFloat();
180  }
181  }
182  }
183  if (json.HasMember("std_") && json["std_"].IsArray()) {
184  const rapidjson::Value &std_array = json["std_"];
185  for (int i = 0; i < 3 && i < std_array.Size(); i++) {
186  if (std_array[i].IsFloat()) {
187  std_[i] = std_array[i].GetFloat();
188  }
189  }
190  }
191 
192  if (json.HasMember("border_type_") && json["border_type_"].IsString()) {
193  border_type_ = base::stringToBorderType(json["border_type_"].GetString());
194  }
195  if (json.HasMember("top_") && json["top_"].IsInt()) {
196  top_ = json["top_"].GetInt();
197  }
198  if (json.HasMember("bottom_") && json["bottom_"].IsInt()) {
199  bottom_ = json["bottom_"].GetInt();
200  }
201  if (json.HasMember("left_") && json["left_"].IsInt()) {
202  left_ = json["left_"].GetInt();
203  }
204  if (json.HasMember("right_") && json["right_"].IsInt()) {
205  right_ = json["right_"].GetInt();
206  }
207 
208  if (json.HasMember("border_val_") && json["border_val_"].IsArray()) {
209  const rapidjson::Value &border_val_array = json["border_val_"];
210  for (int i = 0; i < 4 && i < border_val_array.Size(); i++) {
211  if (border_val_array[i].IsFloat()) {
212  border_val_.val_[i] = border_val_array[i].GetFloat();
213  }
214  }
215  }
216 
217  return base::kStatusCodeOk;
218  }
219 };
220 
222  public:
223  ClassifierPreProcess(const std::string &name) : dag::Node(name) {
224  key_ = "nndeploy::ocr::ClassifierPreProcess";
225  desc_ =
226  "ocr classify preprocess cv::Mat to "
227  "device::Tensor[resize->pad->normalize->transpose]";
228  param_ = std::make_shared<ClassifierPreProcessParam>();
229  this->setInputTypeInfo<OCRResult>();
230  this->setOutputTypeInfo<device::Tensor>();
231  }
232  ClassifierPreProcess(const std::string &name, std::vector<dag::Edge *> inputs,
233  std::vector<dag::Edge *> outputs)
234  : dag::Node(name, inputs, outputs) {
235  key_ = "nndeploy::ocr::ClassifierPreProcess";
236  desc_ =
237  "ocr classify preprocess cv::Mat to "
238  "device::Tensor[resize->pad->normalize->transpose]";
239  param_ = std::make_shared<ClassifierPreProcessParam>();
240  this->setInputTypeInfo<OCRResult>();
241  this->setOutputTypeInfo<device::Tensor>();
242  }
244 
245  virtual base::Status run();
246 };
247 
249  public:
250  int version_ = 2;
251  float cls_thresh_ = 0.9;
252 
254  virtual base::Status serialize(rapidjson::Value &json,
255  rapidjson::Document::AllocatorType &allocator);
257  virtual base::Status deserialize(rapidjson::Value &json);
258 };
259 
261  public:
262  ClassifierPostProcess(const std::string &name) : dag::Node(name) {
263  key_ = "nndeploy::ocr::ClassifierPostProcess";
264  desc_ = "PPOcrClsv2 postprocess[device::Tensor->OcrResult]";
265  param_ = std::make_shared<ClassifierPostParam>();
266  this->setInputTypeInfo<device::Tensor>();
267  this->setOutputTypeInfo<OCRResult>();
268  }
269  ClassifierPostProcess(const std::string &name,
270  std::vector<dag::Edge *> inputs,
271  std::vector<dag::Edge *> outputs)
272  : dag::Node(name, inputs, outputs) {
273  key_ = "nndeploy::ocr::ClassifierPostProcess";
274  desc_ = "PPOcrClsv2 postprocess[device::Tensor->OcrResult]";
275  param_ = std::make_shared<ClassifierPostParam>();
276  this->setInputTypeInfo<device::Tensor>();
277  this->setOutputTypeInfo<OCRResult>();
278  }
281  virtual base::Status run();
282 };
283 
284 // class NNDEPLOY_CC_API ClassifyBBoxResult : public base::Param {
285 // public:
286 // ClassifyBBoxResult(){};
287 // virtual ~ClassifyBBoxResult() {
288 // if (mask_ != nullptr) {
289 // delete mask_;
290 // mask_ = nullptr;
291 // }
292 // };
293 // int index_;
294 // int label_id_;
295 // float score_;
296 // std::array<float, 4> bbox_; // xmin, ymin, xmax, ymax
297 // device::Tensor *mask_ = nullptr;
298 // };
299 
300 // class NNDEPLOY_CC_API ClassifyResult : public base::Param {
301 // public:
302 // ClassifyResult(){};
303 // virtual ~ClassifyResult(){};
304 // std::vector<ClassifyBBoxResult> bboxs_;
305 // };
306 
308  public:
309  ClassifierGraph(const std::string &name) : dag::Graph(name) {
310  key_ = "nndeploy::ocr::ClassifierGraph";
311  desc_ =
312  "PPOcrClsV2 graph[cv::Mat->preprocess->infer->postprocess->OcrResult]";
313  this->setInputTypeInfo<OCRResult>();
314  this->setOutputTypeInfo<OCRResult>();
315  pre_ = dynamic_cast<ClassifierPreProcess *>(
316  this->createNode<ClassifierPreProcess>("preprocess"));
317  infer_ =
318  dynamic_cast<infer::Infer *>(this->createNode<infer::Infer>("infer"));
319  post_ = dynamic_cast<ClassifierPostProcess *>(
320  this->createNode<ClassifierPostProcess>("postprocess"));
321  }
322 
323  ClassifierGraph(const std::string &name, std::vector<dag::Edge *> inputs,
324  std::vector<dag::Edge *> outputs)
325  : dag::Graph(name, inputs, outputs) {
326  key_ = "nndeploy::ocr::ClassifierGraph";
327  desc_ =
328  "PPOcrClsV2 graph[cv::Mat->preprocess->infer->postprocess->OcrResult]";
329  this->setInputTypeInfo<OCRResult>();
330  this->setOutputTypeInfo<OCRResult>();
331  pre_ = dynamic_cast<ClassifierPreProcess *>(
332  this->createNode<ClassifierPreProcess>("preprocess"));
333  infer_ =
334  dynamic_cast<infer::Infer *>(this->createNode<infer::Infer>("infer"));
335  post_ = dynamic_cast<ClassifierPostProcess *>(
336  this->createNode<ClassifierPostProcess>("postprocess"));
337  }
338 
339  virtual ~ClassifierGraph() {}
340 
341  base::Status make(const dag::NodeDesc &pre_desc,
342  const dag::NodeDesc &infer_desc,
343  base::InferenceType inference_type,
344  const dag::NodeDesc &post_desc) {
345  this->setNodeDesc(pre_, pre_desc);
346  this->setNodeDesc(infer_, infer_desc);
347  this->setNodeDesc(post_, post_desc);
348  this->defaultParam();
349  base::Status status = infer_->setInferenceType(inference_type);
350  if (status != base::kStatusCodeOk) {
351  NNDEPLOY_LOGE("Failed to set inference type");
352  return status;
353  }
354  return base::kStatusCodeOk;
355  }
356 
358  ClassifierPreProcessParam *pre_param =
359  dynamic_cast<ClassifierPreProcessParam *>(pre_->getParam());
363  pre_param->cls_image_shape_ = {3, 48, 192};
364  ClassifierPostParam *post_param =
365  dynamic_cast<ClassifierPostParam *>(post_->getParam());
366  post_param->cls_thresh_ = 0.6;
367  post_param->version_ = 5;
368 
369  return base::kStatusCodeOk;
370  }
371 
373  base::Status status = infer_->setInferenceType(inference_type);
374  if (status != base::kStatusCodeOk) {
375  NNDEPLOY_LOGE("Failed to set inference type");
376  return status;
377  }
378  return base::kStatusCodeOk;
379  }
381  base::ModelType model_type, bool is_path,
382  std::vector<std::string> &model_value) {
383  auto param = dynamic_cast<inference::InferenceParam *>(infer_->getParam());
384  param->device_type_ = device_type;
385  param->model_type_ = model_type;
386  param->is_path_ = is_path;
387  param->model_value_ = model_value;
388  return base::kStatusCodeOk;
389  }
390  base::Status setVersion(int version) {
391  ClassifierPostParam *param =
392  dynamic_cast<ClassifierPostParam *>(post_->getParam());
393  param->version_ = version;
394  return base::kStatusCodeOk;
395  }
396  base::Status setClsThresh(float threshold) {
397  ClassifierPostParam *param =
398  dynamic_cast<ClassifierPostParam *>(post_->getParam());
399  param->cls_thresh_ = threshold;
400  return base::kStatusCodeOk;
401  }
404  dynamic_cast<ClassifierPreProcessParam *>(pre_->getParam());
405  param->src_pixel_type_ = pixel_type;
406  return base::kStatusCodeOk;
407  }
408 
409  std::vector<dag::Edge *> forward(std::vector<dag::Edge *> inputs) {
410  inputs = (*pre_)(inputs);
411  inputs = (*infer_)(inputs);
412  std::vector<dag::Edge *> outputs = (*post_)(inputs);
413  return outputs;
414  }
415 
416  private:
417  dag::Node *pre_ = nullptr;
418  infer::Infer *infer_ = nullptr;
419  dag::Node *post_ = nullptr;
420 };
421 } // namespace ocr
422 } // namespace nndeploy
423 
424 #endif
virtual base::Status deserialize(rapidjson::Value &json)
virtual std::string serialize()
Template class for a 4-element vector. Scalar_ and Scalar can be used just as typical 4-element vectors.
Definition: type.h:421
Directed Acyclic Graph Node.
Definition: graph.h:31
Node description class.
Definition: node.h:35
Node base class.
Definition: node.h:171
InferenceParam is the base class of all inference param.
base::Status make(const dag::NodeDesc &pre_desc, const dag::NodeDesc &infer_desc, base::InferenceType inference_type, const dag::NodeDesc &post_desc)
Definition: classifier.h:341
std::vector< dag::Edge * > forward(std::vector< dag::Edge * > inputs)
Definition: classifier.h:409
base::Status setVersion(int version)
Definition: classifier.h:390
base::Status setInferParam(base::DeviceType device_type, base::ModelType model_type, bool is_path, std::vector< std::string > &model_value)
Definition: classifier.h:380
base::Status setInferenceType(base::InferenceType inference_type)
Definition: classifier.h:372
virtual base::Status defaultParam()
Set default parameters.
Definition: classifier.h:357
base::Status setClsThresh(float threshold)
Definition: classifier.h:396
ClassifierGraph(const std::string &name)
Definition: classifier.h:309
base::Status setSrcPixelType(base::PixelType pixel_type)
Definition: classifier.h:402
ClassifierGraph(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
Definition: classifier.h:323
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
virtual base::Status deserialize(rapidjson::Value &json)
ClassifierPostProcess(const std::string &name)
Definition: classifier.h:262
ClassifierPostProcess(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
Definition: classifier.h:269
virtual base::Status run()
Run node (pure virtual function)
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) override
Definition: classifier.h:63
virtual base::Status deserialize(rapidjson::Value &json) override
Definition: classifier.h:127
ClassifierPreProcess(const std::string &name)
Definition: classifier.h:223
virtual base::Status run()
Run node (pure virtual function)
ClassifierPreProcess(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
Definition: classifier.h:232
#define NNDEPLOY_LOGE(fmt,...)
Definition: log.h:59
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
std::string dataFormatToString(DataFormat data_format)
PixelType stringToPixelType(const std::string &pixel_type_str)
@ kStatusCodeOk
Definition: status.h:13
BorderType stringToBorderType(const std::string &border_type_str)
DataType stringToDataType(const std::string &str)
@ kInterpTypeLinear
Definition: type.h:57
DataFormat stringToDataFormat(const std::string &str)
std::string borderTypeToString(BorderType border_type)
std::string dataTypeToString(DataType data_type)
DataType dataTypeOf< float >()
std::string pixelTypeToString(PixelType pixel_type)
std::string interpTypeToString(InterpType interp_type)
@ kBorderTypeConstant
Definition: type.h:71
InterpType stringToInterpType(const std::string &interp_type_str)
@ kDataFormatNCHW
Definition: common.h:146
@ kPixelTypeBGR
Definition: type.h:15