nndeploy C++ API  0.2.0
nndeploy C++ API
op_param.h
Go to the documentation of this file.
1 
2 #ifndef _NNDEPLOY_IR_OP_PARAM_H_
3 #define _NNDEPLOY_IR_OP_PARAM_H_
4 
5 #include "nndeploy/base/any.h"
6 #include "nndeploy/base/common.h"
8 #include "nndeploy/base/log.h"
9 #include "nndeploy/base/macro.h"
10 #include "nndeploy/base/object.h"
11 #include "nndeploy/base/param.h"
13 #include "nndeploy/base/status.h"
14 #include "nndeploy/base/string.h"
15 #include "nndeploy/device/tensor.h"
16 #include "rapidjson/document.h"
17 #include "rapidjson/stringbuffer.h"
18 #include "rapidjson/writer.h"
19 
20 #ifdef _MSC_VER
21 #pragma warning(push) // 保存当前警告状态
22 #pragma warning(disable : 4267) // 禁用 C4267 警告
23 #endif
24 
25 namespace nndeploy {
26 namespace ir {
65 enum OpType : int {
66  kOpTypeNet = 0x0000,
67 
83  kOpTypeConstant, // 该算子转换为权重
212 
215 
217 };
218 
219 NNDEPLOY_CC_API std::string opTypeToString(OpType op_type);
220 
221 NNDEPLOY_CC_API OpType stringToOpType(const std::string &op_type_name);
222 
228  public:
229  virtual ~OpParamCreator() {};
230  virtual std::shared_ptr<base::Param> createOpParam(OpType type) = 0;
231 };
232 
238 template <typename T>
240  virtual std::shared_ptr<base::Param> createOpParam(OpType type) {
241  return std::make_shared<T>();
242  }
243 };
244 
250 extern NNDEPLOY_CC_API std::map<OpType, std::shared_ptr<OpParamCreator>> &
252 
258 template <typename T>
260  public:
261  explicit TypeOpParamRegister(OpType type) {
262  getGlobalOpParamCreatorMap()[type] = std::shared_ptr<T>(new T());
263  }
264 };
265 
272 extern NNDEPLOY_CC_API std::shared_ptr<base::Param> createOpParam(
273  OpType op_type);
274 
// Defines a registrar object named g_<op_type>_<op_param_class>_register.
// Its constructor maps `op_type` to a TypeOpParamCreator<op_param_class>
// when the object is initialized. The expansion already ends in ';'.
#define REGISTER_OP_PARAM_IMPLEMENTION(op_type, op_param_class) \
  TypeOpParamRegister<TypeOpParamCreator<op_param_class>>       \
      g_##op_type##_##op_param_class##_register(op_type);
278 
287  public:
288  OpParam() : base::Param(), reserved_(0) {};
289  virtual ~OpParam() {};
290 
293 
294  public:
295  // 保留字段,也可以充void *使用
296  size_t reserved_;
297 };
298 
300  public:
303 
306 
309  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
310  json.AddMember("epsilon_", epsilon_, allocator);
311  json.AddMember("momentum_", momentum_, allocator);
312  json.AddMember("training_mode_", training_mode_, allocator);
313  return base::kStatusCodeOk;
314  }
316  virtual base::Status deserialize(rapidjson::Value &json) {
317  if (json.HasMember("epsilon_")) {
318  epsilon_ = json["epsilon_"].GetFloat();
319  } else {
320  epsilon_ = 1e-05f; // 默认值
321  }
322 
323  if (json.HasMember("momentum_")) {
324  momentum_ = json["momentum_"].GetFloat();
325  } else {
326  momentum_ = 0.9f; // 默认值
327  }
328 
329  if (json.HasMember("training_mode_")) {
330  training_mode_ = json["training_mode_"].GetInt();
331  } else {
332  training_mode_ = 0; // 默认值
333  }
334 
335  return base::kStatusCodeOk;
336  }
337 
338  public:
339  // The epsilon value to use to avoid division by zero.
340  float epsilon_ = 1e-05f;
341  // Factor used in computing the running mean and variance.e.g., running_mean =
342  // running_mean * momentum + mean * (1 - momentum).
343  float momentum_ = 0.9f;
344  int training_mode_ = 0;
345 };
346 
347 class ConcatParam : public OpParam {
348  public:
350  virtual ~ConcatParam() {};
351 
354 
357  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
358  json.AddMember("axis_", axis_, allocator);
359  return base::kStatusCodeOk;
360  }
362  virtual base::Status deserialize(rapidjson::Value &json) {
363  if (json.HasMember("axis_")) {
364  axis_ = json["axis_"].GetInt();
365  } else {
366  axis_ = 1;
367  }
368 
369  return base::kStatusCodeOk;
370  }
371 
372  public:
373  int axis_ = 1; // 拼接的维度
374 };
375 
376 class MatMulParam : public OpParam {
377  public:
379  virtual ~MatMulParam() {};
380 
383 
386  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
387  json.AddMember("transposeA_", transposeA_, allocator);
388  json.AddMember("transposeB_", transposeB_, allocator);
389  return base::kStatusCodeOk;
390  }
392  virtual base::Status deserialize(rapidjson::Value &json) {
393  if (json.HasMember("transposeA_")) {
394  transposeA_ = json["transposeA_"].GetBool();
395  } else {
396  transposeA_ = false;
397  }
398  if (json.HasMember("transposeB_")) {
399  transposeB_ = json["transposeB_"].GetBool();
400  } else {
401  transposeB_ = false;
402  }
403 
404  return base::kStatusCodeOk;
405  }
406 
407  public:
408  bool transposeA_ = false; // 是否转置A矩阵
409  bool transposeB_ = false; // 是否转置A矩阵
410 };
411 
413  public:
414  // 构造函数
416  virtual ~ConvParam() {}
417 
420 
423  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
424  json.AddMember("auto_pad_", rapidjson::Value(auto_pad_.c_str(), allocator),
425  allocator);
426  json.AddMember("dilations_", rapidjson::Value(rapidjson::kArrayType),
427  allocator);
428  for (size_t i = 0; i < dilations_.size(); ++i) {
429  json["dilations_"].PushBack(dilations_[i], allocator);
430  }
431  json.AddMember("group_", group_, allocator);
432  json.AddMember("kernel_shape_", rapidjson::Value(rapidjson::kArrayType),
433  allocator);
434  for (size_t i = 0; i < kernel_shape_.size(); ++i) {
435  json["kernel_shape_"].PushBack(kernel_shape_[i], allocator);
436  }
437  json.AddMember("pads_", rapidjson::Value(rapidjson::kArrayType), allocator);
438  for (size_t i = 0; i < pads_.size(); ++i) {
439  json["pads_"].PushBack(pads_[i], allocator);
440  }
441  json.AddMember("strides_", rapidjson::Value(rapidjson::kArrayType),
442  allocator);
443  for (size_t i = 0; i < strides_.size(); ++i) {
444  json["strides_"].PushBack(strides_[i], allocator);
445  }
446  json.AddMember(
447  "activate_op_",
448  rapidjson::Value(opTypeToString(activate_op_).c_str(), allocator),
449  allocator);
450  if (activate_op_ != kOpTypeNone && fused_op_param_ != nullptr) {
451  rapidjson::Value op_desc_json(rapidjson::kObjectType);
452  fused_op_param_->serialize(op_desc_json, allocator);
453  json.AddMember("fused_op_param_", op_desc_json, allocator);
454  }
455  return base::kStatusCodeOk;
456  }
458  virtual base::Status deserialize(rapidjson::Value &json) {
459  if (json.HasMember("auto_pad_")) {
460  auto_pad_ = json["auto_pad_"].GetString();
461  } else {
462  auto_pad_ = "NOTSET";
463  }
464 
465  if (json.HasMember("dilations_")) {
466  dilations_.clear();
467  for (size_t i = 0; i < json["dilations_"].Size(); ++i) {
468  dilations_.push_back(json["dilations_"][i].GetInt());
469  }
470  } else {
471  dilations_ = {1, 1};
472  }
473 
474  if (json.HasMember("group_")) {
475  group_ = json["group_"].GetInt();
476  } else {
477  group_ = 1;
478  }
479 
480  if (json.HasMember("kernel_shape_")) {
481  kernel_shape_.clear();
482  for (size_t i = 0; i < json["kernel_shape_"].Size(); ++i) {
483  kernel_shape_.push_back(json["kernel_shape_"][i].GetInt());
484  }
485  } else {
486  kernel_shape_.clear();
487  }
488 
489  if (json.HasMember("pads_")) {
490  pads_.clear();
491  for (size_t i = 0; i < json["pads_"].Size(); ++i) {
492  pads_.push_back(json["pads_"][i].GetInt());
493  }
494  } else {
495  pads_ = {0, 0, 0, 0};
496  }
497 
498  if (json.HasMember("strides_")) {
499  strides_.clear();
500  for (size_t i = 0; i < json["strides_"].Size(); ++i) {
501  strides_.push_back(json["strides_"][i].GetInt());
502  }
503  } else {
504  strides_ = {1, 1};
505  }
506 
507  if (json.HasMember("activate_op_")) {
508  activate_op_ = stringToOpType(json["activate_op_"].GetString());
509  fused_op_param_ = createOpParam(activate_op_);
510  if (json.HasMember("fused_op_param_")) {
511  fused_op_param_->deserialize(json["fused_op_param_"]);
512  }
513  } else {
514  activate_op_ = kOpTypeNone;
515  }
516 
517  return base::kStatusCodeOk;
518  }
519 
520  public:
521  // 自动填充方式
522  std::string auto_pad_ = "NOTSET";
523  // 扩张系数
524  std::vector<int> dilations_ = {1, 1};
525  // 组数
526  int group_ = 1;
527  // 卷积核大小
528  std::vector<int> kernel_shape_;
529  // 填充大小
530  std::vector<int> pads_ = {0, 0, 0, 0};
531  // 卷积步长
532  std::vector<int> strides_ = {1, 1};
533 
534  // 服务与算子融合
535  OpType activate_op_ = kOpTypeNone;
536  // OpParam* fused_op_param_ = nullptr;
537  std::shared_ptr<base::Param> fused_op_param_ = nullptr;
538 };
539 // MaxPool 参数类
541  public:
543  virtual ~MaxPoolParam() {}
544 
547 
550  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
551  json.AddMember("auto_pad_", rapidjson::Value(auto_pad_.c_str(), allocator),
552  allocator);
553  json.AddMember("ceil_mode_", ceil_mode_, allocator);
554  rapidjson::Value dilations_array(rapidjson::kArrayType);
555  for (size_t i = 0; i < dilations_.size(); ++i) {
556  dilations_array.PushBack(dilations_[i], allocator);
557  }
558  json.AddMember("dilations_", dilations_array, allocator);
559 
560  rapidjson::Value kernel_shape_array(rapidjson::kArrayType);
561  for (size_t i = 0; i < kernel_shape_.size(); ++i) {
562  kernel_shape_array.PushBack(kernel_shape_[i], allocator);
563  }
564  json.AddMember("kernel_shape_", kernel_shape_array, allocator);
565 
566  rapidjson::Value pads_array(rapidjson::kArrayType);
567  for (size_t i = 0; i < pads_.size(); ++i) {
568  pads_array.PushBack(pads_[i], allocator);
569  }
570  json.AddMember("pads_", pads_array, allocator);
571 
572  json.AddMember("storage_order_", storage_order_, allocator);
573 
574  rapidjson::Value strides_array(rapidjson::kArrayType);
575  for (size_t i = 0; i < strides_.size(); ++i) {
576  strides_array.PushBack(strides_[i], allocator);
577  }
578  json.AddMember("strides_", strides_array, allocator);
579 
580  return base::kStatusCodeOk;
581  }
583  virtual base::Status deserialize(rapidjson::Value &json) {
584  if (json.HasMember("auto_pad_")) {
585  auto_pad_ = json["auto_pad_"].GetString();
586  } else {
587  auto_pad_ = "NOTSET";
588  }
589 
590  if (json.HasMember("ceil_mode_")) {
591  ceil_mode_ = json["ceil_mode_"].GetInt();
592  } else {
593  ceil_mode_ = 0;
594  }
595 
596  if (json.HasMember("dilations_")) {
597  dilations_.clear();
598  for (size_t i = 0; i < json["dilations_"].Size(); ++i) {
599  dilations_.push_back(json["dilations_"][i].GetInt());
600  }
601  } else {
602  dilations_ = {1, 1};
603  }
604 
605  if (json.HasMember("kernel_shape_")) {
606  kernel_shape_.clear();
607  for (size_t i = 0; i < json["kernel_shape_"].Size(); ++i) {
608  kernel_shape_.push_back(json["kernel_shape_"][i].GetInt());
609  }
610  } else {
611  kernel_shape_.clear();
612  }
613 
614  if (json.HasMember("pads_")) {
615  pads_.clear();
616  for (size_t i = 0; i < json["pads_"].Size(); ++i) {
617  pads_.push_back(json["pads_"][i].GetInt());
618  }
619  } else {
620  pads_ = {0, 0, 0, 0};
621  }
622 
623  if (json.HasMember("storage_order_")) {
624  storage_order_ = json["storage_order_"].GetInt();
625  } else {
626  storage_order_ = 0;
627  }
628 
629  if (json.HasMember("strides_")) {
630  strides_.clear();
631  for (size_t i = 0; i < json["strides_"].Size(); ++i) {
632  strides_.push_back(json["strides_"][i].GetInt());
633  }
634  } else {
635  strides_ = {1, 1};
636  }
637 
638  return base::kStatusCodeOk;
639  }
640 
641  public:
642  std::string auto_pad_ = "NOTSET"; // 自动填充方式
643  int ceil_mode_ = 0; // 是否向上取整
644  std::vector<int> dilations_ = {1, 1}; // 扩张系数
645  std::vector<int> kernel_shape_; // 池化核大小
646  std::vector<int> pads_ = {0, 0, 0, 0}; // 填充大小
647  int storage_order_ = 0; // 存储顺序
648  std::vector<int> strides_ = {1, 1}; // 步长
649 };
650 
651 // Reshape 参数类
653  public:
655  virtual ~ReshapeParam() {}
656 
659 
662  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
663  json.AddMember("allowzero_", allowzero_, allocator);
664  return base::kStatusCodeOk;
665  }
667  virtual base::Status deserialize(rapidjson::Value &json) {
668  if (json.HasMember("allowzero_")) {
669  allowzero_ = json["allowzero_"].GetInt();
670  } else {
671  allowzero_ = 0; // 默认值
672  }
673 
674  return base::kStatusCodeOk;
675  }
676 
677  public:
678  int allowzero_ = 0; // 是否允许0
679 };
680 
681 // Resize 参数类 - opset 18~19
683  public:
685  virtual ~ResizeParam() {}
686 
689 
692  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
693  json.AddMember("antialias_", antialias_, allocator);
694  json.AddMember("axes_", axes_, allocator);
695  json.AddMember(
696  "coordinate_transformation_mode_",
697  rapidjson::Value(coordinate_transformation_mode_.c_str(), allocator),
698  allocator);
699  json.AddMember("cubic_coeff_a_", cubic_coeff_a_, allocator);
700  json.AddMember("exclude_outside_", exclude_outside_, allocator);
701  json.AddMember("extrapolation_value_", extrapolation_value_, allocator);
702  json.AddMember(
703  "keep_aspect_ratio_policy_",
704  rapidjson::Value(keep_aspect_ratio_policy_.c_str(), allocator),
705  allocator);
706  json.AddMember("mode_", rapidjson::Value(mode_.c_str(), allocator),
707  allocator);
708  json.AddMember("nearest_mode_",
709  rapidjson::Value(nearest_mode_.c_str(), allocator),
710  allocator);
711  return base::kStatusCodeOk;
712  }
714  virtual base::Status deserialize(rapidjson::Value &json) {
715  if (json.HasMember("antialias_")) {
716  antialias_ = json["antialias_"].GetInt();
717  } else {
718  antialias_ = 0; // 默认值
719  }
720 
721  if (json.HasMember("axes_")) {
722  axes_ = json["axes_"].GetInt();
723  } else {
724  axes_ = INT_MAX; // 默认值
725  }
726 
727  if (json.HasMember("coordinate_transformation_mode_")) {
728  coordinate_transformation_mode_ =
729  json["coordinate_transformation_mode_"].GetString();
730  } else {
731  coordinate_transformation_mode_ = "half_pixel"; // 默认值
732  }
733 
734  if (json.HasMember("cubic_coeff_a_")) {
735  cubic_coeff_a_ = json["cubic_coeff_a_"].GetFloat();
736  } else {
737  cubic_coeff_a_ = -0.75; // 默认值
738  }
739 
740  if (json.HasMember("exclude_outside_")) {
741  exclude_outside_ = json["exclude_outside_"].GetInt();
742  } else {
743  exclude_outside_ = 0; // 默认值
744  }
745 
746  if (json.HasMember("extrapolation_value_")) {
747  extrapolation_value_ = json["extrapolation_value_"].GetFloat();
748  } else {
749  extrapolation_value_ = -0.0; // 默认值
750  }
751 
752  if (json.HasMember("keep_aspect_ratio_policy_")) {
753  keep_aspect_ratio_policy_ = json["keep_aspect_ratio_policy_"].GetString();
754  } else {
755  keep_aspect_ratio_policy_ = "stretch"; // 默认值
756  }
757 
758  if (json.HasMember("mode_")) {
759  mode_ = json["mode_"].GetString();
760  } else {
761  mode_ = "nearest"; // 默认值
762  }
763 
764  if (json.HasMember("nearest_mode_")) {
765  nearest_mode_ = json["nearest_mode_"].GetString();
766  } else {
767  nearest_mode_ = "round_prefer_floor"; // 默认值
768  }
769 
770  return base::kStatusCodeOk;
771  }
772 
773  public:
774  int antialias_ = 0;
775  int axes_ = INT_MAX; // 轴,当为INT_MAX时,表示未设置
776  std::string coordinate_transformation_mode_ = "half_pixel";
777  float cubic_coeff_a_ = -0.75;
778  int exclude_outside_ = 0;
779  float extrapolation_value_ = -0.0;
780  std::string keep_aspect_ratio_policy_ = "stretch";
781  std::string mode_ = "nearest";
782  std::string nearest_mode_ = "round_prefer_floor";
783 };
784 
785 // Softmax 参数类
787  public:
789  virtual ~SoftmaxParam() {}
790 
793 
796  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
797  json.AddMember("axis_", axis_, allocator);
798  return base::kStatusCodeOk;
799  }
801  virtual base::Status deserialize(rapidjson::Value &json) {
802  if (json.HasMember("axis_")) {
803  axis_ = json["axis_"].GetInt();
804  } else {
805  axis_ = -1;
806  }
807 
808  return base::kStatusCodeOk;
809  }
810 
811  public:
812  int axis_ = -1; // 应用 Softmax 的轴
813 };
814 
815 // Split 参数类
817  public:
818  SplitParam() : OpParam() {} // 默认轴为0,分割数为1
819  virtual ~SplitParam() {}
820 
823 
826  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
827  json.AddMember("axis_", axis_, allocator);
828  json.AddMember("num_outputs_", num_outputs_, allocator);
829  return base::kStatusCodeOk;
830  }
832  virtual base::Status deserialize(rapidjson::Value &json) {
833  if (json.HasMember("axis_")) {
834  axis_ = json["axis_"].GetInt();
835  } else {
836  axis_ = 0; // 默认值
837  }
838 
839  if (json.HasMember("num_outputs_")) {
840  num_outputs_ = json["num_outputs_"].GetInt();
841  } else {
842  num_outputs_ = INT_MAX; // 默认值
843  }
844 
845  return base::kStatusCodeOk;
846  }
847 
848  public:
849  int axis_ = 0; // 分割轴
850  int num_outputs_ = INT_MAX; // 分割数
851 };
852 
853 // Transpose 参数类
855  public:
856  TransposeParam() : OpParam() {} // 默认轴为0,分割数为1
857  virtual ~TransposeParam() {}
858 
861 
864  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
865  rapidjson::Value permArray(rapidjson::kArrayType);
866  for (size_t i = 0; i < perm_.size(); ++i) {
867  permArray.PushBack(perm_[i], allocator);
868  }
869  json.AddMember("perm_", permArray, allocator);
870  return base::kStatusCodeOk;
871  }
873  virtual base::Status deserialize(rapidjson::Value &json) {
874  if (json.HasMember("perm_")) {
875  perm_.clear();
876  for (size_t i = 0; i < json["perm_"].Size(); ++i) {
877  perm_.push_back(json["perm_"][i].GetInt());
878  }
879  } else {
880  perm_.clear(); // 默认值
881  }
882 
883  return base::kStatusCodeOk;
884  }
885 
886  public:
887  std::vector<int> perm_;
888 };
889 
891  public:
892  RMSNormParam() : OpParam() {} // 默认轴为0,分割数为1
893  virtual ~RMSNormParam() {}
894 
897 
900  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
901  json.AddMember("eps_", eps_, allocator);
902  json.AddMember("is_last_", is_last_, allocator);
903  return base::kStatusCodeOk;
904  }
906  virtual base::Status deserialize(rapidjson::Value &json) {
907  if (json.HasMember("eps_")) {
908  eps_ = json["eps_"].GetFloat();
909  } else {
910  eps_ = 1e-6f; // 默认值
911  }
912 
913  if (json.HasMember("is_last_")) {
914  is_last_ = json["is_last_"].GetBool();
915  } else {
916  is_last_ = false; // 默认值
917  }
918 
919  return base::kStatusCodeOk;
920  }
921 
922  public:
923  float eps_ = 1e-6f;
924  bool is_last_ = false;
925 };
926 
928  public:
930  virtual ~FlattenParam() {};
931 
934 
937  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
938  json.AddMember("axis_", axis_, allocator);
939  return base::kStatusCodeOk;
940  }
942  virtual base::Status deserialize(rapidjson::Value &json) {
943  if (json.HasMember("axis_")) {
944  axis_ = json["axis_"].GetInt();
945  } else {
946  axis_ = 1;
947  }
948 
949  return base::kStatusCodeOk;
950  }
951 
952  public:
953  int axis_ = 1;
954 };
955 
957  public:
959  virtual ~EmbeddingParam() {};
960 
963 
966  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
967  return base::kStatusCodeOk;
968  }
970  virtual base::Status deserialize(rapidjson::Value &json) {
971  return base::kStatusCodeOk;
972  }
973 };
974 
976  public:
977  GemmParam() : OpParam() {};
978  virtual ~GemmParam() {};
979 
982 
985  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
986  json.AddMember("alpha_", alpha_, allocator);
987  json.AddMember("beta_", beta_, allocator);
988  json.AddMember("trans_a_", trans_a_, allocator);
989  json.AddMember("trans_b_", trans_b_, allocator);
990  return base::kStatusCodeOk;
991  }
993  virtual base::Status deserialize(rapidjson::Value &json) {
994  if (json.HasMember("alpha_")) {
995  alpha_ = json["alpha_"].GetFloat();
996  } else {
997  alpha_ = 1.0; // 默认值
998  }
999 
1000  if (json.HasMember("beta_")) {
1001  beta_ = json["beta_"].GetFloat();
1002  } else {
1003  beta_ = 1.0; // 默认值
1004  }
1005 
1006  if (json.HasMember("trans_a_")) {
1007  trans_a_ = json["trans_a_"].GetInt();
1008  } else {
1009  trans_a_ = 0; // 默认值
1010  }
1011 
1012  if (json.HasMember("trans_b_")) {
1013  trans_b_ = json["trans_b_"].GetInt();
1014  } else {
1015  trans_b_ = 0; // 默认值
1016  }
1017 
1018  return base::kStatusCodeOk;
1019  }
1020 
1021  public:
1022  float alpha_ = 1.0; // 默认值为1.0
1023  float beta_ = 1.0; // 默认值为1.0
1024  int trans_a_ = 0; // 默认值为0
1025  int trans_b_ = 0; // 默认值为0
1026 };
1027 
1029  public:
1032 
1035 
1036  using base::Param::serialize;
1038  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1039  json.AddMember("axis_", axis_, allocator);
1040  json.AddMember("saturate_", saturate_, allocator);
1041  return base::kStatusCodeOk;
1042  }
1044  virtual base::Status deserialize(rapidjson::Value &json) {
1045  if (json.HasMember("axis_")) {
1046  axis_ = json["axis_"].GetInt();
1047  } else {
1048  axis_ = 1; // 默认值
1049  }
1050 
1051  if (json.HasMember("saturate_")) {
1052  saturate_ = json["saturate_"].GetInt();
1053  } else {
1054  saturate_ = 1; // 默认值
1055  }
1056 
1057  return base::kStatusCodeOk;
1058  }
1059 
1060  public:
1061  int axis_ = 1; // 量化维度,默认为1
1062  int saturate_ = 1; // 是否饱和,默认为1
1063 };
1064 
1066  public:
1069 
1072 
1073  using base::Param::serialize;
1075  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1076  json.AddMember("axis_", axis_, allocator);
1077  return base::kStatusCodeOk;
1078  }
1080  virtual base::Status deserialize(rapidjson::Value &json) {
1081  if (json.HasMember("axis_")) {
1082  axis_ = json["axis_"].GetInt();
1083  } else {
1084  axis_ = 1; // 默认值
1085  }
1086 
1087  return base::kStatusCodeOk;
1088  }
1089 
1090  public:
1091  int axis_ = 1; // 反量化维度,默认为1
1092 };
1093 
1095  public:
1096  // 构造函数
1098  virtual ~QLinearConvParam() {}
1099 
1102 
1103  using base::Param::serialize;
1105  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1106  json.AddMember("auto_pad_", rapidjson::Value(auto_pad_.c_str(), allocator),
1107  allocator);
1108  json.AddMember("dilations_", rapidjson::Value(rapidjson::kArrayType),
1109  allocator);
1110  for (size_t i = 0; i < dilations_.size(); ++i) {
1111  json["dilations_"].PushBack(dilations_[i], allocator);
1112  }
1113  json.AddMember("group_", group_, allocator);
1114  json.AddMember("kernel_shape_", rapidjson::Value(rapidjson::kArrayType),
1115  allocator);
1116  for (size_t i = 0; i < kernel_shape_.size(); ++i) {
1117  json["kernel_shape_"].PushBack(kernel_shape_[i], allocator);
1118  }
1119  json.AddMember("pads_", rapidjson::Value(rapidjson::kArrayType), allocator);
1120  for (size_t i = 0; i < pads_.size(); ++i) {
1121  json["pads_"].PushBack(pads_[i], allocator);
1122  }
1123  json.AddMember("strides_", rapidjson::Value(rapidjson::kArrayType),
1124  allocator);
1125  for (size_t i = 0; i < strides_.size(); ++i) {
1126  json["strides_"].PushBack(strides_[i], allocator);
1127  }
1128 
1129  return base::kStatusCodeOk;
1130  }
1132  virtual base::Status deserialize(rapidjson::Value &json) {
1133  if (json.HasMember("auto_pad_")) {
1134  auto_pad_ = json["auto_pad_"].GetString();
1135  } else {
1136  auto_pad_ = "NOTSET";
1137  }
1138 
1139  if (json.HasMember("dilations_")) {
1140  dilations_.clear();
1141  for (size_t i = 0; i < json["dilations_"].Size(); ++i) {
1142  dilations_.push_back(json["dilations_"][i].GetInt());
1143  }
1144  } else {
1145  dilations_ = {1, 1};
1146  }
1147 
1148  if (json.HasMember("group_")) {
1149  group_ = json["group_"].GetInt();
1150  } else {
1151  group_ = 1;
1152  }
1153 
1154  if (json.HasMember("kernel_shape_")) {
1155  kernel_shape_.clear();
1156  for (size_t i = 0; i < json["kernel_shape_"].Size(); ++i) {
1157  kernel_shape_.push_back(json["kernel_shape_"][i].GetInt());
1158  }
1159  } else {
1160  kernel_shape_.clear();
1161  }
1162 
1163  if (json.HasMember("pads_")) {
1164  pads_.clear();
1165  for (size_t i = 0; i < json["pads_"].Size(); ++i) {
1166  pads_.push_back(json["pads_"][i].GetInt());
1167  }
1168  } else {
1169  pads_ = {0, 0, 0, 0};
1170  }
1171 
1172  if (json.HasMember("strides_")) {
1173  strides_.clear();
1174  for (size_t i = 0; i < json["strides_"].Size(); ++i) {
1175  strides_.push_back(json["strides_"][i].GetInt());
1176  }
1177  } else {
1178  strides_ = {1, 1};
1179  }
1180 
1181  return base::kStatusCodeOk;
1182  }
1183 
1184  public:
1185  // 自动填充方式
1186  std::string auto_pad_ = "NOTSET";
1187  // 扩张系数
1188  std::vector<int> dilations_ = {1, 1};
1189  // 组数
1190  int group_ = 1;
1191  // 卷积核大小
1192  std::vector<int> kernel_shape_;
1193  // 填充大小
1194  std::vector<int> pads_ = {0, 0, 0, 0};
1195  // 卷积步长
1196  std::vector<int> strides_ = {1, 1};
1197 };
1198 
1200  public:
1201  // 构造函数
1203  virtual ~AveragePoolParam() {}
1204 
1207 
1208  using base::Param::serialize;
1210  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1211  json.AddMember("auto_pad_", rapidjson::Value(auto_pad_.c_str(), allocator),
1212  allocator);
1213  json.AddMember("ceil_mode_", ceil_mode_, allocator);
1214  json.AddMember("count_include_pad_",
1215  rapidjson::Value(count_include_pad_.c_str(), allocator),
1216  allocator);
1217  json.AddMember("dilations_", rapidjson::Value(rapidjson::kArrayType),
1218  allocator);
1219  for (size_t i = 0; i < dilations_.size(); ++i) {
1220  json["dilations_"].PushBack(dilations_[i], allocator);
1221  }
1222  json.AddMember("kernel_shape_", rapidjson::Value(rapidjson::kArrayType),
1223  allocator);
1224  for (size_t i = 0; i < kernel_shape_.size(); ++i) {
1225  json["kernel_shape_"].PushBack(kernel_shape_[i], allocator);
1226  }
1227  json.AddMember("pads_", rapidjson::Value(rapidjson::kArrayType), allocator);
1228  for (size_t i = 0; i < pads_.size(); ++i) {
1229  json["pads_"].PushBack(pads_[i], allocator);
1230  }
1231  json.AddMember("strides_", rapidjson::Value(rapidjson::kArrayType),
1232  allocator);
1233  for (size_t i = 0; i < strides_.size(); ++i) {
1234  json["strides_"].PushBack(strides_[i], allocator);
1235  }
1236  return base::kStatusCodeOk;
1237  }
1238 
1240  virtual base::Status deserialize(rapidjson::Value &json) {
1241  if (json.HasMember("auto_pad_")) {
1242  auto_pad_ = json["auto_pad_"].GetString();
1243  } else {
1244  auto_pad_ = "NOTSET";
1245  }
1246 
1247  if (json.HasMember("ceil_mode_")) {
1248  ceil_mode_ = json["ceil_mode_"].GetInt();
1249  } else {
1250  ceil_mode_ = 0;
1251  }
1252 
1253  if (json.HasMember("count_include_pad_")) {
1254  count_include_pad_ = json["count_include_pad_"].GetString();
1255  } else {
1256  count_include_pad_ = "EXCLUDE";
1257  }
1258 
1259  if (json.HasMember("dilations_")) {
1260  dilations_.clear();
1261  for (size_t i = 0; i < json["dilations_"].Size(); ++i) {
1262  dilations_.push_back(json["dilations_"][i].GetInt());
1263  }
1264  } else {
1265  dilations_ = {1, 1};
1266  }
1267 
1268  if (json.HasMember("kernel_shape_")) {
1269  kernel_shape_.clear();
1270  for (size_t i = 0; i < json["kernel_shape_"].Size(); ++i) {
1271  kernel_shape_.push_back(json["kernel_shape_"][i].GetInt());
1272  }
1273  } else {
1274  kernel_shape_.clear();
1275  }
1276 
1277  if (json.HasMember("pads_")) {
1278  pads_.clear();
1279  for (size_t i = 0; i < json["pads_"].Size(); ++i) {
1280  pads_.push_back(json["pads_"][i].GetInt());
1281  }
1282  } else {
1283  pads_ = {0, 0, 0, 0};
1284  }
1285 
1286  if (json.HasMember("strides_")) {
1287  strides_.clear();
1288  for (size_t i = 0; i < json["strides_"].Size(); ++i) {
1289  strides_.push_back(json["strides_"][i].GetInt());
1290  }
1291  } else {
1292  strides_ = {1, 1};
1293  }
1294 
1295  return base::kStatusCodeOk;
1296  }
1297 
1298  public:
1299  // 自动填充方式
1300  std::string auto_pad_ = "NOTSET";
1301  // 是否向上取整
1302  int ceil_mode_ = 0;
1303  // 计算方式
1304  std::string count_include_pad_ = "EXCLUDE";
1305  // 扩张系数
1306  std::vector<int> dilations_ = {1, 1};
1307  // 平均池化的核大小
1308  std::vector<int> kernel_shape_;
1309  // 填充大小
1310  std::vector<int> pads_ = {0, 0, 0, 0};
1311  // 平均池化的步长
1312  std::vector<int> strides_ = {1, 1};
1313 };
1314 
1316  public:
1318  virtual ~CastParam() {}
1319 
1322 
1323  using base::Param::serialize;
1325  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1326  json.AddMember("saturate_", saturate_, allocator);
1327  json.AddMember("to_", rapidjson::Value(rapidjson::kArrayType), allocator);
1328  json["to_"].PushBack(static_cast<int32_t>(to_.code_), allocator);
1329  json["to_"].PushBack(static_cast<int32_t>(to_.bits_), allocator);
1330  json["to_"].PushBack(static_cast<int32_t>(to_.lanes_), allocator);
1331  return base::kStatusCodeOk;
1332  }
1334  virtual base::Status deserialize(rapidjson::Value &json) {
1335  if (json.HasMember("saturate_")) {
1336  saturate_ = json["saturate_"].GetInt();
1337  } else {
1338  saturate_ = 1;
1339  }
1340 
1341  if (json.HasMember("to_")) {
1342  to_.code_ = json["to_"][0].GetInt();
1343  to_.bits_ = json["to_"][1].GetInt();
1344  to_.lanes_ = json["to_"][2].GetInt();
1345  } else {
1346  to_ = base::dataTypeOf<float>();
1347  }
1348 
1349  return base::kStatusCodeOk;
1350  };
1351 
1352  public:
1353  int saturate_ =
1354  1; // https://onnx.org.cn/onnx/operators/onnx__Cast.html#cast-19
1355  base::DataType to_; // 输入张量元素将被转换为的数据类型。
1356 };
1357 
1359  public:
1361  virtual ~UnsqueezeParam() {}
1362 
1365 
1366  using base::Param::serialize;
1368  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1369  json.AddMember("axes_", axes_, allocator);
1370  return base::kStatusCodeOk;
1371  }
1373  virtual base::Status deserialize(rapidjson::Value &json) {
1374  if (json.HasMember("axes_")) {
1375  axes_ = json["axes_"].GetInt();
1376  } else {
1377  axes_ = 0; // 默认值
1378  }
1379  return base::kStatusCodeOk;
1380  }
1381 
1382  public:
1383  int axes_ = 0; // 指定在哪些维度上增加维度,默认值为0
1384 };
1385 
1387  public:
1388  GatherParam() : OpParam() {} // 默认构造函数
1389  virtual ~GatherParam() {}
1390 
1393 
1394  using base::Param::serialize;
1396  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1397  json.AddMember("axis_", axis_, allocator);
1398  return base::kStatusCodeOk;
1399  }
1401  virtual base::Status deserialize(rapidjson::Value &json) {
1402  if (json.HasMember("axis_")) {
1403  axis_ = json["axis_"].GetInt();
1404  } else {
1405  axis_ = 0; // 默认值
1406  }
1407  return base::kStatusCodeOk;
1408  }
1409 
1410  public:
1411  int axis_ = 0; // 用于收集的轴,默认值为0
1412 };
1413 
1415  public:
1417  virtual ~ReduceMeanParam() {}
1418 
1421 
1422  using base::Param::serialize;
1424  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1425  json.AddMember("keepdims_", keepdims_, allocator);
1426  json.AddMember("noop_with_empty_axes_", noop_with_empty_axes_, allocator);
1427  return base::kStatusCodeOk;
1428  }
1430  virtual base::Status deserialize(rapidjson::Value &json) {
1431  if (json.HasMember("keepdims_")) {
1432  keepdims_ = json["keepdims_"].GetInt();
1433  } else {
1434  keepdims_ = 1; // 默认值
1435  }
1436 
1437  if (json.HasMember("noop_with_empty_axes_")) {
1438  noop_with_empty_axes_ = json["noop_with_empty_axes_"].GetInt();
1439  } else {
1440  noop_with_empty_axes_ = 0; // 默认值
1441  }
1442 
1443  return base::kStatusCodeOk;
1444  }
1445 
1446  public:
1447  int keepdims_ = 1; // 是否保留被约简的维度,默认为1
1448  int noop_with_empty_axes_ = 0; // 当axes为空时的行为,默认为0
1449 };
1450 
1452  public:
1454  virtual ~ReduceMaxParam() {}
1455 
1458 
1459  using base::Param::serialize;
1461  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1462  json.AddMember("keepdims_", keepdims_, allocator);
1463  json.AddMember("noop_with_empty_axes_", noop_with_empty_axes_, allocator);
1464  return base::kStatusCodeOk;
1465  }
1467  virtual base::Status deserialize(rapidjson::Value &json) {
1468  if (json.HasMember("keepdims_")) {
1469  keepdims_ = json["keepdims_"].GetInt();
1470  } else {
1471  keepdims_ = 1; // 默认值
1472  }
1473 
1474  if (json.HasMember("noop_with_empty_axes_")) {
1475  noop_with_empty_axes_ = json["noop_with_empty_axes_"].GetInt();
1476  } else {
1477  noop_with_empty_axes_ = 0; // 默认值
1478  }
1479 
1480  return base::kStatusCodeOk;
1481  }
1482 
1483  public:
1484  int keepdims_ = 1; // Whether to keep the reduced dimensions; defaults to 1
1485  int noop_with_empty_axes_ = 0; // Behavior when axes is empty; defaults to 0
1486 };
1487 
1489  public:
1491  virtual ~ReduceMinParam() {}
1492 
1495 
1496  using base::Param::serialize;
1498  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1499  json.AddMember("keepdims_", keepdims_, allocator);
1500  json.AddMember("noop_with_empty_axes_", noop_with_empty_axes_, allocator);
1501  return base::kStatusCodeOk;
1502  }
1504  virtual base::Status deserialize(rapidjson::Value &json) {
1505  if (json.HasMember("keepdims_")) {
1506  keepdims_ = json["keepdims_"].GetInt();
1507  } else {
1508  keepdims_ = 1; // 默认值
1509  }
1510 
1511  if (json.HasMember("noop_with_empty_axes_")) {
1512  noop_with_empty_axes_ = json["noop_with_empty_axes_"].GetInt();
1513  } else {
1514  noop_with_empty_axes_ = 0; // 默认值
1515  }
1516 
1517  return base::kStatusCodeOk;
1518  }
1519 
1520  public:
1521  int keepdims_ = 1; // Whether to keep the reduced dimensions; defaults to 1
1522  int noop_with_empty_axes_ = 0; // Behavior when axes is empty; defaults to 0
1523 };
1524 
1526  public:
1528  virtual ~ReduceSumParam() {}
1529 
1532 
1533  using base::Param::serialize;
1535  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1536  json.AddMember("keepdims_", keepdims_, allocator);
1537  json.AddMember("noop_with_empty_axes_", noop_with_empty_axes_, allocator);
1538  return base::kStatusCodeOk;
1539  }
1541  virtual base::Status deserialize(rapidjson::Value &json) {
1542  if (json.HasMember("keepdims_")) {
1543  keepdims_ = json["keepdims_"].GetInt();
1544  } else {
1545  keepdims_ = 1; // 默认值
1546  }
1547 
1548  if (json.HasMember("noop_with_empty_axes_")) {
1549  noop_with_empty_axes_ = json["noop_with_empty_axes_"].GetInt();
1550  } else {
1551  noop_with_empty_axes_ = 0; // 默认值
1552  }
1553 
1554  return base::kStatusCodeOk;
1555  }
1556 
1557  public:
1558  int keepdims_ = 1; // Whether to keep the reduced dimensions; defaults to 1
1559  int noop_with_empty_axes_ = 0; // Behavior when axes is empty; defaults to 0
1560 };
1561 
1563  public:
1564  ShapeParam() : OpParam() {} // 默认构造函数
1565  virtual ~ShapeParam() {}
1566 
1569 
1570  using base::Param::serialize;
1572  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1573  json.AddMember("start_", start_, allocator);
1574  json.AddMember("end_", end_, allocator);
1575  return base::kStatusCodeOk;
1576  }
1578  virtual base::Status deserialize(rapidjson::Value &json) {
1579  if (json.HasMember("start_")) {
1580  start_ = json["start_"].GetInt();
1581  } else {
1582  start_ = 0; // 默认值
1583  }
1584 
1585  if (json.HasMember("end_")) {
1586  end_ = json["end_"].GetInt();
1587  } else {
1588  end_ = -1; // 默认值
1589  }
1590 
1591  return base::kStatusCodeOk;
1592  }
1593 
1594  public:
1595  int start_ = 0; // Start axis used to slice the shape; defaults to 0
1596  int end_ =
1597  -1; // Negative values count dimensions from the back. If omitted, the sizes of all axes up to (and including) the last axis are included.
1598 };
1599 
1601  public:
1604 
1607 
1608  base::Status serialize(rapidjson::Value &json,
1609  rapidjson::Document::AllocatorType &allocator) {
1610  json.AddMember("value_", value_, allocator);
1611  json.AddMember("datatype_", rapidjson::Value(rapidjson::kArrayType),
1612  allocator);
1613  json["datatype_"].PushBack(static_cast<int32_t>(datatype_.code_),
1614  allocator);
1615  json["datatype_"].PushBack(static_cast<int32_t>(datatype_.bits_),
1616  allocator);
1617  json["datatype_"].PushBack(static_cast<int32_t>(datatype_.lanes_),
1618  allocator);
1619  return base::kStatusCodeOk;
1620  }
1621 
1622  base::Status deserialize(rapidjson::Value &json) {
1623  if (json.HasMember("value_")) {
1624  value_ = json["value_"].GetFloat();
1625  } else {
1626  value_ = 0.0f; // 默认值
1627  }
1628 
1629  if (json.HasMember("datatype_")) {
1630  datatype_.code_ = json["datatype_"][0].GetInt();
1631  datatype_.bits_ = json["datatype_"][1].GetInt();
1632  datatype_.lanes_ = json["datatype_"][2].GetInt();
1633  } else {
1634  datatype_ = base::dataTypeOf<float>(); // 默认为 float 类型
1635  }
1636 
1637  return base::kStatusCodeOk;
1638  }
1639 
1640  public:
1641  float value_ = 0.0f; // Scalar value; defaults to 0.0
1642  base::DataType datatype_ =
1643  base::dataTypeOf<float>(); // Data type; defaults to float
1644 };
1645 
1646 } // namespace ir
1647 } // namespace nndeploy
1648 
1649 #ifdef _MSC_VER
1650 #pragma warning(pop) // 恢复之前的警告状态
1651 #endif
1652 
1653 #endif /* _NNDEPLOY_IR_OP_PARAM_H_ */
virtual base::Status deserialize(rapidjson::Value &json)
virtual std::string serialize()
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1240
std::vector< int > kernel_shape_
Definition: op_param.h:1308
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1209
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:316
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:308
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1324
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1334
base::DataType to_
Definition: op_param.h:1355
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:356
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:362
base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1622
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:422
std::vector< int > kernel_shape_
Definition: op_param.h:528
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:458
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1074
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1080
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:965
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:970
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:942
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:936
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1395
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1401
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:984
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:993
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:385
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:392
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:583
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:549
std::vector< int > kernel_shape_
Definition: op_param.h:645
算子参数的创建类
Definition: op_param.h:227
virtual std::shared_ptr< base::Param > createOpParam(OpType type)=0
virtual ~OpParam()
Definition: op_param.h:289
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1132
std::vector< int > kernel_shape_
Definition: op_param.h:1192
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1104
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1037
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1044
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:906
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:899
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1467
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1460
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1430
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1423
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1497
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1504
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1541
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1534
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:667
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:661
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:691
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:714
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1578
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1571
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:795
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:801
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:832
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:825
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:873
std::vector< int > perm_
Definition: op_param.h:887
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:863
算子参数的创建类模板
Definition: op_param.h:239
算子参数的创建类的注册类模板
Definition: op_param.h:259
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1373
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1367
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
@ kStatusCodeOk
Definition: status.h:13
DataType dataTypeOf< float >()
base::Status serialize(Graph *graph, rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
std::map< OpType, std::shared_ptr< OpParamCreator > > & getGlobalOpParamCreatorMap()
Get the Global base::Param Creator Map object.
OpType stringToOpType(const std::string &op_type_name)
std::string opTypeToString(OpType op_type)
OpType
算子类型 算子分类
Definition: op_param.h:65
@ kOpTypeSequenceEmpty
Definition: op_param.h:180
@ kOpTypeConstant
Definition: op_param.h:83
@ kOpTypeExpand
Definition: op_param.h:98
@ kOpTypeNonMaxSuppression
Definition: op_param.h:139
@ kOpTypeQuantizeLinear
Definition: op_param.h:150
@ kOpTypeTranspose
Definition: op_param.h:207
@ kOpTypeGlobalAveragePool
Definition: op_param.h:104
@ kOpTypeRMSNorm
Definition: op_param.h:213
@ kOpTypeConstantOfShape
Definition: op_param.h:87
@ kOpTypeSequenceConstruct
Definition: op_param.h:179
@ kOpTypeLeakyRelu
Definition: op_param.h:117
@ kOpTypeDropout
Definition: op_param.h:92
@ kOpTypeMomentum
Definition: op_param.h:133
@ kOpTypeIdentity
Definition: op_param.h:109
@ kOpTypeSoftsign
Definition: op_param.h:194
@ kOpTypeOnesLike
Definition: op_param.h:143
@ kOpTypeMultinomial
Definition: op_param.h:136
@ kOpTypeSoftplus
Definition: op_param.h:193
@ kOpTypeSequenceLength
Definition: op_param.h:183
@ kOpTypeMaxUnpool
Definition: op_param.h:129
@ kOpTypeNegLogSoftmax
Definition: op_param.h:138
@ kOpTypeReduceLogSumExp
Definition: op_param.h:161
@ kOpTypeUnsqueeze
Definition: op_param.h:208
@ kOpTypeGlobalLpPool
Definition: op_param.h:105
@ kOpTypeArgMin
Definition: op_param.h:74
@ kOpTypeBatchNormalization
Definition: op_param.h:78
@ kOpTypeFlatten
Definition: op_param.h:99
@ kOpTypeReduceProd
Definition: op_param.h:165
@ kOpTypeMatMulInteger
Definition: op_param.h:125
@ kOpTypeGreater
Definition: op_param.h:107
@ kOpTypeSqueeze
Definition: op_param.h:198
@ kOpTypeArgMax
Definition: op_param.h:73
@ kOpTypeRandomNormalLike
Definition: op_param.h:153
@ kOpTypeDequantizeLinear
Definition: op_param.h:89
@ kOpTypeUpsample
Definition: op_param.h:209
@ kOpTypeAveragePool
Definition: op_param.h:77
@ kOpTypeLogSoftmax
Definition: op_param.h:120
@ kOpTypeEinsum
Definition: op_param.h:93
@ kOpTypeThresholdedRelu
Definition: op_param.h:204
@ kOpTypeReduceSum
Definition: op_param.h:166
@ kOpTypeScatter
Definition: op_param.h:176
@ kOpTypeImageScaler
Definition: op_param.h:111
@ kOpTypeReciprocal
Definition: op_param.h:157
@ kOpTypeRandomNormal
Definition: op_param.h:152
@ kOpTypeRoiAlign
Definition: op_param.h:172
@ kOpTypeGlobalMaxPool
Definition: op_param.h:106
@ kOpTypeReduceL1
Definition: op_param.h:158
@ kOpTypeSoftmax
Definition: op_param.h:192
@ kOpTypeReduceMean
Definition: op_param.h:163
@ kOpTypeSigmoid
Definition: op_param.h:186
@ kOpTypeEmbedding
Definition: op_param.h:214
@ kOpTypeRandomUniformLike
Definition: op_param.h:155
@ kOpTypeReduceL2
Definition: op_param.h:159
@ kOpTypeDepthToSpace
Definition: op_param.h:88
@ kOpTypeSequenceInsert
Definition: op_param.h:182
@ kOpTypeReduceLogSum
Definition: op_param.h:160
@ kOpTypeQLinearMatMul
Definition: op_param.h:149
@ kOpTypeReduceSumSquare
Definition: op_param.h:167
@ kOpTypeReduceMin
Definition: op_param.h:164
@ kOpTypeSequenceErase
Definition: op_param.h:181
@ kOpTypeRandomUniform
Definition: op_param.h:154
@ kOpTypeQLinearConv
Definition: op_param.h:148
@ kOpTypeInstanceNormalization
Definition: op_param.h:112
@ kOpTypeEqual
Definition: op_param.h:95
@ kOpTypeMaxRoiPool
Definition: op_param.h:128
@ kOpTypeReduceMax
Definition: op_param.h:162
@ kOpTypeSpaceToDepth
Definition: op_param.h:195
@ kOpTypeHardSigmoid
Definition: op_param.h:108
@ kOpTypeSequenceAt
Definition: op_param.h:178
@ kOpTypeConcat
Definition: op_param.h:82
@ kOpTypeMaxPool
Definition: op_param.h:127
@ kOpTypeReverseSequence
Definition: op_param.h:171
@ kOpTypeLpNormalization
Definition: op_param.h:122
@ kOpTypeReshape
Definition: op_param.h:169
@ kOpTypeNonZero
Definition: op_param.h:140
std::shared_ptr< base::Param > createOpParam(OpType op_type)
Create a base::Param object.
#define PARAM_COPY_TO(param_type)
Definition: param.h:25
#define PARAM_COPY(param_type)
Definition: param.h:16