nndeploy C++ API  0.2.0
nndeploy C++ API
sample.h
Go to the documentation of this file.
1 //
2 // sample.hpp
3 //
4 // Created by MNN on 2023/09/25.
5 // ZhaodeWang
6 //
7 
20 #ifndef _NNDEPLOY_LLM_SAMPLE_H_
21 #define _NNDEPLOY_LLM_SAMPLE_H_
22 
23 #include "nndeploy/base/any.h"
24 #include "nndeploy/base/common.h"
26 #include "nndeploy/base/log.h"
27 #include "nndeploy/base/macro.h"
28 #include "nndeploy/base/object.h"
30 #include "nndeploy/base/param.h"
31 #include "nndeploy/base/status.h"
32 #include "nndeploy/base/string.h"
34 #include "nndeploy/dag/edge.h"
35 #include "nndeploy/dag/graph.h"
36 #include "nndeploy/dag/loop.h"
37 #include "nndeploy/dag/node.h"
38 #include "nndeploy/device/buffer.h"
39 #include "nndeploy/device/device.h"
41 #include "nndeploy/device/tensor.h"
43 
44 namespace nndeploy {
45 namespace llm {
46 
52  public:
53  SampleParam() = default;
54  virtual ~SampleParam() = default;
55 
 // Name of the sampling strategy, stored as a string (default "temperature").
 // Recognised names and what each one means:
56  // * 1. greedy - greedy sampling: pick the token with the highest probability
57  // * 2. temperature - temperature sampling: the temperature parameter controls randomness
58  // * 3. topK - Top-K sampling: sample from the K highest-probability tokens
59  // * 4. topP - Top-P (nucleus) sampling: sample from the token set whose cumulative probability reaches P
60  // * 5. minP - Min-P sampling: filter out tokens whose probability is below a threshold
61  // * 6. tfs - Tail Free Sampling: a sampling method based on the second derivative
62  // * 7. typical - Typical sampling: a sampling method based on information theory
63  // * 8. penalty - repetition-penalty sampling: penalize repeated tokens
64  // * 9. ngram - N-gram repetition penalty: penalize repeated n-gram sequences
65  std::string sampler =
66  "temperature"; // "greedy", "temperature", "topK", "topP", "minP", "tfs",
67  // "typical", "penalty", "ngram".
68 
 // Per-strategy tuning values; each default below is the in-class initializer.
69  float temperature = 0.8;  // temperature for "temperature" sampling (controls randomness)
70  int topK = 40;  // K for "topK": sample among the K highest-probability tokens
71  float topP = 0.9;  // P for "topP": cumulative-probability cutoff of the candidate set
72  float minP = 0.05;  // threshold for "minP": lower-probability tokens are filtered out
73  float tfsZ = 1.0;  // parameter for "tfs" (Tail Free Sampling); exact role of Z is not shown in this header
74  float typical = 0.95;  // parameter for "typical" sampling; exact role is not shown in this header
75  // penalty
76  float penalty = 1.05;  // factor for "penalty" (repeated-token) sampling
77  int ngram = 8;  // setting for the "ngram" repetition penalty - presumably the n-gram length; TODO confirm
78  float ngram_factor =
79  1.02; // penalize a repeated n-gram by multiplying with ngram_factor.
80  float max_penalty = 10.0f;  // presumably an upper bound on the applied penalty - TODO confirm against the implementation
81  std::vector<std::string> mixed_samplers = {"topK", "tfs", "typical",
82  "topP", "minP", "temperature"};  // list of sampler names; presumably the order used by a mixed strategy - TODO confirm
83 
86  rapidjson::Value& json,
87  rapidjson::Document::AllocatorType& allocator) override;
 // Overrides the base-class deserialize(): takes a rapidjson value and returns
 // a base::Status. Presumably fills the fields above from JSON; the definition
 // is not in this header - TODO confirm.
89  virtual base::Status deserialize(rapidjson::Value& json) override;
90 };
91 
114  public:
 // Builds the node from a name plus its input and output edges.
115  Sampler(const std::string& name, std::vector<dag::Edge*> inputs,
116  std::vector<dag::Edge*> outputs);
117  virtual ~Sampler();
118 
 // NOTE(review): declared virtual but not marked `override`; presumably it
 // overrides the dag node's run() - confirm against the base class.
119  virtual base::Status run();
120 
 // Takes a logits tensor and returns an int - presumably the id of the
 // selected token; TODO confirm against the implementation.
121  int sample(device::Tensor* logits);
122 
 // The six functions below share their names with the sampler option strings
 // "penalty", "topK", "topP", "minP", "tfs" and "typical". Each takes a
 // SubsetLogits by value and returns a SubsetLogits. Their definitions (and
 // the SubsetLogits layout) are not in this header.
123  struct SubsetLogits penalty(struct SubsetLogits superset);
124  struct SubsetLogits topK(struct SubsetLogits superset);
125  struct SubsetLogits topP(struct SubsetLogits superset);
126  struct SubsetLogits minP(struct SubsetLogits superset);
127  struct SubsetLogits tfs(struct SubsetLogits superset);
128  struct SubsetLogits typical(struct SubsetLogits superset);
 // Presumably chains several of the samplers above - TODO confirm.
129  struct SubsetLogits mixed(struct SubsetLogits subset);
 // Takes a sampler name and a SubsetLogits; presumably dispatches on
 // sampler_type to one of the functions above - TODO confirm.
130  struct SubsetLogits subsetSampler(std::string sampler_type,
131  struct SubsetLogits subset);
 // Returns an int computed from a SubsetLogits - presumably the final token
 // choice; TODO confirm.
132  int handleSelect(struct SubsetLogits subset);
133 
 // Presumably stores the flag in is_prefill_; definition not in this header.
134  void setIsPrefill(bool is_prefill);
135 
136  protected:
137  bool is_prefill_ = true;  // starts out true
138  bool is_first_ = true;  // starts out true
139 };
140 
141 } // namespace llm
142 } // namespace nndeploy
143 
144 #endif
virtual base::Status deserialize(rapidjson::Value &json)
virtual std::string serialize()
Node base class.
Definition: node.h:171
SampleParam - Sample节点的参数配置 @wangzhaode.
Definition: sample.h:51
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) override
virtual ~SampleParam()=default
virtual base::Status deserialize(rapidjson::Value &json) override
Sample - 文本生成采样节点
Definition: sample.h:113
virtual base::Status run()
Run node (pure virtual function)
Sampler(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
#define NNDEPLOY_CC_API
api
Definition: macro.h:29