2 #ifndef _NNDEPLOY_NET_OPTIMIZER_H_
3 #define _NNDEPLOY_NET_OPTIMIZER_H_
5 #include <unordered_map>
6 #include <unordered_set>
15 using OpSet = std::unordered_set<ir::OpType>;
57 std::vector<OpWrapper*>& op_repository,
58 const std::vector<ir::OpType>& types,
71 std::vector<OpWrapper*>& op_repository,
72 const std::vector<OpSet>& types,
73 std::vector<ir::OpType>& matched_types,
90 std::vector<TensorWrapper*>& tensor_repository,
91 std::vector<OpWrapper*>& op_repository,
92 const std::vector<ir::OpType>& types,
int begin_op_index);
95 std::vector<TensorWrapper*>& tensor_repository,
96 std::vector<OpWrapper*>& op_repository,
97 const std::vector<ir::OpType>& types,
int begin_op_index);
114 OpWrapper* op_wrapper, std::vector<TensorWrapper*>& tensor_repository);
121 OpWrapper* op_wrapper, std::vector<TensorWrapper*>& tensor_repository);
127 std::vector<OpWrapper*>& op_repository,
128 int begin_op_index) = 0;
153 template <
typename T>
155 virtual std::shared_ptr<OptPass> createOptPass() {
156 return std::shared_ptr<T>(
new T());
166 std::map<int, std::map<OptPassType, std::shared_ptr<OptPassCreator>>>>&
174 template <
typename T>
184 auto device_map = creator_map.find(device_type_code);
185 if (device_map == creator_map.end()) {
186 creator_map[device_type_code] =
188 std::map<OptPassType, std::shared_ptr<OptPassCreator>>>();
190 auto level_map = creator_map[device_type_code].find(level);
191 if (level_map == creator_map[device_type_code].end()) {
192 creator_map[device_type_code][level] =
193 std::map<OptPassType, std::shared_ptr<OptPassCreator>>();
196 auto creator = creator_map[device_type_code][level].find(type);
197 if (creator == creator_map[device_type_code][level].end()) {
198 creator_map[device_type_code][level][type] = std::shared_ptr<T>(
new T());
212 std::set<OptPassType> enable_pass,
213 std::set<OptPassType> disable_pass);
220 std::vector<OpWrapper*>& op_repository,
Net* net);
224 std::map<int, std::map<OptPassType, std::shared_ptr<OptPass>>>
virtual ~OptPassCreator()
virtual std::shared_ptr< OptPass > createOptPass()=0
virtual base::Status rmOutputTensorAndMaybeDelete(OpWrapper *op_wrapper, std::vector< TensorWrapper * > &tensor_repository)
处理一个Op的输出Tensor:将该Op从Tensor的生产者中删除;如果该Op是该Tensor的唯一生产者,则释放该Tensor。
virtual base::Status optimize(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, int begin_op_index)=0
virtual base::Status rmOpFromSuccessors(OpWrapper *op_wrapper)
将一个Op从它后继的前驱中删除
virtual base::Status rmOpFromPredecessor(OpWrapper *op_wrapper)
将一个Op从它前驱的后继中删除
virtual base::Status seqPatternMatchUpateTensorRepository(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, const std::vector< ir::OpType > &types, int begin_op_index)
模式匹配并更新tensor_repository
virtual base::Status rmInputTensorAndMaybeDelete(OpWrapper *op_wrapper, std::vector< TensorWrapper * > &tensor_repository)
处理一个Op的输入Tensor:将该Op从Tensor的消费者中删除;如果该Op是该Tensor的唯一消费者,则释放该Tensor。
virtual int seqPatternMatch(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, const std::vector< OpSet > &types, std::vector< ir::OpType > &matched_types, int begin_op_index)
virtual base::Status seqPatternMatchUpateOpRepository(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, const std::vector< ir::OpType > &types, int begin_op_index)
virtual int seqPatternMatch(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, const std::vector< ir::OpType > &types, int begin_op_index)
模式匹配
base::Status setNet(Net *net)
OptPass(std::string name)
std::map< int, std::map< OptPassType, std::shared_ptr< OptPass > > > opt_passes_
base::Status removePass(OptPassType type)
base::Status optimize(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, Net *net)
base::Status addPass(OptPassType type, int level)
base::Status init(base::DeviceType device_type, std::set< OptPassType > enable_pass, std::set< OptPassType > disable_pass)
base::DeviceType device_type_
TypeOptPassRegister(base::DeviceTypeCode device_type_code, OptPassType type, int level)
#define NNDEPLOY_CC_API
api
std::unordered_set< ir::OpType > OpSet
OptPassType stringToOptPassType(const std::string &src)
@ kOptPassTypeEliminateCommonSubexpression
@ kOptPassTypeFoldConstant
@ kOptPassTypeFuseConvBias
@ kOptPassTypeFuseConvBatchNorm
@ kOptPassTypeEliminateDeadOp
@ kOptPassTypeFuseConvAct
@ kOptPassTypeFuseConvRelu
std::string optPassTypeToString(OptPassType type)
std::shared_ptr< OptPass > createOptPass(base::DeviceType device_type, int level, OptPassType type)
std::map< base::DeviceTypeCode, std::map< int, std::map< OptPassType, std::shared_ptr< OptPassCreator > > > > & getGlobalOptPassCreatorMap()
Get the Global OptPass Creator Map object.