# Copyright (C) 2022 ISIS Rutherford Appleton Laboratory UKRI
# SPDX - License - Identifier: GPL-3.0-or-later
import datetime
import os
from logging import getLogger
from typing import List, Union, Optional, Dict, Callable
import h5py
import numpy as np
from skimage import io as skio
import astropy.io.fits as fits
from .utility import DEFAULT_IO_FILE_FORMAT
from ..data.dataset import StrictDataset
from ..data.imagestack import ImageStack
from ..operations.rescale import RescaleFilter
from ..utility.data_containers import Indices
from ..utility.progress_reporting import Progress
from ..utility.version_check import CheckVersion
LOG = getLogger(__name__)
DEFAULT_ZFILL_LENGTH = 6
DEFAULT_NAME_PREFIX = 'image'
DEFAULT_NAME_POSTFIX = ''
INT16_SIZE = 65536
[docs]
def write_fits(data: np.ndarray, filename: str, overwrite: bool = False, description: Optional[str] = ""):
    """
    Write a single image array to a FITS file as its primary HDU.

    :param data: Image array to store.
    :param filename: Destination path.
    :param overwrite: Forwarded to astropy's ``writeto``; replace an existing file when True.
    :param description: Not used; accepted so this has the same signature as ``write_img``.
    """
    fits.HDUList([fits.PrimaryHDU(data)]).writeto(filename, overwrite=overwrite)
[docs]
def write_img(data: np.ndarray, filename: str, overwrite: bool = False, description: Optional[str] = ""):
    """
    Write a single image array with scikit-image (format chosen from the file extension).

    :param data: Image array to store.
    :param filename: Destination path.
    :param overwrite: Not used; accepted so this has the same signature as ``write_fits``.
    :param description: Text forwarded to ``imsave`` as the image description.
    """
    save_kwargs = dict(description=description, metadata=None, software="Mantid Imaging")
    skio.imsave(filename, data, **save_kwargs)
[docs]
def write_nxs(data: np.ndarray, filename: str, projection_angles: Optional[np.ndarray] = None, overwrite: bool = False):
    """
    Write an image stack (and optionally its projection angles) to an HDF5/NeXus file.

    The stack is stored at ``tomography/sample_data`` and, when given, the angles
    at ``tomography/rotation_angle``. The file is closed before returning.

    :param data: Array to store; it is copied into a dataset of the same shape.
    :param filename: Destination path. The file is opened in 'w' mode, so an
                     existing file is truncated.
    :param projection_angles: Optional array written to ``tomography/rotation_angle``.
    :param overwrite: Currently not used; an existing file is always replaced.
    """
    import h5py
    # Use a context manager so the handle is flushed and closed even on error.
    with h5py.File(filename, 'w') as nxs:
        # appending flat and dark images is disabled for now
        # new shape to account for appending flat and dark images
        # correct_shape = (data.shape[0] + 2, data.shape[1], data.shape[2])
        dset = nxs.create_dataset("tomography/sample_data", data.shape)
        dset[:data.shape[0]] = data[:]
        # left here if we decide to start appending the flat and dark images again
        # dset[-2] = flat[:]
        # dset[-1] = dark[:]
        if projection_angles is not None:
            # passing data= both creates and fills the dataset
            nxs.create_dataset("tomography/rotation_angle", data=projection_angles)
[docs]
def image_save(images: ImageStack,
               output_dir: str,
               name_prefix: str = DEFAULT_NAME_PREFIX,
               swap_axes: bool = False,
               out_format: str = DEFAULT_IO_FILE_FORMAT,
               overwrite_all: bool = False,
               custom_idx: Optional[int] = None,
               zfill_len: int = DEFAULT_ZFILL_LENGTH,
               name_postfix: str = DEFAULT_NAME_POSTFIX,
               indices: Union[List[int], Indices, None] = None,
               pixel_depth: Optional[str] = None,
               progress: Optional[Progress] = None) -> Union[str, List[str]]:
    """
    Save image volume (3d) into a series of slices along the Z axis.
    The Z axis in the script is the ndarray.shape[0].

    A ``<name_prefix>.json`` metadata file is always written into ``output_dir``.

    :param images: Data as images/slices stores in numpy array
    :param output_dir: Output directory for the files
    :param name_prefix: Prefix for the names of the images,
           appended before the image number
    :param swap_axes: Swap the 0 and 1 axis of the images
           (convert from radiograms to sinograms on saving)
    :param out_format: File format of the saved out images. "nxs" writes one
           NeXus file, "fit"/"fits" write FITS slices, anything else is passed to
           scikit-image.
    :param overwrite_all: Overwrite existing images with conflicting names
    :param custom_idx: Single index to be used for the file name,
           instead of incremental numbers
    :param zfill_len: This option is ignored if custom_idx is specified!
           Prepend zeros to the output file names to have a
           constant file name length. Example:
           - saving out an image with zfill_len = 6:
           saved_image000001,...saved_image000201 and so on
           - saving out an image with zfill_len = 3:
           saved_image001,...saved_image201 and so on
    :param name_postfix: Postfix for the name after the index
    :param indices: Only works if custom_idx is not specified.
           Specify the start and end range of the indices
           which will be used for the file names.
    :param pixel_depth: None or "float32" saves the data unchanged. "int16"
           rescales every slice using the stack's global (NaN-ignoring) min/max,
           stores it as uint16, and records the offset/slope in the metadata.
    :param progress: Passed to ensure progress during saving is tracked properly
    :returns: For "nxs", the output path without the ".nxs" extension;
              otherwise the list of saved image paths.
    :raises ValueError: If ``pixel_depth`` is not None, "float32" or "int16".
    :raises RuntimeError: If ``output_dir`` is not empty and ``overwrite_all`` is False.
    """
    progress = Progress.ensure_instance(progress, task_name='Save')

    # Work out the rescale parameters (and reject unknown pixel depths) before
    # touching the filesystem, so a bad argument leaves no directory behind.
    rescale_params: Optional[Dict[str, Union[str, float]]] = None
    rescale_info = ""
    if pixel_depth == "int16":
        # The global range is only needed for rescaling; float32 output skips
        # these two full passes over the stack.
        min_value: float = np.nanmin(images.data)
        max_value: float = np.nanmax(images.data)
        # NOTE(review): RescaleFilter presumably maps [min_value, max_value] onto
        # [0, INT16_SIZE - 1], in which case the inverse slope would be
        # (max_value - min_value) / (INT16_SIZE - 1) — confirm before relying on it.
        int_16_slope = max_value / INT16_SIZE
        # turn the offset to string otherwise json throws a TypeError when trying to save float32
        rescale_params = {"offset": str(min_value), "slope": int_16_slope}
        rescale_info = "offset = {offset} \n slope = {slope}".format(**rescale_params)
    elif pixel_depth is not None and pixel_depth != "float32":
        raise ValueError("The pixel depth given is not handled: " + pixel_depth)

    # expand the path for plugins that don't do it themselves
    output_dir = os.path.abspath(os.path.expanduser(output_dir))
    make_dirs_if_needed(output_dir, overwrite_all)

    # Save metadata
    metadata_filename = os.path.join(output_dir, name_prefix + '.json')
    LOG.debug('Metadata filename: {}'.format(metadata_filename))
    with open(metadata_filename, 'w+') as f:
        images.save_metadata(f, rescale_params)

    data = images.data
    if swap_axes:
        data = np.swapaxes(data, 0, 1)

    if out_format in ['nxs']:
        filename = os.path.join(output_dir, name_prefix + name_postfix)
        write_nxs(data, filename + '.nxs', overwrite=overwrite_all)
        return filename

    write_func: Callable[[np.ndarray, str, bool, Optional[str]], None]
    if out_format in ['fit', 'fits']:
        write_func = write_fits
    else:
        # pass all other formats to skimage
        write_func = write_img

    num_images = data.shape[0]
    progress.set_estimated_steps(num_images)
    names = [
        os.path.join(output_dir, name)
        for name in generate_names(name_prefix, indices, num_images, custom_idx, zfill_len, name_postfix, out_format)
    ]
    with progress:
        for idx, image_path in enumerate(names):
            if pixel_depth == "int16":
                # Rescale a copy so the in-memory stack is left untouched. Slice
                # ``data`` (not ``images.data``) so that swap_axes is honoured.
                output_data = RescaleFilter.filter_array(np.copy(data[idx]),
                                                         min_input=min_value,
                                                         max_input=max_value,
                                                         max_output=INT16_SIZE - 1).astype(np.uint16)
            else:
                output_data = data[idx, :, :]
            write_func(output_data, image_path, overwrite_all, rescale_info)
            progress.update(msg='Image')
    return names
[docs]
def nexus_save(dataset: StrictDataset, path: str, sample_name: str):
    """
    Uses information from a StrictDataset to create a NeXus file.

    The file is opened with h5py's ``core`` driver and is always closed before
    returning. If writing fails with an OSError, the partial file is removed.

    :param dataset: The dataset to save as a NeXus file.
    :param path: The NeXus file path.
    :param sample_name: The sample name.
    :raises RuntimeError: If the file cannot be opened, or writing raises an OSError.
    """
    try:
        nexus_file = h5py.File(path, "w", driver="core")
    except OSError as e:
        raise RuntimeError("Unable to save NeXus file. " + str(e)) from e

    save_error: Optional[OSError] = None
    try:
        _nexus_save(nexus_file, dataset, sample_name)
    except OSError as e:
        save_error = e
    finally:
        # Release the handle on every path, including unexpected exception types.
        nexus_file.close()

    if save_error is not None:
        # Don't leave a partially written file behind; guard the removal so a
        # missing file can't mask the original error.
        if os.path.exists(path):
            os.remove(path)
        raise RuntimeError("Unable to save NeXus file. " + str(save_error)) from save_error
def _nexus_save(nexus_file: h5py.File, dataset: StrictDataset, sample_name: str):
    """
    Takes a NeXus file and writes the StrictDataset information to it.

    Everything goes under ``entry1/tomo_entry`` (an NXsubentry with definition
    "NXtomo"):

    - ``instrument/detector/data``: the arrays in ``dataset.nexus_arrays`` stacked
      along their first axis into one uint16 dataset
    - ``instrument/detector/image_key``: ``dataset.image_keys``
    - ``sample/name`` and ``sample/rotation_angle`` (``units`` attribute "rad")
    - ``data``: an NXdata group whose members are HDF5 links to the datasets above

    Each recon in ``dataset.recons`` is then written as a separate top-level entry.

    :param nexus_file: The NeXus file.
    :param dataset: The StrictDataset.
    :param sample_name: The sample name.
    """
    # Top-level group
    entry = nexus_file.create_group("entry1")
    _set_nx_class(entry, "NXentry")

    # Tomo entry
    tomo_entry = entry.create_group("tomo_entry")
    _set_nx_class(tomo_entry, "NXsubentry")
    # definition field (np.bytes_ replaces np.string_, which NumPy 2.0 removed)
    tomo_entry.create_dataset("definition", data=np.bytes_("NXtomo"))

    # instrument field
    instrument_group = tomo_entry.create_group("instrument")
    _set_nx_class(instrument_group, "NXinstrument")
    # instrument/detector field
    detector = instrument_group.create_group("detector")
    _set_nx_class(detector, "NXdetector")

    # instrument data: allocate the combined dataset once, then copy each array
    # into its slice instead of building a concatenated copy in memory
    nexus_arrays = dataset.nexus_arrays
    combined_data_shape = (sum(len(arr) for arr in nexus_arrays), ) + nexus_arrays[0].shape[1:]
    detector_data = detector.create_dataset("data", shape=combined_data_shape, dtype="uint16")
    index = 0
    for arr in nexus_arrays:
        detector_data[index:index + arr.shape[0]] = arr
        index += arr.shape[0]
    detector.create_dataset("image_key", data=dataset.image_keys)

    # sample field
    sample_group = tomo_entry.create_group("sample")
    _set_nx_class(sample_group, "NXsample")
    sample_group.create_dataset("name", data=np.bytes_(sample_name))
    # rotation angle
    rotation_angle = sample_group.create_dataset("rotation_angle", data=np.concatenate(dataset.nexus_rotation_angles))
    rotation_angle.attrs["units"] = "rad"

    # data field: link the existing datasets rather than duplicating them
    data = tomo_entry.create_group("data")
    _set_nx_class(data, "NXdata")
    data["data"] = detector_data
    data["rotation_angle"] = rotation_angle
    data["image_key"] = detector["image_key"]

    for recon in dataset.recons:
        _save_recon_to_nexus(nexus_file, recon)
def _save_recon_to_nexus(nexus_file: h5py.File, recon: ImageStack):
    """
    Saves a recon to a NeXus file as an ``NXtomoproc`` entry.

    The entry is a top-level group named after the recon. It contains fixed
    source/sample metadata, the program name, version and save timestamp, and
    the recon data rescaled into a uint16 dataset.

    :param nexus_file: The NeXus file.
    :param recon: The recon data.
    """
    # np.bytes_ replaces np.string_ (an alias of it), which NumPy 2.0 removed.
    recon_entry = nexus_file.create_group(recon.name)
    _set_nx_class(recon_entry, "NXentry")
    recon_entry.create_dataset("title", data=np.bytes_(recon.name))
    recon_entry.create_dataset("definition", data=np.bytes_("NXtomoproc"))

    instrument = recon_entry.create_group("INSTRUMENT")
    _set_nx_class(instrument, "NXinstrument")
    source = instrument.create_group("SOURCE")
    _set_nx_class(source, "NXsource")
    source.create_dataset("type", data=np.bytes_("Neutron source"))
    source.create_dataset("name", data=np.bytes_("ISIS"))
    source.create_dataset("probe", data=np.bytes_("neutron"))

    sample = recon_entry.create_group("SAMPLE")
    _set_nx_class(sample, "NXsample")
    sample.create_dataset("name", data=np.bytes_("sample description"))

    reconstruction = recon_entry.create_group("reconstruction")
    _set_nx_class(reconstruction, "NXprocess")
    reconstruction.create_dataset("program", data=np.bytes_("Mantid Imaging"))
    reconstruction.create_dataset("version", data=np.bytes_(CheckVersion().get_version()))
    # Local (naive) time in ISO 8601 form with a "T" separator.
    reconstruction.create_dataset("date", data=np.bytes_(datetime.datetime.now().isoformat()))
    reconstruction.create_group("parameters")

    data = recon_entry.create_group("data")
    _set_nx_class(data, "NXdata")
    # The rescaled values are written into a uint16 dataset.
    data.create_dataset("data", shape=recon.data.shape, dtype="uint16")
    data["data"][:] = _rescale_recon_data(recon.data)
def _set_nx_class(group: h5py.Group, class_name: str):
    """
    Sets the NX_class attribute of data in a NeXus file.

    The class name is stored as a byte string. ``np.bytes_`` is used because
    ``np.string_`` (formerly an alias of it) was removed in NumPy 2.0.

    :param group: The h5py group.
    :param class_name: The class name, e.g. "NXentry".
    """
    group.attrs["NX_class"] = np.bytes_(class_name)
def _rescale_recon_data(data: np.ndarray) -> np.ndarray:
    """
    Rescales recon data so that it can be converted to uint.
    :param data: The recon data.
    :return: The rescaled recon data.
    """
    min_value = np.min(data)
    if min_value < 0:
        data -= min_value
    data *= (np.iinfo("uint16").max / np.max(data))
    return data
[docs]
def generate_names(name_prefix: str,
                   indices: Union[List[int], Indices, None],
                   num_images: int,
                   custom_idx: Optional[int] = None,
                   zfill_len: int = DEFAULT_ZFILL_LENGTH,
                   name_postfix: str = DEFAULT_NAME_POSTFIX,
                   out_format: str = DEFAULT_IO_FILE_FORMAT) -> List[str]:
    """
    Build the file names used when saving a stack as individual images.

    Each name has the form ``<name_prefix>_<zero-padded index><name_postfix>.<out_format>``.

    :param name_prefix: Text placed before the index.
    :param indices: Optional sequence where element 0 is the first index and
                    element 2 the increment. ``None`` or an empty sequence means
                    start at 0 with an increment of 1. Ignored if ``custom_idx`` is given.
    :param num_images: Number of names to generate.
    :param custom_idx: If not None, every name uses this single index instead of an
                       incrementing one. ``0`` is a valid custom index.
    :param zfill_len: Width the index is zero-padded to.
    :param name_postfix: Text placed after the index.
    :param out_format: File extension, without the leading dot.
    :returns: ``num_images`` file names (no directory component).
    """
    # Compare against None: a custom index of 0 is legitimate and must not fall
    # through to the incrementing scheme.
    if custom_idx is not None:
        start, increment = custom_idx, 0
    elif indices:
        start, increment = int(indices[0]), indices[2]
    else:
        start, increment = 0, 1

    return [
        name_prefix + '_' + str(start + i * increment).zfill(zfill_len) + name_postfix + "." + out_format
        for i in range(num_images)
    ]
[docs]
def make_dirs_if_needed(dirname: Optional[str] = None, overwrite_all: bool = False):
    """
    Ensure an output directory exists and, unless allowed otherwise, is empty.

    A missing directory is created together with any missing parents. ``~`` is
    expanded and the path is made absolute first.

    :param dirname: Directory to check or create. ``None`` makes this a no-op.
    :param overwrite_all: When True, an existing non-empty directory is accepted.
    :raises RuntimeError: If the directory already has entries and ``overwrite_all`` is False.
    """
    if dirname is not None:
        target = os.path.abspath(os.path.expanduser(dirname))
        if os.path.exists(target):
            has_entries = bool(os.listdir(target))
            if has_entries and not overwrite_all:
                raise RuntimeError("The output directory is NOT empty:{0}\nThis can be "
                                   "overridden by specifying 'Overwrite on name conflict'.".format(target))
        else:
            os.makedirs(target)