import nndeploy._nndeploy_internal as _C
import nndeploy.base
import nndeploy.device
import nndeploy.ir
import nndeploy.op
class InferenceParam(_C.inference.InferenceParam):
    """Python wrapper around the native ``InferenceParam`` binding.

    Every ``set_*`` / ``get_*`` method forwards to the native method of the
    same name via ``super()``. Separately, the wrapper keeps a Python-only
    dictionary (``_default_dic``) that is filled by :meth:`set`, read by
    :meth:`get` and rendered by ``__str__``. Nothing in this class forwards
    that dictionary to the native object.
    """

    def __init__(self, inference_type):
        """Create the parameter object.

        Args:
            inference_type: forwarded unchanged to the native constructor
                (presumably an ``nndeploy.base.InferenceType`` -- confirm).
        """
        super().__init__(inference_type)
        # Python-only key/value store used by set(), get() and __str__.
        self._default_dic = {}

    @staticmethod
    def _as_list(value):
        """Return ``value`` as a list if it is a list or tuple, else ``None``.

        A ``str`` is deliberately NOT treated as a sequence: one string is one
        value, not a list of characters.
        """
        if isinstance(value, list):
            return value
        if isinstance(value, tuple):
            # Normalise so the native "whole list" overload receives a list,
            # exactly as it did before for list input.
            return list(value)
        return None

    def set_inference_type(self, inference_type: nndeploy.base.InferenceType):
        """Forward ``inference_type`` to the native setter."""
        super().set_inference_type(inference_type)

    def get_inference_type(self):
        """Return the native inference type."""
        return super().get_inference_type()

    def set_model_type(self, model_type: nndeploy.base.ModelType):
        """Forward ``model_type`` to the native setter."""
        super().set_model_type(model_type)

    def get_model_type(self):
        """Return the native model type."""
        return super().get_model_type()

    def set_is_path(self, is_path: bool):
        """Forward the ``is_path`` flag to the native setter."""
        super().set_is_path(is_path)

    def get_is_path(self):
        """Return the native ``is_path`` flag."""
        return super().get_is_path()

    def set_model_value(self, model_value, index=-1):
        """Set model value(s).

        Args:
            model_value: either a list/tuple of values, which is forwarded as
                a whole and ``index`` is ignored, or a single value.
            index: position passed to the native setter together with a
                single value. The meaning of the default ``-1`` is defined
                natively and is not visible here.
        """
        values = self._as_list(model_value)
        if values is not None:
            super().set_model_value(values)
        else:
            super().set_model_value(model_value, index)

    def get_model_value(self):
        """Return the native model value(s)."""
        return super().get_model_value()

    def set_output_num(self, output_num: int):
        """Forward ``output_num`` to the native setter."""
        super().set_output_num(output_num)

    def get_output_num(self):
        """Return the native output count."""
        return super().get_output_num()

    def set_output_name(self, output_name, index=-1):
        """Set output name(s).

        A list/tuple is forwarded as a whole and ``index`` is ignored. A
        single value is forwarded together with ``index``.
        """
        names = self._as_list(output_name)
        if names is not None:
            super().set_output_name(names)
        else:
            super().set_output_name(output_name, index)

    def get_output_name(self):
        """Return the native output name(s)."""
        return super().get_output_name()

    def set_encrypt_type(self, encrypt_type: nndeploy.base.EncryptType):
        """Forward ``encrypt_type`` to the native setter."""
        super().set_encrypt_type(encrypt_type)

    def get_encrypt_type(self):
        """Return the native encryption type."""
        return super().get_encrypt_type()

    def set_license(self, license: str):
        """Forward ``license`` to the native setter."""
        super().set_license(license)

    def get_license(self):
        """Return the native license string."""
        return super().get_license()

    def set_device_type(self, device_type: nndeploy.base.DeviceType):
        """Forward ``device_type`` to the native setter."""
        super().set_device_type(device_type)

    def get_device_type(self):
        """Return the native device type."""
        return super().get_device_type()

    def set_num_thread(self, num_thread: int):
        """Forward ``num_thread`` to the native setter."""
        super().set_num_thread(num_thread)

    def get_num_thread(self):
        """Return the native thread count."""
        return super().get_num_thread()

    def set_gpu_tune_kernel(self, gpu_tune_kernel: int):
        """Forward ``gpu_tune_kernel`` to the native setter."""
        super().set_gpu_tune_kernel(gpu_tune_kernel)

    def get_gpu_tune_kernel(self):
        """Return the native GPU kernel-tuning setting."""
        return super().get_gpu_tune_kernel()

    def set_share_memory_mode(self, share_memory_mode: nndeploy.base.ShareMemoryType):
        """Forward ``share_memory_mode`` to the native setter."""
        super().set_share_memory_mode(share_memory_mode)

    def get_share_memory_mode(self):
        """Return the native share-memory mode."""
        return super().get_share_memory_mode()

    def set_precision_type(self, precision_type: nndeploy.base.PrecisionType):
        """Forward ``precision_type`` to the native setter."""
        super().set_precision_type(precision_type)

    def get_precision_type(self):
        """Return the native precision type."""
        return super().get_precision_type()

    def set_power_type(self, power_type: nndeploy.base.PowerType):
        """Forward ``power_type`` to the native setter."""
        super().set_power_type(power_type)

    def get_power_type(self):
        """Return the native power type."""
        return super().get_power_type()

    def set_is_dynamic_shape(self, is_dynamic_shape: bool):
        """Forward the ``is_dynamic_shape`` flag to the native setter."""
        super().set_is_dynamic_shape(is_dynamic_shape)

    def get_is_dynamic_shape(self):
        """Return the native ``is_dynamic_shape`` flag."""
        return super().get_is_dynamic_shape()

    def set_min_shape(self, min_shape: dict):
        """Forward the ``min_shape`` mapping to the native setter."""
        super().set_min_shape(min_shape)

    def get_min_shape(self):
        """Return the native minimum-shape mapping."""
        return super().get_min_shape()

    def set_opt_shape(self, opt_shape: dict):
        """Forward the ``opt_shape`` mapping to the native setter."""
        super().set_opt_shape(opt_shape)

    def get_opt_shape(self):
        """Return the native optimal-shape mapping."""
        return super().get_opt_shape()

    def set_max_shape(self, max_shape: dict):
        """Forward the ``max_shape`` mapping to the native setter."""
        super().set_max_shape(max_shape)

    def get_max_shape(self):
        """Return the native maximum-shape mapping."""
        return super().get_max_shape()

    def set_cache_path(self, cache_path: list):
        """Forward the ``cache_path`` list to the native setter."""
        super().set_cache_path(cache_path)

    def get_cache_path(self):
        """Return the native cache path list."""
        return super().get_cache_path()

    def set_library_path(self, library_path, index=-1):
        """Set library path(s).

        A list/tuple is forwarded as a whole and ``index`` is ignored. A
        single value is forwarded together with ``index``.
        """
        paths = self._as_list(library_path)
        if paths is not None:
            super().set_library_path(paths)
        else:
            super().set_library_path(library_path, index)

    def get_library_path(self):
        """Return the native library path(s)."""
        return super().get_library_path()

    def __str__(self):
        """Render only the Python-side dictionary, not the native fields."""
        return str(self._default_dic)

    def set(self, dic: dict):
        """Merge every entry of ``dic`` into the Python-side dictionary.

        Existing keys are overwritten. Nothing is forwarded to the native
        object.
        """
        self._default_dic.update(dic)

    def get(self, key: str):
        """Return the Python-side value stored under ``key``.

        If ``key`` is absent, print a notice to stdout and return ``None``.
        """
        try:
            return self._default_dic[key]
        except KeyError:
            print(f"Unsupported key: {key}")
            return None
class InferenceParamCreator(_C.inference.InferenceParamCreator):
    """Base class for Python-defined factories of inference parameter objects.

    Subclasses must override :meth:`create_inference_param`. Instances are
    the ``creator`` argument accepted by ``register_inference_param_creator``.
    """

    def __init__(self):
        """Initialise the native base part of the creator."""
        super().__init__()

    def create_inference_param(self, type: nndeploy.base.InferenceType):
        """Build a parameter object for ``type``. Subclasses MUST override this.

        Raises:
            NotImplementedError: always, in this base class.
        """
        message = "base class InferenceParamCreator does not implement create_inference_param method"
        raise NotImplementedError(message)
def register_inference_param_creator(type: nndeploy.base.InferenceType, creator: InferenceParamCreator):
    """Register ``creator`` with the native registry under inference type ``type``.

    Args:
        type: the inference type the creator is registered for.
        creator: factory object. Presumably the native side later calls its
            ``create_inference_param`` -- confirm against the bindings.

    NOTE(review): no Python-side reference to ``creator`` is kept here. Whether
    the native registry keeps the object alive is not visible -- confirm.
    """
    native_register = _C.inference.register_inference_param_creator
    native_register(type, creator)
def create_inference_param(type: nndeploy.base.InferenceType):
    """Create an inference parameter object for ``type`` via the native factory.

    Returns:
        Whatever ``_C.inference.create_inference_param`` produces for ``type``.
    """
    native_create = _C.inference.create_inference_param
    param = native_create(type)
    return param