nndeploy C++ API  0.2.0
nndeploy C++ API
decode.h
Go to the documentation of this file.
1 
2 
3 // Wraps tokenizer + embedding + infer + sample into a single loop graph
4 #ifndef _NNDEPLOY_LLM_DECODE_H_
5 #define _NNDEPLOY_LLM_DECODE_H_
6 
7 #include "nndeploy/dag/graph.h"
8 #include "nndeploy/dag/loop.h"
10 #include "nndeploy/llm/llm_infer.h"
11 #include "nndeploy/llm/prompt.h"
12 #include "nndeploy/llm/sample.h"
15 
16 namespace nndeploy {
17 namespace llm {
18 
19 class Decode : public dag::Loop {
20  public:
21  Decode(const std::string& name, std::vector<dag::Edge*> inputs,
22  std::vector<dag::Edge*> outputs);
23  virtual ~Decode();
24 
25  virtual base::Status make(const dag::NodeDesc& infer,
26  const dag::NodeDesc& sample,
27  const dag::NodeDesc& tokenizer,
28  const dag::NodeDesc& stream_out);
29 
30  virtual base::Status initEnd() override;
31  virtual base::Status iterAfter() override;
32 
33  virtual std::vector<dag::Edge*> forward(dag::Edge* input) override;
34 
35  void getStopTokens(std::string& token_file);
36  virtual int loops() override;
37  virtual bool isStop();
38  virtual bool isStopTokens();
39  virtual bool isStopTexts();
40 
42  rapidjson::Value& json,
43  rapidjson::Document::AllocatorType& allocator) override;
44  virtual base::Status deserialize(rapidjson::Value& json) override;
45 
46  private:
47  LlmInfer* decode_infer_node_;
48  Sampler* decode_sampler_node_;
49  dag::Node* decode_token_node_;
50  StreamOut* stream_out_node_;
51 
52  bool is_first_ = true;
53  int max_seq_len_ = std::numeric_limits<int>::max();
54 
55  std::string tokenizer_txt_ = "";
56  std::vector<int> stop_tokens_;
57  std::vector<int> special_tokens_;
58 
59  std::vector<std::string> stop_texts_ = {
60  "<|endoftext|>", "<|im_end|>", "</s>", "<|end|>", "<|eot_id|>", "[DONE]"};
61  std::vector<std::string> special_texts_;
62 };
63 
64 } // namespace llm
65 } // namespace nndeploy
66 
67 #endif
Edge class in DAG graph for connecting nodes and transferring data.
Definition: edge.h:35
Node description class.
Definition: node.h:35
Node base class.
Definition: node.h:171
virtual base::Status iterAfter() override
Decode(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
virtual base::Status initEnd() override
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) override
Serialize to JSON.
virtual base::Status deserialize(rapidjson::Value &json) override
Deserialize from JSON.
virtual base::Status make(const dag::NodeDesc &infer, const dag::NodeDesc &sample, const dag::NodeDesc &tokenizer, const dag::NodeDesc &stream_out)
virtual bool isStopTexts()
virtual bool isStop()
virtual int loops() override
virtual bool isStopTokens()
virtual std::vector< dag::Edge * > forward(dag::Edge *input) override
Forward propagation (single input version)
void getStopTokens(std::string &token_file)
LlmInfer - LLM inference node
Definition: llm_infer.h:47
Sample - text generation sampling node
Definition: sample.h:113
Stream - large language model streaming output node
Definition: stream_out.h:41