nndeploy C++ API  0.2.0
nndeploy C++ API
default_llm_infer.h
Go to the documentation of this file.
1 
2 #ifndef _NNDEPLOY_LLM_INFER_DEFAULT_LLM_INFER_H_
3 #define _NNDEPLOY_LLM_INFER_DEFAULT_LLM_INFER_H_
4 
5 #include "nndeploy/base/any.h"
6 #include "nndeploy/base/common.h"
8 #include "nndeploy/base/log.h"
9 #include "nndeploy/base/macro.h"
10 #include "nndeploy/base/object.h"
12 #include "nndeploy/base/param.h"
13 #include "nndeploy/base/status.h"
14 #include "nndeploy/base/string.h"
16 #include "nndeploy/dag/edge.h"
17 #include "nndeploy/dag/graph.h"
18 #include "nndeploy/dag/loop.h"
19 #include "nndeploy/dag/node.h"
20 #include "nndeploy/device/buffer.h"
21 #include "nndeploy/device/device.h"
23 #include "nndeploy/device/tensor.h"
24 #include "nndeploy/infer/infer.h"
26 #include "nndeploy/llm/embedding.h"
27 
28 namespace nndeploy {
29 namespace llm {
30 
32  // Embedding stage: when is_embedding_ is true, DefaultLlmInfer::init() creates an Embedding node and hands it embedding_param_.
33  bool is_embedding_ = false;
34  std::shared_ptr<EmbeddingParam> embedding_param_ = nullptr;
35  // Inference backend settings; serialize()/deserialize() create inference_param_ on demand while it is still null.
37  std::shared_ptr<inference::InferenceParam> inference_param_ = nullptr;
38  // Model description
39  int layer_nums_ = 24;
40  int max_seq_len_ = 2048; // TODO
41  std::vector<int32_t> kv_init_shape_;  // prefill() prepends one leading dimension to a copy of this shape before calling genPastKeyValue().
42  base::DataType attention_mask_data_type_ = base::dataTypeOf<float>();  // Stored in JSON as the string produced by base::dataTypeToString().
43  std::string attention_type_ = "full";
44 
47  rapidjson::Value& json,  // Receives one member per field of this param (the listing elides the line carrying the function name).
48  rapidjson::Document::AllocatorType& allocator) override {
49  base::Status status = base::Param::serialize(json, allocator);
50  if (status != base::kStatusCodeOk) {
51  NNDEPLOY_LOGE("DefaultLlmInferParam::serialize failed\n");
52  return status;
53  }
54  // Embedding settings: embedding_param_ is written only when is_embedding_ is set and the pointer is non-null.
55  json.AddMember("is_embedding_", is_embedding_, allocator);
56  if (is_embedding_ && embedding_param_ != nullptr) {
57  rapidjson::Value embedding_param_value;
58  embedding_param_->serialize(embedding_param_value, allocator);  // NOTE(review): the status returned by the nested serialize() is discarded.
59  json.AddMember("embedding_param_", embedding_param_value, allocator);
60  }
61  // Inference settings: the type is stored as a string, followed by the backend-specific param object.
62  std::string inference_type_str =
63  base::inferenceTypeToString(inference_type_);
64  json.AddMember("inference_type_",
65  rapidjson::Value(inference_type_str.c_str(), allocator),
66  allocator);
67  if (inference_param_ == nullptr) {  // Side effect: serialize() fills in inference_param_ when it has not been set yet.
68  inference_param_ = inference::createInferenceParam(inference_type_);
69  if (inference_param_ == nullptr) {  // Fall back to the base InferenceParam when createInferenceParam() returns null.
70  inference_param_ =
71  std::make_shared<inference::InferenceParam>(inference_type_);
72  }
73  }
74  rapidjson::Value inference_param_value;
75  inference_param_->serialize(inference_param_value, allocator);  // NOTE(review): returned status is discarded here as well.
76  json.AddMember("inference_param_", inference_param_value, allocator);
77  // Model settings
78  json.AddMember("layer_nums_", layer_nums_, allocator);
79  json.AddMember("max_seq_len_", max_seq_len_, allocator);
80  rapidjson::Value kv_init_shape_array(rapidjson::kArrayType);
81  for (auto dim : kv_init_shape_) {
82  kv_init_shape_array.PushBack(dim, allocator);
83  }
84  json.AddMember("kv_init_shape_", kv_init_shape_array, allocator);
85  std::string attention_mask_data_type_str =
86  base::dataTypeToString(attention_mask_data_type_);
87  json.AddMember(
88  "attention_mask_data_type_",
89  rapidjson::Value(attention_mask_data_type_str.c_str(), allocator),  // Passing the allocator makes rapidjson copy the string, so the local std::string may go out of scope.
90  allocator);
91  json.AddMember("attention_type_",
92  rapidjson::Value(attention_type_.c_str(), allocator),
93  allocator);
94  return base::kStatusCodeOk;
95  }
97  virtual base::Status deserialize(rapidjson::Value& json) override {  // Reads each field from json when its key is present with the expected type; fields that are not matched keep their current value.
99  if (status != base::kStatusCodeOk) {  // `status` is declared on a line the listing elides (presumably the base::Param::deserialize() result — confirm).
100  NNDEPLOY_LOGE("DefaultLlmInferParam::deserialize failed\n");
101  return status;
102  }
103  // Embedding settings
104  if (json.HasMember("is_embedding_") && json["is_embedding_"].IsBool()) {
105  is_embedding_ = json["is_embedding_"].GetBool();
106  }
107  if (is_embedding_ && json.HasMember("embedding_param_") &&
108  json["embedding_param_"].IsObject()) {
109  if (embedding_param_ == nullptr) {
110  embedding_param_ = std::make_shared<EmbeddingParam>();
111  }
112  embedding_param_->deserialize(json["embedding_param_"]);  // NOTE(review): returned status is discarded.
113  }
114  // Inference settings
115  if (json.HasMember("inference_type_") &&
116  json["inference_type_"].IsString()) {
117  inference_type_ =
118  base::stringToInferenceType(json["inference_type_"].GetString());
119  }
120  if (inference_param_ == nullptr) {  // Create the param on demand, same fallback chain as in serialize().
121  inference_param_ = inference::createInferenceParam(inference_type_);
122  if (inference_param_ == nullptr) {
123  inference_param_ =
124  std::make_shared<inference::InferenceParam>(inference_type_);
125  }
126  }
127  rapidjson::Value inference_param_value;  // NOTE(review): this local is never used.
128  inference_param_->deserialize(json["inference_param_"]);  // NOTE(review): unlike every other key, "inference_param_" is read without a HasMember()/IsObject() guard — confirm the key is always present.
129  // Model settings
130  if (json.HasMember("layer_nums_") && json["layer_nums_"].IsInt()) {
131  layer_nums_ = json["layer_nums_"].GetInt();
132  }
133  if (json.HasMember("max_seq_len_") && json["max_seq_len_"].IsInt()) {
134  max_seq_len_ = json["max_seq_len_"].GetInt();
135  }
136  if (json.HasMember("kv_init_shape_") && json["kv_init_shape_"].IsArray()) {
137  kv_init_shape_.clear();  // Replace, rather than append to, any previously loaded shape.
138  const rapidjson::Value& kv_init_shape_array = json["kv_init_shape_"];
139  for (rapidjson::SizeType i = 0; i < kv_init_shape_array.Size(); i++) {
140  kv_init_shape_.push_back(kv_init_shape_array[i].GetInt());  // NOTE(review): elements are not checked with IsInt() before GetInt().
141  }
142  }
143  if (json.HasMember("attention_mask_data_type_") &&
144  json["attention_mask_data_type_"].IsString()) {
145  attention_mask_data_type_ =
146  base::stringToDataType(json["attention_mask_data_type_"].GetString());
147  }
148  if (json.HasMember("attention_type_") &&
149  json["attention_type_"].IsString()) {
150  attention_type_ = json["attention_type_"].GetString();
151  }
152  return base::kStatusCodeOk;
153  }
154 };
155 
157  public:
158  DefaultLlmInfer(const std::string& name) : AbstractLlmInfer(name) {
159  // Identify the node, describe its pipeline, and attach a default parameter object.
160  key_ = "nndeploy::llm::DefaultLlmInfer";
161  desc_ = "LLM default pipeline: input_tokens -> inference -> [logits]";
162  param_ = std::make_shared<DefaultLlmInferParam>();
164  }
165  DefaultLlmInfer(const std::string& name, std::vector<dag::Edge*> inputs,
166  std::vector<dag::Edge*> outputs)
167  : AbstractLlmInfer(name, inputs, outputs) {
168  // Same setup as the name-only constructor; the edges are forwarded to the base class.
169  key_ = "nndeploy::llm::DefaultLlmInfer";
170  desc_ = "LLM default pipeline: input_tokens -> inference -> [logits]";
171  param_ = std::make_shared<DefaultLlmInferParam>();
173  }
174  virtual ~DefaultLlmInfer() {}  // Empty body: this class releases nothing itself here.
175 
176  virtual base::Status init() {  // Builds the internal edges, the optional embedding node and the inference node; the visible code always returns kStatusCodeOk.
177  // Parse parameters
178  if (!config_path_.empty()) {  // The body of this branch is on a line the listing elides.
180  }
181 
182  // Create input edges
183  input_ids_name_ = model_inputs_[0];  // NOTE(review): model_inputs_[0] is read without a size check.
184  std::vector<dag::Edge*> input_edges;
185  input_ids_edge_ = this->createEdge(input_ids_name_);
186  input_edges.push_back(input_ids_edge_);
187  if (model_inputs_.size() > 1) {  // Further inputs are optional and taken positionally: [1] attention mask, [2] position ids, [3] past key/values.
188  attention_mask_name_ = model_inputs_[1];
189  attention_mask_edge_ = this->createEdge(attention_mask_name_);
190  input_edges.push_back(attention_mask_edge_);
191  }
192  if (model_inputs_.size() > 2) {
193  position_ids_name_ = model_inputs_[2];
194  position_ids_edge_ = this->createEdge(position_ids_name_);
195  input_edges.push_back(position_ids_edge_);
196  }
197  if (model_inputs_.size() > 3) {
198  past_key_values_name_ = model_inputs_[3];
199  past_key_values_edge_ = this->createEdge(past_key_values_name_);
200  input_edges.push_back(past_key_values_edge_);
201  }
202 
203  // Create output edges
204  std::vector<dag::Edge*> output_edges;
205  logits_name_ = model_outputs_[0];  // NOTE(review): model_outputs_[0] and outputs_[0] are read without a size check.
206  logits_edge_ = outputs_[0];  // The logits edge is this node's first output edge, not a newly created one.
207  output_edges.push_back(logits_edge_);
208  if (model_outputs_.size() > 1) {
209  presents_name_ = model_outputs_[1];
210  presents_edge_ = this->createEdge(presents_name_);
211  output_edges.push_back(presents_edge_);
212  }
213 
214  // Create the embedding node
215  DefaultLlmInferParam* default_llm_infer_param =
216  dynamic_cast<DefaultLlmInferParam*>(param_.get());
217  if (default_llm_infer_param->is_embedding_) {  // NOTE(review): the dynamic_cast result is dereferenced without a null check.
218  dag::NodeDesc desc("embedding_node", {inputs_[0]->getName()},
219  {input_ids_edge_->getName()});
220  embedding_node_ =
221  dynamic_cast<Embedding*>(this->createNode<Embedding>(desc));
222  // Begin parameter setup
223  auto embedding_param = default_llm_infer_param->embedding_param_;
224  embedding_node_->setParamSharedPtr(embedding_param);
225  // End parameter setup
226  embedding_node_->setInitializedFlag(false);
227  embedding_node_->init();  // NOTE(review): the status returned by init() is discarded.
228  embedding_node_->setInitializedFlag(true);
229  } else {
230  // TODO
231  // tokenizer::TokenizerIds -> device::Tensor
232  ;
233  }
234 
235  // Create the infer node
236  std::vector<std::string> input_names;
237  std::vector<std::string> output_names;
238  for (auto input : input_edges) {
239  input_names.push_back(input->getName());
240  }
241  for (auto output : output_edges) {
242  output_names.push_back(output->getName());
243  }
244  dag::NodeDesc desc("llm_infer", input_names, output_names);
245  std::string share_key = this->getShareKey();
246  auto infer = this->getResourceWithoutState<infer::Infer*>(share_key);  // Look up an Infer node previously registered under this share key.
247  if (infer == nullptr) {  // None registered yet: create one, initialize it, and register it under share_key.
248  llm_infer_ = dynamic_cast<infer::Infer*>(this->createInfer<infer::Infer>(
249  desc, default_llm_infer_param->inference_type_));
250  // Begin parameter setup
251  llm_infer_->setParamSharedPtr(default_llm_infer_param->inference_param_);
252  // End parameter setup
253  llm_infer_->init();  // NOTE(review): the status returned by init() is discarded.
254  this->addResourceWithoutState(share_key, llm_infer_);
255  } else {  // One already exists: create a plain node and call shareInference() on the registered one with it.
256  llm_infer_ =
257  dynamic_cast<infer::Infer*>(this->createNode<infer::Infer>(desc));
258  infer->shareInference(llm_infer_);
259  llm_infer_->setInitializedFlag(false);
260  llm_infer_->init();
261  llm_infer_->setInitializedFlag(true);
262  }
263  return base::kStatusCodeOk;
264  }
265 
266  virtual base::Status run() {
267  // Dispatch on is_prefill_: prefill() when it is set, decode() otherwise.
268  return is_prefill_ ? prefill() : decode();
272  }
273 
274  virtual base::Status prefill() {  // Stores the prompt tokens as the "history_tokens" resource, feeds the auxiliary input edges, runs the nodes, and publishes the produced key/value tensor.
275  DefaultLlmInferParam* default_llm_infer_param =
276  dynamic_cast<DefaultLlmInferParam*>(param_.get());
277  // Global history_tokens (the declaration of `ids` is on lines the listing elides)
280  std::vector<int32_t>* history_tokens =
281  new std::vector<int32_t>(ids->ids_[0]);  // Heap copy of the first id sequence.
282  dag::Edge* history_tokens_edge =
283  this->createResourceWithState("history_tokens");
284  history_tokens_edge->set<std::vector<int32_t>>(history_tokens, false);  // Second argument is is_external; presumably `false` hands ownership of the vector to the edge — confirm.
285 
286  auto seq_len = ids->ids_[0].size();
287  auto all_seq_len = all_seq_len_;
288  auto attention_mask_data_type = base::dataTypeOf<float>();  // NOTE(review): hard-coded float; the param's attention_mask_data_type_ is not consulted here — confirm intent.
289  auto attention_mask_data_format = base::DataFormat::kDataFormatS1D;
290  auto position_ids_data_type = base::dataTypeOf<int>();
291  auto position_ids_data_format = base::DataFormat::kDataFormatNC;
292 
293  // Feed data to the input edges
294  if (attention_mask_edge_ != nullptr) {
295  auto attention_mask =
296  genAttentionMask(seq_len, all_seq_len, attention_mask_data_type,
297  attention_mask_data_format);
298  attention_mask_edge_->set(attention_mask, false);
299  }
300  if (position_ids_edge_ != nullptr) {
301  auto position_ids =
302  genPositionIds(seq_len, all_seq_len, position_ids_data_type,
303  position_ids_data_format);
304  position_ids_edge_->set(position_ids, false);
305  }
306  if (past_key_values_edge_ != nullptr) {
307  auto kv_init_shape = default_llm_infer_param->kv_init_shape_;  // Copy, so the stored shape is not modified by the insert below.
308  kv_init_shape.insert(kv_init_shape.begin(), 24);  // NOTE(review): hard-coded 24 equals the default of layer_nums_; presumably this should read default_llm_infer_param->layer_nums_ — confirm.
309  auto past_kv = genPastKeyValue(kv_init_shape);
310  past_key_values_edge_->set(past_kv, false);
311  }
312 
313  // Run the embedding node and the infer node
314  if (embedding_node_ != nullptr) {
315  auto status = embedding_node_->run();
317  "prefill embedding_node_ run failed!");
318  }
319  if (llm_infer_ != nullptr) {
320  auto status = llm_infer_->run();
322  "prefill llm_infer_ run failed!");
323  }
324 
325  // Global tensor resource: expose the produced `presents` tensor under the past_key_values name for later decode() calls.
326  if (presents_edge_ != nullptr && past_key_values_edge_ != nullptr) {
327  device::Tensor* presents =
328  (device::Tensor*)presents_edge_->getTensor(llm_infer_);
329  presents->setName(past_key_values_edge_->getName());
330  dag::Edge* past_key_values_edge =
331  this->createResourceWithState(past_key_values_edge_->getName());
332  past_key_values_edge->set(presents, true);  // is_external = true.
333  }
334 
335  return base::kStatusCodeOk;
336  }
337  virtual base::Status decode() { // Run the embedding node and the infer node
338  tokenizer::TokenizerIds* ids = nullptr;
339  if (inputs_.size() == 1 || inputs_[1]->empty()) {  // Use the first input unless a second input exists and is non-empty.
340  ids = (tokenizer::TokenizerIds*)inputs_[0]->getParam(this);
341  } else {
342  ids = (tokenizer::TokenizerIds*)inputs_[1]->getParam(this);
343  }
344  dag::Edge* history_tokens_edge =
345  this->getResourceWithState("history_tokens");
346  std::vector<int32_t>* history_tokens = nullptr;
347  if (history_tokens_edge != nullptr) {
348  history_tokens = history_tokens_edge->get<std::vector<int32_t>>(this);
349  history_tokens->push_back(ids->ids_[0].back());  // Append the last id of the first sequence to the shared history.
350  }
351 
352  // auto seq_len = ids->ids_[0].size();
353  auto seq_len = 1;
354  all_seq_len_ = history_tokens->size();  // NOTE(review): history_tokens is still null here when the "history_tokens" resource was not found above.
355  auto all_seq_len = all_seq_len_;
356  auto attention_mask_data_type = base::dataTypeOf<float>();
357  auto attention_mask_data_format = base::DataFormat::kDataFormatS1D;
358  auto position_ids_data_type = base::dataTypeOf<int>();
359  auto position_ids_data_format = base::DataFormat::kDataFormatNC;
360 
361  gen_seq_len_++;
362 
363  if (attention_mask_edge_ != nullptr) {
364  auto attention_mask =
365  genAttentionMask(seq_len, all_seq_len, attention_mask_data_type,
366  attention_mask_data_format);
367  attention_mask_edge_->set(attention_mask, false);
368  }
369  if (position_ids_edge_ != nullptr) {
370  auto position_ids =
371  genPositionIds(seq_len, all_seq_len, position_ids_data_type,
372  position_ids_data_format);
373  position_ids_edge_->set(position_ids, false);
374  }
375 
376  if (past_key_values_edge_ != nullptr) {  // Feed back the tensor stored under the past_key_values name by the previous step.
377  auto past_kv = this->getResourceWithState<device::Tensor>(
378  past_key_values_edge_->getName());
379  past_key_values_edge_->set(past_kv, true);
380  }
381 
382  if (embedding_node_ != nullptr) {
383  auto status = embedding_node_->run();
385  "decode embedding_node_ run failed!");
386  }
387  if (llm_infer_ != nullptr) {
388  auto status = llm_infer_->run();
390  "decode llm_infer_ run failed!");
391  }
392 
393  // Global tensor resource: store this step's `presents` tensor under the past_key_values name.
394  if (presents_edge_ != nullptr && past_key_values_edge_ != nullptr) {
395  device::Tensor* presents =
396  (device::Tensor*)presents_edge_->getTensor(llm_infer_);
397  presents->setName(past_key_values_edge_->getName());
398  this->setResourceWithState(past_key_values_edge_->getName(), presents);
399  }
400 
401  return base::kStatusCodeOk;
402  }
403 
404  base::Status parseConfig(const std::string& file_path) {  // Loads file_path into this node's DefaultLlmInferParam via loadFile(); `status` is declared on a line the listing elides.
406  if (param_ != nullptr) {
407  DefaultLlmInferParam* default_llm_infer_param =
408  dynamic_cast<DefaultLlmInferParam*>(param_.get());
409  default_llm_infer_param->loadFile(file_path);  // NOTE(review): the cast result is not null-checked and the status returned by loadFile() is discarded.
410  }
411  return status;
412  }
413 
414  virtual base::Status setIterInput(dag::Edge* input, int index) {
415  // Register the edge on the base node first and stop if that fails.
416  base::Status base_status = dag::Node::setIterInput(input, index);
417  if (base_status != base::kStatusCodeOk) {
418  NNDEPLOY_LOGE("DefaultLlmInfer::setIterInput failed\n");
419  return base_status;
420  }
421  // Without an embedding node there is nothing to forward to.
422  if (embedding_node_ == nullptr) { return base::kStatusCodeOk; }
423  embedding_node_->setIterInput(input, 1);
424  return base::kStatusCodeOk;
425  }
425 
426  private:
427  Embedding* embedding_node_ = nullptr;  // Set by init() only when is_embedding_ is true; must start null because prefill(), decode() and setIterInput() test it against nullptr.
428  infer::Infer* llm_infer_ = nullptr;  // Set by init(); starts null so the `llm_infer_ != nullptr` checks are well-defined before init() runs.
429 
430  // Input edges (names are overwritten from model_inputs_ in init())
431  std::string input_ids_name_ = "input_ids";
432  std::string attention_mask_name_ = "attention_mask";
433  std::string position_ids_name_ = "position_ids";
434  std::string past_key_values_name_ = "past_key_values";
435  dag::Edge* input_ids_edge_ = nullptr;
436  dag::Edge* attention_mask_edge_ = nullptr;
437  dag::Edge* position_ids_edge_ = nullptr;
438  dag::Edge* past_key_values_edge_ = nullptr;
439  // Output edges (names are overwritten from model_outputs_ in init())
440  std::string logits_name_ = "logits";
441  std::string presents_name_ = "presents";
442  dag::Edge* logits_edge_ = nullptr;
443  dag::Edge* presents_edge_ = nullptr;
444 
445  // Sequence-length bookkeeping: decode() sets all_seq_len_ to the history size and increments gen_seq_len_ once per call.
446  int all_seq_len_ = 0;
447  int gen_seq_len_ = 0;
448 };
449 
450 } // namespace llm
451 } // namespace nndeploy
452 
453 #endif
virtual base::Status deserialize(rapidjson::Value &json)
virtual base::Status loadFile(const std::string &path)
virtual std::string serialize()
Edge * createEdge(const std::string &name)
Edge class in DAG graph for connecting nodes and transferring data.
Definition: edge.h:35
base::Status set(device::Buffer *buffer, bool is_external=true)
Set Buffer data to Edge.
T * get(const Node *node)
Get arbitrary type data for specified node (template version)
Definition: edge.h:443
device::Tensor * getTensor(const Node *node)
Get Tensor data for specified node.
std::string getName()
Get the name of the Edge.
Node description class.
Definition: node.h:35
std::string desc_
Node description.
Definition: node.h:1294
base::Status setResourceWithState(const std::string &key, T *value, bool is_external=true)
Set stateful resource (template method)
Definition: node.h:533
virtual base::Param * getParam()
Get parameter.
virtual base::Status addResourceWithoutState(const std::string &key, const base::Any &value)
Add stateless resource.
virtual base::Status setIterInput(Edge *input, int index=-1)
Set iteration input edge.
virtual Edge * createResourceWithState(const std::string &key)
Create stateful resource.
std::vector< Edge * > outputs_
Output edge list.
Definition: node.h:1318
void setInitializedFlag(bool flag)
Set initialized flag.
std::string key_
Node key.
Definition: node.h:1290
std::vector< Edge * > inputs_
Input edge list.
Definition: node.h:1317
virtual Edge * getResourceWithState(const std::string &key)
Get stateful resource.
std::shared_ptr< base::Param > param_
Node parameters.
Definition: node.h:1304
base::Status setName(const std::string &)
virtual base::Status shareInference(Infer *infer)
virtual base::Status run()
Run node (pure virtual function)
virtual base::Status setParamSharedPtr(std::shared_ptr< base::Param > param)
Set parameter (shared pointer)
virtual base::Status init()
Initialize node.
device::Tensor * genAttentionMask(int seq_len, int all_seq_len, base::DataType data_type, base::DataFormat data_format, base::DeviceType device_type=base::kDeviceTypeCodeCpu)
device::Tensor * genPositionIds(int seq_len, int all_seq_len, base::DataType data_type, base::DataFormat data_format, base::DeviceType device_type=base::kDeviceTypeCodeCpu)
std::vector< std::string > model_outputs_
std::vector< std::string > model_inputs_
device::Tensor * genPastKeyValue(const std::vector< int32_t > &kv_init_shape, base::DeviceType device_type=base::kDeviceTypeCodeCpu)
std::vector< std::string > config_path_
DefaultLlmInfer(const std::string &name)
virtual base::Status init()
Initialize node.
virtual base::Status setIterInput(dag::Edge *input, int index)
Set iteration input edge.
virtual base::Status run()
Run node (pure virtual function)
base::Status parseConfig(const std::string &file_path)
DefaultLlmInfer(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
virtual base::Status decode()
virtual base::Status prefill()
Embedding - word embedding node
Definition: embedding.h:120
virtual base::Status run()
Run node (pure virtual function)
std::vector< std::vector< int32_t > > ids_
Definition: tokenizer.h:241
#define NNDEPLOY_LOGE(fmt,...)
Definition: log.h:59
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
@ kInferenceTypeOnnxRuntime
Definition: common.h:293
@ kStatusCodeOk
Definition: status.h:13
DataType stringToDataType(const std::string &str)
std::string dataTypeToString(DataType data_type)
InferenceType stringToInferenceType(const std::string &src)
DataType dataTypeOf< float >()
std::string inferenceTypeToString(InferenceType src)
@ kDataFormatS1D
Definition: common.h:143
std::shared_ptr< InferenceParam > createInferenceParam(base::InferenceType type)
Create an InferenceParam object.
#define NNDEPLOY_RETURN_ON_NEQ(status, expected, str)
Definition: status.h:183
virtual base::Status deserialize(rapidjson::Value &json) override
std::shared_ptr< inference::InferenceParam > inference_param_
std::vector< int32_t > kv_init_shape_
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) override
std::shared_ptr< EmbeddingParam > embedding_param_