20 #ifndef _NNDEPLOY_LLM_SAMPLE_H_
21 #define _NNDEPLOY_LLM_SAMPLE_H_
69 float temperature = 0.8;
80 float max_penalty = 10.0f;
81 std::vector<std::string> mixed_samplers = {
"topK",
"tfs",
"typical",
82 "topP",
"minP",
"temperature"};
86 rapidjson::Value& json,
87 rapidjson::Document::AllocatorType& allocator)
override;
115 Sampler(
const std::string& name, std::vector<dag::Edge*> inputs,
116 std::vector<dag::Edge*> outputs);
123 struct SubsetLogits penalty(struct SubsetLogits superset);
124 struct SubsetLogits topK(struct SubsetLogits superset);
125 struct SubsetLogits topP(struct SubsetLogits superset);
126 struct SubsetLogits minP(struct SubsetLogits superset);
127 struct SubsetLogits tfs(struct SubsetLogits superset);
128 struct SubsetLogits typical(struct SubsetLogits superset);
129 struct SubsetLogits mixed(struct SubsetLogits subset);
130 struct SubsetLogits subsetSampler(std::string sampler_type,
131 struct SubsetLogits subset);
132 int handleSelect(struct SubsetLogits subset);
134 void setIsPrefill(bool is_prefill);
137 bool is_prefill_ = true;
138 bool is_first_ = true;
virtual base::Status deserialize(rapidjson::Value &json)
virtual std::string serialize()
SampleParam - Sample节点的参数配置 @wangzhaode.
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator) override
virtual ~SampleParam()=default
virtual base::Status deserialize(rapidjson::Value &json) override
virtual base::Status run()
Run node (pure virtual function)
Sampler(const std::string &name, std::vector< dag::Edge * > inputs, std::vector< dag::Edge * > outputs)
#define NNDEPLOY_CC_API
api