nndeploy C++ API  0.2.0
nndeploy C++ API
abstract_llm_infer.h
Go to the documentation of this file.
1 
2 #ifndef _NNDEPLOY_LLM_ABSTRACT_LLM_INFER_H_
3 #define _NNDEPLOY_LLM_ABSTRACT_LLM_INFER_H_
4 
5 #include "nndeploy/base/any.h"
6 #include "nndeploy/base/common.h"
8 #include "nndeploy/base/log.h"
9 #include "nndeploy/base/macro.h"
10 #include "nndeploy/base/object.h"
12 #include "nndeploy/base/param.h"
13 #include "nndeploy/base/status.h"
14 #include "nndeploy/base/string.h"
16 #include "nndeploy/dag/edge.h"
17 #include "nndeploy/dag/graph.h"
18 #include "nndeploy/dag/loop.h"
19 #include "nndeploy/dag/node.h"
20 #include "nndeploy/device/buffer.h"
21 #include "nndeploy/device/device.h"
23 #include "nndeploy/device/tensor.h"
24 #include "nndeploy/infer/infer.h"
26 
27 namespace nndeploy {
28 namespace llm {
29 
31  public:
32  AbstractLlmInfer(const std::string &name) : dag::CompositeNode(name) {
33  key_ = "nndeploy::llm::LlmInfer";
34  desc_ =
35  "LLM abstract pipeline: input_tokens -> "
36  "inference -> [logits]";
37  this->setDynamicInput(true);
38  this->setInputTypeInfo<tokenizer::TokenizerIds>("input_tokens");
39  this->setOutputTypeInfo<device::Tensor>("output_logits");
40  }
41  AbstractLlmInfer(const std::string &name, std::vector<dag::Edge *> inputs,
42  std::vector<dag::Edge *> outputs)
43  : dag::CompositeNode(name, inputs, outputs) {
44  key_ = "nndeploy::llm::LlmInfer";
45  desc_ =
46  "LLM abstract pipeline: input_tokens -> "
47  "inference -> [logits]";
48  this->setDynamicInput(true);
49  this->setInputTypeInfo<tokenizer::TokenizerIds>("input_tokens");
50  this->setOutputTypeInfo<device::Tensor>("output_logits");
51  }
52  virtual ~AbstractLlmInfer() {}
53 
54  virtual base::Status run() = 0;
55 
56  virtual base::Status setPrefill(bool is_prefill) {
57  is_prefill_ = is_prefill;
58  return base::kStatusCodeOk;
59  }
61  const std::vector<std::string> &config_path) {
62  config_path_ = config_path;
63  return base::kStatusCodeOk;
64  }
65  virtual base::Status setModelKey(const std::string &model_key) {
66  model_key_ = model_key;
67  return base::kStatusCodeOk;
68  }
69  virtual base::Status setInferKey(const std::string &infer_key) {
70  infer_key_ = infer_key;
71  return base::kStatusCodeOk;
72  }
73 
74  virtual int getMaxSeqLen() { return std::numeric_limits<int>::max(); }
75 
77  const std::vector<int32_t> &kv_init_shape,
79  device::Device *device = device::getDevice(device_type);
80  device::TensorDesc past_kv_desc;
81  past_kv_desc.data_type_ = base::dataTypeOf<float>();
83  past_kv_desc.shape_ = kv_init_shape;
84  device::Tensor *past_kv;
85  past_kv = new device::Tensor(device, past_kv_desc, "past_key_values");
86  return past_kv;
87  }
88 
90  int seq_len, int all_seq_len, base::DataType data_type,
91  base::DataFormat data_format,
93  device::Device *device = device::getDevice(device_type);
94  device::TensorDesc position_ids_desc;
95  position_ids_desc.data_type_ = data_type;
96  position_ids_desc.data_format_ = data_format;
97  position_ids_desc.shape_ = {1, seq_len};
98  device::Tensor *position_ids;
99  position_ids =
100  new device::Tensor(device, position_ids_desc, "position_ids");
101 
102  // only host
103  auto ptr = (int *)position_ids->getData();
104  if (seq_len == 1) {
105  ptr[0] = all_seq_len;
106  } else {
107  for (int i = 0; i < seq_len; i++) {
108  ptr[i] = i + all_seq_len;
109  }
110  }
111 
112  return position_ids;
113  }
114 
116  int seq_len, int all_seq_len, base::DataType data_type,
117  base::DataFormat data_format,
119  int kv_seq_len = all_seq_len + seq_len;
120  if (seq_len == 1) kv_seq_len = seq_len;
121 
122  /* create attetion_mask tensor */
123  device::Device *device = device::getDevice(device_type);
124  device::TensorDesc attention_mask_desc;
125  attention_mask_desc.data_type_ = data_type;
126  attention_mask_desc.data_format_ = data_format;
127  attention_mask_desc.shape_ = {1, 1, seq_len, kv_seq_len};
128  device::Tensor *attention_mask;
129  attention_mask =
130  new device::Tensor(device, attention_mask_desc, "attention_mask");
131 
132  // only host
133  auto ptr = (float *)attention_mask->getData();
134  for (int i = 0; i < seq_len; i++) {
135  for (int j = 0; j < kv_seq_len; j++) {
136  int row = i + all_seq_len;
137  ptr[kv_seq_len * i + j] =
138  (j > row) * std::numeric_limits<float>::lowest();
139  }
140  }
141 
142  return attention_mask;
143  }
144 
147  rapidjson::Value &json,
148  rapidjson::Document::AllocatorType &allocator) override {
149  // 调用父类的序列化方法
150  base::Status status = dag::CompositeNode::serialize(json, allocator);
151  if (status != base::kStatusCodeOk) {
152  return status;
153  }
154 
155  // 序列化 is_prefill_
156  json.AddMember("is_prefill", is_prefill_, allocator);
157 
158  // 序列化 config_path_
159  rapidjson::Value config_path_array(rapidjson::kArrayType);
160  for (const auto &path : config_path_) {
161  rapidjson::Value path_value;
162  path_value.SetString(path.c_str(), path.length(), allocator);
163  config_path_array.PushBack(path_value, allocator);
164  }
165  json.AddMember("config_path", config_path_array, allocator);
166 
167  // 序列化 model_key_
168  rapidjson::Value model_key_value;
169  model_key_value.SetString(model_key_.c_str(), model_key_.length(),
170  allocator);
171  json.AddMember("model_key", model_key_value, allocator);
172 
173  // 序列化 infer_key_
174  rapidjson::Value infer_key_value;
175  infer_key_value.SetString(infer_key_.c_str(), infer_key_.length(),
176  allocator);
177  json.AddMember("infer_key", infer_key_value, allocator);
178 
179  // 序列化输入输出名称
180  // 序列化模型输入
181  rapidjson::Value model_inputs(rapidjson::kArrayType);
182  for (const auto &input : model_inputs_) {
183  model_inputs.PushBack(rapidjson::Value(input.c_str(), allocator),
184  allocator);
185  }
186  json.AddMember("model_inputs", model_inputs, allocator);
187 
188  // 序列化模型输出
189  rapidjson::Value model_outputs(rapidjson::kArrayType);
190  for (const auto &output : model_outputs_) {
191  model_outputs.PushBack(rapidjson::Value(output.c_str(), allocator),
192  allocator);
193  }
194  json.AddMember("model_outputs", model_outputs, allocator);
195 
196  return base::kStatusCodeOk;
197  }
199  virtual base::Status deserialize(rapidjson::Value &json) override {
200  // 调用父类的反序列化方法
202  if (status != base::kStatusCodeOk) {
203  return status;
204  }
205 
206  // 反序列化 is_prefill_
207  if (json.HasMember("is_prefill") && json["is_prefill"].IsBool()) {
208  is_prefill_ = json["is_prefill"].GetBool();
209  }
210 
211  // 反序列化 config_path_
212  if (json.HasMember("config_path") && json["config_path"].IsArray()) {
213  config_path_.clear();
214  const rapidjson::Value &config_path_array = json["config_path"];
215  for (rapidjson::SizeType i = 0; i < config_path_array.Size(); i++) {
216  if (config_path_array[i].IsString()) {
217  config_path_.push_back(config_path_array[i].GetString());
218  }
219  }
220  }
221 
222  // 反序列化 model_key_
223  if (json.HasMember("model_key") && json["model_key"].IsString()) {
224  model_key_ = json["model_key"].GetString();
225  }
226 
227  // 反序列化 infer_key_
228  if (json.HasMember("infer_key") && json["infer_key"].IsString()) {
229  infer_key_ = json["infer_key"].GetString();
230  }
231 
232  // 反序列化模型输入
233  if (json.HasMember("model_inputs") && json["model_inputs"].IsArray()) {
234  model_inputs_.clear();
235  const rapidjson::Value &model_inputs = json["model_inputs"];
236  for (rapidjson::SizeType i = 0; i < model_inputs.Size(); i++) {
237  if (model_inputs[i].IsString()) {
238  model_inputs_.push_back(model_inputs[i].GetString());
239  }
240  }
241  }
242  // 反序列化模型输出
243  if (json.HasMember("model_outputs") && json["model_outputs"].IsArray()) {
244  model_outputs_.clear();
245  const rapidjson::Value &model_outputs = json["model_outputs"];
246  for (rapidjson::SizeType i = 0; i < model_outputs.Size(); i++) {
247  if (model_outputs[i].IsString()) {
248  model_outputs_.push_back(model_outputs[i].GetString());
249  }
250  }
251  }
252 
253  return base::kStatusCodeOk;
254  }
255 
256  std::string getShareKey() {
257  std::string key = "";
258  for (const auto &path : config_path_) {
259  key += path;
260  }
261  key += model_key_;
262  key += infer_key_;
263  return key;
264  }
265 
266  protected:
267  // prefill or decode
268  bool is_prefill_ = true;
269  // config_path
270  std::vector<std::string> config_path_;
271  // qwen or llama...
272  std::string model_key_;
273  // llm::DefaultLlmInfer or llm::MnnLlmInfer
274  std::string infer_key_;
275 
276  // model inputs
277  std::vector<std::string> model_inputs_ = {"input_ids", "attention_mask",
278  "position_ids", "past_key_values"};
279  // model outputs
280  std::vector<std::string> model_outputs_ = {"logits", "presents"};
281 };
282 
283 // 前向声明
284 template <typename T>
285 class TypeLlmInferCreator;
286 
288  public:
289  virtual ~LlmInferCreator() = default;
290  virtual AbstractLlmInfer *createLlmInfer(const std::string &name) = 0;
292  const std::string &name, std::vector<dag::Edge *> inputs,
293  std::vector<dag::Edge *> outputs) = 0;
294 };
295 
296 template <typename T>
298  public:
299  virtual AbstractLlmInfer *createLlmInfer(const std::string &name) override {
300  return new T(name);
301  }
303  const std::string &name, std::vector<dag::Edge *> inputs,
304  std::vector<dag::Edge *> outputs) override {
305  return new T(name, inputs, outputs);
306  }
307 };
308 
310  public:
312  static LlmInferFactory instance;
313  return &instance;
314  }
315 
316  void registerLlmInfer(const std::string &infer_key,
317  const std::string &model_key,
318  std::shared_ptr<LlmInferCreator> creator) {
319  auto it = creators_.find(infer_key);
320  if (it == creators_.end()) {
321  creators_[infer_key] =
322  std::map<std::string, std::shared_ptr<LlmInferCreator>>();
323  }
324 
325  auto model_it = creators_[infer_key].find(model_key);
326  if (model_it != creators_[infer_key].end()) {
327  NNDEPLOY_LOGW("LlmInfer %s@%s already exists, will be overwritten!\n",
328  infer_key.c_str(), model_key.c_str());
329  }
330  creators_[infer_key][model_key] = creator;
331  }
332 
333  std::shared_ptr<LlmInferCreator> getCreator(const std::string &infer_key,
334  const std::string &model_key) {
335  auto infer_it = creators_.find(infer_key);
336  if (infer_it != creators_.end()) {
337  auto model_it = infer_it->second.find(model_key);
338  if (model_it != infer_it->second.end()) {
339  return model_it->second;
340  }
341  }
342  NNDEPLOY_LOGE("LlmInfer %s@%s not found!\n", infer_key.c_str(),
343  model_key.c_str());
344  return nullptr;
345  }
346 
347  std::set<std::string> getInferKeys() {
348  std::set<std::string> keys;
349  for (auto &it : creators_) {
350  keys.insert(it.first);
351  }
352  return keys;
353  }
354 
355  std::set<std::string> getModelKeys(const std::string &infer_key) {
356  std::set<std::string> keys;
357  auto it = creators_.find(infer_key);
358  if (it != creators_.end()) {
359  for (auto &model_it : it->second) {
360  keys.insert(model_it.first);
361  }
362  }
363  return keys;
364  }
365 
366  std::set<std::string> getModelKeys() {
367  std::set<std::string> keys;
368  for (auto &it : creators_) {
369  for (auto &model_it : it.second) {
370  keys.insert(model_it.first);
371  }
372  }
373  return keys;
374  }
375 
376  std::set<std::pair<std::string, std::string>> getAllKeys() {
377  std::set<std::pair<std::string, std::string>> keys;
378  for (auto &infer_it : creators_) {
379  for (auto &model_it : infer_it.second) {
380  keys.insert(std::make_pair(infer_it.first, model_it.first));
381  }
382  }
383  return keys;
384  }
385 
386  private:
387  LlmInferFactory() = default;
388  ~LlmInferFactory() = default;
389  std::map<std::string, std::map<std::string, std::shared_ptr<LlmInferCreator>>>
390  creators_;
391 };
392 
/**
 * @brief Register `node_class` with LlmInferFactory under
 *        (infer_key, model_key).
 *
 * Expands to a struct in an unnamed namespace whose constructor calls
 * LlmInferFactory::registerLlmInfer with a TypeLlmInferCreator<node_class>,
 * plus one static instance of that struct, so the registration runs when
 * that static object is constructed.
 *
 * Both the factory and the creator are fully qualified, so the macro can be
 * expanded from any namespace. `node_class` is pasted into identifiers with
 * `##`, so it must be a plain identifier (no `::` qualification).
 */
#define REGISTER_LLM_INFER(infer_key, model_key, node_class)                   \
  namespace {                                                                  \
  struct LlmInferRegister_##node_class {                                       \
    LlmInferRegister_##node_class() {                                          \
      nndeploy::llm::LlmInferFactory::getInstance()->registerLlmInfer(         \
          infer_key, model_key,                                                \
          std::make_shared<nndeploy::llm::TypeLlmInferCreator<node_class>>()); \
    }                                                                          \
  };                                                                           \
  static LlmInferRegister_##node_class g_llm_infer_register_##node_class;      \
  }
404 
405 } // namespace llm
406 } // namespace nndeploy
407 
408 #endif
Composite node Composite node is a special type of node in nndeploy that enhances the capabilities of...
virtual base::Status deserialize(rapidjson::Value &json)
Deserialize from JSON.
CompositeNode(const std::string &name)
virtual std::string serialize()
Serialize to JSON string.
void setDynamicInput(bool is_dynamic_input)
Set whether it's dynamic input.
std::string desc_
Node description.
Definition: node.h:1294
std::string key_
Node key.
Definition: node.h:1290
设备抽象基类
Definition: device.h:155
void * getData() const
device::Tensor * genAttentionMask(int seq_len, int all_seq_len, base::DataType data_type, base::DataFormat data_format, base::DeviceType device_type=base::kDeviceTypeCodeCpu)
AbstractLlmInfer(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
device::Tensor * genPositionIds(int seq_len, int all_seq_len, base::DataType data_type, base::DataFormat data_format, base::DeviceType device_type=base::kDeviceTypeCodeCpu)
std::vector< std::string > model_outputs_
virtual base::Status setConfigPath(const std::vector< std::string > &config_path)
virtual base::Status setModelKey(const std::string &model_key)
virtual base::Status run()=0
Run node (pure virtual function)
std::vector< std::string > model_inputs_
virtual base::Status setPrefill(bool is_prefill)
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) override
Serialize to JSON.
AbstractLlmInfer(const std::string &name)
device::Tensor * genPastKeyValue(const std::vector< int32_t > &kv_init_shape, base::DeviceType device_type=base::kDeviceTypeCodeCpu)
virtual base::Status deserialize(rapidjson::Value &json) override
Deserialize from JSON.
std::vector< std::string > config_path_
virtual base::Status setInferKey(const std::string &infer_key)
virtual AbstractLlmInfer * createLlmInfer(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)=0
virtual ~LlmInferCreator()=default
virtual AbstractLlmInfer * createLlmInfer(const std::string &name)=0
std::set< std::string > getModelKeys()
std::set< std::string > getModelKeys(const std::string &infer_key)
void registerLlmInfer(const std::string &infer_key, const std::string &model_key, std::shared_ptr< LlmInferCreator > creator)
std::shared_ptr< LlmInferCreator > getCreator(const std::string &infer_key, const std::string &model_key)
std::set< std::pair< std::string, std::string > > getAllKeys()
static LlmInferFactory * getInstance()
std::set< std::string > getInferKeys()
virtual AbstractLlmInfer * createLlmInfer(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs) override
virtual AbstractLlmInfer * createLlmInfer(const std::string &name) override
#define NNDEPLOY_LOGW(fmt,...)
Definition: log.h:61
#define NNDEPLOY_LOGE(fmt,...)
Definition: log.h:59
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
@ kStatusCodeOk
Definition: status.h:13
DataType dataTypeOf< float >()
@ kDeviceTypeCodeCpu
Definition: common.h:82
@ kDataFormatS1D
Definition: common.h:143
Device * getDevice(base::DeviceType device_type)
获取指定类型的设备
base::IntVector shape_
Definition: type.h:113
base::DataFormat data_format_
Definition: type.h:112
base::DataType data_type_
Definition: type.h:111