SourceXtractorPlusPlus 0.21
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
OnnxTaskFactory.cpp
Go to the documentation of this file.
1
18#include <onnxruntime_cxx_api.h>
19
20#include <AlexandriaKernel/memory_tools.h>
21#include <NdArray/NdArray.h>
22
24
29
31
32namespace SourceXtractor {
33
38 std::stringstream prop_name;
39
41 if (!domain.empty()) {
42 prop_name << domain << '.';
43 }
44
46 if (!graph_name.empty()) {
47 prop_name << graph_name << '.';
48 }
49
50 prop_name << model.getOutputName();
51
52 return prop_name.str();
53}
54
56
58 if (property_id == PropertyId::create<OnnxProperty>()) {
60 }
61 return nullptr;
62}
63
67
69 const auto& onnx_config = manager.getConfiguration<OnnxConfig>();
70 const auto& models = onnx_config.getModels();
71
72 for (auto model_path : models) {
74
75 if (model->getInputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
76 throw Elements::Exception() << "Only ONNX models with float input are supported";
77 }
78
79 if (model->getInputShape().size() != 4) {
80 throw Elements::Exception() << "Expected 4 axes for the input layer, got " << model->getInputShape().size();
81 }
82
83 auto prop_name = generatePropertyName(*model);
84 onnx_logger.info() << "Output name will be " << prop_name;
85
86 m_model_infos.emplace_back(OnnxSourceTask::OnnxModelInfo {model, prop_name});
87
88 }
89}
90
91template<typename T>
93 auto key = model_info.prop_name;
94
95 registry.registerColumnConverter<OnnxProperty, Euclid::NdArray::NdArray<T>>(
96 model_info.prop_name, [key](const OnnxProperty& prop) {
97 return prop.getData<T>(key);
98 }, "", model_info.model->getModelPath()
99 );
100}
101
103 for (const auto& model_info : m_model_infos) {
104 switch (model_info.model->getOutputType()) {
107 break;
110 break;
111 default:
112 throw Elements::Exception() << "Unsupported output type: " << model_info.model->getOutputType();
113 }
114 }
115}
116
117} // end of namespace SourceXtractor
void info(const std::string &logMessage)
std::string getGraphName() const
Definition OnnxModel.h:128
std::string getDomain() const
Definition OnnxModel.h:124
std::string getOutputName() const
Definition OnnxModel.h:136
void reportConfigDependencies(Euclid::Configuration::ConfigManager &manager) const override
Registers all the Configuration dependencies.
std::shared_ptr< Task > createTask(const PropertyId &property_id) const override
Returns a Task producing a Property corresponding to the given PropertyId.
void registerPropertyInstances(OutputRegistry &registry) override
std::vector< OnnxSourceTask::OnnxModelInfo > m_model_infos
void configure(Euclid::Configuration::ConfigManager &manager) override
Method which should initialize the object.
Identifier used to set and retrieve properties.
Definition PropertyId.h:40
static void registerColumnConverter(OutputRegistry &registry, const OnnxSourceTask::OnnxModelInfo &model_info)
static std::string generatePropertyName(const OnnxModel &model)
Elements::Logging onnx_logger
Logger for the ONNX plugin.
T str(T... args)