nndeploy C++ API 0.2.0
nndeploy C++ API
scheduler.h
Go to the documentation of this file.
1 
2 #ifndef _NNDEPLOY_MODEL_STABLE_DIFFUSION_SCHEDULER_H_
3 #define _NNDEPLOY_MODEL_STABLE_DIFFUSION_SCHEDULER_H_
4 
5 #include <random>
6 
7 #include "nndeploy/base/any.h"
8 #include "nndeploy/base/common.h"
10 #include "nndeploy/base/log.h"
11 #include "nndeploy/base/macro.h"
12 #include "nndeploy/base/object.h"
14 #include "nndeploy/base/param.h"
15 #include "nndeploy/base/status.h"
16 #include "nndeploy/base/string.h"
17 #include "nndeploy/dag/edge.h"
18 #include "nndeploy/dag/graph.h"
19 #include "nndeploy/dag/loop.h"
20 #include "nndeploy/dag/node.h"
21 #include "nndeploy/device/buffer.h"
22 #include "nndeploy/device/device.h"
24 #include "nndeploy/device/tensor.h"
26 
27 namespace nndeploy {
28 namespace stable_diffusion {
29 
31  public:
32  SchedulerParam() : base::Param() {}
33  virtual ~SchedulerParam() {}
34 
37 
39  if (this == &param) return *this;
40  version_ = param.version_;
41  num_train_timesteps_ = param.num_train_timesteps_;
42  clip_sample_ = param.clip_sample_;
43  num_inference_steps_ = param.num_inference_steps_;
44  unet_channels_ = param.unet_channels_;
45  image_height_ = param.image_height_;
46  image_width_ = param.image_width_;
47  guidance_scale_ = param.guidance_scale_;
48  vae_scale_factor_ = param.vae_scale_factor_;
49  init_noise_sigma_ = param.init_noise_sigma_;
50  return *this;
51  }
52 
53  public:
54  std::string version_ = "v1.5";
55  int num_train_timesteps_ = 1000; // 训练时间步数
56  bool clip_sample_ = false; // 是否裁剪样本
57  int num_inference_steps_ = 50; // 推断步数
58  int unet_channels_ = 4; // channel
59  int image_height_ = 512; // height
60  int image_width_ = 512; // width
61  float guidance_scale_ = 7.5; // 指导比例
62  float vae_scale_factor_ = 0.18215;
63  float init_noise_sigma_ = 1.0f;
64 
65  public:
66  virtual base::Status serialize(rapidjson::Value &json,
67  rapidjson::Document::AllocatorType &allocator);
68  virtual base::Status deserialize(rapidjson::Value &json);
69 };
70 
72  public:
73  Scheduler(SchedulerType type) : scheduler_type_(type) {}
74  virtual ~Scheduler() {}
75 
76  virtual base::Status init(SchedulerParam *param) = 0;
77  virtual base::Status deinit() = 0;
78 
85  virtual base::Status setTimesteps() = 0;
86 
94  virtual base::Status scaleModelInput(device::Tensor *sample, int index) = 0;
104  virtual base::Status step(device::Tensor *sample, device::Tensor *timestep,
105  device::Tensor *latents,
106  device::Tensor *pre_sample) = 0;
107 
113  virtual std::vector<int> &getTimesteps() = 0;
114 
115  protected:
118 };
119 
125  public:
126  virtual ~SchedulerCreator() {};
128 };
129 
135 template <typename T>
137  virtual Scheduler *createScheduler(SchedulerType type) { return new T(type); }
138 };
139 
145 std::map<SchedulerType, std::shared_ptr<SchedulerCreator>> &
147 
153 template <typename T>
155  public:
157  getGlobalSchedulerCreatorMap()[type] = std::shared_ptr<T>(new T());
158  }
159 };
160 
168 
177 base::Status initializeLatents(std::mt19937 &generator, float init_noise_sigma,
178  device ::Tensor *latents);
179 
180 } // namespace stable_diffusion
181 } // namespace nndeploy
182 
183 #endif
virtual Scheduler * createScheduler(SchedulerType type)=0
SchedulerParam & operator=(const SchedulerParam &param)
Definition: scheduler.h:38
virtual base::Status deserialize(rapidjson::Value &json)
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
virtual base::Status deinit()=0
virtual base::Status scaleModelInput(device::Tensor *sample, int index)=0
virtual base::Status step(device::Tensor *sample, device::Tensor *timestep, device::Tensor *latents, device::Tensor *pre_sample)=0
virtual std::vector< int > & getTimesteps()=0
Get the timesteps (returns a reference to the scheduler's std::vector<int> of timesteps).
virtual base::Status setTimesteps()=0
Set the timesteps used by the scheduler.
virtual base::Status init(SchedulerParam *param)=0
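Scheduler creator class template: createScheduler(type) returns a new T constructed with the given SchedulerType.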
Definition: scheduler.h:136
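Scheduler creator registration class template: registers a new creator T in the global scheduler creator map under the given SchedulerType.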
Definition: scheduler.h:154
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
base::Status initializeLatents(std::mt19937 &generator, float init_noise_sigma, device ::Tensor *latents)
std::map< SchedulerType, std::shared_ptr< SchedulerCreator > > & getGlobalSchedulerCreatorMap()
Get the global map from SchedulerType to its registered SchedulerCreator.
Scheduler * createScheduler(SchedulerType type)
Create a Scheduler of the given SchedulerType.