2 #ifndef _NNDEPLOY_DAG_NODE_H_
3 #define _NNDEPLOY_DAG_NODE_H_
#include <utility>

20 #include "rapidjson/document.h"
21 #include "rapidjson/stringbuffer.h"
22 #include "rapidjson/writer.h"
46 std::initializer_list<std::string> inputs,
47 std::initializer_list<std::string> outputs)
48 : node_name_(node_name), inputs_(inputs), outputs_(outputs) {}
56 NodeDesc(
const std::string &node_name, std::vector<std::string> inputs,
57 std::vector<std::string> outputs)
58 : node_name_(node_name), inputs_(inputs), outputs_(outputs) {}
67 NodeDesc(
const std::string &node_key,
const std::string &node_name,
68 std::initializer_list<std::string> inputs,
69 std::initializer_list<std::string> outputs)
70 : node_key_(node_key),
71 node_name_(node_name),
82 NodeDesc(
const std::string &node_key,
const std::string &node_name,
83 std::vector<std::string> inputs, std::vector<std::string> outputs)
84 : node_key_(node_key),
85 node_name_(node_name),
95 std::string
getKey()
const {
return node_key_; }
101 std::string
getName()
const {
return node_name_; }
107 std::vector<std::string>
getInputs()
const {
return inputs_; }
113 std::vector<std::string>
getOutputs()
const {
return outputs_; }
123 rapidjson::Document::AllocatorType &allocator);
160 std::string node_key_;
161 std::string node_name_;
162 std::vector<std::string> inputs_;
163 std::vector<std::string> outputs_;
185 Node(
const std::string &name, std::vector<Edge *> inputs,
186 std::vector<Edge *> outputs);
434 const std::string &key, std::shared_ptr<base::Param> external_param);
466 const std::string &value);
476 const base::Any &value);
491 template <
typename T>
493 base::Any &any = this->getResourceWithoutState(key);
498 return base::get<T>(any);
532 template <
typename T>
534 Edge* edge = this->getResourceWithState(key);
535 if (edge ==
nullptr) {
536 NNDEPLOY_LOGE(
"edge is nullptr in setResourceWithState, key: %s.\n", key.c_str());
539 edge->
set<T>(value, is_external);
549 template <
typename T>
551 Edge* edge = this->getResourceWithState(key);
552 if (edge ==
nullptr) {
553 NNDEPLOY_LOGE(
"edge is nullptr in getResourceWithState, key: %s.\n", key.c_str());
556 return edge->
get<T>(
this);
580 const std::vector<std::string> &required_params);
683 const std::map<std::string, std::vector<std::string>> &dropdown_params);
692 const std::string &dropdown_param,
693 const std::vector<std::string> &dropdown_values);
777 std::vector<std::shared_ptr<Edge>> inputs);
785 std::vector<std::shared_ptr<Edge>> outputs);
808 template <
typename T>
810 Edge *edge = getInput(index);
811 if (edge ==
nullptr) {
814 return edge->
get<T>(
this);
825 template <
typename T>
827 Edge *edge = getOutput(index);
828 if (edge ==
nullptr) {
831 return edge->
set<T>(obj, is_external);
1029 template <
typename T>
1031 std::shared_ptr<EdgeTypeInfo> edge_type_info =
1032 std::make_shared<EdgeTypeInfo>();
1033 edge_type_info->setType<T>();
1034 edge_type_info->setEdgeName(desc);
1035 input_type_info_.push_back(edge_type_info);
1046 std::string desc =
"");
1060 template <
typename T>
1062 std::shared_ptr<EdgeTypeInfo> edge_type_info =
1063 std::make_shared<EdgeTypeInfo>();
1064 edge_type_info->setType<T>();
1065 edge_type_info->setEdgeName(desc);
1066 output_type_info_.push_back(edge_type_info);
1077 std::string desc =
"");
1165 virtual std::vector<Edge *>
forward(std::vector<Edge *> inputs);
1172 virtual std::vector<Edge *>
operator()(std::vector<Edge *> inputs);
1249 rapidjson::Document::AllocatorType &allocator);
1301 bool is_external_stream_ =
false;
1312 bool is_dynamic_input_ =
false;
1313 bool is_dynamic_output_ =
false;
1325 bool constructed_ =
false;
1326 bool is_inner_ =
false;
1327 bool parallel_type_set_ =
false;
1329 bool initialized_ =
false;
1330 bool is_running_ =
false;
1331 size_t run_size_ = 0;
1332 size_t completed_size_ = 0;
1333 bool is_time_profile_ =
false;
1334 bool is_debug_ =
false;
1335 bool is_trace_ =
false;
1336 bool traced_ =
false;
1337 bool is_graph_ =
false;
1338 bool is_loop_ =
false;
1339 bool is_condition_ =
false;
1340 bool is_composite_node_ =
false;
1345 int loop_count_ = -1;
1346 std::atomic<bool> stop_{
false};
1348 std::string version_ =
"1.0.0";
1369 std::vector<Edge *> inputs,
1370 std::vector<Edge *> outputs) = 0;
1380 const std::string &node_name, std::vector<Edge *> inputs,
1381 std::vector<Edge *> outputs) = 0;
1391 template <
typename T>
1402 std::vector<Edge *> inputs,
1403 std::vector<Edge *> outputs)
override {
1404 return new T(node_name, inputs, outputs);
1415 const std::string &node_name, std::vector<Edge *> inputs,
1416 std::vector<Edge *> outputs)
override {
1417 return std::make_shared<T>(node_name, inputs, outputs);
1442 std::shared_ptr<NodeCreator> creator) {
1443 auto it = creators_.find(node_key);
1445 if (it != creators_.end()) {
1448 NNDEPLOY_LOGW(
"Node name %s already exists, will be overwritten!\n",
1451 creators_[node_key] = creator;
1460 std::shared_ptr<NodeCreator>
getCreator(
const std::string &node_key) {
1464 auto it = creators_.find(node_key);
1465 if (it != creators_.end()) {
1476 std::set<std::string> keys;
1477 for (
auto &it : creators_) {
1478 keys.insert(it.first);
1486 std::map<std::string, std::shared_ptr<NodeCreator>> creators_;
1501 #define REGISTER_NODE(node_key, node_class) \
1503 struct NodeRegister_##node_class { \
1504 NodeRegister_##node_class() { \
1505 nndeploy::dag::getGlobalNodeFactory()->registerNode( \
1507 std::make_shared<nndeploy::dag::TypeNodeCreator<node_class>>()); \
1510 static NodeRegister_##node_class g_node_register_##node_class; \
1527 const std::string &node_name);
1538 const std::string &node_name,
1539 std::initializer_list<Edge *> inputs,
1540 std::initializer_list<Edge *> outputs);
1551 const std::string &node_name,
1552 std::vector<Edge *> inputs,
1553 std::vector<Edge *> outputs);
1562 const std::string &node_key,
const std::string &node_name);
1573 const std::string &node_key,
const std::string &node_name,
1574 std::initializer_list<Edge *> inputs,
1575 std::initializer_list<Edge *> outputs);
1586 const std::string &node_key,
const std::string &node_name,
1587 std::vector<Edge *> inputs, std::vector<Edge *> outputs);
1594 std::vector<Edge *> outputs,
Composite node. A composite node is a special type of node in nndeploy that enhances the capabilities of...
Edge class in DAG graph for connecting nodes and transferring data.
base::Status set(device::Buffer *buffer, bool is_external=true)
Set Buffer data to Edge.
T * get(const Node *node)
Get arbitrary type data for specified node (template version)
Directed Acyclic Graph Node.
virtual std::shared_ptr< Node > createNodeSharedPtr(const std::string &node_name, std::vector< Edge * > inputs, std::vector< Edge * > outputs)=0
Create node (shared pointer)
virtual ~NodeCreator()=default
virtual Node * createNode(const std::string &node_name, std::vector< Edge * > inputs, std::vector< Edge * > outputs)=0
Create node.
base::Status saveFile(const std::string &path)
Save to file.
NodeDesc(const std::string &node_key, const std::string &node_name, std::initializer_list< std::string > inputs, std::initializer_list< std::string > outputs)
Constructor.
std::string getKey() const
Get node key.
std::vector< std::string > getInputs() const
Get input edge name list.
std::string serialize()
Serialize to JSON string.
base::Status deserialize(rapidjson::Value &json)
Deserialize from JSON.
std::string getName() const
Get node name.
base::Status deserialize(const std::string &json_str)
Deserialize from JSON string.
virtual ~NodeDesc()=default
base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Serialize to JSON.
base::Status loadFile(const std::string &path)
Load from file.
std::vector< std::string > getOutputs() const
Get output edge name list.
NodeDesc(const std::string &node_name, std::vector< std::string > inputs, std::vector< std::string > outputs)
Constructor.
NodeDesc(const std::string &node_name, std::initializer_list< std::string > inputs, std::initializer_list< std::string > outputs)
Constructor.
NodeDesc(const std::string &node_key, const std::string &node_name, std::vector< std::string > inputs, std::vector< std::string > outputs)
Constructor.
std::shared_ptr< NodeCreator > getCreator(const std::string &node_key)
Get node creator.
std::set< std::string > getNodeKeys()
Get all node keys.
void registerNode(const std::string &node_key, std::shared_ptr< NodeCreator > creator)
Register node.
static NodeFactory * getInstance()
Get singleton instance.
virtual base::Status setOutput(Edge *output, int index=-1)
Set output edge.
bool checkOutputs(std::vector< std::string > &outputs_name)
Check output edge names.
Node(const std::string &name, std::vector< Edge * > inputs, std::vector< Edge * > outputs)
Constructor.
std::string getInputName(int index=0)
Get input edge name at specified index.
base::Status clearUiParams()
Clear UI parameters.
bool isInputsChanged(std::vector< Edge * > inputs)
Check if inputs changed.
base::Status addUiParam(const std::string &ui_param)
Add UI parameter.
virtual std::shared_ptr< RunStatus > getRunStatus()
Get run status.
void setDynamicInput(bool is_dynamic_input)
Set whether it's dynamic input.
base::Status setOutputTypeInfo(std::string desc="")
Set output type information (template method)
base::Status setGraph(Graph *graph)
Set parent graph.
void setKey(const std::string &key)
Set node key.
virtual Edge * createInternalOutputEdge(const std::string &name)
Create internal output edge.
int getOutputCount()
Get output edge count.
virtual base::Status setParam(base::Param *param)
Set parameter.
std::map< std::string, std::vector< std::string > > dropdown_params_
Dropdown parameter mapping.
virtual base::ParallelType getParallelType()
Get parallel type.
std::map< std::string, std::shared_ptr< base::Param > > external_param_
External parameter mapping.
std::string getKey()
Get node key.
std::string name_
Node name.
std::vector< std::shared_ptr< EdgeTypeInfo > > input_type_info_
Input type information.
void setStream(device::Stream *stream)
Set compute stream.
bool getTimeProfileFlag()
Get time profile flag.
virtual base::Status setInputName(const std::string &name, int index=0)
Set input edge name.
std::string developer_
Developer information.
void setDesc(const std::string &desc)
Set node description.
std::string getSource()
Get source information.
virtual base::Status setExternalParam(const std::string &key, std::shared_ptr< base::Param > external_param)
Set external parameter.
bool checkOutputs(std::vector< Edge * > &outputs)
Check output edges.
std::string desc_
Node description.
std::string getOutputName(int index=0)
Get output edge name at specified index.
void setSource(const std::string &source)
Set source information.
std::string getDesc()
Get node description.
std::vector< std::string > getIoParams()
Get IO parameter list.
base::Status setResourceWithState(const std::string &key, T *value, bool is_external=true)
Set stateful resource (template method)
virtual std::vector< Edge * > operator()(Edge *input)
Single input invocation operator overload.
virtual base::Status saveFile(const std::string &path)
Save to file.
void setGraphFlag(bool flag)
Set graph flag.
T * getResourceWithState(const std::string &key)
Get stateful resource (template method)
virtual base::Status getParam(const std::string &key, base::Any &any)
Get parameter (Any type)
virtual base::Status setOutputName(const std::string &name, int index=0)
Set output edge name.
int getInputCount()
Get input edge count.
void setDynamicOutput(bool is_dynamic_output)
Set whether it's dynamic output.
bool isDynamicOutput()
Check if it's dynamic output.
virtual std::vector< Edge * > forward(std::vector< Edge * > inputs)
Node invocation interface.
std::string source_
Source information.
std::vector< Edge * > getAllOutput()
Get all output edges.
virtual std::vector< Edge * > operator()(std::vector< Edge * > inputs)
Node invocation operator overload.
virtual int getLoopCount()
Get loop count.
std::vector< std::string > required_params_
Required parameter list.
std::vector< std::shared_ptr< EdgeTypeInfo > > getInputTypeInfo()
Get input type information.
virtual std::string serialize()
Serialize to JSON string.
base::Status removeRequiredParam(const std::string &required_param)
Remove required parameter.
Graph * getGraph()
Get parent graph.
virtual void setLoopCount(int loop_count)
Set loop count.
virtual std::vector< std::string > getRealOutputsName()
Get real output names.
virtual std::vector< Edge * > operator()()
Parameter-free invocation operator overload.
virtual bool synchronize()
Synchronize execution.
virtual base::Status toStaticGraph()
Convert to static graph.
bool getConstructed()
Get whether it's constructed.
size_t getRunSize()
Get run count.
virtual base::Status addResourceWithState(const std::string &key, Edge *edge)
Add stateful resource.
virtual base::Status setParamSharedPtr(std::shared_ptr< base::Param > param)
Set parameter (shared pointer)
void setTimeProfileFlag(bool flag)
Set time profile flag.
virtual std::vector< Edge * > forward()
Parameter-free forward propagation.
bool getDebugFlag()
Get debug flag.
NodeType getNodeType()
Get node type.
size_t getCompletedSize()
Get completed count.
std::string getDeveloper()
Get developer information.
device::Stream * getStream()
Get compute stream.
virtual base::DeviceType getDeviceType()
Get device type.
base::Status clearRequiredParams()
Clear required parameters.
virtual bool interrupt()
Interrupt execution.
T * getInputData(int index=0)
Get input data (template method)
void setDeveloper(const std::string &developer)
Set developer information.
base::Status clearDropdownParams()
Clear dropdown parameters.
virtual base::Status setOutputNames(const std::vector< std::string > &names)
Set all output edge names.
void setName(const std::string &name)
Set node name.
virtual base::Status run()=0
Run node (pure virtual function)
virtual bool checkInterruptStatus()
Check interrupt status.
std::vector< std::string > getUiParams()
Get UI parameter list.
std::vector< std::string > getInputNames()
Get all input edge names.
virtual base::Param * getParam()
Get parameter.
std::vector< std::string > io_params_
IO parameter list.
base::Status setUiParams(const std::vector< std::string > &ui_params)
Set UI parameter list.
base::Status setRequiredParams(const std::vector< std::string > &required_params)
Set required parameter list.
void setNodeType(NodeType node_type)
Set node type.
virtual int64_t getMemorySize()
Get memory size.
std::string getVersion()
Get version number.
virtual base::Status addResourceWithoutState(const std::string &key, const base::Any &value)
Add stateless resource.
std::vector< std::string > getRequiredParams()
Get required parameter list.
virtual base::Status setIterInput(Edge *input, int index=-1)
Set iteration input edge.
virtual base::Status setParam(const std::string &key, const std::string &value)
Set parameter (string type)
bool getInitialized()
Get whether it's initialized.
bool getTraceFlag()
Get trace flag.
base::Status removeDropdownParam(const std::string &dropdown_param)
Remove dropdown parameter.
virtual base::Status setMemory(device::Buffer *buffer)
Set memory buffer.
virtual base::Status deserialize(rapidjson::Value &json)
Deserialize from JSON.
virtual base::Status init()
Initialize node.
base::Status setOutputTypeInfo(std::shared_ptr< EdgeTypeInfo > output_type_info, std::string desc="")
Set output type information.
virtual base::Status setDeviceType(base::DeviceType device_type)
Set device type.
virtual Edge * createResourceWithState(const std::string &key)
Create stateful resource.
virtual base::EdgeUpdateFlag updateInput()
Update input.
virtual base::Any & getResourceWithoutState(const std::string &key)
Get stateless resource.
virtual std::vector< Edge * > forward(Edge *input)
Single input forward propagation.
std::vector< Edge * > outputs_
Output edge list.
T getResourceWithoutState(const std::string &key)
Get stateless resource (template method)
Edge * getOutput(int index=0)
Get output edge.
Edge * getInput(int index=0)
Get input edge.
bool isDynamicInput()
Check if it's dynamic input.
Node(const std::string &name)
Constructor.
virtual base::Status deinit()
Deinitialize node.
virtual std::shared_ptr< base::Param > getParamSharedPtr()
Get parameter (shared pointer)
std::vector< Edge * > getAllInput()
Get all input edges.
base::Status clearIoParams()
Clear IO parameters.
IOType getIoType()
Get IO type.
void setDebugFlag(bool flag)
Set debug flag.
int getOutputIndex(const std::string &name)
Get output edge index by name.
base::Status setInputTypeInfo(std::string desc="")
Set input type information (template method)
virtual base::Status setParallelType(const base::ParallelType &paralle_type)
Set parallel type.
std::vector< std::string > getOutputNames()
Get all output edge names.
virtual base::Status setInputs(std::vector< Edge * > inputs)
Set all input edges.
void setIoType(IOType io_type)
Set IO type.
void setRunningFlag(bool flag)
Set running flag.
int getInputIndex(const std::string &name)
Get input edge index by name.
virtual base::Status deserialize(const std::string &json_str)
Deserialize from JSON string.
virtual base::Status setParam(const std::string &key, base::Any &any)
Set parameter (Any type)
std::map< std::string, Edge * > internal_outputs_
Internal output edge mapping.
virtual base::Status defaultParam()
Configure default parameters.
void setInitializedFlag(bool flag)
Set initialized flag.
virtual base::Status serialize(rapidjson::Value &json, rapidjson::Document::AllocatorType &allocator)
Serialize to JSON.
std::string key_
Node key.
virtual base::Status setInputSharedPtr(std::shared_ptr< Edge > input, int index=-1)
Set input edge (shared pointer)
virtual base::Status setOutputs(std::vector< Edge * > outputs)
Set all output edges.
std::vector< Edge * > inputs_
Input edge list.
virtual base::Status setInputNames(const std::vector< std::string > &names)
Set all input edge names.
std::vector< std::string > ui_params_
UI parameter list.
bool checkInputs(std::vector< Edge * > &inputs)
Check input edges.
base::Status setInputTypeInfo(std::shared_ptr< EdgeTypeInfo > input_type_info, std::string desc="")
Set input type information.
base::Status setOutputData(T *obj, int index=0, bool is_external=true)
Set output data (template method)
virtual base::Status setInputsSharedPtr(std::vector< std::shared_ptr< Edge >> inputs)
Set all input edges (shared pointer)
virtual void clearInterrupt()
Clear interrupt status.
virtual base::Status setOutputsSharedPtr(std::vector< std::shared_ptr< Edge >> outputs)
Set all output edges (shared pointer)
std::map< std::string, std::vector< std::string > > getDropdownParams()
Get dropdown parameters.
virtual base::Status setOutputSharedPtr(std::shared_ptr< Edge > output, int index=-1)
Set output edge (shared pointer)
base::Status removeIoParam(const std::string &io_param)
Remove IO parameter.
base::Status setCompositeNode(CompositeNode *composite_node)
Set parent composite node.
std::vector< std::shared_ptr< EdgeTypeInfo > > getOutputTypeInfo()
Get output type information.
base::Status addRequiredParam(const std::string &required_param)
Add required parameter.
base::Status setVersion(const std::string &version)
Set version number.
base::Status addIoParam(const std::string &io_param)
Add IO parameter.
base::DeviceType device_type_
Device type.
virtual base::Status setInput(Edge *input, int index=-1)
Set input edge.
virtual void setTraceFlag(bool flag)
Set trace flag.
std::string getName()
Get node name.
virtual Edge * getResourceWithState(const std::string &key)
Get stateful resource.
base::Status removeUiParam(const std::string &ui_param)
Remove UI parameter.
base::Status setIoParams(const std::vector< std::string > &io_params)
Set IO parameter list.
CompositeNode * getCompositeNode()
Get parent composite node.
virtual std::shared_ptr< base::Param > getExternalParam(const std::string &key)
Get external parameter.
base::Status setDropdownParams(const std::map< std::string, std::vector< std::string >> &dropdown_params)
Set dropdown parameters.
virtual base::Status loadFile(const std::string &path)
Load from file.
std::vector< std::shared_ptr< EdgeTypeInfo > > output_type_info_
Output type information.
void setInnerFlag(bool flag)
Set inner flag.
bool isRunning()
Check if it's running.
bool getGraphFlag()
Get graph flag.
std::shared_ptr< base::Param > param_
Node parameters.
base::Status addDropdownParam(const std::string &dropdown_param, const std::vector< std::string > &dropdown_values)
Add dropdown parameter.
virtual std::shared_ptr< Node > createNodeSharedPtr(const std::string &node_name, std::vector< Edge * > inputs, std::vector< Edge * > outputs) override
Create node (shared pointer)
virtual Node * createNode(const std::string &node_name, std::vector< Edge * > inputs, std::vector< Edge * > outputs) override
Create node.
#define NNDEPLOY_LOGW(fmt,...)
#define NNDEPLOY_LOGE(fmt,...)
#define NNDEPLOY_CC_API
api
@ kStatusCodeErrorNullParam
std::function< base::Status(std::vector< Edge * > inputs, std::vector< Edge * > outputs, base::Param *param)> NodeFunc
Node function type definition.
std::shared_ptr< Node > createNodeSharedPtr(const std::string &node_key, const std::string &node_name)
Create node (shared pointer)
NodeFactory * getGlobalNodeFactory()
Get global node factory.
Node * createNode(const std::string &node_key, const std::string &node_name)
Create node.
std::set< std::string > getNodeKeys()
Get all node keys.