nndeploy C++ API  0.2.0
nndeploy C++ API
optimizer.h
Go to the documentation of this file.
1 
2 #ifndef _NNDEPLOY_NET_OPTIMIZER_H_
3 #define _NNDEPLOY_NET_OPTIMIZER_H_
4 
5 #include <unordered_map>
6 #include <unordered_set>
7 
8 #include "nndeploy/ir/ir.h"
9 #include "nndeploy/net/util.h"
10 #include "nndeploy/op/op.h"
11 
12 namespace nndeploy {
13 namespace net {
14 // helper types
// OpSet: an unordered set of ir::OpType values. It is consumed by the
// seqPatternMatch overload that takes `const std::vector<OpSet>& types` and
// fills `matched_types`; presumably each set lists the op types acceptable at
// one position of the sequence pattern — confirm in the implementation.
15 using OpSet = std::unordered_set<ir::OpType>;
16 
17 enum OptPassType : int {
18  // Operator fusion
23 
24  // Eliminate redundant ops (common subexpressions, dead ops)
27 
28  // Constant Folding
30 
31  // QDQ fusion
33 };
34 
35 
// Forward declaration: in the declarations visible in this header, Net is
// only used through pointers (e.g. OptPassManager::optimize(..., Net* net)),
// so the full definition is not needed here.
36 class Net;
37 
38 class OptPass {
39  public:
40  OptPass(std::string name);
41  virtual ~OptPass();
42 
56  virtual int seqPatternMatch(std::vector<TensorWrapper*>& tensor_repository,
57  std::vector<OpWrapper*>& op_repository,
58  const std::vector<ir::OpType>& types,
59  int begin_op_index);
60 
70  virtual int seqPatternMatch(std::vector<TensorWrapper*>& tensor_repository,
71  std::vector<OpWrapper*>& op_repository,
72  const std::vector<OpSet>& types,
73  std::vector<ir::OpType>& matched_types,
74  int begin_op_index);
75 
90  std::vector<TensorWrapper*>& tensor_repository,
91  std::vector<OpWrapper*>& op_repository,
92  const std::vector<ir::OpType>& types, int begin_op_index);
93 
95  std::vector<TensorWrapper*>& tensor_repository,
96  std::vector<OpWrapper*>& op_repository,
97  const std::vector<ir::OpType>& types, int begin_op_index);
98 
103 
108 
114  OpWrapper* op_wrapper, std::vector<TensorWrapper*>& tensor_repository);
115 
121  OpWrapper* op_wrapper, std::vector<TensorWrapper*>& tensor_repository);
122 
123  std::string getName();
125 
126  virtual base::Status optimize(std::vector<TensorWrapper*>& tensor_repository,
127  std::vector<OpWrapper*>& op_repository,
128  int begin_op_index) = 0;
129 
130  protected:
131  std::string name_; // pass名称
132 
134  nullptr; // 该pass所属的Net,可能要修改Net内部的数据,例如释放某些tensor
135 };
136 
142  public:
143  virtual ~OptPassCreator() {};
144 
145  virtual std::shared_ptr<OptPass> createOptPass() = 0;
146 };
147 
153 template <typename T>
155  virtual std::shared_ptr<OptPass> createOptPass() {
156  return std::shared_ptr<T>(new T());
157  }
158 };
159 
165 std::map<base::DeviceTypeCode,
166  std::map<int, std::map<OptPassType, std::shared_ptr<OptPassCreator>>>>&
168 
174 template <typename T>
176  public:
181  explicit TypeOptPassRegister(base::DeviceTypeCode device_type_code,
182  OptPassType type, int level) {
183  auto& creator_map = getGlobalOptPassCreatorMap();
184  auto device_map = creator_map.find(device_type_code);
185  if (device_map == creator_map.end()) {
186  creator_map[device_type_code] =
187  std::map<int,
188  std::map<OptPassType, std::shared_ptr<OptPassCreator>>>();
189  }
190  auto level_map = creator_map[device_type_code].find(level);
191  if (level_map == creator_map[device_type_code].end()) {
192  creator_map[device_type_code][level] =
193  std::map<OptPassType, std::shared_ptr<OptPassCreator>>();
194  }
195 
196  auto creator = creator_map[device_type_code][level].find(type);
197  if (creator == creator_map[device_type_code][level].end()) {
198  creator_map[device_type_code][level][type] = std::shared_ptr<T>(new T());
199  }
200  }
201 };
202 
// Returns a shared_ptr<OptPass> for the given device type, level and pass
// type. NOTE(review): the implementation is not in this header; presumably it
// consults getGlobalOptPassCreatorMap() — confirm what is returned when no
// creator is registered for the key.
203 std::shared_ptr<OptPass> createOptPass(base::DeviceType device_type, int level,
204  OptPassType type);
205 
207  public:
210 
212  std::set<OptPassType> enable_pass,
213  std::set<OptPassType> disable_pass);
215 
216  base::Status addPass(OptPassType type, int level);
218 
219  base::Status optimize(std::vector<TensorWrapper*>& tensor_repository,
220  std::vector<OpWrapper*>& op_repository, Net* net);
221 
222  protected:
224  std::map<int, std::map<OptPassType, std::shared_ptr<OptPass>>>
225  opt_passes_; // 第一个key是优先级,数字越小, 优先级越高,
226  // 在图优化时首先执行这个pass
227 };
228 
// Maps a string to an OptPassType (exported via NNDEPLOY_CC_API).
// NOTE(review): the accepted spellings and the result for an unrecognized
// string are defined in the implementation, not here — confirm before use.
230 extern NNDEPLOY_CC_API OptPassType stringToOptPassType(const std::string &src);
231 
232 
233 } // namespace net
234 } // namespace nndeploy
235 
236 #endif /* _NNDEPLOY_NET_OPTIMIZER_H_ */
Creator (factory) class for OptPass
Definition: optimizer.h:141
virtual std::shared_ptr< OptPass > createOptPass()=0
virtual base::Status rmOutputTensorAndMaybeDelete(OpWrapper *op_wrapper, std::vector< TensorWrapper * > &tensor_repository)
Handles the output tensors of an Op: removes the Op from each tensor's producers, and releases the tensor if this Op was its only producer.
std::string getName()
virtual base::Status optimize(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, int begin_op_index)=0
virtual base::Status rmOpFromSuccessors(OpWrapper *op_wrapper)
Removes an Op from the predecessor lists of its successors.
virtual base::Status rmOpFromPredecessor(OpWrapper *op_wrapper)
Removes an Op from the successor lists of its predecessors.
virtual base::Status seqPatternMatchUpateTensorRepository(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, const std::vector< ir::OpType > &types, int begin_op_index)
Performs pattern matching and updates tensor_repository.
virtual base::Status rmInputTensorAndMaybeDelete(OpWrapper *op_wrapper, std::vector< TensorWrapper * > &tensor_repository)
Handles the input tensors of an Op: removes the Op from each tensor's consumers, and releases the tensor if this Op was its only consumer.
virtual int seqPatternMatch(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, const std::vector< OpSet > &types, std::vector< ir::OpType > &matched_types, int begin_op_index)
virtual base::Status seqPatternMatchUpateOpRepository(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, const std::vector< ir::OpType > &types, int begin_op_index)
virtual int seqPatternMatch(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, const std::vector< ir::OpType > &types, int begin_op_index)
Pattern matching.
base::Status setNet(Net *net)
OptPass(std::string name)
std::map< int, std::map< OptPassType, std::shared_ptr< OptPass > > > opt_passes_
Definition: optimizer.h:225
base::Status removePass(OptPassType type)
base::Status optimize(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, Net *net)
base::Status addPass(OptPassType type, int level)
base::Status init(base::DeviceType device_type, std::set< OptPassType > enable_pass, std::set< OptPassType > disable_pass)
base::DeviceType device_type_
Definition: optimizer.h:223
Class template for OptPass creators
Definition: optimizer.h:154
Class template that registers an OptPass creator
Definition: optimizer.h:175
TypeOptPassRegister(base::DeviceTypeCode device_type_code, OptPassType type, int level)
Definition: optimizer.h:181
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
std::unordered_set< ir::OpType > OpSet
Definition: optimizer.h:15
OptPassType stringToOptPassType(const std::string &src)
@ kOptPassTypeEliminateCommonSubexpression
Definition: optimizer.h:25
@ kOptPassTypeFoldConstant
Definition: optimizer.h:29
@ kOptPassTypeFuseConvBias
Definition: optimizer.h:19
@ kOptPassTypeFuseConvBatchNorm
Definition: optimizer.h:20
@ kOptPassTypeEliminateDeadOp
Definition: optimizer.h:26
@ kOptPassTypeFuseConvAct
Definition: optimizer.h:22
@ kOptPassTypeFuseConvRelu
Definition: optimizer.h:21
@ kOptPassTypeFuseQdq
Definition: optimizer.h:32
std::string optPassTypeToString(OptPassType type)
std::shared_ptr< OptPass > createOptPass(base::DeviceType device_type, int level, OptPassType type)
std::map< base::DeviceTypeCode, std::map< int, std::map< OptPassType, std::shared_ptr< OptPassCreator > > > > & getGlobalOptPassCreatorMap()
Returns a reference to the global OptPass creator map.