# coding=utf-8
"""Classes for aggregating images from multiple files into a single image"""
from abc import ABCMeta, abstractmethod
import numpy as np
import six
from imagesplit.image.image_wrapper import SmartImage
class Source(object):
    """Base class for reading data"""

    # NOTE(review): a class-level ``__metaclass__`` attribute is only honoured
    # by Python 2. On Python 3 it is ignored, so the @abstractmethod markers
    # below are not enforced there (Source itself can be instantiated).
    # ``six.add_metaclass(ABCMeta)`` would enforce them on both versions, but
    # could break subclasses elsewhere that omit one of these methods --
    # confirm before switching.
    __metaclass__ = ABCMeta

    @abstractmethod
    def read_image(self, start, size):
        """Read image from specified starting coordinates and size"""
        raise NotImplementedError

    @abstractmethod
    def close(self):
        """Close any streams and files held by this source"""
        raise NotImplementedError
class CombinedImage(object):
    """A kind of virtual file for writing where the data are distributed
    across multiple real files."""

    def __init__(self, descriptors, file_factory):
        """Create for the given set of descriptors

        :param descriptors: one descriptor per real file; each becomes a
            SubImage
        :param file_factory: passed to every SubImage to create its files
        """
        # Cache for get_limits(); filled in on the first call
        self.limits = None
        self._subimages = [SubImage(descriptor, file_factory)
                           for descriptor in descriptors]

    def read_image(self, start_local, size_local, transformer):
        """Assemble an image range from the subimages

        :param start_local: start of the requested region, in the local
            coordinates defined by ``transformer``
        :param size_local: size of the requested region, in local coordinates
        :param transformer: converts the local region to global coordinates
        :return: a SmartImage for the requested region, into which the
            overlapping part of every subimage has been set
        """
        # Create the output image wrapper
        combined_image = SmartImage(start=start_local,
                                    size=size_local,
                                    image=None,
                                    transformer=transformer)
        # Compute global coordinates to match with subimage descriptors
        start, size = transformer.to_global(start_local, size_local)
        for subimage in self._subimages:
            # Fetch any part of the image which overlaps this subimage's ROI
            part_image = subimage.read_image_bound_by_roi(start, size)
            # None is the explicit "no overlap" signal; don't rely on the
            # truthiness of the returned image object
            if part_image is not None:
                combined_image.set_sub_image(part_image)
        return combined_image

    def close(self):
        """Closes all streams and files held by the subimages"""
        for subimage in self._subimages:
            subimage.close()

    def write_image(self, source, rescale, test=False):
        """Write out all the subimages with data from supplied source

        :param source: data source passed to each subimage's write_image();
            must also provide get_limits() when ``rescale == "limits"``
        :param rescale: falsy for no rescaling, the string ``"limits"`` to
            use ``source.get_limits()``, or a (min, max) pair
        :param test: if True, only report the limits; no files are written
        """
        if not rescale:
            limits = None
            six.print_("Limits: No rescale")
        else:
            if rescale == "limits":
                limits = source.get_limits()
            else:
                limits = Limits(rescale[0], rescale[1])
            six.print_("Limits: " + str(limits.min) + ":" + str(limits.max))

        # Get each subimage to write itself
        if not test:
            for next_image in self._subimages:
                next_image.write_image(source, limits)

    def get_limits(self):
        """Return minimum and maximum values across all subimages

        The result is cached on ``self.limits`` after the first call. With no
        subimages both values are None.
        """
        if self.limits is None:
            minv = None
            maxv = None
            for next_image in self._subimages:
                next_min, next_max = next_image.get_limits()
                if minv is None or next_min < minv:
                    minv = next_min
                if maxv is None or next_max > maxv:
                    maxv = next_max
            self.limits = Limits(minv, maxv)
        return self.limits
class Limits(object):
    """Minimum and maximum image values.

    Holds either the range computed across all subimages or an explicitly
    requested rescale range.

    :param rmin: minimum value, stored as ``min``
    :param rmax: maximum value, stored as ``max``
    """

    def __init__(self, rmin, rmax):
        self.min, self.max = rmin, rmax
class SubImage(Source):
    """An image which forms part of a larger image"""

    def __init__(self, descriptor, file_factory):
        """Create a subimage for a single real file

        :param descriptor: descriptor of this file; its ``ranges``
            (``origin_start``, ``image_size``, ``roi_start``, ``roi_size``)
            and ``axis`` are read here
        :param file_factory: creates the read and write files for
            ``descriptor``
        """
        self._file_factory = file_factory
        self._descriptor = descriptor
        # The read file is opened lazily by _get_read_file()
        self._read_file = None
        # Region of interest, compared against global coordinates in
        # bind_by_roi()
        self._roi_start = self._descriptor.ranges.roi_start
        self._roi_size = self._descriptor.ranges.roi_size
        self._axis = self._descriptor.axis
        # Converts requested coordinates into this file's local coordinates
        self._transformer = CoordinateTransformer(
            self._descriptor.ranges.origin_start,
            self._descriptor.ranges.image_size,
            self._axis)

    def read_image(self, start, size):
        """Returns a subimage containing any overlap from the image

        :param start: start of the region, in global coordinates
        :param size: size of the region, in global coordinates
        :return: SmartImage wrapping the data read from this file
        """
        # Convert to local coordinates for the data source
        start_local, size_local = self._transformer.to_local(start, size)
        # Get the image data from the data source (opened on first use)
        local_source = self._get_read_file()
        image_local = local_source.read_image(start_local, size_local)
        return SmartImage(start=start_local,
                          size=size_local,
                          image=image_local,
                          transformer=self._transformer)

    def read_image_bound_by_roi(self, start, size):
        """Returns a subimage containing any overlap from the ROI

        :return: a SmartImage for the overlapping part, or None if the
            requested region does not overlap this subimage's ROI
        """
        # Find the part of the requested region that fits in the ROI
        sub_start, sub_size = self.bind_by_roi(start, size)
        # Only read when the overlap is non-empty in every dimension
        if np.all(np.greater(sub_size, 0)):
            return self.read_image(sub_start, sub_size)
        return None

    def close(self):
        """Close all streams and files"""
        if self._read_file is not None:
            self._read_file.close()
            self._read_file = None

    def write_image(self, global_source, rescale_limits):
        """Write out SubImage using data from the specified source

        :param global_source: source that is read through a LocalSource
            wrapper using this subimage's transformer
        :param rescale_limits: limits to rescale to, or None for no rescaling
        """
        out_file = self._file_factory.create_write_file(self._descriptor)
        local_source = LocalSource(global_source, self._transformer)
        # NOTE(review): out_file is never closed here. Confirm whether the
        # write file's write_image() releases its own resources; if not, it
        # should be closed once writing completes.
        out_file.write_image(local_source, rescale_limits)

    def bind_by_roi(self, start_global, size_global):
        """Find the part of the specified region that fits within the ROI

        :return: (start, size) of the intersection; a size component is zero
            or negative in any dimension where there is no overlap
        """
        start = np.maximum(start_global, self._roi_start)
        end = np.minimum(np.add(start_global, size_global),
                         np.add(self._roi_start, self._roi_size))
        size = np.subtract(end, start)
        return start, size

    def get_limits(self):
        """Return the minimum and maximum values within this subimage

        Reads the file's full range (``origin_start``, ``image_size``).
        """
        image = self.read_image(self._descriptor.ranges.origin_start,
                                self._descriptor.ranges.image_size)
        # Fetch the raw data once and use it for both reductions
        raw = image.image.get_raw()
        return np.min(raw), np.max(raw)

    def _get_read_file(self):
        """Return the read file for this descriptor, opening it on first use"""
        if self._read_file is None:
            self._read_file = self._file_factory.create_read_file(
                self._descriptor)
        return self._read_file
class LocalSource(Source):
    """Adapter for reading a source using local coordinates

    Every read is forwarded to the wrapped source together with this
    adapter's transformer, so callers only ever deal in local coordinates.

    :param source: wrapped source; must accept
        ``read_image(start, size, transformer)`` and ``close()``
    :param transformer: coordinate transformer passed along with each read
    """

    def __init__(self, source, transformer):
        self._source, self._transformer = source, transformer

    def read_image(self, start, size):
        """Return part of the image, with start and size in local coordinates"""
        wrapped = self._source
        return wrapped.read_image(start, size, self._transformer)

    def close(self):
        """Close the wrapped source, and with it its streams and files"""
        wrapped = self._source
        wrapped.close()
class Axis(object):
    """Coordinate-system description used by image coordinates

    :param dim_order: ordering of the image dimensions
    :param dim_flip: per-dimension flip settings (presumably booleans --
        confirm against callers)
    """

    def __init__(self, dim_order, dim_flip):
        self.dim_order = dim_order
        self.dim_flip = dim_flip
        # For a permutation, argsort yields the inverse permutation
        inverse_order = np.argsort(dim_order)
        self.reverse_dim_order = inverse_order.tolist()

    def __eq__(self, other):
        # Objects of another class never compare equal; otherwise every
        # instance attribute must match
        if not isinstance(other, self.__class__):
            return False
        return vars(self) == vars(other)

    def __ne__(self, other):
        is_equal = self.__eq__(other)
        return not is_equal