SourceXtractorPlusPlus
0.21
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
SEImplementation
src
lib
Plugin
Onnx
OnnxTaskFactory.cpp
Go to the documentation of this file.
1
18
#include <onnxruntime_cxx_api.h>

#include <memory>
#include <sstream>
#include <string>
#include <utility>

#include <AlexandriaKernel/memory_tools.h>
#include <NdArray/NdArray.h>

#include "SEImplementation/Common/OnnxCommon.h"

#include "SEImplementation/Plugin/Onnx/OnnxPlugin.h"
#include "SEImplementation/Plugin/Onnx/OnnxSourceTask.h"
#include "SEImplementation/Plugin/Onnx/OnnxProperty.h"
#include "SEImplementation/Plugin/Onnx/OnnxConfig.h"

#include "SEImplementation/Plugin/Onnx/OnnxTaskFactory.h"
31
32
namespace
SourceXtractor
{
33
37
static
std::string
generatePropertyName
(
const
OnnxModel
& model) {
38
std::stringstream
prop_name;
39
40
std::string
domain
= model.
getDomain
();
41
if
(!
domain
.empty()) {
42
prop_name <<
domain
<<
'.'
;
43
}
44
45
std::string
graph_name
= model.
getGraphName
();
46
if
(!
graph_name
.empty()) {
47
prop_name <<
graph_name
<<
'.'
;
48
}
49
50
prop_name << model.
getOutputName
();
51
52
return
prop_name.
str
();
53
}
54
55
/// Construct a factory with no models; m_model_infos is populated by configure().
OnnxTaskFactory::OnnxTaskFactory() = default;
56
57
std::shared_ptr<Task>
OnnxTaskFactory::createTask
(
const
PropertyId
&
property_id
)
const
{
58
if
(
property_id
== PropertyId::create<OnnxProperty>()) {
59
return
std::make_shared<OnnxSourceTask>
(
m_model_infos
);
60
}
61
return
nullptr
;
62
}
63
64
/**
 * Register the configuration this factory depends on.
 *
 * @param manager Configuration manager in which OnnxConfig is registered.
 */
void OnnxTaskFactory::reportConfigDependencies(Euclid::Configuration::ConfigManager& manager) const {
  // configure() later reads this same configuration via getConfiguration<OnnxConfig>()
  manager.registerConfiguration<OnnxConfig>();
}
67
68
/**
 * Load and validate every ONNX model listed in OnnxConfig.
 *
 * For each model path:
 * - the model is loaded;
 * - its input is checked to be float with exactly 4 axes;
 * - it is appended to m_model_infos together with its generated property name.
 *
 * @param manager Configuration manager providing OnnxConfig.
 * @throws Elements::Exception if a model's input type is not float, or if its
 *         input shape does not have exactly 4 axes.
 */
void OnnxTaskFactory::configure(Euclid::Configuration::ConfigManager& manager) {
  const auto& onnx_config = manager.getConfiguration<OnnxConfig>();
  const auto& models = onnx_config.getModels();

  // Paths are only read, so bind by reference instead of copying each one
  for (const auto& model_path : models) {
    auto model = std::make_shared<OnnxModel>(model_path);

    if (model->getInputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
      throw Elements::Exception() << "Only ONNX models with float input are supported";
    }

    // Query the shape once; reuse the rank for both the check and the message
    const auto input_rank = model->getInputShape().size();
    if (input_rank != 4) {
      throw Elements::Exception() << "Expected 4 axes for the input layer, got " << input_rank;
    }

    auto prop_name = generatePropertyName(*model);
    onnx_logger.info() << "Output name will be " << prop_name;

    // Neither local is used after this point, so hand them over without copying
    m_model_infos.emplace_back(OnnxSourceTask::OnnxModelInfo{std::move(model), std::move(prop_name)});
  }
}
90
91
template
<
typename
T>
92
static
void
registerColumnConverter
(
OutputRegistry
&
registry
,
const
OnnxSourceTask::OnnxModelInfo
&
model_info
) {
93
auto
key
=
model_info
.prop_name;
94
95
registry
.registerColumnConverter<
OnnxProperty
,
Euclid::NdArray::NdArray<T>
>(
96
model_info
.prop_name, [
key
](
const
OnnxProperty
&
prop
) {
97
return
prop
.getData<T>(
key
);
98
},
""
,
model_info
.model->getModelPath()
99
);
100
}
101
102
/**
 * Register one output column per configured model.
 *
 * The converter's element type is chosen from the model's output type.
 *
 * @param registry Registry in which the columns are registered.
 * @throws Elements::Exception if a model's output type is neither float nor int32.
 */
void OnnxTaskFactory::registerPropertyInstances(OutputRegistry& registry) {
  for (const auto& info : m_model_infos) {
    const auto output_type = info.model->getOutputType();

    if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
      registerColumnConverter<float>(registry, info);
    } else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
      registerColumnConverter<int32_t>(registry, info);
    } else {
      throw Elements::Exception() << "Unsupported output type: " << output_type;
    }
  }
}
116
117
}
// end of namespace SourceXtractor
OnnxCommon.h
OnnxConfig.h
OnnxPlugin.h
OnnxProperty.h
OnnxSourceTask.h
OnnxTaskFactory.h
std::string
std::stringstream
Elements::Exception
Elements::Logging::info
void info(const std::string &logMessage)
Euclid::Configuration::ConfigManager
Euclid::NdArray::NdArray
SourceXtractor::OnnxConfig
Definition
OnnxConfig.h:28
SourceXtractor::OnnxModel
Definition
OnnxModel.h:23
SourceXtractor::OnnxModel::getGraphName
std::string getGraphName() const
Definition
OnnxModel.h:128
SourceXtractor::OnnxModel::getDomain
std::string getDomain() const
Definition
OnnxModel.h:124
SourceXtractor::OnnxModel::getOutputName
std::string getOutputName() const
Definition
OnnxModel.h:136
SourceXtractor::OnnxProperty
Definition
OnnxProperty.h:30
SourceXtractor::OnnxTaskFactory::reportConfigDependencies
void reportConfigDependencies(Euclid::Configuration::ConfigManager &manager) const override
Registers all the Configuration dependencies.
Definition
OnnxTaskFactory.cpp:64
SourceXtractor::OnnxTaskFactory::OnnxTaskFactory
OnnxTaskFactory()
Definition
OnnxTaskFactory.cpp:55
SourceXtractor::OnnxTaskFactory::createTask
std::shared_ptr< Task > createTask(const PropertyId &property_id) const override
Returns a Task producing a Property corresponding to the given PropertyId.
Definition
OnnxTaskFactory.cpp:57
SourceXtractor::OnnxTaskFactory::registerPropertyInstances
void registerPropertyInstances(OutputRegistry &registry) override
Definition
OnnxTaskFactory.cpp:102
SourceXtractor::OnnxTaskFactory::m_model_infos
std::vector< OnnxSourceTask::OnnxModelInfo > m_model_infos
Definition
OnnxTaskFactory.h:49
SourceXtractor::OnnxTaskFactory::configure
void configure(Euclid::Configuration::ConfigManager &manager) override
Method which should initialize the object.
Definition
OnnxTaskFactory.cpp:68
SourceXtractor::OutputRegistry
Definition
OutputRegistry.h:37
SourceXtractor::PropertyId
Identifier used to set and retrieve properties.
Definition
PropertyId.h:40
std::function
SourceXtractor
Definition
Aperture.h:30
SourceXtractor::registerColumnConverter
static void registerColumnConverter(OutputRegistry &registry, const OnnxSourceTask::OnnxModelInfo &model_info)
Definition
OnnxTaskFactory.cpp:92
SourceXtractor::generatePropertyName
static std::string generatePropertyName(const OnnxModel &model)
Definition
OnnxTaskFactory.cpp:37
SourceXtractor::onnx_logger
Elements::Logging onnx_logger
Logger for the ONNX plugin.
Definition
OnnxPlugin.cpp:26
std::stringstream::str
T str(T... args)
SourceXtractor::OnnxSourceTask::OnnxModelInfo
Definition
OnnxSourceTask.h:31
Generated by
1.10.0