68 unsigned int max_iterations,
double modified_chi_squared_scale,
73 : m_least_squares_engine(least_squares_engine),
74 m_max_iterations(max_iterations), m_modified_chi_squared_scale(modified_chi_squared_scale),
75 m_parameters(parameters), m_frames(frames), m_priors(priors), m_scale_factor(scale_factor) {}
79 return stamp_rect.
getWidth() > 0 && stamp_rect.getHeight() > 0;
102 SeFloat saturation = frame_info.getSaturation();
107 for (
int y = 0;
y < rect.getHeight();
y++) {
108 for (
int x = 0;
x < rect.getWidth();
x++) {
109 auto back_var = variance_map->getValue(rect.getTopLeft().m_x +
x, rect.getTopLeft().m_y +
y);
110 auto pixel_val = frame_image->getValue(rect.getTopLeft().m_x +
x, rect.getTopLeft().m_y +
y);
111 if (saturation > 0 && pixel_val > saturation) {
112 weight->at(
x,
y) = 0;
114 else if (gain > 0.0 && pixel_val > 0.0) {
115 weight->at(
x,
y) =
sqrt(1.0 / (back_var + pixel_val / gain));
118 weight->at(
x,
y) =
sqrt(1.0 / back_var);
131 int frame_index = frame->getFrameNb();
133 auto frame_coordinates =
135 auto ref_coordinates =
142 if (psf_property.getPsf() ==
nullptr) {
143 throw Elements::Exception() <<
"Missing PSF. No PSF mode is not supported in legacy model fitting";
150 auto group_psf =
ImagePsf(
pixel_scale * psf_property.getPixelSampling(), psf_property.getPsf());
156 for (
auto& source : group) {
157 for (
auto model : frame->getModels()) {
158 model->addForSource(manager, source, constant_models, point_models, extended_models, jacobian, ref_coordinates, frame_coordinates,
159 stamp_rect.getTopLeft());
165 pixel_scale, (
size_t) stamp_rect.getWidth(), (
size_t) stamp_rect.getHeight(),
176 int n_free_parameters = 0;
179 for (
auto& source : group) {
181 if (std::dynamic_pointer_cast<FlexibleModelFittingFreeParameter>(parameter)) {
185 parameter->create(parameter_manager, engine_parameter_manager, source));
196 int valid_frames = 0;
197 int n_good_pixels = 0;
199 int frame_index = frame->getFrameNb();
209 for (
int y = 0;
y < weight->getHeight(); ++
y) {
210 for (
int x = 0;
x < weight->getWidth(); ++
x) {
211 n_good_pixels += (weight->at(
x,
y) != 0.);
220 res_estimator.registerBlockProvider(
std::move(data_vs_model));
226 if (valid_frames == 0) {
229 else if (n_good_pixels < n_free_parameters) {
239 for (
auto& source : group) {
241 prior->setupPrior(parameter_manager, source, res_estimator);
250 auto solution = engine->solveProblem(engine_parameter_manager, res_estimator);
251 auto iterations = solution.iteration_no;
252 auto stop_reason = solution.engine_stop_reason;
253 switch (solution.status_flag) {
264 int total_data_points = 0;
267 int nb_of_free_parameters = 0;
268 for (
auto& source : group) {
270 bool is_free_parameter = std::dynamic_pointer_cast<FlexibleModelFittingFreeParameter>(parameter).get();
271 bool accessed_by_modelfitting = parameter_manager.
isParamAccessed(source, parameter);
272 if (is_free_parameter && accessed_by_modelfitting) {
273 nb_of_free_parameters++;
277 avg_reduced_chi_squared /= (total_data_points - nb_of_free_parameters);
280 for (
auto& source : group) {
285 bool is_dependent_parameter = std::dynamic_pointer_cast<FlexibleModelFittingDependentParameter>(parameter).get();
286 bool is_constant_parameter = std::dynamic_pointer_cast<FlexibleModelFittingConstantParameter>(parameter).get();
287 bool accessed_by_modelfitting = parameter_manager.
isParamAccessed(source, parameter);
288 auto modelfitting_parameter = parameter_manager.
getParameter(source, parameter);
290 if (is_constant_parameter || is_dependent_parameter || accessed_by_modelfitting) {
291 parameter_values[parameter->getId()] = modelfitting_parameter->getValue();
292 parameter_sigmas[parameter->getId()] = parameter->getSigma(parameter_manager, source, solution.parameter_sigmas);
296 auto engine_parameter = std::dynamic_pointer_cast<EngineParameter>(modelfitting_parameter);
297 if (engine_parameter) {
306 avg_reduced_chi_squared, solution.duration,
source_flags,
307 parameter_values, parameter_sigmas,
315 logger.
error() <<
"An exception occured during model fitting: " <<
e.what();
324 for (
auto& source : group) {
327 auto modelfitting_parameter = parameter_manager.
getParameter(source, parameter);
328 auto manual_parameter = std::dynamic_pointer_cast<ManualParameter>(modelfitting_parameter);
329 if (manual_parameter) {
335 dummy_values, dummy_values,
345 int frame_index = frame->getFrameNb();
349 auto final_stamp = frame_model.getImage();
356 for (
int x = 0;
x < final_stamp->getWidth();
x++) {
357 for (
int y = 0;
y < final_stamp->getHeight();
y++) {
358 auto x_coord = stamp_rect.getTopLeft().m_x +
x;
359 auto y_coord = stamp_rect.getTopLeft().m_y +
y;
360 debug_image->setValue(x_coord, y_coord,
361 debugAccessor.
getValue(x_coord, y_coord) + final_stamp->getValue(
x,
y));
372 double reduced_chi_squared = 0.0;
378 for (
int y=0;
y < image->getHeight();
y++) {
379 for (
int x=0;
x < image->getWidth();
x++) {
380 double tmp = imageAccessor.getValue(
x,
y) - modelAccessor.
getValue(
x,
y);
381 reduced_chi_squared += tmp * tmp * weightAccessor.
getValue(
x,
y) * weightAccessor.
getValue(
x,
y);
387 return reduced_chi_squared;
394 total_data_points = 0;
395 int valid_frames = 0;
397 int frame_index = frame->getFrameNb();
402 auto final_stamp = frame_model.getImage();
410 image, final_stamp, weight, data_points);
412 total_data_points += data_points;
413 total_chi_squared += chi_squared;
417 return total_chi_squared;
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > x
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > y
void error(const std::string &logMessage)
static Logging getLogger(const std::string &name="")
Data vs model comparator which computes a modified residual, using asinh.
Class responsible for managing the parameters the least square engine minimizes.
static std::shared_ptr< LeastSquareEngine > create(const std::string &name, unsigned max_iterations=1000)
Provides to the LeastSquareEngine the residual values.
static Elements::Logging logger
std::unique_ptr< DataVsModelResiduals< typename std::remove_reference< DataType >::type, typename std::remove_reference< ModelType >::type, typename std::remove_reference< WeightType >::type, typename std::remove_reference< Comparator >::type > > createDataVsModelResiduals(DataType &&data, ModelType &&model, WeightType &&weight, Comparator &&comparator)