nndeploy C++ API  0.2.0
nndeploy C++ API
op_param.h
Go to the documentation of this file.
1 
2 #ifndef _NNDEPLOY_IR_OP_PARAM_H_
3 #define _NNDEPLOY_IR_OP_PARAM_H_
4 
5 #include "nndeploy/base/any.h"
6 #include "nndeploy/base/common.h"
8 #include "nndeploy/base/log.h"
9 #include "nndeploy/base/macro.h"
10 #include "nndeploy/base/object.h"
11 #include "nndeploy/base/param.h"
13 #include "nndeploy/base/status.h"
14 #include "nndeploy/base/string.h"
15 #include "nndeploy/device/tensor.h"
16 #include "rapidjson/document.h"
17 #include "rapidjson/stringbuffer.h"
18 #include "rapidjson/writer.h"
19 
20 #ifdef _MSC_VER
21 #pragma warning(push) // 保存当前警告状态
22 #pragma warning(disable : 4267) // 禁用 C4267 警告
23 #endif
24 
25 namespace nndeploy {
26 namespace ir {
65 enum OpType : int {
66  kOpTypeNet = 0x0000,
67 
83  kOpTypeConstant, // 该算子转换为权重
214 
218 
220 };
221 
222 NNDEPLOY_CC_API std::string opTypeToString(OpType op_type);
223 
224 NNDEPLOY_CC_API OpType stringToOpType(const std::string &op_type_name);
225 
231  public:
232  virtual ~OpParamCreator() {};
233  virtual std::shared_ptr<base::Param> createOpParam(OpType type) = 0;
234 };
235 
241 template <typename T>
243  virtual std::shared_ptr<base::Param> createOpParam(OpType type) {
244  return std::make_shared<T>();
245  }
246 };
247 
253 extern NNDEPLOY_CC_API std::map<OpType, std::shared_ptr<OpParamCreator>> &
255 
261 template <typename T>
263  public:
264  explicit TypeOpParamRegister(OpType type) {
265  getGlobalOpParamCreatorMap()[type] = std::shared_ptr<T>(new T());
266  }
267 };
268 
275 extern NNDEPLOY_CC_API std::shared_ptr<base::Param> createOpParam(
276  OpType op_type);
277 
// Registers `op_param_class` as the parameter class created for `op_type`.
// Expands to the definition of a registrar object named
// g_<op_type>_<op_param_class>_register whose constructor stores a
// TypeOpParamCreator<op_param_class> in getGlobalOpParamCreatorMap().
// Both arguments are token-pasted into that object name, so they must be
// plain identifiers (e.g. kOpTypeConv, ConvParam), not qualified names.
// The trailing ';' is part of the expansion; the historical spelling
// "IMPLEMENTION" is kept so existing call sites keep compiling.
#define REGISTER_OP_PARAM_IMPLEMENTION(op_type, op_param_class) \
  TypeOpParamRegister<TypeOpParamCreator<op_param_class>>       \
      g_##op_type##_##op_param_class##_register(op_type);
281 
290  public:
291  OpParam() : base::Param(), reserved_(0) {};
292  virtual ~OpParam() {};
293 
296 
297  public:
298  // 保留字段,也可以充void *使用
299  size_t reserved_;
300 };
301 
303  public:
306 
309 
312  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
313  json.AddMember("epsilon_", epsilon_, allocator);
314  json.AddMember("momentum_", momentum_, allocator);
315  json.AddMember("training_mode_", training_mode_, allocator);
316  return base::kStatusCodeOk;
317  }
319  virtual base::Status deserialize(rapidjson::Value &json) {
320  if (json.HasMember("epsilon_")) {
321  epsilon_ = json["epsilon_"].GetFloat();
322  } else {
323  epsilon_ = 1e-05f; // 默认值
324  }
325 
326  if (json.HasMember("momentum_")) {
327  momentum_ = json["momentum_"].GetFloat();
328  } else {
329  momentum_ = 0.9f; // 默认值
330  }
331 
332  if (json.HasMember("training_mode_")) {
333  training_mode_ = json["training_mode_"].GetInt();
334  } else {
335  training_mode_ = 0; // 默认值
336  }
337 
338  return base::kStatusCodeOk;
339  }
340 
341  public:
342  // The epsilon value to use to avoid division by zero.
343  float epsilon_ = 1e-05f;
344  // Factor used in computing the running mean and variance.e.g., running_mean =
345  // running_mean * momentum + mean * (1 - momentum).
346  float momentum_ = 0.9f;
347  int training_mode_ = 0;
348 };
349 
350 class ConcatParam : public OpParam {
351  public:
353  virtual ~ConcatParam() {};
354 
357 
360  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
361  json.AddMember("axis_", axis_, allocator);
362  return base::kStatusCodeOk;
363  }
365  virtual base::Status deserialize(rapidjson::Value &json) {
366  if (json.HasMember("axis_")) {
367  axis_ = json["axis_"].GetInt();
368  } else {
369  axis_ = 1;
370  }
371 
372  return base::kStatusCodeOk;
373  }
374 
375  public:
376  int axis_ = 1; // 拼接的维度
377 };
378 
379 class MatMulParam : public OpParam {
380  public:
382  virtual ~MatMulParam() {};
383 
386 
389  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
390  json.AddMember("transposeA_", transposeA_, allocator);
391  json.AddMember("transposeB_", transposeB_, allocator);
392  return base::kStatusCodeOk;
393  }
395  virtual base::Status deserialize(rapidjson::Value &json) {
396  if (json.HasMember("transposeA_")) {
397  transposeA_ = json["transposeA_"].GetBool();
398  } else {
399  transposeA_ = false;
400  }
401  if (json.HasMember("transposeB_")) {
402  transposeB_ = json["transposeB_"].GetBool();
403  } else {
404  transposeB_ = false;
405  }
406 
407  return base::kStatusCodeOk;
408  }
409 
410  public:
411  bool transposeA_ = false; // 是否转置A矩阵
412  bool transposeB_ = false; // 是否转置A矩阵
413 };
414 
416  public:
417  // 构造函数
419  virtual ~ConvParam() {}
420 
423 
426  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
427  json.AddMember("auto_pad_", rapidjson::Value(auto_pad_.c_str(), allocator),
428  allocator);
429  json.AddMember("dilations_", rapidjson::Value(rapidjson::kArrayType),
430  allocator);
431  for (size_t i = 0; i < dilations_.size(); ++i) {
432  json["dilations_"].PushBack(dilations_[i], allocator);
433  }
434  json.AddMember("group_", group_, allocator);
435  json.AddMember("kernel_shape_", rapidjson::Value(rapidjson::kArrayType),
436  allocator);
437  for (size_t i = 0; i < kernel_shape_.size(); ++i) {
438  json["kernel_shape_"].PushBack(kernel_shape_[i], allocator);
439  }
440  json.AddMember("pads_", rapidjson::Value(rapidjson::kArrayType), allocator);
441  for (size_t i = 0; i < pads_.size(); ++i) {
442  json["pads_"].PushBack(pads_[i], allocator);
443  }
444  json.AddMember("strides_", rapidjson::Value(rapidjson::kArrayType),
445  allocator);
446  for (size_t i = 0; i < strides_.size(); ++i) {
447  json["strides_"].PushBack(strides_[i], allocator);
448  }
449  json.AddMember(
450  "activate_op_",
451  rapidjson::Value(opTypeToString(activate_op_).c_str(), allocator),
452  allocator);
453  if (activate_op_ != kOpTypeNone && fused_op_param_ != nullptr) {
454  rapidjson::Value op_desc_json(rapidjson::kObjectType);
455  fused_op_param_->serialize(op_desc_json, allocator);
456  json.AddMember("fused_op_param_", op_desc_json, allocator);
457  }
458  return base::kStatusCodeOk;
459  }
461  virtual base::Status deserialize(rapidjson::Value &json) {
462  if (json.HasMember("auto_pad_")) {
463  auto_pad_ = json["auto_pad_"].GetString();
464  } else {
465  auto_pad_ = "NOTSET";
466  }
467 
468  if (json.HasMember("dilations_")) {
469  dilations_.clear();
470  for (size_t i = 0; i < json["dilations_"].Size(); ++i) {
471  dilations_.push_back(json["dilations_"][i].GetInt());
472  }
473  } else {
474  dilations_ = {1, 1};
475  }
476 
477  if (json.HasMember("group_")) {
478  group_ = json["group_"].GetInt();
479  } else {
480  group_ = 1;
481  }
482 
483  if (json.HasMember("kernel_shape_")) {
484  kernel_shape_.clear();
485  for (size_t i = 0; i < json["kernel_shape_"].Size(); ++i) {
486  kernel_shape_.push_back(json["kernel_shape_"][i].GetInt());
487  }
488  } else {
489  kernel_shape_.clear();
490  }
491 
492  if (json.HasMember("pads_")) {
493  pads_.clear();
494  for (size_t i = 0; i < json["pads_"].Size(); ++i) {
495  pads_.push_back(json["pads_"][i].GetInt());
496  }
497  } else {
498  pads_ = {0, 0, 0, 0};
499  }
500 
501  if (json.HasMember("strides_")) {
502  strides_.clear();
503  for (size_t i = 0; i < json["strides_"].Size(); ++i) {
504  strides_.push_back(json["strides_"][i].GetInt());
505  }
506  } else {
507  strides_ = {1, 1};
508  }
509 
510  if (json.HasMember("activate_op_")) {
511  activate_op_ = stringToOpType(json["activate_op_"].GetString());
512  fused_op_param_ = createOpParam(activate_op_);
513  if (json.HasMember("fused_op_param_")) {
514  fused_op_param_->deserialize(json["fused_op_param_"]);
515  }
516  } else {
517  activate_op_ = kOpTypeNone;
518  }
519 
520  return base::kStatusCodeOk;
521  }
522 
523  public:
524  // 自动填充方式
525  std::string auto_pad_ = "NOTSET";
526  // 扩张系数
527  std::vector<int> dilations_ = {1, 1};
528  // 组数
529  int group_ = 1;
530  // 卷积核大小
531  std::vector<int> kernel_shape_;
532  // 填充大小
533  std::vector<int> pads_ = {0, 0, 0, 0};
534  // 卷积步长
535  std::vector<int> strides_ = {1, 1};
536 
537  // 服务与算子融合
538  OpType activate_op_ = kOpTypeNone;
539  // OpParam* fused_op_param_ = nullptr;
540  std::shared_ptr<base::Param> fused_op_param_ = nullptr;
541 };
542 // MaxPool 参数类
544  public:
546  virtual ~MaxPoolParam() {}
547 
550 
553  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
554  json.AddMember("auto_pad_", rapidjson::Value(auto_pad_.c_str(), allocator),
555  allocator);
556  json.AddMember("ceil_mode_", ceil_mode_, allocator);
557  rapidjson::Value dilations_array(rapidjson::kArrayType);
558  for (size_t i = 0; i < dilations_.size(); ++i) {
559  dilations_array.PushBack(dilations_[i], allocator);
560  }
561  json.AddMember("dilations_", dilations_array, allocator);
562 
563  rapidjson::Value kernel_shape_array(rapidjson::kArrayType);
564  for (size_t i = 0; i < kernel_shape_.size(); ++i) {
565  kernel_shape_array.PushBack(kernel_shape_[i], allocator);
566  }
567  json.AddMember("kernel_shape_", kernel_shape_array, allocator);
568 
569  rapidjson::Value pads_array(rapidjson::kArrayType);
570  for (size_t i = 0; i < pads_.size(); ++i) {
571  pads_array.PushBack(pads_[i], allocator);
572  }
573  json.AddMember("pads_", pads_array, allocator);
574 
575  json.AddMember("storage_order_", storage_order_, allocator);
576 
577  rapidjson::Value strides_array(rapidjson::kArrayType);
578  for (size_t i = 0; i < strides_.size(); ++i) {
579  strides_array.PushBack(strides_[i], allocator);
580  }
581  json.AddMember("strides_", strides_array, allocator);
582 
583  return base::kStatusCodeOk;
584  }
586  virtual base::Status deserialize(rapidjson::Value &json) {
587  if (json.HasMember("auto_pad_")) {
588  auto_pad_ = json["auto_pad_"].GetString();
589  } else {
590  auto_pad_ = "NOTSET";
591  }
592 
593  if (json.HasMember("ceil_mode_")) {
594  ceil_mode_ = json["ceil_mode_"].GetInt();
595  } else {
596  ceil_mode_ = 0;
597  }
598 
599  if (json.HasMember("dilations_")) {
600  dilations_.clear();
601  for (size_t i = 0; i < json["dilations_"].Size(); ++i) {
602  dilations_.push_back(json["dilations_"][i].GetInt());
603  }
604  } else {
605  dilations_ = {1, 1};
606  }
607 
608  if (json.HasMember("kernel_shape_")) {
609  kernel_shape_.clear();
610  for (size_t i = 0; i < json["kernel_shape_"].Size(); ++i) {
611  kernel_shape_.push_back(json["kernel_shape_"][i].GetInt());
612  }
613  } else {
614  kernel_shape_.clear();
615  }
616 
617  if (json.HasMember("pads_")) {
618  pads_.clear();
619  for (size_t i = 0; i < json["pads_"].Size(); ++i) {
620  pads_.push_back(json["pads_"][i].GetInt());
621  }
622  } else {
623  pads_ = {0, 0, 0, 0};
624  }
625 
626  if (json.HasMember("storage_order_")) {
627  storage_order_ = json["storage_order_"].GetInt();
628  } else {
629  storage_order_ = 0;
630  }
631 
632  if (json.HasMember("strides_")) {
633  strides_.clear();
634  for (size_t i = 0; i < json["strides_"].Size(); ++i) {
635  strides_.push_back(json["strides_"][i].GetInt());
636  }
637  } else {
638  strides_ = {1, 1};
639  }
640 
641  return base::kStatusCodeOk;
642  }
643 
644  public:
645  std::string auto_pad_ = "NOTSET"; // 自动填充方式
646  int ceil_mode_ = 0; // 是否向上取整
647  std::vector<int> dilations_ = {1, 1}; // 扩张系数
648  std::vector<int> kernel_shape_; // 池化核大小
649  std::vector<int> pads_ = {0, 0, 0, 0}; // 填充大小
650  int storage_order_ = 0; // 存储顺序
651  std::vector<int> strides_ = {1, 1}; // 步长
652 };
653 
654 // Reshape 参数类
656  public:
658  virtual ~ReshapeParam() {}
659 
662 
665  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
666  json.AddMember("allowzero_", allowzero_, allocator);
667  return base::kStatusCodeOk;
668  }
670  virtual base::Status deserialize(rapidjson::Value &json) {
671  if (json.HasMember("allowzero_")) {
672  allowzero_ = json["allowzero_"].GetInt();
673  } else {
674  allowzero_ = 0; // 默认值
675  }
676 
677  return base::kStatusCodeOk;
678  }
679 
680  public:
681  int allowzero_ = 0; // 是否允许0
682 };
683 
684 // Resize 参数类 - opset 18~19
686  public:
688  virtual ~ResizeParam() {}
689 
692 
695  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
696  json.AddMember("antialias_", antialias_, allocator);
697  json.AddMember("axes_", axes_, allocator);
698  json.AddMember(
699  "coordinate_transformation_mode_",
700  rapidjson::Value(coordinate_transformation_mode_.c_str(), allocator),
701  allocator);
702  json.AddMember("cubic_coeff_a_", cubic_coeff_a_, allocator);
703  json.AddMember("exclude_outside_", exclude_outside_, allocator);
704  json.AddMember("extrapolation_value_", extrapolation_value_, allocator);
705  json.AddMember(
706  "keep_aspect_ratio_policy_",
707  rapidjson::Value(keep_aspect_ratio_policy_.c_str(), allocator),
708  allocator);
709  json.AddMember("mode_", rapidjson::Value(mode_.c_str(), allocator),
710  allocator);
711  json.AddMember("nearest_mode_",
712  rapidjson::Value(nearest_mode_.c_str(), allocator),
713  allocator);
714  return base::kStatusCodeOk;
715  }
717  virtual base::Status deserialize(rapidjson::Value &json) {
718  if (json.HasMember("antialias_")) {
719  antialias_ = json["antialias_"].GetInt();
720  } else {
721  antialias_ = 0; // 默认值
722  }
723 
724  if (json.HasMember("axes_")) {
725  axes_ = json["axes_"].GetInt();
726  } else {
727  axes_ = INT_MAX; // 默认值
728  }
729 
730  if (json.HasMember("coordinate_transformation_mode_")) {
731  coordinate_transformation_mode_ =
732  json["coordinate_transformation_mode_"].GetString();
733  } else {
734  coordinate_transformation_mode_ = "half_pixel"; // 默认值
735  }
736 
737  if (json.HasMember("cubic_coeff_a_")) {
738  cubic_coeff_a_ = json["cubic_coeff_a_"].GetFloat();
739  } else {
740  cubic_coeff_a_ = -0.75; // 默认值
741  }
742 
743  if (json.HasMember("exclude_outside_")) {
744  exclude_outside_ = json["exclude_outside_"].GetInt();
745  } else {
746  exclude_outside_ = 0; // 默认值
747  }
748 
749  if (json.HasMember("extrapolation_value_")) {
750  extrapolation_value_ = json["extrapolation_value_"].GetFloat();
751  } else {
752  extrapolation_value_ = -0.0; // 默认值
753  }
754 
755  if (json.HasMember("keep_aspect_ratio_policy_")) {
756  keep_aspect_ratio_policy_ = json["keep_aspect_ratio_policy_"].GetString();
757  } else {
758  keep_aspect_ratio_policy_ = "stretch"; // 默认值
759  }
760 
761  if (json.HasMember("mode_")) {
762  mode_ = json["mode_"].GetString();
763  } else {
764  mode_ = "nearest"; // 默认值
765  }
766 
767  if (json.HasMember("nearest_mode_")) {
768  nearest_mode_ = json["nearest_mode_"].GetString();
769  } else {
770  nearest_mode_ = "round_prefer_floor"; // 默认值
771  }
772 
773  return base::kStatusCodeOk;
774  }
775 
776  public:
777  int antialias_ = 0;
778  int axes_ = INT_MAX; // 轴,当为INT_MAX时,表示未设置
779  std::string coordinate_transformation_mode_ = "half_pixel";
780  float cubic_coeff_a_ = -0.75;
781  int exclude_outside_ = 0;
782  float extrapolation_value_ = -0.0;
783  std::string keep_aspect_ratio_policy_ = "stretch";
784  std::string mode_ = "nearest";
785  std::string nearest_mode_ = "round_prefer_floor";
786 };
787 
788 // Softmax 参数类
790  public:
792  virtual ~SoftmaxParam() {}
793 
796 
799  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
800  json.AddMember("axis_", axis_, allocator);
801  return base::kStatusCodeOk;
802  }
804  virtual base::Status deserialize(rapidjson::Value &json) {
805  if (json.HasMember("axis_")) {
806  axis_ = json["axis_"].GetInt();
807  } else {
808  axis_ = -1;
809  }
810 
811  return base::kStatusCodeOk;
812  }
813 
814  public:
815  int axis_ = -1; // 应用 Softmax 的轴
816 };
817 
818 // Split 参数类
820  public:
821  SplitParam() : OpParam() {} // 默认轴为0,分割数为1
822  virtual ~SplitParam() {}
823 
826 
829  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
830  json.AddMember("axis_", axis_, allocator);
831  json.AddMember("num_outputs_", num_outputs_, allocator);
832  return base::kStatusCodeOk;
833  }
835  virtual base::Status deserialize(rapidjson::Value &json) {
836  if (json.HasMember("axis_")) {
837  axis_ = json["axis_"].GetInt();
838  } else {
839  axis_ = 0; // 默认值
840  }
841 
842  if (json.HasMember("num_outputs_")) {
843  num_outputs_ = json["num_outputs_"].GetInt();
844  } else {
845  num_outputs_ = INT_MAX; // 默认值
846  }
847 
848  return base::kStatusCodeOk;
849  }
850 
851  public:
852  int axis_ = 0; // 分割轴
853  int num_outputs_ = INT_MAX; // 分割数
854 };
855 
856 // Transpose 参数类
858  public:
859  TransposeParam() : OpParam() {} // 默认轴为0,分割数为1
860  virtual ~TransposeParam() {}
861 
864 
867  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
868  rapidjson::Value permArray(rapidjson::kArrayType);
869  for (size_t i = 0; i < perm_.size(); ++i) {
870  permArray.PushBack(perm_[i], allocator);
871  }
872  json.AddMember("perm_", permArray, allocator);
873  return base::kStatusCodeOk;
874  }
876  virtual base::Status deserialize(rapidjson::Value &json) {
877  if (json.HasMember("perm_")) {
878  perm_.clear();
879  for (size_t i = 0; i < json["perm_"].Size(); ++i) {
880  perm_.push_back(json["perm_"][i].GetInt());
881  }
882  } else {
883  perm_.clear(); // 默认值
884  }
885 
886  return base::kStatusCodeOk;
887  }
888 
889  public:
890  std::vector<int> perm_;
891 };
892 
894  public:
895  RMSNormParam() : OpParam() {} // 默认轴为0,分割数为1
896  virtual ~RMSNormParam() {}
897 
900 
903  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
904  json.AddMember("eps_", eps_, allocator);
905  json.AddMember("is_last_", is_last_, allocator);
906  return base::kStatusCodeOk;
907  }
909  virtual base::Status deserialize(rapidjson::Value &json) {
910  if (json.HasMember("eps_")) {
911  eps_ = json["eps_"].GetFloat();
912  } else {
913  eps_ = 1e-6f; // 默认值
914  }
915 
916  if (json.HasMember("is_last_")) {
917  is_last_ = json["is_last_"].GetBool();
918  } else {
919  is_last_ = false; // 默认值
920  }
921 
922  return base::kStatusCodeOk;
923  }
924 
925  public:
926  float eps_ = 1e-6f;
927  bool is_last_ = false;
928 };
929 
931  public:
934 
937 
940  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
941  json.AddMember("epsilon_", epsilon_, allocator);
942  return base::kStatusCodeOk;
943  }
945  virtual base::Status deserialize(rapidjson::Value &json) {
946  if (json.HasMember("epsilon_")) {
947  epsilon_ = json["epsilon_"].GetFloat();
948  } else {
949  epsilon_ = 1e-5f;
950  }
951  return base::kStatusCodeOk;
952  }
953 
954  public:
955  float epsilon_ = 1e-5f;
956 };
957 
959  public:
961  virtual ~FlattenParam() {};
962 
965 
968  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
969  json.AddMember("axis_", axis_, allocator);
970  return base::kStatusCodeOk;
971  }
973  virtual base::Status deserialize(rapidjson::Value &json) {
974  if (json.HasMember("axis_")) {
975  axis_ = json["axis_"].GetInt();
976  } else {
977  axis_ = 1;
978  }
979 
980  return base::kStatusCodeOk;
981  }
982 
983  public:
984  int axis_ = 1;
985 };
986 
988  public:
990  virtual ~EmbeddingParam() {};
991 
994 
997  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
998  return base::kStatusCodeOk;
999  }
1001  virtual base::Status deserialize(rapidjson::Value &json) {
1002  return base::kStatusCodeOk;
1003  }
1004 };
1005 
1007  public:
1008  GemmParam() : OpParam() {};
1009  virtual ~GemmParam() {};
1010 
1013 
1014  using base::Param::serialize;
1016  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1017  json.AddMember("alpha_", alpha_, allocator);
1018  json.AddMember("beta_", beta_, allocator);
1019  json.AddMember("trans_a_", trans_a_, allocator);
1020  json.AddMember("trans_b_", trans_b_, allocator);
1021  return base::kStatusCodeOk;
1022  }
1024  virtual base::Status deserialize(rapidjson::Value &json) {
1025  if (json.HasMember("alpha_")) {
1026  alpha_ = json["alpha_"].GetFloat();
1027  } else {
1028  alpha_ = 1.0; // 默认值
1029  }
1030 
1031  if (json.HasMember("beta_")) {
1032  beta_ = json["beta_"].GetFloat();
1033  } else {
1034  beta_ = 1.0; // 默认值
1035  }
1036 
1037  if (json.HasMember("trans_a_")) {
1038  trans_a_ = json["trans_a_"].GetInt();
1039  } else {
1040  trans_a_ = 0; // 默认值
1041  }
1042 
1043  if (json.HasMember("trans_b_")) {
1044  trans_b_ = json["trans_b_"].GetInt();
1045  } else {
1046  trans_b_ = 0; // 默认值
1047  }
1048 
1049  return base::kStatusCodeOk;
1050  }
1051 
1052  public:
1053  float alpha_ = 1.0; // 默认值为1.0
1054  float beta_ = 1.0; // 默认值为1.0
1055  int trans_a_ = 0; // 默认值为0
1056  int trans_b_ = 0; // 默认值为0
1057 };
1058 
1060  public:
1063 
1066 
1067  using base::Param::serialize;
1069  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1070  json.AddMember("axis_", axis_, allocator);
1071  json.AddMember("saturate_", saturate_, allocator);
1072  return base::kStatusCodeOk;
1073  }
1075  virtual base::Status deserialize(rapidjson::Value &json) {
1076  if (json.HasMember("axis_")) {
1077  axis_ = json["axis_"].GetInt();
1078  } else {
1079  axis_ = 1; // 默认值
1080  }
1081 
1082  if (json.HasMember("saturate_")) {
1083  saturate_ = json["saturate_"].GetInt();
1084  } else {
1085  saturate_ = 1; // 默认值
1086  }
1087 
1088  return base::kStatusCodeOk;
1089  }
1090 
1091  public:
1092  int axis_ = 1; // 量化维度,默认为1
1093  int saturate_ = 1; // 是否饱和,默认为1
1094 };
1095 
1097  public:
1100 
1103 
1104  using base::Param::serialize;
1106  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1107  json.AddMember("axis_", axis_, allocator);
1108  return base::kStatusCodeOk;
1109  }
1111  virtual base::Status deserialize(rapidjson::Value &json) {
1112  if (json.HasMember("axis_")) {
1113  axis_ = json["axis_"].GetInt();
1114  } else {
1115  axis_ = 1; // 默认值
1116  }
1117 
1118  return base::kStatusCodeOk;
1119  }
1120 
1121  public:
1122  int axis_ = 1; // 反量化维度,默认为1
1123 };
1124 
1126  public:
1127  // 构造函数
1129  virtual ~QLinearConvParam() {}
1130 
1133 
1134  using base::Param::serialize;
1136  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1137  json.AddMember("auto_pad_", rapidjson::Value(auto_pad_.c_str(), allocator),
1138  allocator);
1139  json.AddMember("dilations_", rapidjson::Value(rapidjson::kArrayType),
1140  allocator);
1141  for (size_t i = 0; i < dilations_.size(); ++i) {
1142  json["dilations_"].PushBack(dilations_[i], allocator);
1143  }
1144  json.AddMember("group_", group_, allocator);
1145  json.AddMember("kernel_shape_", rapidjson::Value(rapidjson::kArrayType),
1146  allocator);
1147  for (size_t i = 0; i < kernel_shape_.size(); ++i) {
1148  json["kernel_shape_"].PushBack(kernel_shape_[i], allocator);
1149  }
1150  json.AddMember("pads_", rapidjson::Value(rapidjson::kArrayType), allocator);
1151  for (size_t i = 0; i < pads_.size(); ++i) {
1152  json["pads_"].PushBack(pads_[i], allocator);
1153  }
1154  json.AddMember("strides_", rapidjson::Value(rapidjson::kArrayType),
1155  allocator);
1156  for (size_t i = 0; i < strides_.size(); ++i) {
1157  json["strides_"].PushBack(strides_[i], allocator);
1158  }
1159 
1160  return base::kStatusCodeOk;
1161  }
1163  virtual base::Status deserialize(rapidjson::Value &json) {
1164  if (json.HasMember("auto_pad_")) {
1165  auto_pad_ = json["auto_pad_"].GetString();
1166  } else {
1167  auto_pad_ = "NOTSET";
1168  }
1169 
1170  if (json.HasMember("dilations_")) {
1171  dilations_.clear();
1172  for (size_t i = 0; i < json["dilations_"].Size(); ++i) {
1173  dilations_.push_back(json["dilations_"][i].GetInt());
1174  }
1175  } else {
1176  dilations_ = {1, 1};
1177  }
1178 
1179  if (json.HasMember("group_")) {
1180  group_ = json["group_"].GetInt();
1181  } else {
1182  group_ = 1;
1183  }
1184 
1185  if (json.HasMember("kernel_shape_")) {
1186  kernel_shape_.clear();
1187  for (size_t i = 0; i < json["kernel_shape_"].Size(); ++i) {
1188  kernel_shape_.push_back(json["kernel_shape_"][i].GetInt());
1189  }
1190  } else {
1191  kernel_shape_.clear();
1192  }
1193 
1194  if (json.HasMember("pads_")) {
1195  pads_.clear();
1196  for (size_t i = 0; i < json["pads_"].Size(); ++i) {
1197  pads_.push_back(json["pads_"][i].GetInt());
1198  }
1199  } else {
1200  pads_ = {0, 0, 0, 0};
1201  }
1202 
1203  if (json.HasMember("strides_")) {
1204  strides_.clear();
1205  for (size_t i = 0; i < json["strides_"].Size(); ++i) {
1206  strides_.push_back(json["strides_"][i].GetInt());
1207  }
1208  } else {
1209  strides_ = {1, 1};
1210  }
1211 
1212  return base::kStatusCodeOk;
1213  }
1214 
1215  public:
1216  // 自动填充方式
1217  std::string auto_pad_ = "NOTSET";
1218  // 扩张系数
1219  std::vector<int> dilations_ = {1, 1};
1220  // 组数
1221  int group_ = 1;
1222  // 卷积核大小
1223  std::vector<int> kernel_shape_;
1224  // 填充大小
1225  std::vector<int> pads_ = {0, 0, 0, 0};
1226  // 卷积步长
1227  std::vector<int> strides_ = {1, 1};
1228 };
1229 
1231  public:
1232  // 构造函数
1234  virtual ~AveragePoolParam() {}
1235 
1238 
1239  using base::Param::serialize;
1241  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1242  json.AddMember("auto_pad_", rapidjson::Value(auto_pad_.c_str(), allocator),
1243  allocator);
1244  json.AddMember("ceil_mode_", ceil_mode_, allocator);
1245  json.AddMember("count_include_pad_",
1246  rapidjson::Value(count_include_pad_.c_str(), allocator),
1247  allocator);
1248  json.AddMember("dilations_", rapidjson::Value(rapidjson::kArrayType),
1249  allocator);
1250  for (size_t i = 0; i < dilations_.size(); ++i) {
1251  json["dilations_"].PushBack(dilations_[i], allocator);
1252  }
1253  json.AddMember("kernel_shape_", rapidjson::Value(rapidjson::kArrayType),
1254  allocator);
1255  for (size_t i = 0; i < kernel_shape_.size(); ++i) {
1256  json["kernel_shape_"].PushBack(kernel_shape_[i], allocator);
1257  }
1258  json.AddMember("pads_", rapidjson::Value(rapidjson::kArrayType), allocator);
1259  for (size_t i = 0; i < pads_.size(); ++i) {
1260  json["pads_"].PushBack(pads_[i], allocator);
1261  }
1262  json.AddMember("strides_", rapidjson::Value(rapidjson::kArrayType),
1263  allocator);
1264  for (size_t i = 0; i < strides_.size(); ++i) {
1265  json["strides_"].PushBack(strides_[i], allocator);
1266  }
1267  return base::kStatusCodeOk;
1268  }
1269 
1271  virtual base::Status deserialize(rapidjson::Value &json) {
1272  if (json.HasMember("auto_pad_")) {
1273  auto_pad_ = json["auto_pad_"].GetString();
1274  } else {
1275  auto_pad_ = "NOTSET";
1276  }
1277 
1278  if (json.HasMember("ceil_mode_")) {
1279  ceil_mode_ = json["ceil_mode_"].GetInt();
1280  } else {
1281  ceil_mode_ = 0;
1282  }
1283 
1284  if (json.HasMember("count_include_pad_")) {
1285  count_include_pad_ = json["count_include_pad_"].GetString();
1286  } else {
1287  count_include_pad_ = "EXCLUDE";
1288  }
1289 
1290  if (json.HasMember("dilations_")) {
1291  dilations_.clear();
1292  for (size_t i = 0; i < json["dilations_"].Size(); ++i) {
1293  dilations_.push_back(json["dilations_"][i].GetInt());
1294  }
1295  } else {
1296  dilations_ = {1, 1};
1297  }
1298 
1299  if (json.HasMember("kernel_shape_")) {
1300  kernel_shape_.clear();
1301  for (size_t i = 0; i < json["kernel_shape_"].Size(); ++i) {
1302  kernel_shape_.push_back(json["kernel_shape_"][i].GetInt());
1303  }
1304  } else {
1305  kernel_shape_.clear();
1306  }
1307 
1308  if (json.HasMember("pads_")) {
1309  pads_.clear();
1310  for (size_t i = 0; i < json["pads_"].Size(); ++i) {
1311  pads_.push_back(json["pads_"][i].GetInt());
1312  }
1313  } else {
1314  pads_ = {0, 0, 0, 0};
1315  }
1316 
1317  if (json.HasMember("strides_")) {
1318  strides_.clear();
1319  for (size_t i = 0; i < json["strides_"].Size(); ++i) {
1320  strides_.push_back(json["strides_"][i].GetInt());
1321  }
1322  } else {
1323  strides_ = {1, 1};
1324  }
1325 
1326  return base::kStatusCodeOk;
1327  }
1328 
1329  public:
1330  // 自动填充方式
1331  std::string auto_pad_ = "NOTSET";
1332  // 是否向上取整
1333  int ceil_mode_ = 0;
1334  // 计算方式
1335  std::string count_include_pad_ = "EXCLUDE";
1336  // 扩张系数
1337  std::vector<int> dilations_ = {1, 1};
1338  // 平均池化的核大小
1339  std::vector<int> kernel_shape_;
1340  // 填充大小
1341  std::vector<int> pads_ = {0, 0, 0, 0};
1342  // 平均池化的步长
1343  std::vector<int> strides_ = {1, 1};
1344 };
1345 
1347  public:
1349  virtual ~CastParam() {}
1350 
1353 
1354  using base::Param::serialize;
1356  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1357  json.AddMember("saturate_", saturate_, allocator);
1358  json.AddMember("to_", rapidjson::Value(rapidjson::kArrayType), allocator);
1359  json["to_"].PushBack(static_cast<int32_t>(to_.code_), allocator);
1360  json["to_"].PushBack(static_cast<int32_t>(to_.bits_), allocator);
1361  json["to_"].PushBack(static_cast<int32_t>(to_.lanes_), allocator);
1362  return base::kStatusCodeOk;
1363  }
1365  virtual base::Status deserialize(rapidjson::Value &json) {
1366  if (json.HasMember("saturate_")) {
1367  saturate_ = json["saturate_"].GetInt();
1368  } else {
1369  saturate_ = 1;
1370  }
1371 
1372  if (json.HasMember("to_")) {
1373  to_.code_ = json["to_"][0].GetInt();
1374  to_.bits_ = json["to_"][1].GetInt();
1375  to_.lanes_ = json["to_"][2].GetInt();
1376  } else {
1377  to_ = base::dataTypeOf<float>();
1378  }
1379 
1380  return base::kStatusCodeOk;
1381  };
1382 
1383  public:
1384  int saturate_ =
1385  1; // https://onnx.org.cn/onnx/operators/onnx__Cast.html#cast-19
1386  base::DataType to_; // 输入张量元素将被转换为的数据类型。
1387 };
1388 
1390  public:
1392  virtual ~UnsqueezeParam() {}
1393 
1396 
1397  using base::Param::serialize;
1399  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1400  json.AddMember("axes_", axes_, allocator);
1401  return base::kStatusCodeOk;
1402  }
1404  virtual base::Status deserialize(rapidjson::Value &json) {
1405  if (json.HasMember("axes_")) {
1406  axes_ = json["axes_"].GetInt();
1407  } else {
1408  axes_ = 0; // 默认值
1409  }
1410  return base::kStatusCodeOk;
1411  }
1412 
1413  public:
1414  int axes_ = 0; // 指定在哪些维度上增加维度,默认值为0
1415 };
1416 
1418  public:
1419  GatherParam() : OpParam() {} // 默认构造函数
1420  virtual ~GatherParam() {}
1421 
1424 
1425  using base::Param::serialize;
1427  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1428  json.AddMember("axis_", axis_, allocator);
1429  return base::kStatusCodeOk;
1430  }
1432  virtual base::Status deserialize(rapidjson::Value &json) {
1433  if (json.HasMember("axis_")) {
1434  axis_ = json["axis_"].GetInt();
1435  } else {
1436  axis_ = 0; // 默认值
1437  }
1438  return base::kStatusCodeOk;
1439  }
1440 
1441  public:
1442  int axis_ = 0; // 用于收集的轴,默认值为0
1443 };
1444 
1446  public:
1448  virtual ~ReduceMeanParam() {}
1449 
1452 
1453  using base::Param::serialize;
1455  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1456  json.AddMember("keepdims_", keepdims_, allocator);
1457  json.AddMember("noop_with_empty_axes_", noop_with_empty_axes_, allocator);
1458  return base::kStatusCodeOk;
1459  }
1461  virtual base::Status deserialize(rapidjson::Value &json) {
1462  if (json.HasMember("keepdims_")) {
1463  keepdims_ = json["keepdims_"].GetInt();
1464  } else {
1465  keepdims_ = 1; // 默认值
1466  }
1467 
1468  if (json.HasMember("noop_with_empty_axes_")) {
1469  noop_with_empty_axes_ = json["noop_with_empty_axes_"].GetInt();
1470  } else {
1471  noop_with_empty_axes_ = 0; // 默认值
1472  }
1473 
1474  return base::kStatusCodeOk;
1475  }
1476 
1477  public:
1478  int keepdims_ = 1; // 是否保留被约简的维度,默认为1
1479  int noop_with_empty_axes_ = 0; // 当axes为空时的行为,默认为0
1480 };
1481 
1483  public:
1485  virtual ~ReduceMaxParam() {}
1486 
1489 
1490  using base::Param::serialize;
1492  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1493  json.AddMember("keepdims_", keepdims_, allocator);
1494  json.AddMember("noop_with_empty_axes_", noop_with_empty_axes_, allocator);
1495  return base::kStatusCodeOk;
1496  }
1498  virtual base::Status deserialize(rapidjson::Value &json) {
1499  if (json.HasMember("keepdims_")) {
1500  keepdims_ = json["keepdims_"].GetInt();
1501  } else {
1502  keepdims_ = 1; // 默认值
1503  }
1504 
1505  if (json.HasMember("noop_with_empty_axes_")) {
1506  noop_with_empty_axes_ = json["noop_with_empty_axes_"].GetInt();
1507  } else {
1508  noop_with_empty_axes_ = 0; // 默认值
1509  }
1510 
1511  return base::kStatusCodeOk;
1512  }
1513 
1514  public:
1515  int keepdims_ = 1; // 是否保留被约简的维度,默认为1
1516  int noop_with_empty_axes_ = 0; // 当axes为空时的行为,默认为0
1517 };
1518 
1520  public:
1522  virtual ~ReduceMinParam() {}
1523 
1526 
1527  using base::Param::serialize;
1529  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1530  json.AddMember("keepdims_", keepdims_, allocator);
1531  json.AddMember("noop_with_empty_axes_", noop_with_empty_axes_, allocator);
1532  return base::kStatusCodeOk;
1533  }
1535  virtual base::Status deserialize(rapidjson::Value &json) {
1536  if (json.HasMember("keepdims_")) {
1537  keepdims_ = json["keepdims_"].GetInt();
1538  } else {
1539  keepdims_ = 1; // 默认值
1540  }
1541 
1542  if (json.HasMember("noop_with_empty_axes_")) {
1543  noop_with_empty_axes_ = json["noop_with_empty_axes_"].GetInt();
1544  } else {
1545  noop_with_empty_axes_ = 0; // 默认值
1546  }
1547 
1548  return base::kStatusCodeOk;
1549  }
1550 
1551  public:
1552  int keepdims_ = 1; // 是否保留被约简的维度,默认为1
1553  int noop_with_empty_axes_ = 0; // 当axes为空时的行为,默认为0
1554 };
1555 
1557  public:
1559  virtual ~ReduceSumParam() {}
1560 
1563 
1564  using base::Param::serialize;
1566  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1567  json.AddMember("keepdims_", keepdims_, allocator);
1568  json.AddMember("noop_with_empty_axes_", noop_with_empty_axes_, allocator);
1569  return base::kStatusCodeOk;
1570  }
1572  virtual base::Status deserialize(rapidjson::Value &json) {
1573  if (json.HasMember("keepdims_")) {
1574  keepdims_ = json["keepdims_"].GetInt();
1575  } else {
1576  keepdims_ = 1; // 默认值
1577  }
1578 
1579  if (json.HasMember("noop_with_empty_axes_")) {
1580  noop_with_empty_axes_ = json["noop_with_empty_axes_"].GetInt();
1581  } else {
1582  noop_with_empty_axes_ = 0; // 默认值
1583  }
1584 
1585  return base::kStatusCodeOk;
1586  }
1587 
1588  public:
1589  int keepdims_ = 1; // 是否保留被约简的维度,默认为1
1590  int noop_with_empty_axes_ = 0; // 当axes为空时的行为,默认为0
1591 };
1592 
1594  public:
1595  ShapeParam() : OpParam() {} // 默认构造函数
1596  virtual ~ShapeParam() {}
1597 
1600 
1601  using base::Param::serialize;
1603  rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) {
1604  json.AddMember("start_", start_, allocator);
1605  json.AddMember("end_", end_, allocator);
1606  return base::kStatusCodeOk;
1607  }
1609  virtual base::Status deserialize(rapidjson::Value &json) {
1610  if (json.HasMember("start_")) {
1611  start_ = json["start_"].GetInt();
1612  } else {
1613  start_ = 0; // 默认值
1614  }
1615 
1616  if (json.HasMember("end_")) {
1617  end_ = json["end_"].GetInt();
1618  } else {
1619  end_ = -1; // 默认值
1620  }
1621 
1622  return base::kStatusCodeOk;
1623  }
1624 
1625  public:
1626  int start_ = 0; // 用于切片形状的起始轴,默认值为0
1627  int end_ =
1628  -1; // 负值表示从后向前计数维度。如果省略,将包含直到(包括)最后一个轴的所有轴的大小。
1629 };
1630 
1632  public:
1635 
1638 
1639  base::Status serialize(rapidjson::Value &json,
1640  rapidjson::Document::AllocatorType &allocator) {
1641  json.AddMember("value_", value_, allocator);
1642  json.AddMember("datatype_", rapidjson::Value(rapidjson::kArrayType),
1643  allocator);
1644  json["datatype_"].PushBack(static_cast<int32_t>(datatype_.code_),
1645  allocator);
1646  json["datatype_"].PushBack(static_cast<int32_t>(datatype_.bits_),
1647  allocator);
1648  json["datatype_"].PushBack(static_cast<int32_t>(datatype_.lanes_),
1649  allocator);
1650  return base::kStatusCodeOk;
1651  }
1652 
1653  base::Status deserialize(rapidjson::Value &json) {
1654  if (json.HasMember("value_")) {
1655  value_ = json["value_"].GetFloat();
1656  } else {
1657  value_ = 0.0f; // 默认值
1658  }
1659 
1660  if (json.HasMember("datatype_")) {
1661  datatype_.code_ = json["datatype_"][0].GetInt();
1662  datatype_.bits_ = json["datatype_"][1].GetInt();
1663  datatype_.lanes_ = json["datatype_"][2].GetInt();
1664  } else {
1665  datatype_ = base::dataTypeOf<float>(); // 默认为 float 类型
1666  }
1667 
1668  return base::kStatusCodeOk;
1669  }
1670 
1671  public:
1672  float value_ = 0.0f; // 默认值为 0.0
1673  base::DataType datatype_ =
1674  base::dataTypeOf<float>(); // 数据类型,默认为 float 类型
1675 };
1676 
1677 } // namespace ir
1678 } // namespace nndeploy
1679 
1680 #ifdef _MSC_VER
1681 #pragma warning(pop) // 恢复之前的警告状态
1682 #endif
1683 
1684 #endif /* _NNDEPLOY_IR_OP_PARAM_H_ */
virtual base::Status deserialize(rapidjson::Value &json)
virtual std::string serialize()
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1271
std::vector< int > kernel_shape_
Definition: op_param.h:1339
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1240
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:319
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:311
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1355
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1365
base::DataType to_
Definition: op_param.h:1386
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:359
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:365
base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1653
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:425
std::vector< int > kernel_shape_
Definition: op_param.h:531
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:461
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1105
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1111
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:996
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1001
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:973
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:967
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1426
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1432
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1015
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1024
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:939
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:945
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:388
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:395
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:586
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:552
std::vector< int > kernel_shape_
Definition: op_param.h:648
算子参数的创建类
Definition: op_param.h:230
virtual std::shared_ptr< base::Param > createOpParam(OpType type)=0
virtual ~OpParam()
Definition: op_param.h:292
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1163
std::vector< int > kernel_shape_
Definition: op_param.h:1223
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1135
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1068
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1075
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:909
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:902
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1498
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1491
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1461
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1454
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1528
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1535
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1572
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1565
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:670
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:664
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:694
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:717
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1609
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1602
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:798
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:804
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:835
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:828
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:876
std::vector< int > perm_
Definition: op_param.h:890
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:866
算子参数的创建类模板
Definition: op_param.h:242
算子参数的创建类的注册类模板
Definition: op_param.h:262
virtual base::Status deserialize(rapidjson::Value &json)
Definition: op_param.h:1404
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Definition: op_param.h:1398
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
@ kStatusCodeOk
Definition: status.h:13
DataType dataTypeOf< float >()
base::Status serialize(Graph *graph, rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
std::map< OpType, std::shared_ptr< OpParamCreator > > & getGlobalOpParamCreatorMap()
Get the Global base::Param Creator Map object.
OpType stringToOpType(const std::string &op_type_name)
std::string opTypeToString(OpType op_type)
OpType
算子类型 算子分类
Definition: op_param.h:65
@ kOpTypeSequenceEmpty
Definition: op_param.h:181
@ kOpTypeConstant
Definition: op_param.h:83
@ kOpTypeExpand
Definition: op_param.h:98
@ kOpTypeNonMaxSuppression
Definition: op_param.h:139
@ kOpTypeQuantizeLinear
Definition: op_param.h:150
@ kOpTypeTranspose
Definition: op_param.h:209
@ kOpTypeGlobalAveragePool
Definition: op_param.h:104
@ kOpTypeRMSNorm
Definition: op_param.h:215
@ kOpTypeConstantOfShape
Definition: op_param.h:87
@ kOpTypeSequenceConstruct
Definition: op_param.h:180
@ kOpTypeLeakyRelu
Definition: op_param.h:117
@ kOpTypeDropout
Definition: op_param.h:92
@ kOpTypeMomentum
Definition: op_param.h:133
@ kOpTypeIdentity
Definition: op_param.h:109
@ kOpTypeSoftsign
Definition: op_param.h:195
@ kOpTypeOnesLike
Definition: op_param.h:143
@ kOpTypeMultinomial
Definition: op_param.h:136
@ kOpTypeSoftplus
Definition: op_param.h:194
@ kOpTypeSequenceLength
Definition: op_param.h:184
@ kOpTypeMaxUnpool
Definition: op_param.h:129
@ kOpTypeNegLogSoftmax
Definition: op_param.h:138
@ kOpTypeReduceLogSumExp
Definition: op_param.h:161
@ kOpTypeUnsqueeze
Definition: op_param.h:210
@ kOpTypeGlobalLpPool
Definition: op_param.h:105
@ kOpTypeArgMin
Definition: op_param.h:74
@ kOpTypeBatchNormalization
Definition: op_param.h:78
@ kOpTypeFlatten
Definition: op_param.h:99
@ kOpTypeReduceProd
Definition: op_param.h:165
@ kOpTypeMatMulInteger
Definition: op_param.h:125
@ kOpTypeGreater
Definition: op_param.h:107
@ kOpTypeSqueeze
Definition: op_param.h:199
@ kOpTypeArgMax
Definition: op_param.h:73
@ kOpTypeRandomNormalLike
Definition: op_param.h:153
@ kOpTypeDequantizeLinear
Definition: op_param.h:89
@ kOpTypeUpsample
Definition: op_param.h:211
@ kOpTypeAveragePool
Definition: op_param.h:77
@ kOpTypeLogSoftmax
Definition: op_param.h:120
@ kOpTypeEinsum
Definition: op_param.h:93
@ kOpTypeThresholdedRelu
Definition: op_param.h:206
@ kOpTypeReduceSum
Definition: op_param.h:166
@ kOpTypeScatter
Definition: op_param.h:176
@ kOpTypeImageScaler
Definition: op_param.h:111
@ kOpTypeReciprocal
Definition: op_param.h:157
@ kOpTypeRandomNormal
Definition: op_param.h:152
@ kOpTypeRoiAlign
Definition: op_param.h:172
@ kOpTypeGlobalMaxPool
Definition: op_param.h:106
@ kOpTypeReduceL1
Definition: op_param.h:158
@ kOpTypeSoftmax
Definition: op_param.h:193
@ kOpTypeReduceMean
Definition: op_param.h:163
@ kOpTypeSigmoid
Definition: op_param.h:187
@ kOpTypeEmbedding
Definition: op_param.h:216
@ kOpTypeRandomUniformLike
Definition: op_param.h:155
@ kOpTypeReduceL2
Definition: op_param.h:159
@ kOpTypeDepthToSpace
Definition: op_param.h:88
@ kOpTypeLayerNormalization
Definition: op_param.h:217
@ kOpTypeSequenceInsert
Definition: op_param.h:183
@ kOpTypeReduceLogSum
Definition: op_param.h:160
@ kOpTypeQLinearMatMul
Definition: op_param.h:149
@ kOpTypeReduceSumSquare
Definition: op_param.h:167
@ kOpTypeReduceMin
Definition: op_param.h:164
@ kOpTypeSequenceErase
Definition: op_param.h:182
@ kOpTypeRandomUniform
Definition: op_param.h:154
@ kOpTypeQLinearConv
Definition: op_param.h:148
@ kOpTypeInstanceNormalization
Definition: op_param.h:112
@ kOpTypeEqual
Definition: op_param.h:95
@ kOpTypeMaxRoiPool
Definition: op_param.h:128
@ kOpTypeReduceMax
Definition: op_param.h:162
@ kOpTypeSpaceToDepth
Definition: op_param.h:196
@ kOpTypeHardSigmoid
Definition: op_param.h:108
@ kOpTypeSequenceAt
Definition: op_param.h:179
@ kOpTypeConcat
Definition: op_param.h:82
@ kOpTypeMaxPool
Definition: op_param.h:127
@ kOpTypeReverseSequence
Definition: op_param.h:171
@ kOpTypeLpNormalization
Definition: op_param.h:122
@ kOpTypeReshape
Definition: op_param.h:169
@ kOpTypeNonZero
Definition: op_param.h:140
std::shared_ptr< base::Param > createOpParam(OpType op_type)
Create a base::Param object.
#define PARAM_COPY_TO(param_type)
Definition: param.h:25
#define PARAM_COPY(param_type)
Definition: param.h:16