nndeploy C++ API  0.2.0
nndeploy C++ API
op.h
Go to the documentation of this file.
1 
2 #ifndef _NNDEPLOY_OP_OP_H_
3 #define _NNDEPLOY_OP_OP_H_
4 
5 #include "nndeploy/base/any.h"
6 #include "nndeploy/base/common.h"
8 #include "nndeploy/base/log.h"
9 #include "nndeploy/base/macro.h"
10 #include "nndeploy/base/object.h"
11 #include "nndeploy/base/param.h"
12 #include "nndeploy/base/shape.h"
13 #include "nndeploy/base/status.h"
14 #include "nndeploy/base/string.h"
16 #include "nndeploy/device/buffer.h"
17 #include "nndeploy/device/device.h"
19 #include "nndeploy/device/tensor.h"
20 #include "nndeploy/ir/ir.h"
21 
22 namespace nndeploy {
23 namespace op {
24 
43  public:
44  Op();
45 
46  virtual ~Op();
47 
48  base::Status setName(std::string name);
49  std::string getName();
50 
53 
54  virtual base::Status setParam(std::shared_ptr<base::Param> param);
55  virtual std::shared_ptr<base::Param> getParam();
56 
59 
60  void setStream(device::Stream *stream);
62 
73 
74  std::string getInputName(int index = 0);
75  std::string getOutputName(int index = 0);
76 
77  device::Tensor *getInput(int index = 0);
78  device::Tensor *getOutput(int index = 0);
79 
82 
83  base::Status setInput(device::Tensor *input, int index);
84  base::Status setOutput(device::Tensor *output, int index);
85 
86  base::Status setAllInputName(std::initializer_list<std::string>);
87  base::Status setAllOutputName(std::initializer_list<std::string>);
88 
89  base::Status setAllInputName(std::vector<std::string> &);
90  base::Status setAllOutputName(std::vector<std::string> &);
91 
92  std::vector<std::string> getAllInputName();
93  std::vector<std::string> getAllOutputName();
94 
95  std::vector<device::Tensor *> getAllInput();
96  std::vector<device::Tensor *> getAllOutput();
97 
98  device::Tensor *getInputTensor(const std::string &name);
99  device::Tensor *getOutputTensor(const std::string &name);
100  base::Status replaceInputTensor(const std::string &name,
101  device::Tensor *tensor);
102  base::Status replaceOutputTensor(const std::string &name,
103  device::Tensor *tensor);
104 
106 
107  base::Status setAllInput(std::vector<device::Tensor *> inputs);
108  base::Status setAllOutput(std::vector<device::Tensor *> outputs);
109 
111 
114 
115  void setInnerFlag(bool flag);
116 
117  void setInitializedFlag(bool flag);
119 
120  void setTimeProfileFlag(bool flag);
122 
123  void setDebugFlag(bool flag);
124  bool getDebugFlag();
125 
126  void setRunningFlag(bool flag);
127  bool isRunning();
128 
150 
// Lifecycle / execution hooks of Op. All are virtual so concrete ops can
// override them; only run() is pure virtual.
// init(): initialization hook.
159  virtual base::Status init();
// deinit(): declared as the counterpart of init(); this listing carries no
// description for it.
160  virtual base::Status deinit();
161 
// reshape(): re-infers shapes; normally called after initialization and
// before preRun(). base::ShapeMap is std::map<std::string, std::vector<int>>.
// NOTE(review): the map is presumably keyed by tensor name — confirm.
168  virtual base::Status reshape(base::ShapeMap &shape_map);
169 
// preRun(): hook declared alongside run(); no description in this listing.
170  virtual base::Status preRun();
// getWorkspaceSize(): returns the op's workspace size. This is the runtime
// workspace size, so call it after the inputs are determined. Example: a Conv
// with padding needs extra memory to hold the padded data.
176  virtual uint64_t getWorkspaceSize();
// setWorkspace(): hands a workspace pointer to the op.
// NOTE(review): presumably pairs with workspace_ / workspace_is_external_ —
// confirm ownership in the implementation.
177  virtual void setWorkspace(void *workspace);
// getFlops(): returns the op's FLOPs.
184  virtual uint64_t getFlops();
185 
194 
// run(): pure virtual — every concrete op must implement it.
195  virtual base::Status run() = 0;
197 
198  protected:
204 
213  bool is_external_stream_ = false;
214  device::Stream *stream_ = nullptr;
215 
224 
// Input tensors of the op. When a weight is a tensor, the weight tensor is
// stored in this vector as well.
237  std::vector<device::Tensor *> inputs_;
// Output tensors of the op.
256  std::vector<device::Tensor *> outputs_;
257 
263  bool workspace_is_external_ = false; // whether the workspace was passed in from outside
264  uint64_t workspace_size_ = 0; // size of the workspace
265  void *workspace_ = nullptr; // the op's workspace
266  uint64_t flops_ = 0; // the op's FLOPs
267 
268  // Whether this op is an internal node of a graph
269  bool is_inner_ = false;
270  // Parallel type (the member this comment describes, listing line 271, is not shown here)
272  // Whether this op can be an in-place op
273  bool is_inplace_ = false;
274  // Whether the parameters and/or inputs have changed
275  bool is_changed_ = false;
276 
277  bool constructed_ = false;
278  bool initialized_ = false;
279  bool is_running_ = false;
280  bool is_time_profile_ = false;
281  bool is_debug_ = false;
282 };
283 
289  public:
290  virtual ~OpCreator(){};
291 
// Pure-virtual factory hook: produces an Op for the given device type, op
// name and op type, returned as a raw pointer.
// `inputs` / `outputs` are taken by non-const reference; the TypeOpCreator
// implementation in this file forwards them to setAllInputName() /
// setAllOutputName().
// NOTE(review): a raw Op* is returned, so the caller presumably owns and must
// delete it — confirm against call sites.
292  virtual Op *createOp(base::DeviceType device_type, const std::string &name,
293  ir::OpType op_type, std::vector<std::string> &inputs,
294  std::vector<std::string> &outputs) = 0;
295 
// Pure-virtual factory hook with the same parameters as createOp(), but the
// created Op is returned as std::shared_ptr<Op> instead of a raw pointer.
296  virtual std::shared_ptr<Op> createOpSharedPtr(
297  base::DeviceType device_type, const std::string &name, ir::OpType op_type,
298  std::vector<std::string> &inputs, std::vector<std::string> &outputs) = 0;
299 };
300 
// Creator class template: an OpCreator that manufactures ops of concrete type T.
// T must be default-constructible and provide setDeviceType(), setName(),
// setOpType() and setAllInputName()/setAllOutputName() taking a
// std::vector<std::string>&, because both factory methods call exactly those.
// No access specifier appears in the class body, so both overrides are private
// in TypeOpCreator; they are public in OpCreator and reached through it.
306 template <typename T>
307 class TypeOpCreator : public OpCreator {
// Raw-pointer factory: allocates a T with `new`, configures it, and returns it.
// The setters return base::Status, but every return value is discarded here,
// so a failing setter still yields a non-null op.
// NOTE(review): the caller presumably takes ownership of the pointer — confirm.
308  virtual Op *createOp(base::DeviceType device_type, const std::string &name,
309  ir::OpType op_type, std::vector<std::string> &inputs,
310  std::vector<std::string> &outputs) {
311  auto op = new T();
312  op->setDeviceType(device_type);
313  op->setName(name);
314  op->setOpType(op_type);
315  op->setAllInputName(inputs);
316  op->setAllOutputName(outputs);
317  return op;
318  }
319 
// shared_ptr factory: same configuration sequence as createOp(), but the T is
// built with std::make_shared and returned as std::shared_ptr<Op>.
// Setter statuses are discarded here as well.
320  virtual std::shared_ptr<Op> createOpSharedPtr(
321  base::DeviceType device_type, const std::string &name, ir::OpType op_type,
322  std::vector<std::string> &inputs, std::vector<std::string> &outputs) {
323  auto op = std::make_shared<T>();
324  op->setDeviceType(device_type);
325  op->setName(name);
326  op->setOpType(op_type);
327  op->setAllInputName(inputs);
328  op->setAllOutputName(outputs);
329  return op;
330  }
331 };
332 
339 extern NNDEPLOY_CC_API std::map<base::DeviceTypeCode, std::map<ir::OpType, std::shared_ptr<OpCreator>>>
341 
347 template <typename T>
349  public:
// Registering constructor: default-constructs a T on the heap, wraps it in a
// std::shared_ptr<T>, and stores it in the global creator map under
// [device_type_code][op_type].
// The store goes through std::map::operator[], so missing map entries are
// created and an existing creator for the same key pair is replaced.
350  explicit TypeOpRegister(base::DeviceTypeCode device_type_code,
351  ir::OpType op_type) {
352  getGlobalOpCreatorMap()[device_type_code][op_type] =
353  std::shared_ptr<T>(new T());
354  }
355 };
356 
// createOp overload set: free factory functions that return a raw Op*.
// Only the declarations are visible in this listing.
// NOTE(review): presumably these look up a creator in getGlobalOpCreatorMap()
// and return nullptr when none is registered — confirm in the implementation.
// Overload 1: device type, op name and op type only.
357 extern NNDEPLOY_CC_API Op *createOp(base::DeviceType device_type, const std::string &name,
358  ir::OpType op_type);
359 
// Overload 2: additionally takes input/output names as initializer lists.
360 extern NNDEPLOY_CC_API Op *createOp(base::DeviceType device_type, const std::string &name,
361  ir::OpType op_type, std::initializer_list<std::string> inputs,
362  std::initializer_list<std::string> outputs);
363 
// Overload 3: input/output names passed as vectors (non-const references).
364 extern NNDEPLOY_CC_API Op *createOp(base::DeviceType device_type, const std::string &name,
365  ir::OpType op_type, std::vector<std::string> &inputs,
366  std::vector<std::string> &outputs);
367 
// Overload 4: as overload 3, plus a shared op parameter object.
368 extern NNDEPLOY_CC_API Op *createOp(base::DeviceType device_type, const std::string &name,
369  ir::OpType op_type, std::vector<std::string> &inputs,
370  std::vector<std::string> &outputs,
371  std::shared_ptr<base::Param> param);
372 
// Overload 5: takes a shared ir::OpDesc in place of the separate
// name / op type / inputs / outputs arguments.
373 extern NNDEPLOY_CC_API Op *createOp(base::DeviceType device_type, std::shared_ptr<ir::OpDesc> op_desc);
374 
// createOpSharedPtr overload set: same parameter lists as the createOp
// overloads, but each returns std::shared_ptr<Op> instead of a raw pointer.
// Only the declarations are visible in this listing.
// Overload 1: device type, op name and op type only.
375 extern NNDEPLOY_CC_API std::shared_ptr<Op> createOpSharedPtr(base::DeviceType device_type,
376  const std::string &name,
377  ir::OpType op_type);
378 
// Overload 2: additionally takes input/output names as initializer lists.
379 extern NNDEPLOY_CC_API std::shared_ptr<Op> createOpSharedPtr(
380  base::DeviceType device_type, const std::string &name, ir::OpType op_type,
381  std::initializer_list<std::string> inputs,
382  std::initializer_list<std::string> outputs);
383 
// Overload 3: input/output names passed as vectors (non-const references).
384 extern NNDEPLOY_CC_API std::shared_ptr<Op> createOpSharedPtr(base::DeviceType device_type,
385  const std::string &name,
386  ir::OpType op_type,
387  std::vector<std::string> &inputs,
388  std::vector<std::string> &outputs);
389 
// Overload 4: as overload 3, plus a shared op parameter object.
390 extern NNDEPLOY_CC_API std::shared_ptr<Op> createOpSharedPtr(base::DeviceType device_type,
391  const std::string &name,
392  ir::OpType op_type,
393  std::vector<std::string> &inputs,
394  std::vector<std::string> &outputs,
395  std::shared_ptr<base::Param> param);
396 
// Overload 5: takes a shared ir::OpDesc in place of the separate
// name / op type / inputs / outputs arguments.
397 extern NNDEPLOY_CC_API std::shared_ptr<Op> createOpSharedPtr(base::DeviceType device_type,
398  std::shared_ptr<ir::OpDesc> op_desc);
399 
// Callable signatures for functional-style op entry points. All four return
// base::Status and take a shared op parameter object last; they differ only in
// whether inputs/outputs are one tensor pointer or an initializer_list of them.
// NOTE(review): the prefixes presumably stand for Single/Multiple Input,
// Single/Multiple Output — confirm.
// SISOOpFunc: one input tensor, one output tensor.
400 using SISOOpFunc =
401  std::function<base::Status(device::Tensor *input, device::Tensor *output,
402  std::shared_ptr<base::Param> op_param)>;
403 
// SIMOOpFunc: one input tensor, a list of output tensors.
404 using SIMOOpFunc = std::function<base::Status(
405  device::Tensor *input, std::initializer_list<device::Tensor *> outputs,
406  std::shared_ptr<base::Param> op_param)>;
407 
// MISOOpFunc: a list of input tensors, one output tensor.
408 using MISOOpFunc = std::function<base::Status(
409  std::initializer_list<device::Tensor *> inputs, device::Tensor *output,
410  std::shared_ptr<base::Param> op_param)>;
411 
// MIMOOpFunc: a list of input tensors, a list of output tensors.
412 using MIMOOpFunc =
413  std::function<base::Status(std::initializer_list<device::Tensor *> inputs,
414  std::initializer_list<device::Tensor *> outputs,
415  std::shared_ptr<base::Param> op_param)>;
416 
// NOTE(review): these using-directives sit in a header inside nndeploy::op, so
// every includer gets base:: and ir:: names visible in nndeploy::op.
// Presumably they exist so registration sites can pass unqualified enumerators
// to the macro below — confirm before changing.
417 using namespace base;
418 using namespace ir;
// Registers op_class as the implementation of op_type for device_type_code.
// Expands to the definition of an object named
// g_<device_type_code><op_class>_register of type
// TypeOpRegister<TypeOpCreator<op_class>>, constructed with
// (device_type_code, op_type).
// device_type_code and op_class are token-pasted into the variable name, so
// each must be a single identifier (no `ns::` qualification).
// The expansion already ends with ';'.
// The macro name's spelling ("IMPLEMENTION") is part of the public interface
// and is left as-is.
419 #define REGISTER_OP_IMPLEMENTION(device_type_code, op_type, op_class) \
420  TypeOpRegister<TypeOpCreator<op_class>> \
421  g_##device_type_code##op_class##_register(device_type_code, op_type);
422 
423 } // namespace op
424 } // namespace nndeploy
425 
426 #endif
参照并扩充了onnx的格式,描述算子的基本信息
Definition: ir.h:34
Op的创建类
Definition: op.h:288
virtual std::shared_ptr< Op > createOpSharedPtr(base::DeviceType device_type, const std::string &name, ir::OpType op_type, std::vector< std::string > &inputs, std::vector< std::string > &outputs)=0
virtual Op * createOp(base::DeviceType device_type, const std::string &name, ir::OpType op_type, std::vector< std::string > &inputs, std::vector< std::string > &outputs)=0
virtual ~OpCreator()
Definition: op.h:290
Op的基类
Definition: op.h:42
virtual base::Status reshape(base::ShapeMap &shape_map)
重新推理形状,通常在初始化之后、preRun前调用
base::Status setInput(device::Tensor *input, int index)
base::Status setAllInput(std::vector< device::Tensor * > inputs)
virtual base::Status deinit()
std::vector< device::Tensor * > getAllOutput()
device::Tensor * getOutput(int index=0)
base::Status replaceInputTensor(const std::string &name, device::Tensor *tensor)
std::vector< device::Tensor * > getAllInput()
base::Status replaceOutputTensor(const std::string &name, device::Tensor *tensor)
bool getTimeProfileFlag()
std::vector< std::string > getAllOutputName()
virtual base::Status inferDataFormat()
数据格式推理
std::vector< std::string > getAllInputName()
std::string getName()
virtual base::Status preRun()
device::Tensor * getInput(int index=0)
virtual void setWorkspace(void *workspace)
base::Status setAllOutput(std::vector< device::Tensor * > outputs)
virtual base::Status init()
初始化
base::Status setAllOutputName(std::vector< std::string > &)
base::Status setName(std::string name)
virtual uint64_t getWorkspaceSize()
得到op的workspace大小 note: op在运行时的workspace大小,在输入确定后调用 eg:例如Conv,当存在padding时,需要分配额外的内存,存放padding后的内存
virtual base::Status inferDataType()
类型推理
virtual base::Status allocateWorkspace()
base::Status setAllInputName(std::vector< std::string > &)
ir::OpDesc op_desc_
op的描述 包含op的类型、名称、输入名称、输出名称、参数
Definition: op.h:203
virtual base::Status inferShape()
形状推理
virtual base::Status checkOrAllocOutput()
检查输出tensor
device::Tensor * getInputTensor(const std::string &name)
virtual base::Status setInput(device::Tensor *input)
base::DeviceType device_type_
op的设备类型
Definition: op.h:208
virtual base::Status postRun()
device::Stream * getStream()
virtual uint64_t getFlops()
得到op的flops
void setInitializedFlag(bool flag)
base::Status setAllOutputName(std::initializer_list< std::string >)
base::DeviceType getDeviceType()
virtual base::Status setOutput(device::Tensor *output)
void setInnerFlag(bool flag)
void setTimeProfileFlag(bool flag)
base::PrecisionType getPrecisionType()
base::ParallelType getParallelType()
void setStream(device::Stream *stream)
void setDebugFlag(bool flag)
base::Status rmInput(device::Tensor *tensor)
virtual base::Status run()=0
base::Status setAllInputName(std::initializer_list< std::string >)
std::string getOutputName(int index=0)
std::string getInputName(int index=0)
base::Status setOpType(ir::OpType op_type)
std::vector< device::Tensor * > inputs_
op的输入tensor note: 当权重为tensor时,权重tensor也会在这里 eg:
Definition: op.h:237
void setRunningFlag(bool flag)
base::Status setDeviceType(base::DeviceType device_type)
base::Status setOutput(device::Tensor *output, int index)
device::Tensor * getOutputTensor(const std::string &name)
ir::OpType getOpType()
virtual std::shared_ptr< base::Param > getParam()
virtual base::Status setParam(std::shared_ptr< base::Param > param)
base::Status setParallelType(const base::ParallelType &paralle_type)
virtual base::Status setPrecisionType(base::PrecisionType precision_type)
设置精度类型 精度不同,计算方式不同,内存分配不同
std::vector< device::Tensor * > outputs_
op的输出tensor
Definition: op.h:256
Op的创建类模板
Definition: op.h:307
Op的创建类的注册类模板
Definition: op.h:348
TypeOpRegister(base::DeviceTypeCode device_type_code, ir::OpType op_type)
Definition: op.h:350
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
@ kPrecisionTypeFp32
Definition: common.h:170
std::map< std::string, std::vector< int > > ShapeMap
Definition: common.h:381
@ kParallelTypeNone
Definition: common.h:354
OpType
算子类型 算子分类
Definition: op_param.h:65
std::function< base::Status(device::Tensor *input, std::initializer_list< device::Tensor * > outputs, std::shared_ptr< base::Param > op_param)> SIMOOpFunc
Definition: op.h:406
std::function< base::Status(device::Tensor *input, device::Tensor *output, std::shared_ptr< base::Param > op_param)> SISOOpFunc
Definition: op.h:402
std::function< base::Status(std::initializer_list< device::Tensor * > inputs, std::initializer_list< device::Tensor * > outputs, std::shared_ptr< base::Param > op_param)> MIMOOpFunc
Definition: op.h:415
Op * createOp(base::DeviceType device_type, const std::string &name, ir::OpType op_type)
std::function< base::Status(std::initializer_list< device::Tensor * > inputs, device::Tensor *output, std::shared_ptr< base::Param > op_param)> MISOOpFunc
Definition: op.h:410
std::shared_ptr< Op > createOpSharedPtr(base::DeviceType device_type, const std::string &name, ir::OpType op_type)
std::map< base::DeviceTypeCode, std::map< ir::OpType, std::shared_ptr< OpCreator > > > & getGlobalOpCreatorMap()
Get the Global Op Creator Map object.