SourceXtractorPlusPlus
0.21
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
SEImplementation
src
lib
Plugin
Onnx
OnnxSourceTask.cpp
Go to the documentation of this file.
1
18
#include "
SEImplementation/Plugin/Onnx/OnnxSourceTask.h
"
19
#include "
SEImplementation/Plugin/Onnx/OnnxProperty.h
"
20
#include "
SEImplementation/Plugin/DetectionFrameImages/DetectionFrameImages.h
"
21
#include "
SEImplementation/Plugin/PixelCentroid/PixelCentroid.h
"
22
#include <NdArray/NdArray.h>
23
#include <AlexandriaKernel/memory_tools.h>
24
#include <onnxruntime_cxx_api.h>
25
26
// Shorthand alias for the Euclid::NdArray namespace.
// NOTE(review): the alias is not referenced anywhere else in this file; candidate for removal.
namespace NdArray = Euclid::NdArray;
27
28
namespace
SourceXtractor
{
29
30
31
template
<
typename
T>
32
static
void
fillCutout
(
const
Image<T>
&
image
,
int
center_x
,
int
center_y
,
int
width,
int
height,
std::vector<T>
& out) {
33
int
x_start
=
center_x
- width / 2;
34
int
y_start
=
center_y
- height / 2;
35
int
x_end
=
x_start
+ width;
36
int
y_end
=
y_start
+ height;
37
38
ImageAccessor<T>
accessor
(
image
);
39
40
int
index = 0;
41
for
(
int
iy
=
y_start
;
iy
<
y_end
;
iy
++) {
42
for
(
int
ix
=
x_start
;
ix
<
x_end
;
ix
++, index++) {
43
if
(
ix
>= 0 &&
iy
>= 0 &&
ix
<
image
.getWidth() &&
iy
<
image
.getHeight()) {
44
out[index] =
accessor
.getValue(
ix
,
iy
);
45
}
46
}
47
}
48
}
49
50
/**
 * Build the task for a set of ONNX models.
 *
 * @param model_infos models to evaluate; each entry carries the model and the property
 *                    name (`prop_name`) its output is stored under in computeProperties.
 *
 * NOTE(review): the generated documentation shows `m_model_infos` declared in
 * OnnxSourceTask.h as `const std::vector<OnnxModelInfo>&`, so this stores a reference,
 * not a copy. The vector passed in must outlive the task; a temporary argument would
 * leave a dangling reference. Consider storing the vector by value in the header.
 */
OnnxSourceTask::OnnxSourceTask(const std::vector<OnnxModelInfo>& model_infos) : m_model_infos(model_infos) {}
51
59
template
<
typename
O>
60
static
std::unique_ptr<OnnxProperty::NdWrapperBase>
61
computePropertiesSpecialized
(
const
OnnxModel
& model,
const
DetectionFrameImages
&
detection_frame_images
,
62
const
PixelCentroid
&
centroid
) {
63
Ort::RunOptions
run_options
;
64
auto
mem_info
= Ort::MemoryInfo::CreateCpu(
OrtDeviceAllocator
,
OrtMemTypeCPU
);
65
66
const
int
center_x
=
static_cast<
int
>
(
centroid
.getCentroidX() + 0.5);
67
const
int
center_y
=
static_cast<
int
>
(
centroid
.getCentroidY() + 0.5);
68
69
// Allocate memory
70
std::vector<int64_t>
input_shape
(model.
getInputShape
().begin(), model.
getInputShape
().end());
71
input_shape
[0] = 1;
72
size_t
input_size
=
std::accumulate
(
input_shape
.begin(),
input_shape
.end(), 1u,
std::multiplies<size_t>
());
73
std::vector<float>
input_data
(
input_size
);
74
75
std::vector<int64_t>
output_shape
(model.
getOutputShape
().begin(), model.
getOutputShape
().end());
76
output_shape
[0] = 1;
77
size_t
output_size
=
std::accumulate
(
output_shape
.begin(),
output_shape
.end(), 1u,
std::multiplies<size_t>
());
78
std::vector<O>
output_data
(
output_size
);
79
80
// Cut the needed area
81
{
82
const
auto
&
image
=
detection_frame_images
.getLockedImage(
LayerSubtractedImage
);
83
fillCutout
(*
image
,
center_x
,
center_y
,
input_shape
[2],
input_shape
[3],
input_data
);
84
}
85
86
model.
run
<
float
,
O
>(
input_data
,
output_data
);
87
88
// Set the output
89
std::vector<size_t>
catalog_shape
{model.
getOutputShape
().begin() + 1, model.
getOutputShape
().end()};
90
return
Euclid::make_unique<OnnxProperty::NdWrapper<O>>(
catalog_shape
,
output_data
);
91
}
92
93
/**
 * Evaluate every configured ONNX model on `source` and attach the results as an OnnxProperty.
 *
 * The output element type of each model selects the instantiation of
 * computePropertiesSpecialized (float or int32_t); any other type raises
 * Elements::Exception. Results are keyed by the model's `prop_name`. Because
 * std::map::emplace is used, the first result is kept when two models share a name.
 *
 * @param source source providing DetectionFrameImages and PixelCentroid; receives OnnxProperty
 */
void OnnxSourceTask::computeProperties(SourceXtractor::SourceInterface& source) const {
  const auto& frame_images = source.getProperty<DetectionFrameImages>();
  const auto& pixel_centroid = source.getProperty<PixelCentroid>();

  std::map<std::string, std::unique_ptr<OnnxProperty::NdWrapperBase>> outputs_by_name;

  for (const auto& info : m_model_infos) {
    auto& model = *info.model;
    std::unique_ptr<OnnxProperty::NdWrapperBase> wrapped;

    switch (model.getOutputType()) {
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
        wrapped = computePropertiesSpecialized<float>(model, frame_images, pixel_centroid);
        break;
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
        wrapped = computePropertiesSpecialized<int32_t>(model, frame_images, pixel_centroid);
        break;
      default:
        // Only the output types handled above are expected here.
        throw Elements::Exception() << "This should have not happened!" << model.getOutputType();
    }

    outputs_by_name.emplace(info.prop_name, std::move(wrapped));
  }

  source.setProperty<OnnxProperty>(std::move(outputs_by_name));
}
118
119
}
// end of namespace SourceXtractor
DetectionFrameImages.h
OnnxProperty.h
OnnxSourceTask.h
PixelCentroid.h
std::accumulate
T accumulate(T... args)
Elements::Exception
SourceXtractor::DetectionFrameImages
Definition
DetectionFrameImages.h:30
SourceXtractor::OnnxModel
Definition
OnnxModel.h:23
SourceXtractor::OnnxModel::run
void run(std::vector< T > &input_data, std::vector< U > &output_data) const
Definition
OnnxModel.h:29
SourceXtractor::OnnxModel::getOutputShape
const std::vector< std::int64_t > & getOutputShape() const
Definition
OnnxModel.h:120
SourceXtractor::OnnxModel::getInputShape
const std::vector< std::int64_t > & getInputShape() const
Definition
OnnxModel.h:116
SourceXtractor::OnnxProperty
Definition
OnnxProperty.h:30
SourceXtractor::OnnxSourceTask::computeProperties
void computeProperties(SourceInterface &source) const override
Computes one or more properties for the Source.
Definition
OnnxSourceTask.cpp:93
SourceXtractor::OnnxSourceTask::OnnxSourceTask
OnnxSourceTask(const std::vector< OnnxModelInfo > &model_infos)
Definition
OnnxSourceTask.cpp:50
SourceXtractor::OnnxSourceTask::m_model_infos
const std::vector< OnnxModelInfo > & m_model_infos
Definition
OnnxSourceTask.h:53
SourceXtractor::PixelCentroid
The centroid of all the pixels in the source, weighted by their DetectionImage pixel values.
Definition
PixelCentroid.h:37
SourceXtractor::SourceInterface
The SourceInterface is an abstract "source" that has properties attached to it.
Definition
SourceInterface.h:46
std::function
std::move
T move(T... args)
Euclid::NdArray
SourceXtractor
Definition
Aperture.h:30
SourceXtractor::LayerSubtractedImage
@ LayerSubtractedImage
Definition
Frame.h:39
SourceXtractor::fillCutout
static void fillCutout(const Image< T > &image, int center_x, int center_y, int width, int height, std::vector< T > &out)
Definition
OnnxSourceTask.cpp:32
SourceXtractor::computePropertiesSpecialized
static std::unique_ptr< OnnxProperty::NdWrapperBase > computePropertiesSpecialized(const OnnxModel &model, const DetectionFrameImages &detection_frame_images, const PixelCentroid &centroid)
Definition
OnnxSourceTask.cpp:61
Generated by
1.10.0