# Source code for mantidimaging.core.data.dataset
# Copyright (C) 2022 ISIS Rutherford Appleton Laboratory UKRI
# SPDX - License - Identifier: GPL-3.0-or-later
import uuid
from dataclasses import dataclass
from typing import Optional, List
import numpy as np
from mantidimaging.core.data import Images
from mantidimaging.core.data.reconlist import ReconList
def _delete_stack_error_message(images_id: uuid.UUID) -> str:
return f"Unable to delete stack: Images with ID {images_id} not present in dataset."
class BaseDataset:
    """
    Common base for datasets: a uniquely identified, named group of image stacks.

    Stores the dataset ID, its name, a ``ReconList`` of reconstructions and an optional sinograms stack.
    Subclasses are expected to provide ``all`` and ``delete_stack``; the versions defined here raise
    ``NotImplementedError``.
    """
    def __init__(self, name: str = ""):
        """
        :param name: Name of the dataset. Defaults to an empty string.
        """
        # Every dataset gets its own random ID at construction time
        self._id: uuid.UUID = uuid.uuid4()
        self.recons = ReconList()
        self._name = name
        self._sinograms: Optional[Images] = None

    @property
    def id(self) -> uuid.UUID:
        """The ID generated for this dataset when it was created."""
        return self._id

    @property
    def name(self) -> str:
        """The dataset name."""
        return self._name

    @name.setter
    def name(self, arg: str):
        self._name = arg

    @property
    def sinograms(self) -> Optional[Images]:
        """The sinograms stack, or None if none has been set."""
        return self._sinograms

    @sinograms.setter
    def sinograms(self, sino: Optional[Images]):
        self._sinograms = sino

    @property
    def all(self):
        """
        The image stacks belonging to the dataset. Must be provided by subclasses.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def delete_stack(self, images_id: uuid.UUID):
        """
        Remove the stack with the given ID from the dataset. Must be provided by subclasses.

        :param images_id: ID of the stack to delete.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def replace(self, images_id: uuid.UUID, new_data: np.ndarray):
        """
        Overwrite the ``data`` attribute of the stack with the given ID.

        :param images_id: ID of the stack whose data should be replaced.
        :param new_data: The replacement data.
        :raises KeyError: if no stack in ``all`` has the given ID.
        """
        for image in self.all:
            if image.id == images_id:
                image.data = new_data
                return
        raise KeyError(f"Unable to replace: Images with ID {images_id} not present in dataset.")

    def __contains__(self, images_id: uuid.UUID) -> bool:
        """
        :param images_id: ID to look for.
        :return: True if a stack in ``all`` has the given ID, False otherwise.
        """
        # A generator (rather than a list) lets any() stop at the first matching stack
        return any(image.id == images_id for image in self.all)

    @property
    def all_image_ids(self) -> List[uuid.UUID]:
        """The IDs of every stack in ``all``, skipping any None entries."""
        return [image_stack.id for image_stack in self.all if image_stack is not None]

    def delete_recons(self):
        """Remove every reconstruction from the dataset."""
        self.recons.clear()
class MixedDataset(BaseDataset):
    """
    Dataset made of a plain list of image stacks, plus the reconstructions and optional sinograms held by the
    base class.
    """
    def __init__(self, stacks: Optional[List[Images]] = None, name: str = ""):
        """
        :param stacks: The image stacks to put in the dataset. The list is stored as given (not copied). When
            omitted, the dataset starts with its own empty list.
        :param name: Name of the dataset. Defaults to an empty string.
        """
        super().__init__(name=name)
        # A literal ``[]`` default is created once and would be shared by every dataset built without stacks,
        # so a fresh list is made per instance instead.
        self._stacks: List[Images] = [] if stacks is None else stacks

    @property
    def all(self) -> List[Images]:
        """The dataset's stacks, followed by the recon stacks and then the sinograms (when set)."""
        all_images = self._stacks + self.recons.stacks
        if self.sinograms is None:
            return all_images
        return all_images + [self.sinograms]

    def delete_stack(self, images_id: uuid.UUID):
        """
        Remove the stack, recon or sinograms with the given ID from the dataset.

        :param images_id: ID of the stack to delete.
        :raises KeyError: if nothing in the dataset has the given ID.
        """
        # Each loop returns straight after removing, so the list is never iterated after being modified
        for image in self._stacks:
            if image.id == images_id:
                self._stacks.remove(image)
                return
        for recon in self.recons:
            if recon.id == images_id:
                self.recons.remove(recon)
                return
        if self.sinograms is not None and self.sinograms.id == images_id:
            self.sinograms = None
            return
        raise KeyError(_delete_stack_error_message(images_id))
@dataclass
class StrictDataset(BaseDataset):
    """
    Dataset with fixed, named slots: a sample stack plus optional flat/dark stacks taken before and after.

    The class defines its own ``__init__``, so ``@dataclass`` does not generate one; the field declarations
    below are still what the dataclass-generated methods (such as ``__repr__`` and ``__eq__``) are built from.
    """
    sample: Images
    flat_before: Optional[Images] = None
    flat_after: Optional[Images] = None
    dark_before: Optional[Images] = None
    dark_after: Optional[Images] = None

    def __init__(self,
                 sample: Images,
                 flat_before: Optional[Images] = None,
                 flat_after: Optional[Images] = None,
                 dark_before: Optional[Images] = None,
                 dark_after: Optional[Images] = None,
                 name: str = ""):
        """
        :param sample: The sample stack.
        :param flat_before: Optional flat-before stack.
        :param flat_after: Optional flat-after stack.
        :param dark_before: Optional dark-before stack.
        :param dark_after: Optional dark-after stack.
        :param name: Name of the dataset. When left empty, the sample's name is used.
        """
        super().__init__(name=name)
        self.sample = sample
        self.flat_before = flat_before
        self.flat_after = flat_after
        self.dark_before = dark_before
        self.dark_after = dark_after

        if self.name == "":
            self._name = sample.name

    @property
    def all(self) -> List[Images]:
        """Every stack that is currently set (None entries are dropped), followed by the recon stacks."""
        image_stacks = [
            self.sample, self.proj180deg, self.flat_before, self.flat_after, self.dark_before, self.dark_after,
            self.sinograms
        ]
        return [image_stack for image_stack in image_stacks if image_stack is not None] + self.recons.stacks

    @property
    def proj180deg(self):
        """The sample's ``proj180deg`` attribute, or None when the dataset has no sample."""
        if self.sample is not None:
            return self.sample.proj180deg
        else:
            return None

    @proj180deg.setter
    def proj180deg(self, _180_deg: Images):
        self.sample.proj180deg = _180_deg

    def delete_stack(self, images_id: uuid.UUID):
        """
        Remove the stack with the given ID by clearing whichever slot holds it, or by removing the matching
        recon.

        :param images_id: ID of the stack to delete.
        :raises KeyError: if nothing in the dataset has the given ID.
        """
        if isinstance(self.sample, Images) and self.sample.id == images_id:
            self.sample = None  # type: ignore
        elif isinstance(self.flat_before, Images) and self.flat_before.id == images_id:
            self.flat_before = None
        elif isinstance(self.flat_after, Images) and self.flat_after.id == images_id:
            self.flat_after = None
        elif isinstance(self.dark_before, Images) and self.dark_before.id == images_id:
            self.dark_before = None
        elif isinstance(self.dark_after, Images) and self.dark_after.id == images_id:
            self.dark_after = None
        elif isinstance(self.proj180deg, Images) and self.proj180deg.id == images_id:
            # The 180 degree projection is held by the sample, so it is cleared through the sample
            self.sample.clear_proj180deg()
        elif isinstance(self.sinograms, Images) and self.sinograms.id == images_id:
            self.sinograms = None
        elif images_id in self.recons.ids:
            for recon in self.recons:
                if recon.id == images_id:
                    self.recons.remove(recon)
                    # Stop here: carrying on iterating a list that has just shrunk would skip an element
                    break
        else:
            raise KeyError(_delete_stack_error_message(images_id))