nndeploy C++ API  0.2.0
nndeploy C++ API
runtime.h
Go to the documentation of this file.
1 
2 #ifndef _NNDEPLOY_NET_RUNTIME_H_
3 #define _NNDEPLOY_NET_RUNTIME_H_
4 
5 #include "nndeploy/base/any.h"
6 #include "nndeploy/base/common.h"
8 #include "nndeploy/base/log.h"
9 #include "nndeploy/base/macro.h"
10 #include "nndeploy/base/object.h"
11 #include "nndeploy/base/status.h"
12 #include "nndeploy/base/string.h"
14 #include "nndeploy/net/util.h"
15 
16 namespace nndeploy {
17 namespace net {
18 
19 class Runtime;
20 
22  public:
24  virtual ~PipelineTensor() {};
25  std::vector<device::Tensor *> tensors_;
26  std::vector<Runtime *> producers_;
27  std::vector<Runtime *> consumers_;
28 
29  // 添加互斥锁和条件变量,用于同步不同阶段之间的数据传递
30  std::mutex mutex_;
31  std::condition_variable cv_;
32  std::map<Runtime *, int> current_index_;
33  bool is_finish_ = false;
34 
35  void push(device::Tensor *tensor) {
36  // NNDEPLOY_LOGI("tensor name %s\n", tensor->getName().c_str());
37  std::lock_guard<std::mutex> lock(mutex_);
38  tensors_.push_back(tensor);
39  cv_.notify_all();
40  }
41 
42  void setFinish() {
43  std::lock_guard<std::mutex> lock(mutex_);
44  is_finish_ = true;
45  cv_.notify_all();
46  }
47 
48  device::Tensor *pop(Runtime *runtime) {
49  std::unique_lock<std::mutex> lock(mutex_);
50  cv_.wait(lock, [this, runtime]() {
51  bool flag = current_index_[runtime] < tensors_.size();
52  return flag || is_finish_;
53  });
54  if (is_finish_) {
55  return nullptr;
56  }
57  device::Tensor *tensor = tensors_[current_index_[runtime]];
58  current_index_[runtime]++;
59  return tensor;
60  }
61 };
62 
64  public:
65  Runtime(const base::DeviceType &device_type) : device_type_(device_type) {};
66  virtual ~Runtime() {
67  if (!is_external_stream_ && stream_ != nullptr) {
68  device::destroyStream(stream_);
69  stream_ = nullptr;
70  }
71  };
72 
73  void setStream(device::Stream *stream);
76 
77  base::Status setWorkers(int worker_num,
78  std::vector<base::DeviceType> device_types =
79  std::vector<base::DeviceType>());
80 
81  virtual base::Status init(
82  std::vector<TensorWrapper *> &tensor_repository,
83  std::vector<OpWrapper *> &op_repository,
84  std::vector<device::Tensor *> &input_tensors,
85  std::vector<device::Tensor *> &output_tensors, bool is_dynamic_shape,
86  base::ShapeMap max_shape,
87  TensorPoolType tensor_pool_type =
89  bool is_external_tensor_pool_memory = false) = 0;
90  virtual base::Status deinit() = 0;
91 
92  virtual base::Status reshape(base::ShapeMap &shape_map) = 0;
93 
99  virtual int64_t getMemorySize();
107 
108  virtual base::Status preRun() = 0;
109  virtual base::Status run() = 0;
110  virtual base::Status postRun() = 0;
111 
119 
130  const std::string &name, base::DeviceType device_type, bool is_copy,
131  base::DataFormat data_format) = 0;
132 
133  protected:
139  bool is_external_stream_ = false;
140  device::Stream *stream_ = nullptr;
141  TensorPoolType tensor_pool_type_ =
143  bool is_external_tensor_pool_memory_ = false;
145  bool is_dynamic_shape_ = false; // 是否是动态shape
146  base::ShapeMap max_shape_ = base::ShapeMap(); // 当为动态输入时最大shape
147  bool is_pure_dynamic_shape_ = false;
148  std::vector<TensorWrapper *> tensor_repository_;
149  std::vector<OpWrapper *> op_repository_;
150  std::vector<device::Tensor *> input_tensors_;
151  std::vector<device::Tensor *> output_tensors_;
152  int worker_num_ = 1;
153  std::vector<base::DeviceType> device_types_;
154 };
155 
161  public:
162  virtual ~RuntimeCreator() {};
163 
164  virtual Runtime *createRuntime(const base::DeviceType &device_type,
165  base::ParallelType parallel_type) = 0;
166 };
167 
173 template <typename T>
175  virtual Runtime *createRuntime(const base::DeviceType &device_type,
176  base::ParallelType parallel_type) {
177  auto Runtime = new T(device_type);
178  return Runtime;
179  }
180 };
181 
188 std::map<base::ParallelType, std::shared_ptr<RuntimeCreator>> &
190 
196 template <typename T>
198  public:
199  explicit TypeRuntimeRegister(base::ParallelType parallel_type) {
200  getGlobalRuntimeCreatorMap()[parallel_type] = std::shared_ptr<T>(new T());
201  }
202 };
203 
205  base::ParallelType parallel_type);
206 
207 } // namespace net
208 } // namespace nndeploy
209 
210 #endif
device::Tensor * pop(Runtime *runtime)
Definition: runtime.h:48
std::condition_variable cv_
Definition: runtime.h:31
std::vector< Runtime * > consumers_
Definition: runtime.h:27
std::vector< Runtime * > producers_
Definition: runtime.h:26
void push(device::Tensor *tensor)
Definition: runtime.h:35
std::map< Runtime *, int > current_index_
Definition: runtime.h:32
std::vector< device::Tensor * > tensors_
Definition: runtime.h:24
Runtime的创建类
Definition: runtime.h:160
virtual Runtime * createRuntime(const base::DeviceType &device_type, base::ParallelType parallel_type)=0
std::vector< base::DeviceType > device_types_
Definition: runtime.h:153
virtual base::Status setMemory(device::Buffer *buffer)
设置推理所需的内存(推理内存由外部分配)
virtual base::Status deinit()=0
std::vector< device::Tensor * > input_tensors_
Definition: runtime.h:150
std::vector< TensorWrapper * > tensor_repository_
Definition: runtime.h:148
TensorPool * tensor_pool_
Definition: runtime.h:144
base::Status setWorkers(int worker_num, std::vector< base::DeviceType > device_types=std::vector< base::DeviceType >())
std::vector< device::Tensor * > output_tensors_
Definition: runtime.h:151
device::Stream * getStream()
void setStream(device::Stream *stream)
base::DeviceType device_type_
Definition: runtime.h:134
base::Status synchronize()
virtual base::Status postRun()=0
virtual ~Runtime()
Definition: runtime.h:66
virtual base::Status copyToInputTensor(device::Tensor *tensor)=0
将输入tensor复制到输入tensor
virtual int64_t getMemorySize()
获取推理所需的内存大小
virtual base::Status preRun()=0
virtual base::Status init(std::vector< TensorWrapper * > &tensor_repository, std::vector< OpWrapper * > &op_repository, std::vector< device::Tensor * > &input_tensors, std::vector< device::Tensor * > &output_tensors, bool is_dynamic_shape, base::ShapeMap max_shape, TensorPoolType tensor_pool_type=kTensorPool1DSharedObjectTypeGreedyBySizeImprove, bool is_external_tensor_pool_memory=false)=0
Runtime(const base::DeviceType &device_type)
Definition: runtime.h:65
virtual device::Tensor * getOutputTensorAfterRun(const std::string &name, base::DeviceType device_type, bool is_copy, base::DataFormat data_format)=0
获取推理后的输出tensor
virtual base::Status reshape(base::ShapeMap &shape_map)=0
virtual base::Status run()=0
std::vector< OpWrapper * > op_repository_
Definition: runtime.h:149
Runtime的创建类模板
Definition: runtime.h:174
Runtime的创建类的注册类模板
Definition: runtime.h:197
TypeRuntimeRegister(base::ParallelType parallel_type)
Definition: runtime.h:199
#define NNDEPLOY_CC_API
api
Definition: macro.h:29
std::map< std::string, std::vector< int > > ShapeMap
Definition: common.h:381
base::Status destroyStream(Stream *stream)
销毁流
std::map< base::ParallelType, std::shared_ptr< RuntimeCreator > > & getGlobalRuntimeCreatorMap()
Get the Global Runtime Creator Map object.
@ kTensorPool1DOffsetCalculateTypeGreedyByBreadth
Definition: tensor_pool.h:33
@ kTensorPool1DSharedObjectTypeGreedyBySizeImprove
Definition: tensor_pool.h:31
Runtime * createRuntime(const base::DeviceType &device_type, base::ParallelType parallel_type)