# Copyright (C) 2022 ISIS Rutherford Appleton Laboratory UKRI
# SPDX - License - Identifier: GPL-3.0-or-later
import datetime
import os
from logging import getLogger
from typing import List, Union, Optional, Dict, Callable
import h5py
import numpy as np
from skimage import io as skio
import astropy.io.fits as fits
from .utility import DEFAULT_IO_FILE_FORMAT
from ..data.dataset import StrictDataset
from ..data.imagestack import ImageStack
from ..operations.rescale import RescaleFilter
from ..utility.data_containers import Indices
from ..utility.progress_reporting import Progress
from ..utility.version_check import CheckVersion
# Module-level logger, named after this module.
LOG = getLogger(__name__)
# Default width the image index is zero-padded to in generated file names.
DEFAULT_ZFILL_LENGTH = 6
# Default text placed before / after the image index in generated file names.
DEFAULT_NAME_PREFIX = 'image'
DEFAULT_NAME_POSTFIX = ''
# Number of distinct values in a 16-bit range (2**16). image_save rescales to
# INT16_SIZE - 1 before casting the data to uint16.
INT16_SIZE = 65536
[docs]
def write_fits(data: np.ndarray, filename: str, overwrite: bool = False, description: Optional[str] = ""):
    """
    Write a single image to a FITS file as its primary HDU.

    :param data: Image data to store
    :param filename: Path of the FITS file to write
    :param overwrite: Forwarded to astropy; whether an existing file may be replaced
    :param description: Unused here; accepted so the signature matches write_img
    """
    primary = fits.PrimaryHDU(data)
    fits.HDUList([primary]).writeto(filename, overwrite=overwrite)
[docs]
def write_img(data: np.ndarray, filename: str, overwrite: bool = False, description: Optional[str] = ""):
    """
    Write a single image through scikit-image's ``imsave``.

    :param data: Image data to store
    :param filename: Path of the image file to write
    :param overwrite: Unused here; accepted so the signature matches write_fits
    :param description: Text forwarded to the writer as the image description
    """
    # Extra keyword arguments are handed straight to the skimage writer.
    writer_options = {"description": description, "metadata": None, "software": "Mantid Imaging"}
    skio.imsave(filename, data, **writer_options)
[docs]
def write_nxs(data: np.ndarray, filename: str, projection_angles: Optional[np.ndarray] = None, overwrite: bool = False):
    """
    Write an image volume, and optionally its projection angles, to an HDF5 file.

    The volume is stored under "tomography/sample_data" and the angles, when
    given, under "tomography/rotation_angle".

    :param data: Image volume to store
    :param filename: Path of the file to create. It is opened in 'w' mode, so
                     an existing file is always truncated
    :param projection_angles: Optional angles to store alongside the volume
    :param overwrite: Currently unused; kept so existing callers keep working
    """
    # The context manager guarantees the file is flushed and closed, even when
    # writing one of the datasets fails.
    with h5py.File(filename, 'w') as nxs:
        # Appending flat and dark images is disabled for now. If it is ever
        # re-enabled the dataset needs two extra slots along axis 0:
        #   (data.shape[0] + 2, data.shape[1], data.shape[2])
        # with the flat stored at [-2] and the dark at [-1].
        dset = nxs.create_dataset("tomography/sample_data", data.shape)
        dset[:data.shape[0]] = data[:]

        if projection_angles is not None:
            # create_dataset(data=...) already stores the values, so no second
            # assignment is needed.
            nxs.create_dataset("tomography/rotation_angle", data=projection_angles)
[docs]
def image_save(images: ImageStack,
               output_dir: str,
               name_prefix: str = DEFAULT_NAME_PREFIX,
               swap_axes: bool = False,
               out_format: str = DEFAULT_IO_FILE_FORMAT,
               overwrite_all: bool = False,
               custom_idx: Optional[int] = None,
               zfill_len: int = DEFAULT_ZFILL_LENGTH,
               name_postfix: str = DEFAULT_NAME_POSTFIX,
               indices: Union[List[int], Indices, None] = None,
               pixel_depth: Optional[str] = None,
               progress: Optional[Progress] = None) -> Union[str, List[str]]:
    """
    Save image volume (3d) into a series of slices along the Z axis.
    The Z axis in the script is the ndarray.shape[0].

    A metadata file named ``<name_prefix>.json`` is always written to the
    output directory before the image data.

    :param images: Data as images/slices stored in numpy array
    :param output_dir: Output directory for the files
    :param name_prefix: Prefix for the names of the images,
                        appended before the image number
    :param swap_axes: Swap the 0 and 1 axis of the images
                      (convert from radiograms to sinograms on saving)
    :param out_format: File format of the saved out images. 'nxs' writes one
                       file for the whole volume, 'fit'/'fits' use the FITS
                       writer, anything else is passed to skimage
    :param overwrite_all: Overwrite existing images with conflicting names
    :param custom_idx: Single index to be used for the file name,
                       instead of incremental numbers
    :param zfill_len: Prepend zeros to the index in the output file names to
                      have a constant file name length. Example:
                      - saving out an image with zfill_len = 6:
                          saved_image000001,...saved_image000201 and so on
                      - saving out an image with zfill_len = 3:
                          saved_image001,...saved_image201 and so on
    :param name_postfix: Postfix for the name after the index
    :param indices: Only works if custom_idx is not specified.
                    Specify the start and end range of the indices
                    which will be used for the file names.
    :param pixel_depth: Target pixel depth of the save operation. None and
                        "float32" write the data unchanged, "int16" rescales
                        every image to the 16 bit range and writes it as uint16
    :param progress: Passed to ensure progress during saving is tracked properly
    :returns: For 'nxs' the output path WITHOUT the '.nxs' extension, otherwise
              the list of file names of the saved images
    :raises ValueError: If pixel_depth is not None, "float32" or "int16"
    :raises RuntimeError: If the output directory is not empty and
                          overwrite_all is False
    """
    progress = Progress.ensure_instance(progress, task_name='Save')

    # expand the path for plugins that don't do it themselves
    output_dir = os.path.abspath(os.path.expanduser(output_dir))
    make_dirs_if_needed(output_dir, overwrite_all)

    # Work out the rescale parameters. The full-volume min/max scan is only
    # needed (and therefore only performed) when saving as int16.
    save_as_int16 = pixel_depth == "int16"
    rescale_params: Optional[Dict[str, Union[str, float]]] = None
    rescale_info = ""
    if save_as_int16:
        min_value: float = np.nanmin(images.data)
        max_value: float = np.nanmax(images.data)
        # turn the offset to string otherwise json throws a TypeError when trying to save float32
        rescale_params = {"offset": str(min_value), "slope": max_value / INT16_SIZE}
        rescale_info = "offset = {offset} \n slope = {slope}".format(**rescale_params)
    elif pixel_depth is not None and pixel_depth != "float32":
        raise ValueError("The pixel depth given is not handled: " + pixel_depth)

    # Save metadata
    metadata_filename = os.path.join(output_dir, name_prefix + '.json')
    LOG.debug('Metadata filename: %s', metadata_filename)
    with open(metadata_filename, 'w+') as f:
        images.save_metadata(f, rescale_params)

    data = images.data
    if swap_axes:
        data = np.swapaxes(data, 0, 1)

    if out_format in ['nxs']:
        filename = os.path.join(output_dir, name_prefix + name_postfix)
        write_nxs(data, filename + '.nxs', overwrite=overwrite_all)
        # NOTE: the extension is deliberately not part of the returned path;
        # existing callers receive the path this way.
        return filename

    if out_format in ['fit', 'fits']:
        write_func: Callable[[np.ndarray, str, bool, Optional[str]], None] = write_fits
    else:
        # pass all other formats to skimage
        write_func = write_img

    num_images = data.shape[0]
    progress.set_estimated_steps(num_images)

    names = [
        os.path.join(output_dir, name)
        for name in generate_names(name_prefix, indices, num_images, custom_idx, zfill_len, name_postfix,
                                   out_format)
    ]

    with progress:
        for idx, name in enumerate(names):
            if save_as_int16:
                # Rescale a copy so the stack itself is left untouched. The
                # slice is taken from `data` (not `images.data`) so that
                # swap_axes is honoured for int16 output as well.
                output_data = RescaleFilter.filter_array(np.copy(data[idx]),
                                                         min_input=min_value,
                                                         max_input=max_value,
                                                         max_output=INT16_SIZE - 1).astype(np.uint16)
                write_func(output_data, name, overwrite_all, rescale_info)
            else:
                write_func(data[idx, :, :], name, overwrite_all, rescale_info)

            progress.update(msg='Image')

    return names
[docs]
def nexus_save(dataset: StrictDataset, path: str, sample_name: str):
    """
    Uses information from a StrictDataset to create a NeXus file.

    The file is opened with h5py's "core" driver. If writing fails with an
    OSError the file is closed, removed from disk and the error is re-raised
    as a RuntimeError.

    :param dataset: The dataset to save as a NeXus file.
    :param path: The NeXus file path.
    :param sample_name: The sample name.
    :raises RuntimeError: If the file cannot be created or written.
    """
    try:
        nexus_file = h5py.File(path, "w", driver="core")
    except OSError as e:
        raise RuntimeError("Unable to save NeXus file. " + str(e)) from e

    write_error: Optional[OSError] = None
    try:
        _nexus_save(nexus_file, dataset, sample_name)
    except OSError as e:
        write_error = e
    finally:
        # Close on every path, including unexpected (non-OSError) exceptions,
        # so the file handle is never leaked.
        nexus_file.close()

    if write_error is not None:
        # Do not leave a partially written file behind.
        os.remove(path)
        raise RuntimeError("Unable to save NeXus file. " + str(write_error)) from write_error
def _nexus_save(nexus_file: h5py.File, dataset: StrictDataset, sample_name: str):
    """
    Takes a NeXus file and writes the StrictDataset information to it.

    Creates an "entry1/tomo_entry" tree holding the detector data, image keys,
    sample name and rotation angles, then appends one entry per recon in the
    dataset.

    :param nexus_file: The NeXus file.
    :param dataset: The StrictDataset.
    :param sample_name: The sample name.
    """
    # Top-level group
    entry = nexus_file.create_group("entry1")
    _set_nx_class(entry, "NXentry")

    # Tomo entry
    tomo_entry = entry.create_group("tomo_entry")
    _set_nx_class(tomo_entry, "NXsubentry")

    # definition field
    # np.bytes_ is used instead of np.string_, which was removed in NumPy 2.0
    tomo_entry.create_dataset("definition", data=np.bytes_("NXtomo"))

    # instrument field
    instrument_group = tomo_entry.create_group("instrument")
    _set_nx_class(instrument_group, "NXinstrument")

    # instrument/detector field
    detector = instrument_group.create_group("detector")
    _set_nx_class(detector, "NXdetector")

    # instrument data: all arrays are stacked along axis 0 into one dataset
    arrays = dataset.nexus_arrays
    combined_data_shape = (sum(len(arr) for arr in arrays), ) + arrays[0].shape[1:]
    detector_data = detector.create_dataset("data", shape=combined_data_shape, dtype="uint16")

    index = 0
    for arr in arrays:
        detector_data[index:index + arr.shape[0]] = arr
        index += arr.shape[0]

    image_key = detector.create_dataset("image_key", data=dataset.image_keys)

    # sample field
    sample_group = tomo_entry.create_group("sample")
    _set_nx_class(sample_group, "NXsample")
    sample_group.create_dataset("name", data=np.bytes_(sample_name))

    # rotation angle
    rotation_angle = sample_group.create_dataset("rotation_angle", data=np.concatenate(dataset.nexus_rotation_angles))
    rotation_angle.attrs["units"] = "rad"

    # data field: links to the datasets created above rather than copies
    data = tomo_entry.create_group("data")
    _set_nx_class(data, "NXdata")
    data["data"] = detector_data
    data["rotation_angle"] = rotation_angle
    data["image_key"] = image_key

    for recon in dataset.recons:
        _save_recon_to_nexus(nexus_file, recon)
def _save_recon_to_nexus(nexus_file: h5py.File, recon: ImageStack):
    """
    Saves a recon to a NeXus file.

    The recon is written as its own top-level "NXtomoproc" entry, named after
    the recon, with its data rescaled and stored as uint16.

    :param nexus_file: The NeXus file.
    :param recon: The recon data.
    """
    # np.bytes_ is used throughout instead of np.string_, which was removed in
    # NumPy 2.0
    recon_entry = nexus_file.create_group(recon.name)
    _set_nx_class(recon_entry, "NXentry")

    recon_entry.create_dataset("title", data=np.bytes_(recon.name))
    recon_entry.create_dataset("definition", data=np.bytes_("NXtomoproc"))

    instrument = recon_entry.create_group("INSTRUMENT")
    _set_nx_class(instrument, "NXinstrument")

    source = instrument.create_group("SOURCE")
    _set_nx_class(source, "NXsource")
    source.create_dataset("type", data=np.bytes_("Neutron source"))
    source.create_dataset("name", data=np.bytes_("ISIS"))
    source.create_dataset("probe", data=np.bytes_("neutron"))

    sample = recon_entry.create_group("SAMPLE")
    _set_nx_class(sample, "NXsample")
    sample.create_dataset("name", data=np.bytes_("sample description"))

    reconstruction = recon_entry.create_group("reconstruction")
    _set_nx_class(reconstruction, "NXprocess")
    reconstruction.create_dataset("program", data=np.bytes_("Mantid Imaging"))
    reconstruction.create_dataset("version", data=np.bytes_(CheckVersion().get_version()))
    # isoformat() yields the same "<date>T<time>" text the previous
    # str()/split()/join() construction produced
    reconstruction.create_dataset("date", data=np.bytes_(datetime.datetime.now().isoformat()))
    reconstruction.create_group("parameters")

    data = recon_entry.create_group("data")
    _set_nx_class(data, "NXdata")
    recon_data = data.create_dataset("data", shape=recon.data.shape, dtype="uint16")
    recon_data[:] = _rescale_recon_data(recon.data)
def _set_nx_class(group: "h5py.Group", class_name: str):
    """
    Sets the NX_class attribute of data in a NeXus file.

    :param group: The h5py group.
    :param class_name: The class name.
    """
    # np.bytes_ replaces np.string_, which was removed in NumPy 2.0; both store
    # the value as a byte string.
    group.attrs["NX_class"] = np.bytes_(class_name)
def _rescale_recon_data(data: np.ndarray) -> np.ndarray:
    """
    Rescales recon data so that it can be converted to uint.

    Negative data is shifted so that its minimum becomes 0, then the data is
    scaled so that its maximum becomes the largest uint16 value. The input
    array is left unmodified; the result is a new floating point array.

    :param data: The recon data.
    :return: The rescaled recon data.
    """
    source = np.asarray(data)
    # Work on a floating point copy: the caller's array must not change, and
    # integer input cannot be scaled in place by a float factor.
    rescaled = source.astype(np.result_type(source.dtype, np.float32), copy=True)
    if rescaled.size == 0:
        return rescaled

    min_value = rescaled.min()
    if min_value < 0:
        rescaled -= min_value

    max_value = rescaled.max()
    # All-zero data has nothing to stretch; skipping avoids a division by zero.
    if max_value != 0:
        rescaled *= np.iinfo("uint16").max / max_value
    return rescaled
[docs]
def generate_names(name_prefix: str,
                   indices: Union[List[int], Indices, None],
                   num_images: int,
                   custom_idx: Optional[int] = None,
                   zfill_len: int = DEFAULT_ZFILL_LENGTH,
                   name_postfix: str = DEFAULT_NAME_POSTFIX,
                   out_format: str = DEFAULT_IO_FILE_FORMAT) -> List[str]:
    """
    Build the file names for a series of images.

    Every name has the form ``<name_prefix>_<index><name_postfix>.<out_format>``
    with the index zero-padded to ``zfill_len`` characters.

    :param name_prefix: Text placed before the index
    :param indices: Optional (start, end, step) sequence. The first element is
                    the index of the first image and the third, when present,
                    the increment between images. Ignored if custom_idx is given
    :param num_images: Number of names to generate
    :param custom_idx: If given (including 0), every name uses this index
    :param zfill_len: Width the index is zero-padded to
    :param name_postfix: Text placed after the index
    :param out_format: File format, used as the extension
    :returns: The generated file names, without any directory
    """
    if custom_idx is not None:
        # `is not None` rather than truthiness, so a custom index of 0 is honoured
        index = custom_idx
        increment = 0
    elif indices:
        index = int(indices[0])
        try:
            increment = indices[2]
        except IndexError:
            # only a start (and end) was supplied: count up one by one
            increment = 1
    else:
        index = 0
        increment = 1

    # use the format as extension
    suffix = name_postfix + "." + out_format
    return [f"{name_prefix}_{str(index + i * increment).zfill(zfill_len)}{suffix}" for i in range(num_images)]
[docs]
def make_dirs_if_needed(dirname: Optional[str] = None, overwrite_all: bool = False):
    """
    Ensure the given (output) directory exists, creating it when it is missing.

    :param dirname: Directory to check; nothing happens when it is None
    :param overwrite_all: Accept an existing directory that already has content
    :raises RuntimeError: If the directory exists, is not empty and
                          overwrite_all is False
    """
    if dirname is None:
        return

    target = os.path.abspath(os.path.expanduser(dirname))
    if os.path.exists(target):
        # An existing directory is only acceptable when it is empty or the
        # caller explicitly allowed overwriting.
        has_content = bool(os.listdir(target))
        if has_content and not overwrite_all:
            raise RuntimeError("The output directory is NOT empty:{0}\nThis can be "
                               "overridden by specifying 'Overwrite on name conflict'.".format(target))
        return

    os.makedirs(target)