2 #ifndef _NNDEPLOY_LLM_ABSTRACT_LLM_INFER_H_
3 #define _NNDEPLOY_LLM_ABSTRACT_LLM_INFER_H_
33 key_ =
"nndeploy::llm::LlmInfer";
35 "LLM abstract pipeline: input_tokens -> "
36 "inference -> [logits]";
38 this->setInputTypeInfo<tokenizer::TokenizerIds>(
"input_tokens");
39 this->setOutputTypeInfo<device::Tensor>(
"output_logits");
42 std::vector<dag::Edge *> outputs)
44 key_ =
"nndeploy::llm::LlmInfer";
46 "LLM abstract pipeline: input_tokens -> "
47 "inference -> [logits]";
49 this->setInputTypeInfo<tokenizer::TokenizerIds>(
"input_tokens");
50 this->setOutputTypeInfo<device::Tensor>(
"output_logits");
61 const std::vector<std::string> &config_path) {
74 virtual int getMaxSeqLen() {
return std::numeric_limits<int>::max(); }
77 const std::vector<int32_t> &kv_init_shape,
83 past_kv_desc.
shape_ = kv_init_shape;
85 past_kv =
new device::Tensor(device, past_kv_desc,
"past_key_values");
97 position_ids_desc.
shape_ = {1, seq_len};
103 auto ptr = (
int *)position_ids->
getData();
105 ptr[0] = all_seq_len;
107 for (
int i = 0; i < seq_len; i++) {
108 ptr[i] = i + all_seq_len;
119 int kv_seq_len = all_seq_len + seq_len;
120 if (seq_len == 1) kv_seq_len = seq_len;
127 attention_mask_desc.
shape_ = {1, 1, seq_len, kv_seq_len};
130 new device::Tensor(device, attention_mask_desc,
"attention_mask");
133 auto ptr = (
float *)attention_mask->
getData();
134 for (
int i = 0; i < seq_len; i++) {
135 for (
int j = 0; j < kv_seq_len; j++) {
136 int row = i + all_seq_len;
137 ptr[kv_seq_len * i + j] =
138 (j > row) * std::numeric_limits<float>::lowest();
142 return attention_mask;
147 rapidjson::Value &json,
148 rapidjson::Document::AllocatorType &allocator)
override {
156 json.AddMember(
"is_prefill",
is_prefill_, allocator);
159 rapidjson::Value config_path_array(rapidjson::kArrayType);
161 rapidjson::Value path_value;
162 path_value.SetString(path.c_str(), path.length(), allocator);
163 config_path_array.PushBack(path_value, allocator);
165 json.AddMember(
"config_path", config_path_array, allocator);
168 rapidjson::Value model_key_value;
171 json.AddMember(
"model_key", model_key_value, allocator);
174 rapidjson::Value infer_key_value;
177 json.AddMember(
"infer_key", infer_key_value, allocator);
181 rapidjson::Value model_inputs(rapidjson::kArrayType);
183 model_inputs.PushBack(rapidjson::Value(input.c_str(), allocator),
186 json.AddMember(
"model_inputs", model_inputs, allocator);
189 rapidjson::Value model_outputs(rapidjson::kArrayType);
191 model_outputs.PushBack(rapidjson::Value(output.c_str(), allocator),
194 json.AddMember(
"model_outputs", model_outputs, allocator);
207 if (json.HasMember(
"is_prefill") && json[
"is_prefill"].IsBool()) {
212 if (json.HasMember(
"config_path") && json[
"config_path"].IsArray()) {
214 const rapidjson::Value &config_path_array = json[
"config_path"];
215 for (rapidjson::SizeType i = 0; i < config_path_array.Size(); i++) {
216 if (config_path_array[i].IsString()) {
217 config_path_.push_back(config_path_array[i].GetString());
223 if (json.HasMember(
"model_key") && json[
"model_key"].IsString()) {
228 if (json.HasMember(
"infer_key") && json[
"infer_key"].IsString()) {
233 if (json.HasMember(
"model_inputs") && json[
"model_inputs"].IsArray()) {
235 const rapidjson::Value &model_inputs = json[
"model_inputs"];
236 for (rapidjson::SizeType i = 0; i < model_inputs.Size(); i++) {
237 if (model_inputs[i].IsString()) {
243 if (json.HasMember(
"model_outputs") && json[
"model_outputs"].IsArray()) {
245 const rapidjson::Value &model_outputs = json[
"model_outputs"];
246 for (rapidjson::SizeType i = 0; i < model_outputs.Size(); i++) {
247 if (model_outputs[i].IsString()) {
257 std::string key =
"";
278 "position_ids",
"past_key_values"};
284 template <
typename T>
285 class TypeLlmInferCreator;
292 const std::string &name, std::vector<dag::Edge *> inputs,
293 std::vector<dag::Edge *> outputs) = 0;
296 template <
typename T>
303 const std::string &name, std::vector<dag::Edge *> inputs,
304 std::vector<dag::Edge *> outputs)
override {
305 return new T(name, inputs, outputs);
317 const std::string &model_key,
318 std::shared_ptr<LlmInferCreator> creator) {
319 auto it = creators_.find(infer_key);
320 if (it == creators_.end()) {
321 creators_[infer_key] =
322 std::map<std::string, std::shared_ptr<LlmInferCreator>>();
325 auto model_it = creators_[infer_key].find(model_key);
326 if (model_it != creators_[infer_key].end()) {
327 NNDEPLOY_LOGW(
"LlmInfer %s@%s already exists, will be overwritten!\n",
328 infer_key.c_str(), model_key.c_str());
330 creators_[infer_key][model_key] = creator;
333 std::shared_ptr<LlmInferCreator>
getCreator(
const std::string &infer_key,
334 const std::string &model_key) {
335 auto infer_it = creators_.find(infer_key);
336 if (infer_it != creators_.end()) {
337 auto model_it = infer_it->second.find(model_key);
338 if (model_it != infer_it->second.end()) {
339 return model_it->second;
342 NNDEPLOY_LOGE(
"LlmInfer %s@%s not found!\n", infer_key.c_str(),
348 std::set<std::string> keys;
349 for (
auto &it : creators_) {
350 keys.insert(it.first);
356 std::set<std::string> keys;
357 auto it = creators_.find(infer_key);
358 if (it != creators_.end()) {
359 for (
auto &model_it : it->second) {
360 keys.insert(model_it.first);
367 std::set<std::string> keys;
368 for (
auto &it : creators_) {
369 for (
auto &model_it : it.second) {
370 keys.insert(model_it.first);
377 std::set<std::pair<std::string, std::string>> keys;
378 for (
auto &infer_it : creators_) {
379 for (
auto &model_it : infer_it.second) {
380 keys.insert(std::make_pair(infer_it.first, model_it.first));
389 std::map<std::string, std::map<std::string, std::shared_ptr<LlmInferCreator>>>
393 #define REGISTER_LLM_INFER(infer_key, model_key, node_class) \
395 struct LlmInferRegister_##node_class { \
396 LlmInferRegister_##node_class() { \
397 LlmInferFactory::getInstance()->registerLlmInfer( \
398 infer_key, model_key, \
399 std::make_shared<nndeploy::llm::TypeLlmInferCreator<node_class>>()); \
402 static LlmInferRegister_##node_class g_llm_infer_register_##node_class; \
Composite node: a special type of node in nndeploy that enhances the capabilities of...
virtual base::Status deserialize(rapidjson::Value &json)
Deserialize from JSON.
CompositeNode(const std::string &name)
virtual std::string serialize()
Serialize to JSON string.
void setDynamicInput(bool is_dynamic_input)
Set whether it's dynamic input.
std::string desc_
Node description.
std::string key_
Node key.
device::Tensor * genAttentionMask(int seq_len, int all_seq_len, base::DataType data_type, base::DataFormat data_format, base::DeviceType device_type=base::kDeviceTypeCodeCpu)
std::string getShareKey()
AbstractLlmInfer(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
virtual ~AbstractLlmInfer()
device::Tensor * genPositionIds(int seq_len, int all_seq_len, base::DataType data_type, base::DataFormat data_format, base::DeviceType device_type=base::kDeviceTypeCodeCpu)
std::vector< std::string > model_outputs_
virtual base::Status setConfigPath(const std::vector< std::string > &config_path)
virtual base::Status setModelKey(const std::string &model_key)
virtual base::Status run()=0
Run node (pure virtual function)
std::vector< std::string > model_inputs_
virtual base::Status setPrefill(bool is_prefill)
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) override
Serialize to JSON.
AbstractLlmInfer(const std::string &name)
device::Tensor * genPastKeyValue(const std::vector< int32_t > &kv_init_shape, base::DeviceType device_type=base::kDeviceTypeCodeCpu)
virtual int getMaxSeqLen()
virtual base::Status deserialize(rapidjson::Value &json) override
Deserialize from JSON.
std::vector< std::string > config_path_
virtual base::Status setInferKey(const std::string &infer_key)
virtual AbstractLlmInfer * createLlmInfer(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)=0
virtual ~LlmInferCreator()=default
virtual AbstractLlmInfer * createLlmInfer(const std::string &name)=0
std::set< std::string > getModelKeys()
std::set< std::string > getModelKeys(const std::string &infer_key)
void registerLlmInfer(const std::string &infer_key, const std::string &model_key, std::shared_ptr< LlmInferCreator > creator)
std::shared_ptr< LlmInferCreator > getCreator(const std::string &infer_key, const std::string &model_key)
std::set< std::pair< std::string, std::string > > getAllKeys()
static LlmInferFactory * getInstance()
std::set< std::string > getInferKeys()
virtual AbstractLlmInfer * createLlmInfer(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs) override
virtual AbstractLlmInfer * createLlmInfer(const std::string &name) override
#define NNDEPLOY_LOGW(fmt,...)
#define NNDEPLOY_LOGE(fmt,...)
#define NNDEPLOY_CC_API
api
DataType dataTypeOf< float >()
Device * getDevice(base::DeviceType device_type)
Get the device of the specified type.
base::DataFormat data_format_
base::DataType data_type_