nndeploy C++ API  0.2.0
nndeploy C++ API
classification.h
Go to the documentation of this file.
1 
2 #ifndef _NNDEPLOY_CLASSIFICATION_CLASSIFICATION_H_
3 #define _NNDEPLOY_CLASSIFICATION_CLASSIFICATION_H_
4 
5 #include "nndeploy/base/any.h"
6 #include "nndeploy/base/common.h"
8 #include "nndeploy/base/log.h"
9 #include "nndeploy/base/macro.h"
10 #include "nndeploy/base/object.h"
12 #include "nndeploy/base/param.h"
13 #include "nndeploy/base/status.h"
14 #include "nndeploy/base/string.h"
16 #include "nndeploy/dag/edge.h"
17 #include "nndeploy/dag/graph.h"
18 #include "nndeploy/dag/node.h"
19 #include "nndeploy/device/buffer.h"
20 #include "nndeploy/device/device.h"
22 #include "nndeploy/device/tensor.h"
23 #include "nndeploy/infer/infer.h"
27 
28 namespace nndeploy {
29 namespace classification {
30 
32  public:
/// Top-k count; 1 by default and overridable via the graphs' setTopk().
/// Presumably the number of highest-scoring classes run() keeps -- confirm.
int topk_{1};
/// Softmax switch; true by default and toggled via the graphs' setSoftmax().
/// NOTE(review): presumably softmax is applied to raw scores before ranking;
/// confirm in ClassificationPostProcess::run().
bool is_softmax_{true};
/// Version tag, -1 by default; its meaning is not visible in this header.
int version_{-1};
36 
38  virtual base::Status serialize(rapidjson::Value &json,
39  rapidjson::Document::AllocatorType &allocator);
41  virtual base::Status deserialize(rapidjson::Value &json);
42 };
43 
45  public:
46  ClassificationPostProcess(const std::string &name) : dag::Node(name) {
47  key_ = "nndeploy::classification::ClassificationPostProcess";
48  desc_ = "Classification postprocess[device::Tensor->ClassificationResult]";
49  param_ = std::make_shared<ClassificationPostParam>();
50  this->setInputTypeInfo<device::Tensor>();
51  this->setOutputTypeInfo<ClassificationResult>();
52  }
53  ClassificationPostProcess(const std::string &name,
54  std::vector<dag::Edge *> inputs,
55  std::vector<dag::Edge *> outputs)
56  : dag::Node(name, inputs, outputs) {
57  key_ = "nndeploy::classification::ClassificationPostProcess";
58  desc_ = "Classification postprocess[device::Tensor->ClassificationResult]";
59  param_ = std::make_shared<ClassificationPostParam>();
60  this->setInputTypeInfo<device::Tensor>();
61  this->setOutputTypeInfo<ClassificationResult>();
62  }
64 
65  virtual base::Status run();
66 };
67 
77  public:
78  ClassificationGraph(const std::string &name) : dag::Graph(name) {
79  key_ = "nndeploy::classification::ClassificationGraph";
80  desc_ =
81  "Classification "
82  "graph[cv::Mat->preprocess->infer->postprocess->ClassificationResult]";
83  this->setInputTypeInfo<cv::Mat>();
84  this->setOutputTypeInfo<ClassificationResult>();
85  pre_ = dynamic_cast<preprocess::CvtResizeCropNormTrans *>(
86  this->createNode<preprocess::CvtResizeCropNormTrans>("preprocess"));
87  infer_ =
88  dynamic_cast<infer::Infer *>(this->createNode<infer::Infer>("infer"));
89  post_ = dynamic_cast<ClassificationPostProcess *>(
90  this->createNode<ClassificationPostProcess>("postprocess"));
91  }
92  ClassificationGraph(const std::string &name, std::vector<dag::Edge *> inputs,
93  std::vector<dag::Edge *> outputs)
94  : dag::Graph(name, inputs, outputs) {
95  key_ = "nndeploy::classification::ClassificationGraph";
96  desc_ =
97  "Classification "
98  "graph[cv::Mat->preprocess->infer->postprocess->ClassificationResult]";
99  this->setInputTypeInfo<cv::Mat>();
100  this->setOutputTypeInfo<ClassificationResult>();
101  pre_ = dynamic_cast<preprocess::CvtResizeCropNormTrans *>(
102  this->createNode<preprocess::CvtResizeCropNormTrans>("preprocess"));
103  infer_ =
104  dynamic_cast<infer::Infer *>(this->createNode<infer::Infer>("infer"));
105  post_ = dynamic_cast<ClassificationPostProcess *>(
106  this->createNode<ClassificationPostProcess>("postprocess"));
107  }
108 
109  virtual ~ClassificationGraph() {}
110 
112  // preprocess::CvtResizeNormTransParam *pre_param =
113  // dynamic_cast<preprocess::CvtResizeNormTransParam
114  // *>(pre_->getParam());
115  // pre_param->src_pixel_type_ = base::kPixelTypeBGR;
116  // pre_param->dst_pixel_type_ = base::kPixelTypeRGB;
117  // pre_param->interp_type_ = base::kInterpTypeLinear;
118  // pre_param->h_ = 224;
119  // pre_param->w_ = 224;
120  // pre_param->mean_[0] = 0.485;
121  // pre_param->mean_[1] = 0.456;
122  // pre_param->mean_[2] = 0.406;
123  // pre_param->std_[0] = 0.229;
124  // pre_param->std_[1] = 0.224;
125  // pre_param->std_[2] = 0.225;
128  pre_->getParam());
132  pre_param->resize_h_ = 256;
133  pre_param->resize_w_ = 256;
134  // pre_param->resize_h_ = 224;
135  // pre_param->resize_w_ = 224;
136  pre_param->mean_[0] = 0.485;
137  pre_param->mean_[1] = 0.456;
138  pre_param->mean_[2] = 0.406;
139  pre_param->std_[0] = 0.229;
140  pre_param->std_[1] = 0.224;
141  pre_param->std_[2] = 0.225;
142  pre_param->width_ = 224;
143  pre_param->height_ = 224;
144 
145  ClassificationPostParam *post_param =
146  dynamic_cast<ClassificationPostParam *>(post_->getParam());
147  post_param->topk_ = 1;
148 
149  return base::kStatusCodeOk;
150  }
151 
152  base::Status make(const dag::NodeDesc &pre_desc,
153  const dag::NodeDesc &infer_desc,
154  base::InferenceType inference_type,
155  const dag::NodeDesc &post_desc) {
156  this->setNodeDesc(pre_, pre_desc);
157  this->setNodeDesc(infer_, infer_desc);
158  this->setNodeDesc(post_, post_desc);
159  this->defaultParam();
160  base::Status status = infer_->setInferenceType(inference_type);
161  if (status != base::kStatusCodeOk) {
162  NNDEPLOY_LOGE("Failed to set inference type");
163  return status;
164  }
165  return base::kStatusCodeOk;
166  }
167 
169  base::Status status = infer_->setInferenceType(inference_type);
170  if (status != base::kStatusCodeOk) {
171  NNDEPLOY_LOGE("Failed to set inference type");
172  return status;
173  }
174  return base::kStatusCodeOk;
175  }
176 
178  base::ModelType model_type, bool is_path,
179  std::vector<std::string> &model_value) {
180  // auto infer = dynamic_cast<infer::Infer *>(infer_);
181  auto param = dynamic_cast<inference::InferenceParam *>(infer_->getParam());
182  param->device_type_ = device_type;
183  param->model_type_ = model_type;
184  param->is_path_ = is_path;
185  param->model_value_ = model_value;
186  return base::kStatusCodeOk;
187  }
188 
197  pre_->getParam());
198  param->src_pixel_type_ = pixel_type;
199  return base::kStatusCodeOk;
200  }
201 
202  base::Status setTopk(int topk) {
203  ClassificationPostParam *param =
204  dynamic_cast<ClassificationPostParam *>(post_->getParam());
205  param->topk_ = topk;
206  return base::kStatusCodeOk;
207  }
208 
209  base::Status setSoftmax(bool is_softmax) {
210  ClassificationPostParam *param =
211  dynamic_cast<ClassificationPostParam *>(post_->getParam());
212  param->is_softmax_ = is_softmax;
213  return base::kStatusCodeOk;
214  }
215 
216  std::vector<dag::Edge *> forward(std::vector<dag::Edge *> inputs) {
217  inputs = (*pre_)(inputs);
218  inputs = (*infer_)(inputs);
219  std::vector<dag::Edge *> outputs = (*post_)(inputs);
220  return outputs;
221  }
222 
223  private:
224  dag::Node *pre_ = nullptr;
225  infer::Infer *infer_ = nullptr;
226  dag::Node *post_ = nullptr;
227 };
228 
238  public:
239  ResnetGraph(const std::string &name) : dag::Graph(name) {
240  key_ = "nndeploy::classification::ResnetGraph";
241  desc_ =
242  "Resnet "
243  "graph[cv::Mat->preprocess->infer->postprocess->ClassificationResult]";
244  this->setInputTypeInfo<cv::Mat>();
245  this->setOutputTypeInfo<ClassificationResult>();
246  pre_ = dynamic_cast<preprocess::CvtResizeNormTrans *>(
247  this->createNode<preprocess::CvtResizeNormTrans>("preprocess"));
248  infer_ =
249  dynamic_cast<infer::Infer *>(this->createNode<infer::Infer>("infer"));
250  post_ = dynamic_cast<ClassificationPostProcess *>(
251  this->createNode<ClassificationPostProcess>("postprocess"));
252  }
253  ResnetGraph(const std::string &name, std::vector<dag::Edge *> inputs,
254  std::vector<dag::Edge *> outputs)
255  : dag::Graph(name, inputs, outputs) {
256  key_ = "nndeploy::classification::ClassificationGraph";
257  desc_ =
258  "Classification "
259  "graph[cv::Mat->preprocess->infer->postprocess->ClassificationResult]";
260  this->setInputTypeInfo<cv::Mat>();
261  this->setOutputTypeInfo<ClassificationResult>();
262  pre_ = dynamic_cast<preprocess::CvtResizeNormTrans *>(
263  this->createNode<preprocess::CvtResizeNormTrans>("preprocess"));
264  infer_ =
265  dynamic_cast<infer::Infer *>(this->createNode<infer::Infer>("infer"));
266  post_ = dynamic_cast<ClassificationPostProcess *>(
267  this->createNode<ClassificationPostProcess>("postprocess"));
268  }
269 
270  virtual ~ResnetGraph() {}
271 
274  dynamic_cast<preprocess::CvtResizeNormTransParam *>(pre_->getParam());
278  pre_param->h_ = 224;
279  pre_param->w_ = 224;
280  pre_param->mean_[0] = 0.485;
281  pre_param->mean_[1] = 0.456;
282  pre_param->mean_[2] = 0.406;
283  pre_param->std_[0] = 0.229;
284  pre_param->std_[1] = 0.224;
285  pre_param->std_[2] = 0.225;
286 
287  ClassificationPostParam *post_param =
288  dynamic_cast<ClassificationPostParam *>(post_->getParam());
289  post_param->topk_ = 1;
290 
291  return base::kStatusCodeOk;
292  }
293 
294  base::Status make(const dag::NodeDesc &pre_desc,
295  const dag::NodeDesc &infer_desc,
296  base::InferenceType inference_type,
297  const dag::NodeDesc &post_desc) {
298  this->setNodeDesc(pre_, pre_desc);
299  this->setNodeDesc(infer_, infer_desc);
300  this->setNodeDesc(post_, post_desc);
301  this->defaultParam();
302  base::Status status = infer_->setInferenceType(inference_type);
303  if (status != base::kStatusCodeOk) {
304  NNDEPLOY_LOGE("Failed to set inference type");
305  return status;
306  }
307  return base::kStatusCodeOk;
308  }
309 
311  base::Status status = infer_->setInferenceType(inference_type);
312  if (status != base::kStatusCodeOk) {
313  NNDEPLOY_LOGE("Failed to set inference type");
314  return status;
315  }
316  return base::kStatusCodeOk;
317  }
318 
320  base::ModelType model_type, bool is_path,
321  std::vector<std::string> &model_value) {
322  // auto infer = dynamic_cast<infer::Infer *>(infer_);
323  auto param = dynamic_cast<inference::InferenceParam *>(infer_->getParam());
324  param->device_type_ = device_type;
325  param->model_type_ = model_type;
326  param->is_path_ = is_path;
327  param->model_value_ = model_value;
328  return base::kStatusCodeOk;
329  }
330 
338  dynamic_cast<preprocess::CvtResizeNormTransParam *>(pre_->getParam());
339  param->src_pixel_type_ = pixel_type;
340  return base::kStatusCodeOk;
341  }
342 
343  base::Status setTopk(int topk) {
344  ClassificationPostParam *param =
345  dynamic_cast<ClassificationPostParam *>(post_->getParam());
346  param->topk_ = topk;
347  return base::kStatusCodeOk;
348  }
349 
350  base::Status setSoftmax(bool is_softmax) {
351  ClassificationPostParam *param =
352  dynamic_cast<ClassificationPostParam *>(post_->getParam());
353  param->is_softmax_ = is_softmax;
354  return base::kStatusCodeOk;
355  }
356 
357  std::vector<dag::Edge *> forward(std::vector<dag::Edge *> inputs) {
358  inputs = (*pre_)(inputs);
359  inputs = (*infer_)(inputs);
360  std::vector<dag::Edge *> outputs = (*post_)(inputs);
361  return outputs;
362  }
363 
364  private:
365  dag::Node *pre_ = nullptr;
366  infer::Infer *infer_ = nullptr;
367  dag::Node *post_ = nullptr;
368 };
369 
370 } // namespace classification
371 } // namespace nndeploy
372 
373 #endif /* _NNDEPLOY_CLASSIFICATION_CLASSIFICATION_H_ */
virtual base::Status deserialize(rapidjson::Value &json)
virtual std::string serialize()
Implementation of ResNet classification network graph structure.
virtual base::Status defaultParam()
Set default parameters.
base::Status setInferenceType(base::InferenceType inference_type)
base::Status setInferParam(base::DeviceType device_type, base::ModelType model_type, bool is_path, std::vector< std::string > &model_value)
base::Status setSoftmax(bool is_softmax)
std::vector< dag::Edge * > forward(std::vector< dag::Edge * > inputs)
base::Status make(const dag::NodeDesc &pre_desc, const dag::NodeDesc &infer_desc, base::InferenceType inference_type, const dag::NodeDesc &post_desc)
base::Status setSrcPixelType(base::PixelType pixel_type)
Set preprocessing parameters.
ClassificationGraph(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
virtual base::Status deserialize(rapidjson::Value &json)
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
virtual base::Status run()
Run node (pure virtual function)
ClassificationPostProcess(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
Implementation of ResNet classification network graph structure.
base::Status setSoftmax(bool is_softmax)
base::Status setInferenceType(base::InferenceType inference_type)
virtual base::Status defaultParam()
Set default parameters.
std::vector< dag::Edge * > forward(std::vector< dag::Edge * > inputs)
ResnetGraph(const std::string &name)
base::Status setSrcPixelType(base::PixelType pixel_type)
Set preprocessing parameters.
base::Status make(const dag::NodeDesc &pre_desc, const dag::NodeDesc &infer_desc, base::InferenceType inference_type, const dag::NodeDesc &post_desc)
ResnetGraph(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
base::Status setInferParam(base::DeviceType device_type, base::ModelType model_type, bool is_path, std::vector< std::string > &model_value)
Directed Acyclic Graph Node.
Definition: graph.h:31
Node description class.
Definition: node.h:35
Node base class.
Definition: node.h:171
InferenceParam is the base class of all inference param.
#define NNDEPLOY_LOGE(fmt,...)
Definition: log.h:59
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
@ kStatusCodeOk
Definition: status.h:13
@ kInterpTypeLinear
Definition: type.h:57
@ kPixelTypeBGR
Definition: type.h:15
@ kPixelTypeRGB
Definition: type.h:14