2 #ifndef _NNDEPLOY_OP_OP_H_
3 #define _NNDEPLOY_OP_OP_H_
55 virtual std::shared_ptr<base::Param>
getParam();
213 bool is_external_stream_ =
false;
263 bool workspace_is_external_ =
false;
264 uint64_t workspace_size_ = 0;
265 void *workspace_ =
nullptr;
269 bool is_inner_ =
false;
273 bool is_inplace_ =
false;
275 bool is_changed_ =
false;
277 bool constructed_ =
false;
278 bool initialized_ =
false;
279 bool is_running_ =
false;
280 bool is_time_profile_ =
false;
281 bool is_debug_ =
false;
293 ir::OpType op_type, std::vector<std::string> &inputs,
294 std::vector<std::string> &outputs) = 0;
298 std::vector<std::string> &inputs, std::vector<std::string> &outputs) = 0;
306 template <
typename T>
309 ir::OpType op_type, std::vector<std::string> &inputs,
310 std::vector<std::string> &outputs) {
312 op->setDeviceType(device_type);
314 op->setOpType(op_type);
315 op->setAllInputName(inputs);
316 op->setAllOutputName(outputs);
320 virtual std::shared_ptr<Op> createOpSharedPtr(
322 std::vector<std::string> &inputs, std::vector<std::string> &outputs) {
323 auto op = std::make_shared<T>();
324 op->setDeviceType(device_type);
326 op->setOpType(op_type);
327 op->setAllInputName(inputs);
328 op->setAllOutputName(outputs);
339 extern NNDEPLOY_CC_API std::map<base::DeviceTypeCode, std::map<ir::OpType, std::shared_ptr<OpCreator>>>
347 template <
typename T>
353 std::shared_ptr<T>(
new T());
361 ir::OpType op_type, std::initializer_list<std::string> inputs,
362 std::initializer_list<std::string> outputs);
365 ir::OpType op_type, std::vector<std::string> &inputs,
366 std::vector<std::string> &outputs);
369 ir::OpType op_type, std::vector<std::string> &inputs,
370 std::vector<std::string> &outputs,
371 std::shared_ptr<base::Param> param);
376 const std::string &name,
381 std::initializer_list<std::string> inputs,
382 std::initializer_list<std::string> outputs);
385 const std::string &name,
387 std::vector<std::string> &inputs,
388 std::vector<std::string> &outputs);
391 const std::string &name,
393 std::vector<std::string> &inputs,
394 std::vector<std::string> &outputs,
395 std::shared_ptr<base::Param> param);
398 std::shared_ptr<ir::OpDesc> op_desc);
402 std::shared_ptr<base::Param> op_param)>;
405 device::Tensor *input, std::initializer_list<device::Tensor *> outputs,
406 std::shared_ptr<base::Param> op_param)>;
409 std::initializer_list<device::Tensor *> inputs,
device::Tensor *output,
410 std::shared_ptr<base::Param> op_param)>;
413 std::function<
base::Status(std::initializer_list<device::Tensor *> inputs,
414 std::initializer_list<device::Tensor *> outputs,
415 std::shared_ptr<base::Param> op_param)>;
417 using namespace base;
419 #define REGISTER_OP_IMPLEMENTION(device_type_code, op_type, op_class) \
420 TypeOpRegister<TypeOpCreator<op_class>> \
421 g_##device_type_code##op_class##_register(device_type_code, op_type);
virtual std::shared_ptr< Op > createOpSharedPtr(base::DeviceType device_type, const std::string &name, ir::OpType op_type, std::vector< std::string > &inputs, std::vector< std::string > &outputs)=0
virtual Op * createOp(base::DeviceType device_type, const std::string &name, ir::OpType op_type, std::vector< std::string > &inputs, std::vector< std::string > &outputs)=0
virtual base::Status reshape(base::ShapeMap &shape_map)
Re-infers shapes; usually called after initialization and before preRun.
base::Status setInput(device::Tensor *input, int index)
base::Status setAllInput(std::vector< device::Tensor * > inputs)
virtual base::Status deinit()
std::vector< device::Tensor * > getAllOutput()
device::Tensor * getOutput(int index=0)
base::Status replaceInputTensor(const std::string &name, device::Tensor *tensor)
std::vector< device::Tensor * > getAllInput()
base::Status replaceOutputTensor(const std::string &name, device::Tensor *tensor)
bool getTimeProfileFlag()
std::vector< std::string > getAllOutputName()
virtual base::Status inferDataFormat()
Data format inference.
std::vector< std::string > getAllInputName()
virtual base::Status preRun()
device::Tensor * getInput(int index=0)
virtual void setWorkspace(void *workspace)
base::Status setAllOutput(std::vector< device::Tensor * > outputs)
virtual base::Status init()
Initialization.
base::Status setAllOutputName(std::vector< std::string > &)
base::Status setName(std::string name)
virtual uint64_t getWorkspaceSize()
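Gets the op's workspace size. Note: this is the workspace the op needs at run time, so call it only after the inputs are determined. For example, a Conv with padding needs extra memory to hold the padded data.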
virtual base::Status inferDataType()
Data type inference.
virtual base::Status allocateWorkspace()
base::Status setAllInputName(std::vector< std::string > &)
ir::OpDesc op_desc_
The op's description, containing the op's type, name, input names, output names, and parameters.
virtual base::Status inferShape()
Shape inference.
virtual base::Status checkOrAllocOutput()
Checks the output tensors.
device::Tensor * getInputTensor(const std::string &name)
virtual base::Status setInput(device::Tensor *input)
base::DeviceType device_type_
The op's device type.
virtual base::Status postRun()
device::Stream * getStream()
virtual uint64_t getFlops()
Gets the op's FLOPs.
void setInitializedFlag(bool flag)
base::Status setAllOutputName(std::initializer_list< std::string >)
base::DeviceType getDeviceType()
virtual base::Status setOutput(device::Tensor *output)
void setInnerFlag(bool flag)
void setTimeProfileFlag(bool flag)
base::PrecisionType getPrecisionType()
base::ParallelType getParallelType()
void setStream(device::Stream *stream)
void setDebugFlag(bool flag)
base::Status rmInput(device::Tensor *tensor)
virtual base::Status run()=0
base::Status setAllInputName(std::initializer_list< std::string >)
std::string getOutputName(int index=0)
std::string getInputName(int index=0)
base::Status setOpType(ir::OpType op_type)
std::vector< device::Tensor * > inputs_
The op's input tensors. Note: when a weight is stored as a tensor, that weight tensor is also included here (e.g. the weight tensor of a Conv op).
void setRunningFlag(bool flag)
base::Status setDeviceType(base::DeviceType device_type)
base::Status setOutput(device::Tensor *output, int index)
device::Tensor * getOutputTensor(const std::string &name)
virtual std::shared_ptr< base::Param > getParam()
virtual base::Status setParam(std::shared_ptr< base::Param > param)
base::Status setParallelType(const base::ParallelType &paralle_type)
virtual base::Status setPrecisionType(base::PrecisionType precision_type)
Sets the precision type. Different precisions use different computation methods and different memory allocation.
std::vector< device::Tensor * > outputs_
The op's output tensors.
TypeOpRegister(base::DeviceTypeCode device_type_code, ir::OpType op_type)
#define NNDEPLOY_CC_API
Marks symbols exported as part of the public API (defined as empty in this configuration).
std::map< std::string, std::vector< int > > ShapeMap
std::function< base::Status(device::Tensor *input, std::initializer_list< device::Tensor * > outputs, std::shared_ptr< base::Param > op_param)> SIMOOpFunc
std::function< base::Status(device::Tensor *input, device::Tensor *output, std::shared_ptr< base::Param > op_param)> SISOOpFunc
std::function< base::Status(std::initializer_list< device::Tensor * > inputs, std::initializer_list< device::Tensor * > outputs, std::shared_ptr< base::Param > op_param)> MIMOOpFunc
Op * createOp(base::DeviceType device_type, const std::string &name, ir::OpType op_type)
std::function< base::Status(std::initializer_list< device::Tensor * > inputs, device::Tensor *output, std::shared_ptr< base::Param > op_param)> MISOOpFunc
std::shared_ptr< Op > createOpSharedPtr(base::DeviceType device_type, const std::string &name, ir::OpType op_type)
std::map< base::DeviceTypeCode, std::map< ir::OpType, std::shared_ptr< OpCreator > > > & getGlobalOpCreatorMap()
Gets the global op creator map, keyed first by device type code and then by op type.