nndeploy C++ API  0.2.0
nndeploy C++ API
stream_out.h
Go to the documentation of this file.
1 
2 #ifndef _NNDEPLOY_LLM_STREAM_OUT_H_
3 #define _NNDEPLOY_LLM_STREAM_OUT_H_
4 
5 #include "nndeploy/base/any.h"
6 #include "nndeploy/base/common.h"
8 #include "nndeploy/base/log.h"
9 #include "nndeploy/base/macro.h"
10 #include "nndeploy/base/object.h"
12 #include "nndeploy/base/param.h"
13 #include "nndeploy/base/status.h"
14 #include "nndeploy/base/string.h"
16 #include "nndeploy/dag/edge.h"
17 #include "nndeploy/dag/graph.h"
18 #include "nndeploy/dag/loop.h"
19 #include "nndeploy/dag/node.h"
20 #include "nndeploy/device/buffer.h"
21 #include "nndeploy/device/device.h"
23 #include "nndeploy/device/tensor.h"
25 
26 namespace nndeploy {
27 namespace llm {
28 
42  public:
43  StreamOut(const std::string& name, std::vector<dag::Edge*> inputs,
44  std::vector<dag::Edge*> outputs)
45  : dag::Node(name, inputs, outputs) {
46  key_ = "nndeploy::llm::StreamOut";
47  desc_ = "StreamOut: Stream output node";
48  this->setDynamicOutput(true);
49  this->setInputTypeInfo<tokenizer::TokenizerText>("input_text");
50  this->setOutputTypeInfo<tokenizer::TokenizerText>("output_text");
51  // this->setOutputTypeInfo<std::string>("stream_output");
52  }
53 
54  virtual ~StreamOut() {}
55 
56  virtual base::Status init() override {
57  is_first_ = true;
58  return base::kStatusCodeOk;
59  };
60 
61  virtual base::Status deinit() override {
62  if (stream_output_) {
63  delete stream_output_;
64  }
65  return base::kStatusCodeOk;
66  };
67 
68  bool isStopTexts() {
69  // TODO,大语言模型的输出字符是什么?
70  std::string text = *stream_output_;
71  std::string last_text = "";
72  size_t last_space = text.find_last_of(' ');
73  if (last_space != std::string::npos) {
74  last_text = text.substr(last_space + 1);
75  } else {
76  last_text = text;
77  }
78  return std::find(stop_texts_.begin(), stop_texts_.end(), last_text) !=
79  stop_texts_.end();
80  }
81 
82  virtual base::Status run() override {
83  tokenizer::TokenizerText* input_text =
84  dynamic_cast<tokenizer::TokenizerText*>(inputs_[0]->getParam(this));
85  if (is_first_) {
86  stream_output_ = new std::string();
87  if (enable_stream_) {
88  NNDEPLOY_PRINTF("A: ");
89  }
90  output_text_ = new tokenizer::TokenizerText();
91  output_text_->texts_.resize(input_text->texts_.size());
92  is_first_ = false;
93  }
94 
95  *stream_output_ = input_text->texts_[0];
96 
97  if (isStopTexts()) {
98  outputs_[0]->set(output_text_, true);
99  if (enable_stream_) {
100  NNDEPLOY_PRINTF("\n");
101  }
102  } else {
103  output_text_->texts_[0] += (*stream_output_);
104  if (enable_stream_) {
105  NNDEPLOY_PRINTF("%s", input_text->texts_[0].c_str());
106  }
107  }
108 
109  if (outputs_.size() > 1) {
110  outputs_[1]->set(stream_output_, true);
111  }
112  return base::kStatusCodeOk;
113  };
114 
115  using dag::Node::serialize;
117  rapidjson::Value& json,
118  rapidjson::Document::AllocatorType& allocator) override {
119  base::Status status = dag::Node::serialize(json, allocator);
120  if (status != base::kStatusCodeOk) {
121  return status;
122  }
123  json.AddMember("enable_stream_", enable_stream_, allocator);
124  return status;
125  }
127  virtual base::Status deserialize(rapidjson::Value& json) override {
128  base::Status status = dag::Node::deserialize(json);
129  if (status != base::kStatusCodeOk) {
130  return status;
131  }
132  if (json.HasMember("enable_stream_") && json["enable_stream_"].IsBool()) {
133  enable_stream_ = json["enable_stream_"].GetBool();
134  }
135  return status;
136  }
137 
138  private:
139  // 流式输出配置
140  bool enable_stream_ = true;
141  // 是否为第一次输出
142  bool is_first_ = true;
143  // 结果
144  tokenizer::TokenizerText* output_text_ = nullptr;
145  std::string* stream_output_ = nullptr;
146  //
147  std::vector<std::string> stop_texts_ = {
148  "<|endoftext|>", "<|im_end|>", "</s>", "<|end|>", "<|eot_id|>", "[DONE]"};
149 };
150 
151 } // namespace llm
152 } // namespace nndeploy
153 
154 #endif
Node base class.
Definition: node.h:171
virtual std::string serialize()
Serialize to JSON string.
virtual base::Status deserialize(rapidjson::Value &json)
Deserialize from JSON.
Stream - 大语言模型流式输出节点
Definition: stream_out.h:41
virtual base::Status init() override
Initialize node.
Definition: stream_out.h:56
virtual base::Status run() override
Run node (overrides the pure virtual dag::Node::run()).
Definition: stream_out.h:82
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) override
Serialize to JSON.
Definition: stream_out.h:116
virtual base::Status deserialize(rapidjson::Value &json) override
Deserialize from JSON.
Definition: stream_out.h:127
StreamOut(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
Definition: stream_out.h:43
virtual base::Status deinit() override
Deinitialize node.
Definition: stream_out.h:61
std::vector< std::string > texts_
Definition: tokenizer.h:213
#define NNDEPLOY_PRINTF(fmt,...)
Definition: log.h:63
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
@ kStatusCodeOk
Definition: status.h:13