SourceXtractorPlusPlus
0.21
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
SEImplementation
SEImplementation
Common
OnnxModel.h
Go to the documentation of this file.
1
/*
 * OnnxModel.h
 *
 *  Created on: Feb 16, 2021
 *      Author: mschefer
 */
7
8
#ifndef _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_
9
#define _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_
10
11
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>
20
21
namespace
SourceXtractor
{
22
23
class
OnnxModel
{
24
public
:
25
26
explicit
OnnxModel
(
const
std::string
&
model_path
);
27
28
template
<
typename
T,
typename
U>
29
void
run
(
std::vector<T>
&
input_data
,
std::vector<U>
&
output_data
)
const
{
30
Ort::RunOptions
run_options
;
31
auto
mem_info
= Ort::MemoryInfo::CreateCpu(
OrtDeviceAllocator
,
OrtMemTypeCPU
);
32
33
// Allocate memory
34
std::vector<int64_t>
input_shape
(
m_input_shapes
[0].
begin
(),
m_input_shapes
[0].
end
());
35
input_shape
[0] = 1;
36
size_t
input_size
=
std::accumulate
(
input_shape
.begin(),
input_shape
.end(), 1u,
std::multiplies<size_t>
());
37
38
std::vector<int64_t>
output_shape
(
m_output_shape
.
begin
(),
m_output_shape
.
end
());
39
output_shape
[0] = 1;
40
size_t
output_size
=
std::accumulate
(
output_shape
.begin(),
output_shape
.end(), 1u,
std::multiplies<size_t>
());
41
42
// Check input and output size are OK
43
if
(
input_data
.size() <
input_size
||
output_data
.size() <
output_size
) {
44
throw
Elements::Exception
() <<
"OnnxModel: Insufficient buffer size "
;
45
}
46
47
// Setup input/output tensors
48
auto
input_tensor
= Ort::Value::CreateTensor<T>(
49
mem_info
,
input_data
.data(),
input_data
.size(),
input_shape
.data(),
input_shape
.size());
50
auto
output_tensor
= Ort::Value::CreateTensor<U>(
51
mem_info
,
output_data
.data(),
output_data
.size(),
output_shape
.data(),
output_shape
.size());
52
53
// Run the model
54
const
char
*
input_name
=
m_input_names
[0].c_str();
55
const
char
*
output_name
=
m_output_name
.
c_str
();
56
57
m_session
->Run(
run_options
, &
input_name
, &
input_tensor
, 1, &
output_name
, &
output_tensor
, 1);
58
}
59
60
template
<
typename
T,
typename
U>
61
void
runMultiInput
(
std::map
<
std::string
,
std::vector<T>
>&
input_data
,
std::vector<U>
&
output_data
)
const
{
62
Ort::RunOptions
run_options
;
63
auto
mem_info
= Ort::MemoryInfo::CreateCpu(
OrtDeviceAllocator
,
OrtMemTypeCPU
);
64
65
std::vector<const char *>
input_names
;
66
std::vector<Ort::Value>
input_tensors
;
67
68
int
inputs_nb
=
m_input_names
.
size
();
69
for
(
int
i
=0;
i
<
inputs_nb
;
i
++) {
70
input_names
.emplace_back(
m_input_names
[
i
].c_str());
71
72
// Allocate memory
73
std::vector<int64_t>
input_shape
(
m_input_shapes
[
i
].
begin
(),
m_input_shapes
[
i
].
end
());
74
input_shape
[0] = 1;
75
size_t
input_size
=
std::accumulate
(
input_shape
.begin(),
input_shape
.end(), 1u,
std::multiplies<size_t>
());
76
77
// Check input size is OK
78
if
(
input_data
[
m_input_names
[
i
]].size() <
input_size
) {
79
throw
Elements::Exception
() <<
"OnnxModel: Insufficient buffer size "
;
80
}
81
82
input_tensors
.emplace_back(Ort::Value::CreateTensor<T>(
83
mem_info
,
input_data
[
m_input_names
[
i
]].data(),
input_data
[
m_input_names
[
i
]].size(),
84
input_shape
.data(),
input_shape
.size()));
85
}
86
87
// Output name and shape
88
const
char
*
output_name
=
m_output_name
.
c_str
();
89
std::vector<int64_t>
output_shape
(
m_output_shape
.
begin
(),
m_output_shape
.
end
());
90
output_shape
[0] = 1;
91
92
// Setup output tensor
93
size_t
output_size
=
std::accumulate
(
output_shape
.begin(),
output_shape
.end(), 1u,
std::multiplies<size_t>
());
94
95
// Check output and output size are OK
96
if
(
output_data
.size() <
output_size
) {
97
throw
Elements::Exception
() <<
"OnnxModel: Insufficient buffer size "
;
98
}
99
100
auto
output_tensor
= Ort::Value::CreateTensor<U>(
101
mem_info
,
output_data
.data(),
output_data
.size(),
output_shape
.data(),
output_shape
.size());
102
103
// Run the model
104
m_session
->Run(
run_options
, &
input_names
[0], &
input_tensors
[0],
inputs_nb
, &
output_name
, &
output_tensor
, 1);
105
}
106
107
108
ONNXTensorElementDataType
getInputType
()
const
{
109
return
m_input_types
[0];
110
}
111
112
ONNXTensorElementDataType
getOutputType
()
const
{
113
return
m_output_type
;
114
}
115
116
const
std::vector<std::int64_t>
&
getInputShape
()
const
{
117
return
m_input_shapes
[0];
118
}
119
120
const
std::vector<std::int64_t>
&
getOutputShape
()
const
{
121
return
m_output_shape
;
122
}
123
124
std::string
getDomain
()
const
{
125
return
m_domain_name
;
126
}
127
128
std::string
getGraphName
()
const
{
129
return
m_graph_name
;
130
}
131
132
std::string
getInputName
()
const
{
133
return
m_input_names
[0];
134
}
135
136
std::string
getOutputName
()
const
{
137
return
m_output_name
;
138
}
139
140
std::string
getModelPath
()
const
{
141
return
m_model_path
;
142
}
143
144
size_t
getInputNb
()
const
{
145
return
m_input_names
.
size
();
146
}
147
148
size_t
getOutputNb
()
const
{
149
return
1U;
150
}
151
152
private
:
153
std::string
m_domain_name
;
154
std::string
m_graph_name
;
155
std::vector<std::string>
m_input_names
;
156
std::string
m_output_name
;
157
std::vector<ONNXTensorElementDataType>
m_input_types
;
158
ONNXTensorElementDataType
m_output_type
;
159
std::vector<std::vector<std::int64_t>
>
m_input_shapes
;
160
std::vector<std::int64_t>
m_output_shape
;
161
std::string
m_model_path
;
162
std::unique_ptr<Ort::Session>
m_session
;
163
};
164
165
}
166
167
168
#endif
/* _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_ */
std::accumulate
T accumulate(T... args)
std::string
std::begin
T begin(T... args)
std::string::c_str
T c_str(T... args)
Elements::Exception
SourceXtractor::OnnxModel
Definition
OnnxModel.h:23
SourceXtractor::OnnxModel::run
void run(std::vector< T > &input_data, std::vector< U > &output_data) const
Definition
OnnxModel.h:29
SourceXtractor::OnnxModel::getInputType
ONNXTensorElementDataType getInputType() const
Definition
OnnxModel.h:108
SourceXtractor::OnnxModel::getOutputType
ONNXTensorElementDataType getOutputType() const
Definition
OnnxModel.h:112
SourceXtractor::OnnxModel::m_input_types
std::vector< ONNXTensorElementDataType > m_input_types
Input type.
Definition
OnnxModel.h:157
SourceXtractor::OnnxModel::m_session
std::unique_ptr< Ort::Session > m_session
Session, one per model. In theory, it is thread-safe.
Definition
OnnxModel.h:162
SourceXtractor::OnnxModel::getGraphName
std::string getGraphName() const
Definition
OnnxModel.h:128
SourceXtractor::OnnxModel::getDomain
std::string getDomain() const
Definition
OnnxModel.h:124
SourceXtractor::OnnxModel::m_output_name
std::string m_output_name
Output tensor name.
Definition
OnnxModel.h:156
SourceXtractor::OnnxModel::getOutputNb
size_t getOutputNb() const
Definition
OnnxModel.h:148
SourceXtractor::OnnxModel::getOutputShape
const std::vector< std::int64_t > & getOutputShape() const
Definition
OnnxModel.h:120
SourceXtractor::OnnxModel::getOutputName
std::string getOutputName() const
Definition
OnnxModel.h:136
SourceXtractor::OnnxModel::m_output_type
ONNXTensorElementDataType m_output_type
Output type.
Definition
OnnxModel.h:158
SourceXtractor::OnnxModel::m_input_names
std::vector< std::string > m_input_names
Input tensor name.
Definition
OnnxModel.h:155
SourceXtractor::OnnxModel::OnnxModel
OnnxModel(const std::string &model_path)
Definition
OnnxModel.cpp:17
SourceXtractor::OnnxModel::getInputName
std::string getInputName() const
Definition
OnnxModel.h:132
SourceXtractor::OnnxModel::m_output_shape
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition
OnnxModel.h:160
SourceXtractor::OnnxModel::m_graph_name
std::string m_graph_name
graph name
Definition
OnnxModel.h:154
SourceXtractor::OnnxModel::runMultiInput
void runMultiInput(std::map< std::string, std::vector< T > > &input_data, std::vector< U > &output_data) const
Definition
OnnxModel.h:61
SourceXtractor::OnnxModel::m_domain_name
std::string m_domain_name
domain name
Definition
OnnxModel.h:153
SourceXtractor::OnnxModel::getInputShape
const std::vector< std::int64_t > & getInputShape() const
Definition
OnnxModel.h:116
SourceXtractor::OnnxModel::m_model_path
std::string m_model_path
Path to the ONNX model.
Definition
OnnxModel.h:161
SourceXtractor::OnnxModel::getInputNb
size_t getInputNb() const
Definition
OnnxModel.h:144
SourceXtractor::OnnxModel::m_input_shapes
std::vector< std::vector< std::int64_t > > m_input_shapes
Input tensor shape.
Definition
OnnxModel.h:159
SourceXtractor::OnnxModel::getModelPath
std::string getModelPath() const
Definition
OnnxModel.h:140
std::end
T end(T... args)
std::function
std::map
SourceXtractor
Definition
Aperture.h:30
std::vector::size
T size(T... args)
Generated by
1.10.0