nndeploy C++ API  0.2.0
nndeploy C++ API
super_resolution.h
Go to the documentation of this file.
1 
2 #ifndef _NNDEPLOY_SUPER_RESOLUTION_SUPER_RESOLUTION_H_
3 #define _NNDEPLOY_SUPER_RESOLUTION_SUPER_RESOLUTION_H_
4 
5 #include "nndeploy/base/any.h"
6 #include "nndeploy/base/common.h"
8 #include "nndeploy/base/log.h"
9 #include "nndeploy/base/macro.h"
10 #include "nndeploy/base/object.h"
12 #include "nndeploy/base/param.h"
13 #include "nndeploy/base/status.h"
14 #include "nndeploy/base/string.h"
15 #include "nndeploy/dag/edge.h"
16 #include "nndeploy/dag/graph.h"
17 #include "nndeploy/dag/node.h"
18 #include "nndeploy/device/buffer.h"
19 #include "nndeploy/device/device.h"
21 #include "nndeploy/device/tensor.h"
22 #include "nndeploy/infer/infer.h"
26 
27 namespace nndeploy {
28 namespace super_resolution {
29 
31  public:
32  SuperResolutionPostProcess(const std::string &name) : dag::Node(name) {
33  key_ = "nndeploy::super_resolution::SuperResolutionPostProcess";
34  this->setInputTypeInfo<device::Tensor>();
35  this->setOutputTypeInfo<std::vector<cv::Mat>>();
36  }
37  SuperResolutionPostProcess(const std::string &name,
38  std::vector<dag::Edge *> inputs,
39  std::vector<dag::Edge *> outputs)
40  : dag::Node(name, inputs, outputs) {
41  key_ = "nndeploy::super_resolution::SuperResolutionPostProcess";
42  this->setInputTypeInfo<device::Tensor>();
43  this->setOutputTypeInfo<std::vector<cv::Mat>>();
44  }
46 
47  virtual base::Status run();
48 };
49 
59  public:
60  SuperResolutionGraph(const std::string &name) : dag::Graph(name) {
61  key_ = "nndeploy::super_resolution::SuperResolutionGraph";
62  this->setInputTypeInfo<cv::Mat>();
63  this->setOutputTypeInfo<std::vector<cv::Mat>>();
64  }
65  SuperResolutionGraph(const std::string &name,
66  std::vector<dag::Edge *> inputs,
67  std::vector<dag::Edge *> outputs)
68  : dag::Graph(name, inputs, outputs) {
69  key_ = "nndeploy::super_resolution::SuperResolutionGraph";
70  this->setInputTypeInfo<cv::Mat>();
71  this->setOutputTypeInfo<std::vector<cv::Mat>>();
72  }
73 
74  virtual ~SuperResolutionGraph() {}
75 
76  base::Status make(const dag::NodeDesc &pre_desc,
77  const dag::NodeDesc &infer_desc,
78  base::InferenceType inference_type,
79  const dag::NodeDesc &post_desc) {
80  // Create preprocessing node for image preprocessing
81  pre_ = (preprocess::BatchPreprocess *)this->createNode<preprocess::BatchPreprocess>(pre_desc);
82  if (pre_ == nullptr) {
83  NNDEPLOY_LOGE("Failed to create preprocessing node");
85  }
86  pre_->setNodeKey("nndeploy::preprocess::CvtNormTrans");
88  dynamic_cast<preprocess::CvtNormTransParam *>(pre_->getParam());
91  pre_param->mean_[0] = 0.0;
92  pre_param->mean_[1] = 0.0;
93  pre_param->mean_[2] = 0.0;
94  pre_param->std_[0] = 1.0;
95  pre_param->std_[1] = 1.0;
96  pre_param->std_[2] = 1.0;
97 
98  // Create inference node for ResNet model execution
99  infer_ = dynamic_cast<infer::Infer *>(
100  this->createNode<infer::Infer>(infer_desc));
101  if (infer_ == nullptr) {
102  NNDEPLOY_LOGE("Failed to create inference node");
104  }
105  infer_->setInferenceType(inference_type);
106 
107  // Create postprocessing node for SuperResolution results
108  post_ = this->createNode<SuperResolutionPostProcess>(post_desc);
109  if (post_ == nullptr) {
110  NNDEPLOY_LOGE("Failed to create postprocessing node");
112  }
113 
114  return base::kStatusCodeOk;
115  }
116 
118  // Create preprocessing node for image preprocessing
119  pre_ = (preprocess::BatchPreprocess *)this->createNode<preprocess::BatchPreprocess>(
120  "preprocess::BatchPreprocess");
121  if (pre_ == nullptr) {
122  NNDEPLOY_LOGE("Failed to create preprocessing node");
124  }
125  pre_->setGraph(this);
126  pre_->setNodeKey("nndeploy::preprocess::CvtNormTrans");
127  preprocess::CvtNormTransParam *pre_param =
128  dynamic_cast<preprocess::CvtNormTransParam *>(pre_->getParam());
129  if (pre_param == nullptr) {
130  NNDEPLOY_LOGE("Failed to get preprocessing node parameter.\n");
132  }
135  pre_param->mean_[0] = 0.485;
136  pre_param->mean_[1] = 0.456;
137  pre_param->mean_[2] = 0.406;
138  pre_param->std_[0] = 0.229;
139  pre_param->std_[1] = 0.224;
140  pre_param->std_[2] = 0.225;
141 
142  // Create inference node for ResNet model execution
143  infer_ = dynamic_cast<infer::Infer *>(
144  this->createNode<infer::Infer>("infer::Infer"));
145  if (infer_ == nullptr) {
146  NNDEPLOY_LOGE("Failed to create inference node");
148  }
149  infer_->setGraph(this);
150  infer_->setInferenceType(inference_type);
151 
152  // Create postprocessing node for SuperResolution results
153  post_ = this->createNode<SuperResolutionPostProcess>(
154  "SuperResolutionPostProcess");
155  if (post_ == nullptr) {
156  NNDEPLOY_LOGE("Failed to create postprocessing node");
158  }
159  post_->setGraph(this);
160 
161  return base::kStatusCodeOk;
162  }
163 
165  base::ModelType model_type, bool is_path,
166  std::vector<std::string> &model_value) {
167  // auto infer = dynamic_cast<infer::Infer *>(infer_);
168  auto param = dynamic_cast<inference::InferenceParam *>(infer_->getParam());
169  param->device_type_ = device_type;
170  param->model_type_ = model_type;
171  param->is_path_ = is_path;
172  param->model_value_ = model_value;
173  return base::kStatusCodeOk;
174  }
175 
183  dynamic_cast<preprocess::CvtNormTransParam *>(pre_->getParam());
184  param->src_pixel_type_ = pixel_type;
185  return base::kStatusCodeOk;
186  }
187 
188  std::vector<dag::Edge *> forward(std::vector<dag::Edge *> inputs) {
189  inputs = (*pre_)(inputs);
190  inputs = (*infer_)(inputs);
191  std::vector<dag::Edge *> outputs = (*post_)(inputs);
192  return outputs;
193  }
194 
195  private:
196  preprocess::BatchPreprocess *pre_ = nullptr;
197  infer::Infer *infer_ = nullptr;
198  dag::Node *post_ = nullptr;
199 };
200 
201 } // namespace super_resolution
202 } // namespace nndeploy
203 
204 #endif /* _NNDEPLOY_SUPER_RESOLUTION_SUPER_RESOLUTION_H_ */
Directed Acyclic Graph Node.
Definition: graph.h:31
Node description class.
Definition: node.h:35
Node base class.
Definition: node.h:171
InferenceParam is the base class of all inference parameters.
Implementation of ResNet SuperResolution network graph structure.
base::Status setInferParam(base::DeviceType device_type, base::ModelType model_type, bool is_path, std::vector< std::string > &model_value)
SuperResolutionGraph(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
base::Status setSrcPixelType(base::PixelType pixel_type)
Set preprocessing parameters.
base::Status make(const dag::NodeDesc &pre_desc, const dag::NodeDesc &infer_desc, base::InferenceType inference_type, const dag::NodeDesc &post_desc)
std::vector< dag::Edge * > forward(std::vector< dag::Edge * > inputs)
base::Status make(base::InferenceType inference_type)
SuperResolutionPostProcess(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
virtual base::Status run()
Run node (pure virtual function)
#define NNDEPLOY_LOGE(fmt,...)
Definition: log.h:59
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
@ kStatusCodeOk
Definition: status.h:13
@ kStatusCodeErrorInvalidParam
Definition: status.h:21
@ kPixelTypeBGR
Definition: type.h:15
@ kPixelTypeRGB
Definition: type.h:14