2 #ifndef _NNDEPLOY_LLM_INFER_DEFAULT_LLM_INFER_H_
3 #define _NNDEPLOY_LLM_INFER_DEFAULT_LLM_INFER_H_
33 bool is_embedding_ =
false;
34 std::shared_ptr<EmbeddingParam> embedding_param_ =
nullptr;
37 std::shared_ptr<inference::InferenceParam> inference_param_ =
nullptr;
40 int max_seq_len_ = 2048;
43 std::string attention_type_ =
"full";
47 rapidjson::Value& json,
48 rapidjson::Document::AllocatorType& allocator)
override {
55 json.AddMember(
"is_embedding_", is_embedding_, allocator);
56 if (is_embedding_ && embedding_param_ !=
nullptr) {
57 rapidjson::Value embedding_param_value;
58 embedding_param_->serialize(embedding_param_value, allocator);
59 json.AddMember(
"embedding_param_", embedding_param_value, allocator);
62 std::string inference_type_str =
64 json.AddMember(
"inference_type_",
65 rapidjson::Value(inference_type_str.c_str(), allocator),
67 if (inference_param_ ==
nullptr) {
69 if (inference_param_ ==
nullptr) {
71 std::make_shared<inference::InferenceParam>(inference_type_);
74 rapidjson::Value inference_param_value;
75 inference_param_->serialize(inference_param_value, allocator);
76 json.AddMember(
"inference_param_", inference_param_value, allocator);
78 json.AddMember(
"layer_nums_", layer_nums_, allocator);
79 json.AddMember(
"max_seq_len_", max_seq_len_, allocator);
80 rapidjson::Value kv_init_shape_array(rapidjson::kArrayType);
81 for (
auto dim : kv_init_shape_) {
82 kv_init_shape_array.PushBack(dim, allocator);
84 json.AddMember(
"kv_init_shape_", kv_init_shape_array, allocator);
85 std::string attention_mask_data_type_str =
88 "attention_mask_data_type_",
89 rapidjson::Value(attention_mask_data_type_str.c_str(), allocator),
91 json.AddMember(
"attention_type_",
92 rapidjson::Value(attention_type_.c_str(), allocator),
104 if (json.HasMember(
"is_embedding_") && json[
"is_embedding_"].IsBool()) {
105 is_embedding_ = json[
"is_embedding_"].GetBool();
107 if (is_embedding_ && json.HasMember(
"embedding_param_") &&
108 json[
"embedding_param_"].IsObject()) {
109 if (embedding_param_ ==
nullptr) {
110 embedding_param_ = std::make_shared<EmbeddingParam>();
112 embedding_param_->deserialize(json[
"embedding_param_"]);
115 if (json.HasMember(
"inference_type_") &&
116 json[
"inference_type_"].IsString()) {
120 if (inference_param_ ==
nullptr) {
122 if (inference_param_ ==
nullptr) {
124 std::make_shared<inference::InferenceParam>(inference_type_);
127 rapidjson::Value inference_param_value;
128 inference_param_->deserialize(json[
"inference_param_"]);
130 if (json.HasMember(
"layer_nums_") && json[
"layer_nums_"].IsInt()) {
131 layer_nums_ = json[
"layer_nums_"].GetInt();
133 if (json.HasMember(
"max_seq_len_") && json[
"max_seq_len_"].IsInt()) {
134 max_seq_len_ = json[
"max_seq_len_"].GetInt();
136 if (json.HasMember(
"kv_init_shape_") && json[
"kv_init_shape_"].IsArray()) {
137 kv_init_shape_.clear();
138 const rapidjson::Value& kv_init_shape_array = json[
"kv_init_shape_"];
139 for (rapidjson::SizeType i = 0; i < kv_init_shape_array.Size(); i++) {
140 kv_init_shape_.push_back(kv_init_shape_array[i].GetInt());
143 if (json.HasMember(
"attention_mask_data_type_") &&
144 json[
"attention_mask_data_type_"].IsString()) {
145 attention_mask_data_type_ =
148 if (json.HasMember(
"attention_type_") &&
149 json[
"attention_type_"].IsString()) {
150 attention_type_ = json[
"attention_type_"].GetString();
159 param_ = std::make_shared<DefaultLlmInferParam>();
160 key_ =
"nndeploy::llm::DefaultLlmInfer";
162 "LLM default pipeline: input_tokens -> "
163 "inference -> [logits]";
166 std::vector<dag::Edge*> outputs)
168 param_ = std::make_shared<DefaultLlmInferParam>();
169 key_ =
"nndeploy::llm::DefaultLlmInfer";
171 "LLM default pipeline: input_tokens -> "
172 "inference -> [logits]";
184 std::vector<dag::Edge*> input_edges;
185 input_ids_edge_ = this->
createEdge(input_ids_name_);
186 input_edges.push_back(input_ids_edge_);
189 attention_mask_edge_ = this->
createEdge(attention_mask_name_);
190 input_edges.push_back(attention_mask_edge_);
194 position_ids_edge_ = this->
createEdge(position_ids_name_);
195 input_edges.push_back(position_ids_edge_);
199 past_key_values_edge_ = this->
createEdge(past_key_values_name_);
200 input_edges.push_back(past_key_values_edge_);
204 std::vector<dag::Edge*> output_edges;
207 output_edges.push_back(logits_edge_);
210 presents_edge_ = this->
createEdge(presents_name_);
211 output_edges.push_back(presents_edge_);
221 dynamic_cast<Embedding*
>(this->createNode<Embedding>(desc));
224 embedding_node_->setParamSharedPtr(embedding_param);
226 embedding_node_->setInitializedFlag(
false);
227 embedding_node_->init();
228 embedding_node_->setInitializedFlag(
true);
236 std::vector<std::string> input_names;
237 std::vector<std::string> output_names;
238 for (
auto input : input_edges) {
239 input_names.push_back(input->getName());
241 for (
auto output : output_edges) {
242 output_names.push_back(output->getName());
246 auto infer = this->getResourceWithoutState<infer::Infer*>(share_key);
247 if (infer ==
nullptr) {
248 llm_infer_ =
dynamic_cast<infer::Infer*
>(this->createInfer<infer::Infer>(
257 dynamic_cast<infer::Infer*
>(this->createNode<infer::Infer>(desc));
280 std::vector<int32_t>* history_tokens =
281 new std::vector<int32_t>(ids->
ids_[0]);
284 history_tokens_edge->
set<std::vector<int32_t>>(history_tokens,
false);
286 auto seq_len = ids->
ids_[0].size();
287 auto all_seq_len = all_seq_len_;
290 auto position_ids_data_type = base::dataTypeOf<int>();
294 if (attention_mask_edge_ !=
nullptr) {
295 auto attention_mask =
297 attention_mask_data_format);
298 attention_mask_edge_->
set(attention_mask,
false);
300 if (position_ids_edge_ !=
nullptr) {
303 position_ids_data_format);
304 position_ids_edge_->
set(position_ids,
false);
306 if (past_key_values_edge_ !=
nullptr) {
308 kv_init_shape.insert(kv_init_shape.begin(), 24);
310 past_key_values_edge_->
set(past_kv,
false);
314 if (embedding_node_ !=
nullptr) {
315 auto status = embedding_node_->
run();
317 "prefill embedding_node_ run failed!");
319 if (llm_infer_ !=
nullptr) {
320 auto status = llm_infer_->
run();
322 "prefill llm_infer_ run failed!");
326 if (presents_edge_ !=
nullptr && past_key_values_edge_ !=
nullptr) {
332 past_key_values_edge->
set(presents,
true);
346 std::vector<int32_t>* history_tokens =
nullptr;
347 if (history_tokens_edge !=
nullptr) {
348 history_tokens = history_tokens_edge->
get<std::vector<int32_t>>(
this);
349 history_tokens->push_back(ids->
ids_[0].back());
354 all_seq_len_ = history_tokens->size();
355 auto all_seq_len = all_seq_len_;
358 auto position_ids_data_type = base::dataTypeOf<int>();
363 if (attention_mask_edge_ !=
nullptr) {
364 auto attention_mask =
366 attention_mask_data_format);
367 attention_mask_edge_->
set(attention_mask,
false);
369 if (position_ids_edge_ !=
nullptr) {
372 position_ids_data_format);
373 position_ids_edge_->
set(position_ids,
false);
376 if (past_key_values_edge_ !=
nullptr) {
377 auto past_kv = this->getResourceWithState<device::Tensor>(
378 past_key_values_edge_->
getName());
379 past_key_values_edge_->
set(past_kv,
true);
382 if (embedding_node_ !=
nullptr) {
383 auto status = embedding_node_->
run();
385 "decode embedding_node_ run failed!");
387 if (llm_infer_ !=
nullptr) {
388 auto status = llm_infer_->
run();
390 "decode llm_infer_ run failed!");
394 if (presents_edge_ !=
nullptr && past_key_values_edge_ !=
nullptr) {
409 default_llm_infer_param->
loadFile(file_path);
420 if (embedding_node_ !=
nullptr) {
431 std::string input_ids_name_ =
"input_ids";
432 std::string attention_mask_name_ =
"attention_mask";
433 std::string position_ids_name_ =
"position_ids";
434 std::string past_key_values_name_ =
"past_key_values";
436 dag::Edge* attention_mask_edge_ =
nullptr;
438 dag::Edge* past_key_values_edge_ =
nullptr;
440 std::string logits_name_ =
"logits";
441 std::string presents_name_ =
"presents";
446 int all_seq_len_ = 0;
447 int gen_seq_len_ = 0;
virtual base::Status deserialize(rapidjson::Value &json)
virtual base::Status loadFile(const std::string &path)
virtual std::string serialize()
Edge * createEdge(const std::string &name)
Edge class in DAG graph for connecting nodes and transferring data.
base::Status set(device::Buffer *buffer, bool is_external=true)
Set Buffer data to Edge.
T * get(const Node *node)
Get arbitrary type data for specified node (template version)
device::Tensor * getTensor(const Node *node)
Get Tensor data for specified node.
std::string getName()
Get the name of the Edge.
std::string desc_
Node description.
base::Status setResourceWithState(const std::string &key, T *value, bool is_external=true)
Set stateful resource (template method)
virtual base::Param * getParam()
Get parameter.
virtual base::Status addResourceWithoutState(const std::string &key, const base::Any &value)
Add stateless resource.
virtual base::Status setIterInput(Edge *input, int index=-1)
Set iteration input edge.
virtual Edge * createResourceWithState(const std::string &key)
Create stateful resource.
std::vector< Edge * > outputs_
Output edge list.
void setInitializedFlag(bool flag)
Set initialized flag.
std::string key_
Node key.
std::vector< Edge * > inputs_
Input edge list.
virtual Edge * getResourceWithState(const std::string &key)
Get stateful resource.
std::shared_ptr< base::Param > param_
Node parameters.
base::Status setName(const std::string &)
virtual base::Status shareInference(Infer *infer)
virtual base::Status run()
Run node (pure virtual function)
virtual base::Status setParamSharedPtr(std::shared_ptr< base::Param > param)
Set parameter (shared pointer)
virtual base::Status init()
Initialize node.
device::Tensor * genAttentionMask(int seq_len, int all_seq_len, base::DataType data_type, base::DataFormat data_format, base::DeviceType device_type=base::kDeviceTypeCodeCpu)
std::string getShareKey()
device::Tensor * genPositionIds(int seq_len, int all_seq_len, base::DataType data_type, base::DataFormat data_format, base::DeviceType device_type=base::kDeviceTypeCodeCpu)
std::vector< std::string > model_outputs_
std::vector< std::string > model_inputs_
device::Tensor * genPastKeyValue(const std::vector< int32_t > &kv_init_shape, base::DeviceType device_type=base::kDeviceTypeCodeCpu)
std::vector< std::string > config_path_
virtual ~DefaultLlmInfer()
DefaultLlmInfer(const std::string &name)
virtual base::Status init()
Initialize node.
virtual base::Status setIterInput(dag::Edge *input, int index)
Set iteration input edge.
virtual base::Status run()
Run node (pure virtual function)
base::Status parseConfig(const std::string &file_path)
DefaultLlmInfer(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
virtual base::Status decode()
virtual base::Status prefill()
virtual base::Status run()
Run node (pure virtual function)
std::vector< std::vector< int32_t > > ids_
#define NNDEPLOY_LOGE(fmt,...)
#define NNDEPLOY_CC_API
api
@ kInferenceTypeOnnxRuntime
DataType stringToDataType(const std::string &str)
std::string dataTypeToString(DataType data_type)
InferenceType stringToInferenceType(const std::string &src)
DataType dataTypeOf< float >()
std::string inferenceTypeToString(InferenceType src)
std::shared_ptr< InferenceParam > createInferenceParam(base::InferenceType type)
Create a Inference Param object.
#define NNDEPLOY_RETURN_ON_NEQ(status, expected, str)
base::InferenceType inference_type_
virtual base::Status deserialize(rapidjson::Value &json) override
std::shared_ptr< inference::InferenceParam > inference_param_
std::vector< int32_t > kv_init_shape_
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) override
std::shared_ptr< EmbeddingParam > embedding_param_