SourceXtractorPlusPlus 0.19.2
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
SegmentationConfig.cpp
Go to the documentation of this file.
1
23#include <iostream>
24#include <fstream>
25
26#include <boost/regex.hpp>
27#include <boost/algorithm/string.hpp>
28
35
36using boost::regex;
37using boost::regex_match;
38using boost::smatch;
39
40using namespace Euclid::Configuration;
41namespace po = boost::program_options;
42
43namespace SourceXtractor {
44
46
// Names of the segmentation options. They are registered in
// getProgramOptions() and read back from the user values in preInitialize().
static const std::string SEGMENTATION_ALGORITHM = "segmentation-algorithm";
static const std::string SEGMENTATION_USE_FILTERING = "segmentation-use-filtering";
static const std::string SEGMENTATION_FILTER = "segmentation-filter";
static const std::string SEGMENTATION_LUTZ_WINDOW_SIZE = "segmentation-lutz-window-size";
static const std::string SEGMENTATION_BFS_MAX_DELTA = "segmentation-bfs-max-delta";
static const std::string SEGMENTATION_ML_MODEL = "segmentation-ml-model";
static const std::string SEGMENTATION_ML_THRESHOLD = "segmentation-ml-threshold";
54
55SegmentationConfig::SegmentationConfig(long manager_id) : Configuration(manager_id), m_selected_algorithm(Algorithm::UNKNOWN)
56 , m_lutz_window_size(0)
57 , m_bfs_max_delta(1000)
58 , m_ml_threshold(0.9) {}
59
61 return { {"Detection image", {
62 {SEGMENTATION_ALGORITHM.c_str(), po::value<std::string>()->default_value("LUTZ"),
63 "Segmentation algorithm to be used (LUTZ, TILES or ML (a ONNX-format model must be provided))"},
64 {SEGMENTATION_USE_FILTERING.c_str(), po::value<bool>()->default_value(true),
65 "Is filtering used"},
66 {SEGMENTATION_FILTER.c_str(), po::value<std::string>()->default_value(""),
67 "Loads a filter"},
68 {SEGMENTATION_LUTZ_WINDOW_SIZE.c_str(), po::value<int>()->default_value(0),
69 "Lutz sliding window size (0=disable)"},
70 {SEGMENTATION_BFS_MAX_DELTA.c_str(), po::value<int>()->default_value(1000),
71 "BFS algorithm max source x/y size (default=1000)"},
72 {SEGMENTATION_ML_MODEL.c_str(), po::value<std::string>()->default_value(""),
73 "ONNX model to use with machine learning segmentation"},
74 {SEGMENTATION_ML_THRESHOLD.c_str(), po::value<double>()->default_value(0.9),
75 "Probability threshold for ML detection"},
76 }}};
77}
78
// Reads the parsed segmentation options into this object's members.
// Several statements of this function are not visible in this view; comments
// marked NOTE(review) describe hidden lines and must be confirmed.
// Throws Elements::Exception for an unknown algorithm, for ML without ONNX
// support, for an unloadable filter file, and for ML without a model.
// The algorithm name is upper-cased first, so matching is case-insensitive.
80 auto algorithm_name = boost::to_upper_copy(args.at(SEGMENTATION_ALGORITHM).as<std::string>());
81 if (algorithm_name == "LUTZ") {
// NOTE(review): the LUTZ-branch statement is hidden here; presumably it sets
// m_selected_algorithm — confirm.
83 } else if (algorithm_name == "BFS") {
// NOTE(review): the BFS-branch statement is hidden here as well.
85 } else if (algorithm_name == "ML") {
// ML segmentation is only accepted when the build defines WITH_ML_SEGMENTATION;
// otherwise requesting it is a hard error.
86#ifdef WITH_ML_SEGMENTATION
88#else
89 throw Elements::Exception() << "SourceXtractor++ has not been compiled with ONNX support";
90#endif
91 } else {
92 throw Elements::Exception() << "Unknown segmentation algorithm : " << algorithm_name;
93 }
94
// Filter selection. With filtering enabled, a non-empty filter name is loaded
// through loadFilter() and a null result is reported as an error. With
// filtering disabled, no filter is used (m_filter = nullptr).
95 if (args.at(SEGMENTATION_USE_FILTERING).as<bool>()) {
96 auto filter_filename = args.at(SEGMENTATION_FILTER).as<std::string>();
97 if (filter_filename != "") {
98 m_filter = loadFilter(filter_filename);
99 if (m_filter == nullptr)
100 throw Elements::Exception() << "Can not load filter: " << filter_filename;
101 } else {
// NOTE(review): the empty-filename branch is hidden; presumably it falls back
// to getDefaultFilter() — confirm.
103 }
104 } else {
105 m_filter = nullptr;
106 }
107
// NOTE(review): lines reading the LUTZ window size, BFS max delta and ML model
// path are hidden in this view; only the ML threshold read is visible.
111 m_ml_threshold = args.at(SEGMENTATION_ML_THRESHOLD).as<double>();
112
// NOTE(review): the guarding condition is hidden; judging by the message it
// presumably checks "ML selected and model path empty" — confirm.
114 throw Elements::Exception() << "Machine learning segmentation requested but no ONNX model was provided";
115 }
116}
117
119 segConfigLogger.info() << "Using the default segmentation (3x3) filter.";
120 auto convolution_kernel = VectorImage<SeFloat>::create(3, 3);
121 convolution_kernel->setValue(0,0, 1);
122 convolution_kernel->setValue(0,1, 2);
123 convolution_kernel->setValue(0,2, 1);
124
125 convolution_kernel->setValue(1,0, 2);
126 convolution_kernel->setValue(1,1, 4);
127 convolution_kernel->setValue(1,2, 2);
128
129 convolution_kernel->setValue(2,0, 1);
130 convolution_kernel->setValue(2,1, 2);
131 convolution_kernel->setValue(2,2, 1);
132
133 return std::make_shared<BackgroundConvolution>(convolution_kernel, true);
134}
135
137 // check for the extension ".fits"
138 std::string fits_ending(".fits");
139 if (filename.length() >= fits_ending.length()
140 && filename.compare (filename.length() - fits_ending.length(), fits_ending.length(), fits_ending)==0) {
141 // load a FITS filter
142 return loadFITSFilter(filename);
143 }
144 else{
145 // load an ASCII filter
146 return loadASCIIFilter(filename);
147 }
148}
149
151
152 // read in the FITS file
153 auto convolution_kernel = FitsReader<SeFloat>::readFile(filename);
154
155 // give some feedback on the filter
156 segConfigLogger.info() << "Loaded segmentation filter: " << filename << " height: " << convolution_kernel->getHeight() << " width: " << convolution_kernel->getWidth();
157
158 // return the correct object
159 return std::make_shared<BackgroundConvolution>(convolution_kernel, true);
160}
161
162static bool getNormalization(std::istream& line_stream) {
163 std::string conv, norm_type;
164 line_stream >> conv >> norm_type;
165 if (conv != "CONV") {
166 throw Elements::Exception() << "Unexpected start for ASCII filter: " << conv;
167 }
168 if (norm_type == "NORM") {
169 return true;
170 }
171 else if (norm_type == "NONORM") {
172 return false;
173 }
174
175 throw Elements::Exception() << "Unexpected normalization type: " << norm_type;
176}
177
/**
 * Parses whitespace-separated values of type T from @p line_stream and
 * appends them to @p data.
 *
 * Extraction stops at the end of the stream or at the first token that cannot
 * be parsed as T. Only successfully parsed values are appended. On return the
 * stream's failbit is set (that is what terminates the loop).
 *
 * @param line_stream source of the values (typically one line of a filter file)
 * @param data        vector the parsed values are appended to
 */
template <typename T>
static void extractValues(std::istream& line_stream, std::vector<T>& data) {
  T value{};
  // Test the extraction itself rather than good() beforehand: the old form
  // appended one extra element whenever a read failed, e.g. for an empty
  // stream or a non-numeric token (which appended a bogus 0).
  while (line_stream >> value) {
    data.push_back(value);
  }
}
186
188 std::ifstream file;
189
190 // open the file and check
191 file.open(filename);
192 if (!file.good() || !file.is_open()){
193 throw Elements::Exception() << "Can not load filter: " << filename;
194 }
195
196 enum class LoadState {
197 STATE_START,
198 STATE_FIRST_LINE,
199 STATE_OTHER_LINES
200 };
201
202 LoadState state = LoadState::STATE_START;
203 bool normalize = false;
204 std::vector<SeFloat> kernel_data;
205 size_t kernel_width = 0;
206
207 while (file.good()) {
208 std::string line;
209 std::getline(file, line);
210 line = regex_replace(line, regex("\\s*#.*"), std::string(""));
211 line = regex_replace(line, regex("\\s*$"), std::string(""));
212 if (line.size() == 0) {
213 continue;
214 }
215
216 std::stringstream line_stream(line);
217
218 switch (state) {
219 case LoadState::STATE_START:
220 normalize = getNormalization(line_stream);
221 state = LoadState::STATE_FIRST_LINE;
222 break;
223 case LoadState::STATE_FIRST_LINE:
224 extractValues(line_stream, kernel_data);
225 kernel_width = kernel_data.size();
226 state = LoadState::STATE_OTHER_LINES;
227 break;
228 case LoadState::STATE_OTHER_LINES:
229 extractValues(line_stream, kernel_data);
230 break;
231 }
232 }
233
234 // compute the dimensions and create the kernel
235 if (kernel_width == 0) {
236 throw Elements::Exception() << "Malformed segmentation filter: width is 0";
237 }
238 auto kernel_height = kernel_data.size() / kernel_width;
239 auto convolution_kernel = VectorImage<SeFloat>::create(kernel_width, kernel_height, kernel_data);
240
241 // give some feedback on the filter
242 segConfigLogger.info() << "Loaded segmentation filter: " << filename << " width: " << convolution_kernel->getWidth() << " height: " << convolution_kernel->getHeight();
243
244 // return the correct object
245 return std::make_shared<BackgroundConvolution>(convolution_kernel, normalize);
246}
247
248} // SourceXtractor namespace
T at(T... args)
T c_str(T... args)
static Logging getLogger(const std::string &name="")
void info(const std::string &logMessage)
static std::shared_ptr< Image< T > > readFile(const std::string &filename)
Definition FitsReader.h:46
std::shared_ptr< DetectionImageFrame::ImageFilter > m_filter
SegmentationConfig(long manager_id)
Constructs a new SegmentationConfig object.
std::map< std::string, Configuration::OptionDescriptionList > getProgramOptions() override
void preInitialize(const UserValues &args) override
std::shared_ptr< DetectionImageFrame::ImageFilter > getDefaultFilter() const
std::shared_ptr< DetectionImageFrame::ImageFilter > loadFITSFilter(const std::string &filename) const
std::shared_ptr< DetectionImageFrame::ImageFilter > loadASCIIFilter(const std::string &filename) const
std::shared_ptr< DetectionImageFrame::ImageFilter > loadFilter(const std::string &filename) const
static std::shared_ptr< VectorImage< T > > create(Args &&... args)
T getline(T... args)
T good(T... args)
T is_open(T... args)
static void extractValues(std::istream &line_stream, std::vector< T > &data)
static const std::string SEGMENTATION_ML_THRESHOLD
static const std::string SEGMENTATION_USE_FILTERING
static const std::string SEGMENTATION_ALGORITHM
static const std::string SEGMENTATION_FILTER
static bool getNormalization(std::istream &line_stream)
static const std::string SEGMENTATION_LUTZ_WINDOW_SIZE
static const std::string SEGMENTATION_ML_MODEL
static Elements::Logging segConfigLogger
static const std::string SEGMENTATION_BFS_MAX_DELTA
T open(T... args)
T push_back(T... args)
T regex_replace(T... args)
T length(T... args)