# Copyright (C) 2022 ISIS Rutherford Appleton Laboratory UKRI
# SPDX - License - Identifier: GPL-3.0-or-later
from logging import getLogger
from threading import Lock
from typing import List, Optional
import numpy as np
from cil.framework import AcquisitionData, AcquisitionGeometry, DataOrder, ImageGeometry
from cil.optimisation.algorithms import PDHG
from cil.optimisation.operators import GradientOperator, BlockOperator
from cil.optimisation.functions import MixedL21Norm, L2NormSquared, BlockFunction, ZeroFunction, IndicatorBox
from cil.plugins.astra.operators import ProjectionOperator
from mantidimaging.core.data import ImageStack
from mantidimaging.core.reconstruct.base_recon import BaseRecon
from mantidimaging.core.utility.data_containers import ProjectionAngles, ReconstructionParameters, ScalarCoR
from mantidimaging.core.utility.optional_imports import safe_import
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.core.utility.size_calculator import full_size_KB
from mantidimaging.core.utility.memory_usage import system_free_memory
LOG = getLogger(__name__)
tomopy = safe_import('tomopy')
cil_mutex = Lock()
[docs]
class CILRecon(BaseRecon):
    """
    Reconstruction back end that runs CIL's PDHG algorithm on a TV-regularised problem.

    ``single_sino`` and ``full`` both hold the module-level ``cil_mutex`` while they work,
    so only one of them runs at any time within this process.
    """
    @staticmethod
    def set_up_TV_regularisation(image_geometry: ImageGeometry, acquisition_data: AcquisitionData,
                                 recon_params: ReconstructionParameters) -> tuple:
        """
        Build the operator and functions that make up the TV-regularised problem.

        :param image_geometry: Geometry of the image/volume to reconstruct
        :param acquisition_data: Measured data; its geometry is used for the projection operator
        :param recon_params: Reconstruction Parameters (``alpha`` and ``non_negative`` are read)
        :return: Tuple ``(K, f1, f2, G)``: the block operator, the two data/regulariser
                 functions, and the constraint function
        """
        # Forward operator, requested on the 'gpu' device
        A2d = ProjectionOperator(image_geometry, acquisition_data.geometry, 'gpu')

        # Define Gradient Operator and BlockOperator; alpha weights the gradient term
        alpha = recon_params.alpha
        Grad = GradientOperator(image_geometry)
        K = BlockOperator(alpha * Grad, A2d)

        # Define the two parts of BlockFunction F: MixedL21Norm() and L2NormSquared()
        f1 = MixedL21Norm()
        f2 = L2NormSquared(b=acquisition_data)

        if recon_params.non_negative:
            # Constrain the solution to values >= 0
            G = IndicatorBox(lower=0)
        else:
            # Define Function G simply as zero
            G = ZeroFunction()

        return (K, f1, f2, G)

    @staticmethod
    def _set_up_pdhg(image_geometry: ImageGeometry, acquisition_data: AcquisitionData,
                     recon_params: ReconstructionParameters) -> PDHG:
        """
        Create the PDHG solver shared by the single slice and the full volume reconstruction.

        :param image_geometry: Geometry of the image/volume to reconstruct
        :param acquisition_data: Measured data, already filled and in the order expected by the operator
        :param recon_params: Reconstruction Parameters
        :return: A PDHG instance on which ``next()`` has not been called yet
        """
        K, f1, f2, G = CILRecon.set_up_TV_regularisation(image_geometry, acquisition_data, recon_params)
        F = BlockFunction(f1, f2)

        # Step sizes are chosen so that sigma * tau * ||K||^2 == 1
        normK = K.norm()
        sigma = 1
        tau = 1 / (sigma * normK**2)

        # NOTE(review): max_iteration is presumably set this high so that the callers' own loops,
        # which call next() num_iter times, are what bound the iteration count - confirm.
        return PDHG(f=F, g=G, operator=K, tau=tau, sigma=sigma, max_iteration=100000, update_objective_interval=10)

    @staticmethod
    def find_cor(images: ImageStack, slice_idx: int, start_cor: float, recon_params: ReconstructionParameters) -> float:
        """
        Find the centre of rotation for one slice by delegating to ``tomopy.find_center``.

        :param images: Image stack providing the sinograms and projection angles
        :param slice_idx: Index of the slice to use
        :param start_cor: Initial guess passed to tomopy
        :param recon_params: Reconstruction Parameters (``max_projection_angle`` is read)
        :return: The value returned by ``tomopy.find_center``
        """
        return tomopy.find_center(images.sinograms,
                                  images.projection_angles(recon_params.max_projection_angle).value,
                                  ind=slice_idx,
                                  init=start_cor,
                                  sinogram_order=True)

    @staticmethod
    def single_sino(sino: np.ndarray,
                    cor: ScalarCoR,
                    proj_angles: ProjectionAngles,
                    recon_params: ReconstructionParameters,
                    progress: Optional[Progress] = None) -> np.ndarray:
        """
        Reconstruct a single slice from a single sinogram. Used for the preview and the single slice button.

        :param sino: Sinogram to reconstruct
        :param cor: Centre of rotation for this slice
        :param proj_angles: Projection angles in radians
        :param recon_params: Reconstruction Parameters
        :param progress: Optional progress reporter
        :return: The reconstructed slice as an array
        """
        if progress:
            # One step for the set-up plus one per iteration
            progress.add_estimated_steps(recon_params.num_iter + 1)
            progress.update(steps=1, msg='CIL: Setting up reconstruction', force_continue=False)

        if cil_mutex.locked():
            LOG.warning("CIL recon already in progress")

        with cil_mutex:
            sino = BaseRecon.prepare_sinogram(sino, recon_params)
            pixel_num_h = sino.shape[1]
            pixel_size = 1.
            # Offset of the rotation axis from the middle of the panel
            rot_pos_x = (cor.value - pixel_num_h / 2) * pixel_size

            ag = AcquisitionGeometry.create_Parallel2D(rotation_axis_position=[rot_pos_x, 0])
            ag.set_panel(pixel_num_h, pixel_size=pixel_size)
            ag.set_labels(DataOrder.ASTRA_AG_LABELS)
            ag.set_angles(angles=proj_angles.value, angle_unit='radian')

            data = ag.allocate(None)
            data.fill(sino)

            ig = ag.get_ImageGeometry()
            pdhg = CILRecon._set_up_pdhg(ig, data, recon_params)

            try:
                for iteration in range(recon_params.num_iter):
                    if progress:
                        progress.update(steps=1,
                                        msg=f'CIL: Iteration {iteration + 1} of {recon_params.num_iter}'
                                        f': Objective {pdhg.get_last_objective():.2f}',
                                        force_continue=False)
                    pdhg.next()
            finally:
                # Always close the progress reporter, even if an iteration raised
                if progress:
                    progress.mark_complete()
            return pdhg.solution.as_array()

    @staticmethod
    def full(images: ImageStack,
             cors: List[ScalarCoR],
             recon_params: ReconstructionParameters,
             progress: Optional[Progress] = None) -> ImageStack:
        """
        Performs a volume reconstruction using sample data provided as sinograms or projections.

        :param images: Image stack to reconstruct
        :param cors: Array of centre of rotation values; the value for the middle slice is used
        :param recon_params: Reconstruction Parameters
        :param progress: Optional progress reporter
        :return: 3D image data for reconstructed volume
        :raises RuntimeError: if the estimated memory requirement exceeds the free system memory
        :raises ValueError: if ``recon_params.tilt`` is not set
        """
        progress = Progress.ensure_instance(progress,
                                            task_name='CIL reconstruction',
                                            num_steps=recon_params.num_iter + 1)
        shape = images.data.shape
        # The axis holding the vertical pixel count depends on how the stack is stored
        if images.is_sinograms:
            data_order = DataOrder.ASTRA_AG_LABELS
            pixel_num_h, pixel_num_v = shape[2], shape[0]
        else:
            data_order = DataOrder.TIGRE_AG_LABELS
            pixel_num_h, pixel_num_v = shape[2], shape[1]

        projection_size = full_size_KB(images.data.shape, images.dtype)
        recon_volume_shape = pixel_num_h, pixel_num_h, pixel_num_v
        recon_volume_size = full_size_KB(recon_volume_shape, images.dtype)
        # NOTE(review): the multipliers 5 and 13 are presumably empirical - confirm their origin
        estimated_mem_required = 5 * projection_size + 13 * recon_volume_size
        free_mem = system_free_memory().kb()

        if estimated_mem_required > free_mem:
            estimate_gb = estimated_mem_required / 1024 / 1024
            raise RuntimeError(
                "The machine does not have enough physical memory available to allocate space for this data."
                f" Estimated RAM needed is {estimate_gb:.2f} GB")

        # Validate before taking the lock so an invalid request never waits on a running reconstruction
        if recon_params.tilt is None:
            raise ValueError("recon_params.tilt is not set")

        if cil_mutex.locked():
            LOG.warning("CIL recon already in progress")

        with cil_mutex:
            LOG.info(f"Starting 3D PDHG-TV reconstruction: input shape {images.data.shape}, "
                     f"output shape {recon_volume_shape}\n"
                     f"Num iter {recon_params.num_iter}, alpha {recon_params.alpha}, "
                     f"Non-negative {recon_params.non_negative}")
            progress.update(steps=1, msg='CIL: Setting up reconstruction', force_continue=False)
            angles = images.projection_angles(recon_params.max_projection_angle).value

            pixel_size = 1.
            # Offset of the rotation axis from the middle of the panel, taken from the middle slice's COR
            rot_pos = [(cors[pixel_num_v // 2].value - pixel_num_h / 2) * pixel_size, 0, 0]
            # The tilt is given in degrees; it is turned into the x component of the axis direction
            slope = -np.tan(np.deg2rad(recon_params.tilt.value))
            rot_angle = [slope, 0, 1]

            ag = AcquisitionGeometry.create_Parallel3D(rotation_axis_position=rot_pos,
                                                       rotation_axis_direction=rot_angle)
            ag.set_panel([pixel_num_h, pixel_num_v], pixel_size=(pixel_size, pixel_size))
            ag.set_angles(angles=angles, angle_unit='radian')
            ag.set_labels(data_order)

            data = ag.allocate(None)
            data.fill(BaseRecon.prepare_sinogram(images.data, recon_params))
            data.reorder('astra')

            ig = ag.get_ImageGeometry()
            pdhg = CILRecon._set_up_pdhg(ig, data, recon_params)

            with progress:
                for iteration in range(recon_params.num_iter):
                    progress.update(steps=1,
                                    msg=f'CIL: Iteration {iteration + 1} of {recon_params.num_iter}'
                                    f': Objective {pdhg.get_last_objective():.2f}',
                                    force_continue=False)
                    pdhg.next()
                volume = pdhg.solution.as_array()
                LOG.info(f'Reconstructed 3D volume with shape: {volume.shape}')
            return ImageStack(volume)
[docs]
def allowed_recon_kwargs() -> dict:
    """
    Return the reconstruction keyword arguments listed for each CIL algorithm.

    :return: Dictionary with one entry, keyed 'CIL: PDHG-TV', holding a list of parameter names
    """
    pdhg_tv_kwargs = ['alpha', 'num_iter', 'non_negative']
    return {'CIL: PDHG-TV': pdhg_tv_kwargs}