# Copyright (C) 2023 ISIS Rutherford Appleton Laboratory UKRI
# SPDX - License - Identifier: GPL-3.0-or-later
from __future__ import annotations
import traceback
from enum import Enum, auto
from functools import partial
from logging import getLogger
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Callable, Set
import numpy as np
from PyQt5.QtWidgets import QWidget
from mantidimaging.core.data import ImageStack
from mantidimaging.core.utility.data_containers import ScalarCoR, Degrees
from mantidimaging.gui.dialogs.async_task import start_async_task_view, TaskWorkerThread
from mantidimaging.gui.dialogs.cor_inspection.view import CORInspectionDialogView
from mantidimaging.gui.mvp_base import BasePresenter
from mantidimaging.gui.utility.qt_helpers import BlockQtSignals
from mantidimaging.gui.windows.recon.model import ReconstructWindowModel
LOG = getLogger(__name__)

# Imports only needed for type annotations; guarded to avoid runtime import cycles.
if TYPE_CHECKING:
    from mantidimaging.gui.windows.recon.view import ReconstructWindowView  # pragma: no cover
    from mantidimaging.gui.windows.main import MainWindowView
class AutoCorMethod(Enum):
    """Methods for automatically finding the centre of rotation (COR)."""
    CORRELATION = auto()
    MINIMISATION_SQUARE_SUM = auto()
class Notifications(Enum):
    """Events the reconstruction window view sends to ``ReconstructWindowPresenter.notify``."""
    RECONSTRUCT_VOLUME = auto()
    RECONSTRUCT_PREVIEW_SLICE = auto()
    RECONSTRUCT_PREVIEW_USER_CLICK = auto()
    RECONSTRUCT_STACK_SLICE = auto()
    RECONSTRUCT_USER_CLICK = auto()
    COR_FIT = auto()
    CLEAR_ALL_CORS = auto()
    REMOVE_SELECTED_COR = auto()
    CALCULATE_CORS_FROM_MANUAL_TILT = auto()
    ALGORITHM_CHANGED = auto()
    UPDATE_PROJECTION = auto()
    ADD_COR = auto()
    REFINE_COR = auto()
    REFINE_ITERS = auto()
    AUTO_FIND_COR_CORRELATE = auto()
    AUTO_FIND_COR_MINIMISE = auto()
class ReconstructWindowPresenter(BasePresenter):
    """
    Presenter for the reconstruction window.

    Receives :class:`Notifications` from the view via :meth:`notify`, keeps the
    :class:`ReconstructWindowModel` (built around the view's COR table model) in sync with the
    selected stack, and launches preview, stack-slice, volume and automatic COR-finding work
    through ``start_async_task_view``.
    """
    ERROR_STRING = "COR/Tilt finding failed: {}"
    view: 'ReconstructWindowView'

    def __init__(self, view: 'ReconstructWindowView', main_window: 'MainWindowView'):
        super().__init__(view)

        self.view = view
        self.model = ReconstructWindowModel(self.view.cor_table_model)
        # Algorithm name -> names of the reconstruction kwargs it accepts (see do_algorithm_changed)
        self.allowed_recon_kwargs: Dict[str, List[str]] = self.model.load_allowed_recon_kwargs()
        # Reconstruction kwarg -> [input widget, label]; each pair is only shown while the selected
        # algorithm accepts that kwarg.
        self.restricted_arg_widgets: Dict[str, List[QWidget]] = {
            'filter_name': [self.view.filterName, self.view.filterNameLabel],
            'num_iter': [self.view.numIter, self.view.numIterLabel],
            'alpha': [self.view.alphaSpinBox, self.view.alphaLabel],
            'non_negative': [self.view.nonNegativeCheckBox, self.view.nonNegativeLabel],
            'stochastic': [self.view.stochasticCheckBox, self.view.stochasticLabel],
            'projections_per_subset': [self.view.subsetsSpinBox, self.view.subsetsLabel],
            'regularisation_percent': [self.view.regPercentSpinBox, self.view.regPercentLabel],
        }
        self.main_window = main_window

        # True while a full-volume reconstruction or a correlation COR search is in progress
        self.recon_is_running = False
        # Passed as ``tracker`` to every start_async_task_view call
        # (presumably keeps running tasks referenced until they finish - confirm in async_task)
        self.async_tracker: Set[Any] = set()

        self.main_window.stack_changed.connect(self.handle_stack_changed)
        # Set when a stack change / selection arrives while this window is hidden. They are not read
        # within this class; presumably consumed when the window is shown - verify against the view.
        self.stack_changed_pending = False
        self.stack_selection_change_pending = False

    def notify(self, notification, slice_idx=None):
        """
        Dispatch a view notification to its handler.

        :param notification: A :class:`Notifications` member.
        :param slice_idx: Slice to reconstruct; only used by ``Notifications.RECONSTRUCT_USER_CLICK``.

        Notifications without a handler are ignored. Any exception raised while handling a
        notification is reported via ``show_error`` (with the formatted traceback) instead of
        propagating to the caller.
        """
        try:
            handlers: Dict[Notifications, Callable[[], Any]] = {
                Notifications.RECONSTRUCT_VOLUME: self.do_reconstruct_volume,
                Notifications.RECONSTRUCT_PREVIEW_SLICE: self.do_preview_reconstruct_slice,
                Notifications.RECONSTRUCT_PREVIEW_USER_CLICK: partial(self.do_preview_reconstruct_slice,
                                                                      force_update=True),
                Notifications.RECONSTRUCT_STACK_SLICE: self.do_stack_reconstruct_slice,
                Notifications.RECONSTRUCT_USER_CLICK: partial(self.do_preview_reconstruct_slice,
                                                              slice_idx=slice_idx),
                Notifications.COR_FIT: self.do_cor_fit,
                Notifications.CLEAR_ALL_CORS: self.do_clear_all_cors,
                Notifications.REMOVE_SELECTED_COR: self.do_remove_selected_cor,
                Notifications.CALCULATE_CORS_FROM_MANUAL_TILT: self.do_calculate_cors_from_manual_tilt,
                Notifications.ALGORITHM_CHANGED: self.do_algorithm_changed,
                Notifications.UPDATE_PROJECTION: self.do_update_projection,
                Notifications.ADD_COR: self.do_add_cor,
                Notifications.REFINE_COR: self._do_refine_selected_cor,
                Notifications.REFINE_ITERS: self._do_refine_iterations,
                Notifications.AUTO_FIND_COR_CORRELATE: self._auto_find_correlation,
                Notifications.AUTO_FIND_COR_MINIMISE: self._auto_find_minimisation_square_sum,
            }
            handler = handlers.get(notification)
            if handler is not None:
                handler()
        except Exception as err:
            self.show_error(err, traceback.format_exc())

    def do_algorithm_changed(self):
        """
        React to a change of reconstruction algorithm.

        Shows only the argument widgets the new algorithm accepts, refreshes the available filters
        (with the filter-name and iteration widgets' signals blocked), then refreshes the preview
        and the refine-iterations control.
        """
        alg_name = self.view.algorithm_name
        allowed_args = self.allowed_recon_kwargs[alg_name]
        for arg, widgets in self.restricted_arg_widgets.items():
            visible = arg in allowed_args
            for widget in widgets:
                if visible:
                    widget.show()
                else:
                    widget.hide()
        with BlockQtSignals([self.view.filterName, self.view.numIter]):
            self.view.set_filters_for_recon_tool(self.model.get_allowed_filters(alg_name))
        self.do_preview_reconstruct_slice()
        self.view.change_refine_iterations()

    def set_stack_uuid(self, uuid):
        """
        Select the stack with the given UUID for reconstruction.

        If the window is hidden, only ``stack_selection_change_pending`` is set. If the stack is
        already the current one, nothing happens. Otherwise the previews and COR table are reset,
        the model is re-initialised with the stack, and (if a stack was found) a preview
        reconstruction and a NaN/zero/negative check are run.
        """
        if not self.view.isVisible():
            self.stack_selection_change_pending = True
            return
        images = self.view.get_stack(uuid)
        if self.model.is_current_stack(uuid):
            return

        self.view.reset_recon_and_sino_previews()
        self.view.clear_cor_table()
        self.model.initial_select_data(images)
        self.view.rotation_centre = self.model.last_cor.value
        self.view.pixel_size = self.get_pixel_size_from_images()
        self.do_update_projection()
        self.view.update_recon_hist_needed = True
        if images is None:
            self.view.reset_recon_line_profile()
            self.view.show_status_message("")
            return

        self._set_max_preview_indexes()
        self.do_preview_reconstruct_slice(reset_roi=True)
        self._do_nan_zero_negative_check()

    def _set_max_preview_indexes(self):
        """Limit the projection/slice index controls to the current stack's last valid index."""
        images = self.model.images
        if images is not None:
            self.view.set_max_projection_index(images.num_projections - 1)
            self.view.set_max_slice_index(images.height - 1)

    def set_preview_projection_idx(self, idx):
        """Set the projection shown in the preview and redraw it."""
        self.model.preview_projection_idx = idx
        self.do_update_projection()

    def set_preview_slice_idx(self, idx):
        """Set the preview slice, redraw the projection and refresh the slice preview reconstruction."""
        self.model.preview_slice_idx = idx
        self.do_update_projection()
        self.do_preview_reconstruct_slice()

    def set_row(self, row):
        """Record the COR table row selected in the view."""
        self.model.selected_row = row

    def get_pixel_size_from_images(self):
        """Return the current stack's pixel size, or ``0.`` if there is no stack or no pixel size."""
        if self.model.images is not None and self.model.images.pixel_size is not None:
            return self.model.images.pixel_size
        return 0.

    def do_update_projection(self):
        """Redraw the projection preview, or reset it when no stack is selected."""
        images = self.model.images
        if images is None:
            self.view.reset_projection_preview()
            return

        img_data = images.projection(self.model.preview_projection_idx)
        self.view.update_projection(img_data, self.model.preview_slice_idx, self.model.tilt_angle)

    def handle_stack_changed(self):
        """
        Handle the main window's ``stack_changed`` signal.

        When visible, resets the COR model and refreshes previews; when hidden, only sets
        ``stack_changed_pending``.
        """
        if self.view.isVisible():
            self.model.reset_cor_model()
            self.do_update_projection()
            self._set_max_preview_indexes()
            self.do_preview_reconstruct_slice(reset_roi=True)
        else:
            self.stack_changed_pending = True

    def _find_next_free_slice_index(self) -> int:
        """
        Return the first slice index not already in the COR table's first column.

        The search starts just after the current preview slice, then wraps around from 0.

        :raises RuntimeError: If every slice index is already used.
        """
        slice_index = self.model.preview_slice_idx
        max_slice = self.model.images.height
        # Set for O(1) membership tests while scanning all slice indexes
        used = set(self.view.cor_table_model.getColumn(0))

        for index in range(slice_index + 1, max_slice):
            if index not in used:
                return index
        for index in range(0, slice_index):
            if index not in used:
                return index

        raise RuntimeError("No free slice indexes to add to the COR Table")

    def do_add_cor(self):
        """Add a COR table row at the next free slice index, using the model's current COR."""
        row = self.model.selected_row
        cor = self.model.get_me_a_cor()
        slice_index = self._find_next_free_slice_index()
        self.view.add_cor_table_row(row, slice_index, cor.value)

    def do_reconstruct_volume(self):
        """
        Start an asynchronous full-volume reconstruction.

        :raises ValueError: If no COR fit has been performed yet.
        """
        if not self.model.has_results:
            raise ValueError("Fit is not performed on the data, therefore the CoR cannot be found for each slice.")

        self.recon_is_running = True
        self.view.set_recon_buttons_enabled(False)
        try:
            start_async_task_view(self.view,
                                  self.model.run_full_recon,
                                  self._on_volume_recon_done, {'recon_params': self.view.recon_params()},
                                  tracker=self.async_tracker)
        except Exception:
            # Nothing was started, so undo the "running" UI state before re-raising
            self.recon_is_running = False
            self.view.set_recon_buttons_enabled(True)
            raise

    def _get_reconstruct_slice(self, cor, slice_idx: int, call_back: Callable[[TaskWorkerThread], None]) -> None:
        """
        Start an asynchronous single-slice reconstruction and call ``call_back`` with the task.

        :param cor: COR to use; if None, the model picks one (e.g. from regression results).
        :param slice_idx: Slice to reconstruct.
        :param call_back: Invoked with the finished TaskWorkerThread.
        """
        # If no COR is provided and there are regression results then calculate
        # the COR for the selected preview slice
        cor = self.model.get_me_a_cor(cor)
        start_async_task_view(self.view,
                              self.model.run_preview_recon,
                              call_back, {
                                  'slice_idx': slice_idx,
                                  'cor': cor,
                                  'recon_params': self.view.recon_params()
                              },
                              tracker=self.async_tracker)

    def _get_slice_index(self, slice_idx: Optional[int]) -> int:
        """Return ``slice_idx`` (storing it as the preview slice), or the current preview slice if None."""
        if slice_idx is None:
            slice_idx = self.model.preview_slice_idx
        else:
            self.model.preview_slice_idx = slice_idx
        return slice_idx

    def do_preview_reconstruct_slice(self,
                                     cor=None,
                                     slice_idx: Optional[int] = None,
                                     force_update: bool = False,
                                     reset_roi: bool = False):
        """
        Update the sinogram and, if auto-update is on or ``force_update`` is set, reconstruct a preview slice.

        :param cor: COR to use; if None, the model chooses one.
        :param slice_idx: Slice to show; defaults to the current preview slice.
        :param force_update: Reconstruct even when auto-update is disabled.
        :param reset_roi: Passed through to the view when the preview is displayed.
        """
        if self.model.images is None:
            self.view.reset_recon_and_sino_previews()
            return

        slice_idx = self._get_slice_index(slice_idx)
        self.view.update_sinogram(self.model.images.sino(slice_idx))
        if self.view.is_auto_update_preview() or force_update:
            on_preview_complete = partial(self._on_preview_reconstruct_slice_done, reset_roi=reset_roi)
            self._get_reconstruct_slice(cor, slice_idx, on_preview_complete)

    def _on_preview_reconstruct_slice_done(self, task: TaskWorkerThread, reset_roi: bool = False):
        """Show the preview reconstruction from ``task``, or an error dialog if it failed."""
        if task.error is not None:
            self.view.show_error_dialog(f"Encountered error while trying to reconstruct: {str(task.error)}")
            return

        images: ImageStack = task.result
        if images is not None:
            # We copy the preview data out of shared memory when passing it into update_recon_preview so that it
            # will still be available after this function ends
            self.view.update_recon_preview(np.copy(images.data[0]), reset_roi)

    def do_stack_reconstruct_slice(self, cor=None, slice_idx: Optional[int] = None):
        """Reconstruct a single slice asynchronously and show it as a new "Recon" stack."""
        self.view.set_recon_buttons_enabled(False)
        try:
            slice_idx = self._get_slice_index(slice_idx)
            # Capture the slice index and parameters now so that the history entry recorded on completion
            # describes this reconstruction, even if the user changes the preview slice or params meanwhile.
            on_done = partial(self._on_stack_reconstruct_slice_done,
                              slice_idx=slice_idx,
                              recon_params=self.view.recon_params())
            self._get_reconstruct_slice(cor, slice_idx, on_done)
        except Exception:
            # No task was started, so its completion callback will never re-enable the buttons
            self.view.set_recon_buttons_enabled(True)
            raise

    def _on_stack_reconstruct_slice_done(self,
                                         task: TaskWorkerThread,
                                         slice_idx: Optional[int] = None,
                                         recon_params: Any = None):
        """
        Show the reconstructed slice from ``task`` and record the operation in its history.

        :param task: The finished reconstruction task.
        :param slice_idx: Slice that was reconstructed; falls back to the current preview slice.
        :param recon_params: Parameters used; falls back to the view's current parameters.
        """
        if task.error is not None:
            self.view.show_error_dialog(f"Encountered error while trying to reconstruct: {str(task.error)}")
            self.view.set_recon_buttons_enabled(True)
            return

        try:
            images: ImageStack = task.result
            if slice_idx is None:
                slice_idx = self._get_slice_index(None)
            if recon_params is None:
                recon_params = self.view.recon_params()
            if images is not None:
                assert self.model.images is not None
                images.name = "Recon"
                self._replace_inf_nan(images)  # pyqtgraph workaround
                self.view.show_recon_volume(images, self.model.stack_id)
                images.record_operation('AstraRecon.single_sino',
                                        'Slice Reconstruction',
                                        slice_idx=slice_idx,
                                        **recon_params.to_dict())
        finally:
            self.view.set_recon_buttons_enabled(True)

    def _do_refine_selected_cor(self):
        """
        Open the COR inspection dialog for the first selected COR table row and apply an accepted COR.

        :raises ValueError: If no COR table row is selected.
        """
        selected_rows = self.view.get_cor_table_selected_rows()
        if len(selected_rows) == 0:
            raise ValueError("No slice selected in COR table")
        slice_idx = self.model.slices[selected_rows[0]]

        dialog = CORInspectionDialogView(self.view, self.model.images, slice_idx, self.model.last_cor,
                                         self.view.recon_params(), False)
        try:
            res = dialog.exec()
            LOG.debug('COR refine dialog result: %s', res)

            if res == CORInspectionDialogView.Accepted:
                new_cor = dialog.optimal_rotation_centre
                LOG.debug('New optimal rotation centre: %s', new_cor)
                self.model.data_model.set_cor_at_slice(slice_idx, new_cor.value)
                self.model.last_cor = new_cor
                # Update reconstruction preview with new COR
                self.set_preview_slice_idx(slice_idx)
        finally:
            dialog.deleteLater()

    def _do_refine_iterations(self):
        """Open the COR inspection dialog in iterations mode and apply an accepted iteration count."""
        slice_idx = self.model.preview_slice_idx

        dialog = CORInspectionDialogView(self.view, self.model.images, slice_idx, self.model.last_cor,
                                         self.view.recon_params(), True)
        try:
            res = dialog.exec()
            LOG.debug('COR refine iteration result: %s', res)

            if res == CORInspectionDialogView.Accepted:
                new_iters = dialog.optimal_iterations
                LOG.debug('New optimal iterations: %s', new_iters)
                self.view.num_iter = new_iters
        finally:
            # The dialog is constructed with the view as its first argument (presumably its Qt parent),
            # so without this every refine would leave a dialog object alive.
            dialog.deleteLater()

    def do_cor_fit(self):
        """Fit the COR table, show the results, and refresh the projection and preview."""
        self.model.do_fit()
        self.view.set_results(*self.model.get_results())
        self.do_update_projection()
        self.do_preview_reconstruct_slice()

    def _on_volume_recon_done(self, task):
        """Show the reconstructed volume from ``task`` as a "Recon" stack, or an error dialog if it failed."""
        self.recon_is_running = False
        if task.error is not None:
            self.view.show_error_dialog(f"Encountered error while trying to reconstruct: {str(task.error)}")
            self.view.set_recon_buttons_enabled(True)
            return

        try:
            self._replace_inf_nan(task.result)  # pyqtgraph workaround
            assert self.model.images is not None
            task.result.name = "Recon"
            self.view.show_recon_volume(task.result, self.model.stack_id)
        finally:
            self.view.set_recon_buttons_enabled(True)

    def do_clear_all_cors(self):
        """Clear the COR table and reset the selected row."""
        self.view.clear_cor_table()
        self.model.reset_selected_row()

    def do_remove_selected_cor(self):
        """Remove the selected COR table row."""
        self.view.remove_selected_cor()

    def set_last_cor(self, cor):
        """Store ``cor`` (wrapped in ScalarCoR) as the model's last COR."""
        self.model.last_cor = ScalarCoR(cor)

    def do_calculate_cors_from_manual_tilt(self):
        """Fill the COR table from the rotation centre and tilt entered in the view."""
        cor = ScalarCoR(self.view.rotation_centre)
        tilt = Degrees(self.view.tilt)
        self._set_precalculated_cor_tilt(cor, tilt)

    def _set_precalculated_cor_tilt(self, cor: ScalarCoR, tilt: Degrees):
        """Apply a known COR/tilt: update the model, results, COR table rows, projection and preview."""
        self.model.set_precalculated(cor, tilt)
        self.view.set_results(*self.model.get_results())
        for idx, point in enumerate(self.model.data_model.iter_points()):
            self.view.set_table_point(idx, point.slice_index, point.cor)
        self.do_update_projection()
        self.do_preview_reconstruct_slice()

    def _auto_find_correlation(self):
        """
        Asynchronously find COR and tilt by correlating the 0 and 180 degree projections.

        Requires the stack to have a 180 degree projection; otherwise a status message is shown.
        """
        if not self.model.images.has_proj180deg():
            self.view.show_status_message("Unable to correlate 0 and 180 because the dataset doesn't have a 180 "
                                          "projection set. Please load a 180 projection manually.")
            return

        self.recon_is_running = True

        def completed(task: TaskWorkerThread):
            try:
                if task.result is None and task.error is not None:
                    selected_stack = self.view.main_window.get_images_from_stack_uuid(self.view.stackSelector.current())
                    self.view.show_error_dialog(
                        f"Finding the COR failed, likely caused by the selected stack's 180 "
                        f"degree projection being a different shape. \n\n "
                        f"Error: {str(task.error)} "
                        f"\n\n Suggestion: Use crop coordinates to resize the 180 degree projection to "
                        f"({selected_stack.height}, {selected_stack.width})")
                else:
                    cor, tilt = task.result
                    self._set_precalculated_cor_tilt(cor, tilt)
            finally:
                # Always restore the UI, even if handling the result raised
                self.view.set_correlate_buttons_enabled(True)
                self.recon_is_running = False

        self.view.set_correlate_buttons_enabled(False)
        try:
            start_async_task_view(self.view, self.model.auto_find_correlation, completed, tracker=self.async_tracker)
        except Exception:
            # No task was started, so `completed` will never restore the UI
            self.view.set_correlate_buttons_enabled(True)
            self.recon_is_running = False
            raise

    def _auto_find_minimisation_square_sum(self):
        """
        Asynchronously find CORs for a user-chosen number of slices by square-sum minimisation.

        The COR table is cleared first. Each slice's search starts from the regression COR when fit
        results exist, otherwise from the view's rotation centre. Found CORs are added to the table
        and a fit is run.
        """
        num_cors = self.view.get_number_of_cors()
        if num_cors is None:
            return

        self.do_clear_all_cors()
        selected_row, slice_indices = self.model.get_slice_indices(num_cors)

        if self.model.has_results:
            initial_cor = [self.model.data_model.get_cor_from_regression(slc) for slc in slice_indices]
        else:
            initial_cor = self.view.rotation_centre

        def _completed_finding_cors(task: TaskWorkerThread):
            try:
                if task.error is not None:
                    self.view.show_error_dialog(f"Finding the COR failed.\n\n Error: {str(task.error)}")
                else:
                    cors = task.result
                    for slice_idx, cor in zip(slice_indices, cors, strict=True):
                        self.view.add_cor_table_row(selected_row, slice_idx, cor)
                    self.do_cor_fit()
            finally:
                # Always restore the UI, even if handling the result raised
                self.view.set_correlate_buttons_enabled(True)

        self.view.set_correlate_buttons_enabled(False)
        try:
            start_async_task_view(self.view,
                                  self.model.auto_find_minimisation_sqsum,
                                  _completed_finding_cors, {
                                      'slices': slice_indices,
                                      'recon_params': self.view.recon_params(),
                                      'initial_cor': initial_cor
                                  },
                                  tracker=self.async_tracker)
        except Exception:
            # No task was started, so the completion callback will never re-enable the buttons
            self.view.set_correlate_buttons_enabled(True)
            raise

    def proj_180_degree_shape_matches_images(self, images):
        """Return the model's check of whether the 180 degree projection shape matches ``images``."""
        return self.model.proj_180_degree_shape_matches_images(images)

    def _do_nan_zero_negative_check(self):
        """
        Checks if the data contains NaNs/zeroes/negative values and displays a message if they are found.

        Clears the status message when none are found.
        """
        checks = (
            (self.model.stack_contains_nans, "NaN(s) found in the stack."),
            (self.model.stack_contains_zeroes, "Zero(es) found in the stack."),
            (self.model.stack_contains_negative_values, "Negative value(s) found in the stack."),
        )
        msg_list = [message for check, message in checks if check()]

        if not msg_list:
            self.view.show_status_message("")
        else:
            self.view.show_status_message(" ".join(["Warning:", *msg_list]))

    @staticmethod
    def _replace_inf_nan(images: ImageStack):
        """
        Replaces infinity values in a data array with NaNs, in place. Used because pyqtgraph has problems
        with arrays containing inf.

        :param images: The ImageStack object.
        """
        images.data[np.isinf(images.data)] = np.nan