SourceXtractorPlusPlus
0.21
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
SEImplementation
src
lib
Common
OnnxModel.cpp
Go to the documentation of this file.
1
/*
2
* OnnxModel.cpp
3
*
4
* Created on: Feb 16, 2021
5
* Author: mschefer
6
*/
7
8
#include "ElementsKernel/Exception.h"
9
#include "ElementsKernel/Logging.h"
10
#include "AlexandriaKernel/memory_tools.h"
11
12
#include "
SEImplementation/Common/OnnxModel.h
"
13
#include "
SEImplementation/Common/OnnxCommon.h
"
14
15
namespace
SourceXtractor
{
16
17
OnnxModel::OnnxModel
(
const
std::string
&
model_path
) {
18
m_model_path
=
model_path
;
19
20
Elements::Logging
onnx_logger
=
Elements::Logging::getLogger
(
"Onnx"
);
21
auto
allocator
= Ort::AllocatorWithDefaultOptions();
22
23
onnx_logger
.
info
() <<
"Loading ONNX model "
<<
model_path
;
24
m_session
= Euclid::make_unique<Ort::Session>(
ORT_ENV
,
model_path
.c_str(), Ort::SessionOptions{nullptr});
25
26
if
(
m_session
->GetOutputCount() != 1) {
27
throw
Elements::Exception
() <<
"Only ONNX models with a single output tensor are supported"
;
28
}
29
30
for
(
size_t
i
=0;
i
<
m_session
->GetInputCount();
i
++) {
31
auto
input_type
=
m_session
->GetInputTypeInfo(
i
);
32
33
m_input_names
.
emplace_back
(
m_session
->GetInputNameAllocated(
i
,
allocator
).
get
());
34
m_input_shapes
.
emplace_back
(
input_type
.GetTensorTypeAndShapeInfo().GetShape());
35
m_input_types
.
emplace_back
(
input_type
.GetTensorTypeAndShapeInfo().GetElementType());
36
}
37
38
m_output_name
=
std::string
(
m_session
->GetOutputNameAllocated(0,
allocator
).
get
());
39
m_domain_name
=
std::string
(
m_session
->GetModelMetadata().GetDomainAllocated(
allocator
).
get
());
40
m_graph_name
=
std::string
(
m_session
->GetModelMetadata().GetGraphNameAllocated(
allocator
).
get
());
41
42
auto
output_type
=
m_session
->GetOutputTypeInfo(0);
43
44
m_output_shape
=
output_type
.GetTensorTypeAndShapeInfo().GetShape();
45
m_output_type
=
output_type
.GetTensorTypeAndShapeInfo().GetElementType();
46
47
// onnx_logger.info() << "ONNX model with input of " << formatShape(m_input_shapes[0]);
48
// onnx_logger.info() << "ONNX model with output of " << formatShape(m_output_shape);
49
}
50
51
}
OnnxCommon.h
OnnxModel.h
std::allocator
std::string
Elements::Exception
Elements::Logging
Elements::Logging::getLogger
static Logging getLogger(const std::string &name="")
Elements::Logging::info
void info(const std::string &logMessage)
SourceXtractor::OnnxModel::m_input_types
std::vector< ONNXTensorElementDataType > m_input_types
Input type.
Definition
OnnxModel.h:157
SourceXtractor::OnnxModel::m_session
std::unique_ptr< Ort::Session > m_session
Session, one per model. In theory, it is thread-safe.
Definition
OnnxModel.h:162
SourceXtractor::OnnxModel::m_output_name
std::string m_output_name
Output tensor name.
Definition
OnnxModel.h:156
SourceXtractor::OnnxModel::m_output_type
ONNXTensorElementDataType m_output_type
Output type.
Definition
OnnxModel.h:158
SourceXtractor::OnnxModel::m_input_names
std::vector< std::string > m_input_names
Input tensor name.
Definition
OnnxModel.h:155
SourceXtractor::OnnxModel::OnnxModel
OnnxModel(const std::string &model_path)
Definition
OnnxModel.cpp:17
SourceXtractor::OnnxModel::m_output_shape
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition
OnnxModel.h:160
SourceXtractor::OnnxModel::m_graph_name
std::string m_graph_name
graph name
Definition
OnnxModel.h:154
SourceXtractor::OnnxModel::m_domain_name
std::string m_domain_name
domain name
Definition
OnnxModel.h:153
SourceXtractor::OnnxModel::m_model_path
std::string m_model_path
Path to the ONNX model.
Definition
OnnxModel.h:161
SourceXtractor::OnnxModel::m_input_shapes
std::vector< std::vector< std::int64_t > > m_input_shapes
Input tensor shape.
Definition
OnnxModel.h:159
std::vector::emplace_back
T emplace_back(T... args)
std::function
std::unique_ptr::get
T get(T... args)
SourceXtractor
Definition
Aperture.h:30
SourceXtractor::onnx_logger
Elements::Logging onnx_logger
Logger for the ONNX plugin.
Definition
OnnxPlugin.cpp:26
SourceXtractor::ORT_ENV
Ort::Env ORT_ENV
Definition
OnnxCommon.cpp:25
Generated by
1.10.0