Source code for mantidimaging.core.reconstruct.astra_recon

# Copyright (C) 2024 ISIS Rutherford Appleton Laboratory UKRI
# SPDX - License - Identifier: GPL-3.0-or-later
from __future__ import annotations

from contextlib import contextmanager
from logging import getLogger
from threading import Lock
from collections.abc import Generator

import astra
import numpy as np
from scipy.optimize import minimize

from mantidimaging.core.data import ImageStack
from mantidimaging.core.reconstruct.base_recon import BaseRecon
from mantidimaging.core.utility.cuda_check import CudaChecker
from mantidimaging.core.utility.data_containers import ScalarCoR, ProjectionAngles, ReconstructionParameters
from mantidimaging.core.utility.progress_reporting import Progress

LOG = getLogger(__name__)
astra_mutex = Lock()


# Full credit for the following code to Daniil Kazantsev
# Source:
# https://github.com/dkazanc/ToMoBAR/blob/5990aaa264e2f08bd9b0069c8847e5021fbf2ee2/src/Python/tomobar/supp/astraOP.py#L20-L70
def rotation_matrix2d(theta: float) -> np.ndarray:
    """
    Build the 2x2 rotation matrix for the angle ``theta``.

    :param theta: rotation angle, fed straight into ``np.cos`` / ``np.sin`` (so in radians)
    :return: the array ``[[cos(theta), -sin(theta)], [sin(theta), cos(theta)]]``
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    return np.array([[cos_t, -sin_t], [sin_t, cos_t]])
def vec_geom_init2d(angles_rad: ProjectionAngles, detector_spacing_x: float, center_rot_offset: float) -> np.ndarray:
    """
    Build the per-angle geometry vectors, one row of six values per projection angle.

    For every angle the three base vectors below are rotated by that angle and
    written side by side into the row:

    * columns 0-1: ray position, from the base source vector ``[0, -1]``
    * columns 2-3: center of detector position, from ``[center_rot_offset, 0]``
    * columns 4-5: detector pixel (0,0) to (0,1), from ``[detector_spacing_x, 0]``

    :param angles_rad: projection angles; only its ``.value`` array is read
    :param detector_spacing_x: x component of the base detector pixel vector
    :param center_rot_offset: x component of the base detector centre vector
    :return: array of shape ``(number of angles, 6)``
    """
    thetas = angles_rad.value

    # Base (unrotated) vectors; none of them depends on the angle.
    source = [0.0, -1.0]
    detector_centre = [center_rot_offset, 0.0]
    detector_pixel = [detector_spacing_x, 0.0]

    geometry = np.zeros([thetas.size, 6])
    for row, theta in enumerate(thetas):
        rotation = rotation_matrix2d(theta)
        geometry[row, 0:2] = np.dot(rotation, source)  # ray position
        geometry[row, 2:4] = np.dot(rotation, detector_centre)  # center of detector position
        geometry[row, 4:6] = np.dot(rotation, detector_pixel)  # detector pixel (0,0) to (0,1).
    return geometry
@contextmanager
def _managed_recon(sino: np.ndarray, cfg, proj_geom, vol_geom) -> Generator[tuple[int, int], None, None]:
    """
    Create the ASTRA objects needed for one reconstruction and delete them again on exit.

    A projector, a '-sino' data object filled from ``sino`` and a '-vol' data object are
    created. Their IDs are stored in ``cfg`` (the caller's mapping is modified in place) and
    an algorithm is then created from ``cfg``. Whatever was created is deleted when the
    ``with`` block is left, including when setup or the block body raises.

    :param sino: sinogram data placed in the ASTRA '-sino' object
    :param cfg: ASTRA algorithm configuration; the keys 'ReconstructionDataId',
                'ProjectionDataId' and 'ProjectorId' are set on it
    :param proj_geom: projection geometry passed to ASTRA
    :param vol_geom: volume geometry passed to ASTRA
    :return: yields ``(alg_id, rec_id)``, the algorithm ID and the reconstruction data ID
    """
    # None marks "not created yet" so that cleanup only touches objects that exist.
    proj_id = None
    sino_id = None
    rec_id = None
    alg_id = None
    try:
        proj_type = 'cuda' if CudaChecker().cuda_is_present() else 'line'
        LOG.debug(f"Using projection type {proj_type}")

        proj_id = astra.create_projector(proj_type, proj_geom, vol_geom)
        sino_id = astra.data2d.create('-sino', proj_geom, sino)
        rec_id = astra.data2d.create('-vol', vol_geom)

        cfg['ReconstructionDataId'] = rec_id
        cfg['ProjectionDataId'] = sino_id
        cfg['ProjectorId'] = proj_id

        alg_id = astra.algorithm.create(cfg)

        yield alg_id, rec_id
    finally:
        # Test against the None sentinel instead of truthiness: a falsy integer ID (0)
        # would otherwise be skipped and the corresponding ASTRA object leaked.
        if alg_id is not None:
            astra.algorithm.delete(alg_id)
        if proj_id is not None:
            astra.projector.delete(proj_id)
        if sino_id is not None:
            astra.data2d.delete(sino_id)
        if rec_id is not None:
            astra.data2d.delete(rec_id)
class AstraRecon(BaseRecon):
    """
    Reconstruction routines implemented with the ASTRA toolbox.

    All methods are static; ASTRA calls made by ``single_sino`` are serialised through the
    module-level ``astra_mutex``.
    """

    @staticmethod
    def _count_gpus() -> int:
        """
        Probe ASTRA device indices 1, 2, 3, ... until the returned info mentions "Invalid device".

        :return: the first probed index whose info text contains "Invalid device"
        """
        device_index = 1
        while "Invalid device" not in astra.get_gpu_info(device_index):
            device_index += 1
        return device_index

    @staticmethod
    def find_cor(images: ImageStack, slice_idx: int, start_cor: float | np.ndarray,
                 recon_params: ReconstructionParameters) -> float:
        """
        Find the best CoR for this slice by maximising the squared sum of the reconstructed slice.

        Larger squared sum -> bigger deviance from the mean, i.e. larger distance between noise and data

        :param images: stack providing the sinogram and the projection angles
        :param slice_idx: index of the sinogram to reconstruct
        :param start_cor: initial guess handed to the optimiser
        :param recon_params: reconstruction settings forwarded to ``single_sino``
        :return: first element of the optimiser's solution vector
        """
        proj_angles = images.projection_angles(recon_params.max_projection_angle)

        def negative_sumsq(cor: float | np.ndarray) -> float:
            # The optimiser may pass the candidate as an array; ScalarCoR is given a plain float.
            cor_value = float(cor[0]) if isinstance(cor, np.ndarray) else cor
            recon = AstraRecon.single_sino(images.sino(slice_idx), ScalarCoR(cor_value), proj_angles, recon_params)
            # Negated because the optimiser minimises while the squared sum is to be maximised.
            return -np.sum(recon**2)

        result = minimize(negative_sumsq, start_cor, method='nelder-mead', tol=0.1)
        return result.x[0]

    @staticmethod
    def single_sino(sino: np.ndarray,
                    cor: ScalarCoR,
                    proj_angles: ProjectionAngles,
                    recon_params: ReconstructionParameters,
                    progress: Progress | None = None) -> np.ndarray:
        """
        Reconstruct one slice from a single 2D sinogram.

        :param sino: 2D sinogram; passed through ``BaseRecon.prepare_sinogram`` before use
        :param cor: centre of rotation, converted with ``cor.to_vec(width)`` for the geometry
        :param proj_angles: projection angles used to build the geometry vectors
        :param recon_params: supplies the algorithm name, filter name and iteration count
        :param progress: not used by this method
        :return: the reconstructed data read back from ASTRA
        """
        assert sino.ndim == 2, "Sinogram must be a 2D image"

        sino = BaseRecon.prepare_sinogram(sino, recon_params)
        width = sino.shape[1]

        if astra_mutex.locked():
            LOG.warning("Astra recon already in progress. Waiting")

        with astra_mutex:
            cor_offset = cor.to_vec(width).value
            geometry_vectors = vec_geom_init2d(proj_angles, 1.0, cor_offset)
            volume_geometry = astra.create_vol_geom((width, width))
            projection_geometry = astra.create_proj_geom('parallel_vec', width, geometry_vectors)

            algorithm_cfg = astra.astra_dict(recon_params.algorithm)
            algorithm_cfg['FilterType'] = recon_params.filter_name

            with _managed_recon(sino, algorithm_cfg, projection_geometry, volume_geometry) as (algorithm_id,
                                                                                                 recon_id):
                astra.algorithm.run(algorithm_id, iterations=recon_params.num_iter)
                # Read the result while the ASTRA data object still exists.
                reconstructed = astra.data2d.get(recon_id)

        return reconstructed

    @staticmethod
    def full(images: ImageStack,
             cors: list[ScalarCoR],
             recon_params: ReconstructionParameters,
             progress: Progress | None = None) -> ImageStack:
        """
        Reconstruct every sinogram of ``images`` into a new image stack.

        :param images: input stack; one slice is reconstructed per index in ``range(images.height)``
        :param cors: centre of rotation for each slice, indexed like the slices
        :param recon_params: reconstruction settings, also recorded on the output stack
        :param progress: optional progress reporter, updated once per slice
        :return: new stack of shape ``(num_sinograms, width, width)`` holding the reconstructions
        """
        progress = Progress.ensure_instance(progress, num_steps=images.height)

        recon_shape = (images.num_sinograms, images.width, images.width)
        reconstructed: ImageStack = ImageStack.create_empty_image_stack(recon_shape, images.dtype, images.metadata)
        reconstructed.record_operation('AstraRecon.full', 'Volume Reconstruction', **recon_params.to_dict())

        proj_angles = images.projection_angles(recon_params.max_projection_angle)
        for slice_index in range(images.height):
            reconstructed.data[slice_index] = AstraRecon.single_sino(images.sino(slice_index), cors[slice_index],
                                                                     proj_angles, recon_params)
            progress.update(1, "Reconstructed slice")
        return reconstructed

    @staticmethod
    def allowed_filters() -> list[str]:
        """
        List the filter names offered for reconstruction.

        :return: a new list of filter name strings
        """
        # removed from list: 'kaiser' as it hard crashes ASTRA
        #                    'projection', 'sinogram', 'rprojection', 'rsinogram' as they error
        return [
            'ram-lak',
            'shepp-logan',
            'cosine',
            'hamming',
            'hann',
            'none',
            'tukey',
            'lanczos',
            'triangular',
            'gaussian',
            'barlett-hann',
            'blackman',
            'nuttall',
            'blackman-harris',
            'blackman-nuttall',
            'flat-top',
            'parzen',
        ]
def allowed_recon_kwargs() -> dict:
    """
    Give, for each algorithm name, the list of reconstruction keyword names allowed with it.

    :return: a new dict keyed by 'FBP_CUDA', 'SIRT_CUDA' and 'SIRT3D_CUDA'; every value is its own list
    """
    filter_kwargs = ['filter_name', 'filter_par']
    sirt_kwargs = ['num_iter', 'min_constraint', 'max_constraint', 'DetectorSuperSampling', 'PixelSuperSampling']
    # Copy the SIRT list per key so the two entries never share one list object.
    return {
        'FBP_CUDA': filter_kwargs,
        'SIRT_CUDA': list(sirt_kwargs),
        'SIRT3D_CUDA': list(sirt_kwargs),
    }