nndeploy C++ API  0.2.0
nndeploy C++ API
embedding.h
Go to the documentation of this file.
1 //
2 // embedding.cpp
3 //
4 // Created by MNN on 2023/09/25.
5 // ZhaodeWang
6 //
7 
27 #ifndef _NNDEPLOY_LLM_EMBEDDING_H_
28 #define _NNDEPLOY_LLM_EMBEDDING_H_
29 
30 #include "nndeploy/base/any.h"
31 #include "nndeploy/base/common.h"
33 #include "nndeploy/base/log.h"
34 #include "nndeploy/base/macro.h"
35 #include "nndeploy/base/object.h"
37 #include "nndeploy/base/param.h"
38 #include "nndeploy/base/status.h"
39 #include "nndeploy/base/string.h"
41 #include "nndeploy/dag/edge.h"
42 #include "nndeploy/dag/graph.h"
43 #include "nndeploy/dag/loop.h"
44 #include "nndeploy/dag/node.h"
45 #include "nndeploy/device/buffer.h"
46 #include "nndeploy/device/device.h"
48 #include "nndeploy/device/tensor.h"
49 #include "nndeploy/llm/embedding/diskembedding.hpp"
51 
52 namespace nndeploy {
53 namespace llm {
54 
59  public:
60  EmbeddingParam() = default;  // Every field starts from its in-class default initializer.
61  virtual ~EmbeddingParam() = default;  // Virtual, compiler-generated destructor.
62 
65 
66  // Hidden-layer dimension.
67  int hidden_size_ = 4096;
68  // Path of the embedding weight file.
69  std::string embedding_weight_path_ = "";
70  // Quantization-related parameters (this flag and the fields below).
71  bool use_quantization_ = false;
72  // NOTE(review): undocumented in the original — presumably where the weight data starts in the file; confirm.
73  int weight_offset_ = 0;
74  // NOTE(review): undocumented in the original — presumably the offset and size of the quantization "alpha" data; confirm.
75  int a_offset_ = 0;
76  int alpha_size_ = 0;
77  // Number of quantization bits.
78  int quant_bit_ = 8;
79  // Quantization block size.
80  int quant_block_ = 0;
81 
82  // Other settings: element type, data layout, and a string key that defaults to "disk_embedding".
83  base::DataType data_type_ = base::dataTypeOf<float>();
84  base::DataFormat data_format_ = base::DataFormat::kDataFormatNCHW;
85  std::string share_disk_embedding_key_ = "disk_embedding";
86 
87  // Returns a string key built from the weight path, tensor settings and all
88  // quantization fields (presumably used to share one DiskEmbedding — confirm).
89  // Fields are '|'-separated so adjacent numbers cannot merge (1,23 vs 12,3).
90  std::string getShareKey() const {
91  std::string key = embedding_weight_path_;
92  key += '|' + std::to_string(hidden_size_);
93  key += '|' + base::dataTypeToString(data_type_);
94  key += '|' + base::dataFormatToString(data_format_);
95  key += '|' + std::to_string(use_quantization_);
96  key += '|' + std::to_string(weight_offset_) + '|' + std::to_string(a_offset_);
97  key += '|' + std::to_string(alpha_size_) + '|' + std::to_string(quant_bit_);
98  key += '|' + std::to_string(quant_block_);  // previously left out of the key
99  return key;
100  }
102  rapidjson::Value& json,
103  rapidjson::Document::AllocatorType& allocator) override;
105  virtual base::Status deserialize(rapidjson::Value& json) override;  // Overrides the base-class deserialize(rapidjson::Value&); the definition is not in this header.
106 };
107 
121  public:
122  Embedding(const std::string& name, std::vector<dag::Edge*> inputs,
123  std::vector<dag::Edge*> outputs);  // Takes the node name plus its input and output edges; defined out of line.
124  virtual ~Embedding();
125 
126  virtual base::Status init();  // Initialize node.
127  virtual base::Status deinit();  // Deinitialize node.
128 
129  virtual base::Status run();  // Run node.
130 
131  private:
132  std::shared_ptr<MNN::Transformer::DiskEmbedding> disk_embedding_ = nullptr;  // Shared handle to the MNN DiskEmbedding; starts null (presumably set in init() — confirm in the implementation file).
133 };
134 
135 } // namespace llm
136 } // namespace nndeploy
137 
138 #endif
virtual base::Status deserialize(rapidjson::Value &json)
virtual std::string serialize()
Node base class.
Definition: node.h:171
EmbeddingParam - Embedding节点的参数配置
Definition: embedding.h:58
virtual ~EmbeddingParam()=default
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) override
virtual base::Status deserialize(rapidjson::Value &json) override
Embedding - 词嵌入节点
Definition: embedding.h:120
virtual base::Status run()
Run node (pure virtual function)
virtual base::Status deinit()
Deinitialize node.
virtual base::Status init()
Initialize node.
Embedding(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
std::string dataFormatToString(DataFormat data_format)
DataType dataTypeOf()
Definition: common.h:53
std::string dataTypeToString(DataType data_type)
@ kDataFormatNCHW
Definition: common.h:146
#define PARAM_COPY_TO(param_type)
Definition: param.h:25
#define PARAM_COPY(param_type)
Definition: param.h:16