Source code for kbmod.reprojection

import numpy as np
import concurrent.futures
import reproject
from astropy.nddata import CCDData
from astropy.wcs import WCS
from tqdm.asyncio import tqdm

from kbmod import is_interactive
from kbmod.search import KB_NO_DATA, ImageStack, LayeredImage, RawImage
from kbmod.work_unit import WorkUnit
from kbmod.wcs_utils import append_wcs_to_hdu_header
from astropy.io import fits
import os
from copy import copy


# The number of executors to use in the parallel reprojecting function.
# Used as the default for the `max_parallel_processes` arguments in this module.
MAX_PROCESSES = 8
# Format string passed as `bar_format` to the tqdm progress bars in this module:
# the standard left part and bar, then "<done>/<total>" and the elapsed time.
_DEFAULT_TQDM_BAR = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]"


def reproject_image(image, original_wcs, common_wcs):
    """Reproject image data into a common WCS.

    The data (science, variance, or mask when used with `reproject_work_unit`)
    is wrapped in a `CCDData` carrying ``original_wcs`` and reprojected with
    `reproject.reproject_adaptive` onto the pixel grid of ``common_wcs``.

    Parameters
    ----------
    image : `kbmod.search.RawImage`, `numpy.ndarray` or `list[numpy.ndarray]`
        The image data to be reprojected. A list is treated as a stack of
        images that all share ``original_wcs`` (e.g. science, variance, mask)
        and is reprojected in a single call.
    original_wcs : `astropy.wcs.WCS`
        The WCS of the original image.
    common_wcs : `astropy.wcs.WCS`
        The WCS to reproject all the images into.

    Returns
    -------
    new_image : `numpy.ndarray`
        The image data reprojected with a common `astropy.wcs.WCS`, cast to
        ``numpy.float32``.
    footprint : `numpy.ndarray`
        An array containing the footprint of pixels that have data.
        For footprint[i][j], it's 1 if there is a corresponding reprojected
        pixel and 0 if there is no data. Always a single 2D footprint, even
        when a stack of images was passed in.
    """
    # isinstance (rather than an exact type comparison) so subclasses are
    # unwrapped / recognised as well.
    if isinstance(image, RawImage):
        image = image.image
    is_stack = isinstance(image, list)

    image_data = CCDData(image, unit="adu")
    image_data.wcs = original_wcs

    # If the input is actually a stack of images, the footprint needs one
    # plane per image so that it matches the shape of the reprojected output.
    footprint_shape = tuple(common_wcs.array_shape)
    if is_stack:
        footprint_shape = (len(image),) + footprint_shape
    footprint = np.zeros(footprint_shape, dtype=np.ubyte)

    # The footprint is filled in place through `output_footprint`, so the
    # second return value of `reproject_adaptive` is not needed.
    new_image, _ = reproject.reproject_adaptive(
        image_data,
        common_wcs,
        shape_out=common_wcs.array_shape,
        bad_value_mode="ignore",
        output_footprint=footprint,
        roundtrip_coords=False,
    )

    # If we passed in a stack of ndarrays (i.e. science, variance, mask), we
    # only need to return the first footprint, as they should all be the same.
    if footprint.ndim == 3:
        footprint = footprint[0]

    return new_image.astype(np.float32), footprint
def reproject_work_unit(
    work_unit,
    common_wcs,
    frame="original",
    parallelize=True,
    max_parallel_processes=MAX_PROCESSES,
    write_output=False,
    directory=None,
    filename=None,
    show_progress=None,
):
    """Reproject all of the images of a WorkUnit into a common WCS.

    This is the public entry point. It validates the arguments and then hands
    the work to the lazy, parallel, or serial implementation.

    Parameters
    ----------
    work_unit : `kbmod.WorkUnit`
        The WorkUnit to be reprojected.
    common_wcs : `astropy.wcs.WCS`
        The WCS to reproject all the images into.
    frame : `str`
        The WCS frame of reference to use when reprojecting.
        Can either be 'original' or 'ebd' to specify which WCS to
        access from the WorkUnit.
    parallelize : `bool`
        If True, use multiprocessing to reproject the images in parallel.
        Default is True. Ignored for lazily loaded WorkUnits, which always
        go through `reproject_lazy_work_unit`.
    max_parallel_processes : `int`
        The maximum number of parallel processes to use when reprojecting.
        Only used when parallelize is True. Default is 8. For more see
        `concurrent.futures.ProcessPoolExecutor` in the Python docs.
    write_output : `bool`
        Whether or not to write the reprojection results out as a sharded
        `WorkUnit`.
    directory : `str`
        The directory where output will be written if `write_output` is set
        to True (also required for lazily loaded WorkUnits).
    filename : `str`
        The base filename where output will be written if `write_output` is
        set to True (also required for lazily loaded WorkUnits).
    show_progress : `bool` or `None`, optional
        If `None` use default settings, when a boolean forces the progress
        bar to be displayed or hidden.

    Returns
    -------
    A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None`
    in the case where `write_output` is set to True.

    Raises
    ------
    ValueError
        If the WorkUnit has already been reprojected, or if output has to be
        written but `directory` or `filename` is missing.
    """
    if work_unit.reprojected:
        raise ValueError("Unable to reproject a reprojected WorkUnit.")

    if show_progress is None:
        show_progress = is_interactive()

    # Lazy WorkUnits are always written to disk, so they need a destination
    # just like an explicit `write_output` request does.
    must_write = work_unit.lazy or write_output
    destination_missing = directory is None or filename is None
    if must_write and destination_missing:
        raise ValueError("can't write output to sharded fits without directory and filename provided.")

    if work_unit.lazy:
        return reproject_lazy_work_unit(
            work_unit,
            common_wcs,
            frame=frame,
            max_parallel_processes=max_parallel_processes,
            directory=directory,
            filename=filename,
            show_progress=show_progress,
        )

    # Options shared by the serial and the parallel in-memory implementations.
    output_options = {
        "write_output": write_output,
        "directory": directory,
        "filename": filename,
        "show_progress": show_progress,
    }
    if not parallelize:
        return _reproject_work_unit(work_unit, common_wcs, frame, **output_options)
    return _reproject_work_unit_in_parallel(
        work_unit, common_wcs, frame, max_parallel_processes, **output_options
    )
def _reproject_work_unit(
    work_unit,
    common_wcs,
    frame="original",
    write_output=False,
    directory=None,
    filename=None,
    show_progress=False,
):
    """Given a WorkUnit and a WCS, reproject all of the images in the ImageStack
    into a common WCS, one image at a time in the current process.

    Parameters
    ----------
    work_unit : `kbmod.WorkUnit`
        The WorkUnit to be reprojected.
    common_wcs : `astropy.wcs.WCS`
        The WCS to reproject all the images into.
    frame : `str`
        The WCS frame of reference to use when reprojecting.
        Can either be 'original' or 'ebd' to specify which WCS to
        access from the WorkUnit.
    write_output : `bool`
        Whether or not to write the reprojection results out as a sharded
        `WorkUnit`.
    directory : `str`
        The directory where output will be written if `write_output` is set
        to True.
    filename : `str`
        The base filename where output will be written if `write_output` is
        set to True.
    show_progress : `bool`
        Whether or not to enable the `tqdm` progress bar.

    Returns
    -------
    A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None`
    in the case where `write_output` is set to True.

    Raises
    ------
    ValueError
        If `frame` is invalid, an image has no WCS, or images with the same
        obstime overlap after reprojection.
    """
    images = work_unit.im_stack.get_images()
    unique_obstimes, unique_obstime_indices = work_unit.get_unique_obstimes_and_indices()

    # Create a list of the correct WCS. We do this extraction once and reuse
    # it for all images.
    if frame == "original":
        wcs_list = work_unit.get_constituent_meta("per_image_wcs")
    elif frame == "ebd":
        wcs_list = work_unit.get_constituent_meta("ebd_wcs")
    else:
        raise ValueError("Invalid projection frame provided.")

    stack = ImageStack()
    for obstime_index, (time, indices) in tqdm(
        enumerate(zip(unique_obstimes, unique_obstime_indices)),
        # enumerate() hides the length of the iterable, so pass it explicitly;
        # otherwise the "{n_fmt}/{total_fmt}" part of the bar has no total.
        total=len(unique_obstime_indices),
        bar_format=_DEFAULT_TQDM_BAR,
        desc="Reprojecting",
        disable=not show_progress,
    ):
        science_add = np.zeros(common_wcs.array_shape, dtype=np.float32)
        variance_add = np.zeros(common_wcs.array_shape, dtype=np.float32)
        mask_add = np.zeros(common_wcs.array_shape, dtype=np.float32)
        footprint_add = np.zeros(common_wcs.array_shape, dtype=np.ubyte)

        for index in indices:
            image = images[index]
            science = image.get_science()
            variance = image.get_variance()
            mask = image.get_mask()

            original_wcs = wcs_list[index]
            if original_wcs is None:
                raise ValueError(f"No WCS provided for index {index}")

            reprojected_science, footprint = reproject_image(science, original_wcs, common_wcs)
            footprint_add += footprint

            # we'll enforce that there be no overlapping images at the same time,
            # for now. We might be able to add some ability co-add in the future.
            if np.any(footprint_add > 1):
                raise ValueError("Images with the same obstime are overlapping.")

            reprojected_variance, _ = reproject_image(variance, original_wcs, common_wcs)
            reprojected_mask, _ = reproject_image(mask, original_wcs, common_wcs)

            # change all the NaNs to zeroes so that the matrix addition works properly.
            # `footprint_add` will maintain the information about what areas of the frame
            # don't have any data so that we can change it back after we combine.
            reprojected_science[np.isnan(reprojected_science)] = 0.0
            reprojected_variance[np.isnan(reprojected_variance)] = 0.0
            reprojected_mask[np.isnan(reprojected_mask)] = 0.0

            science_add += reprojected_science
            variance_add += reprojected_variance
            mask_add += reprojected_mask

        # change all the values where there is no corresponding data to `KB_NO_DATA`.
        gaps = footprint_add == 0
        science_add[gaps] = KB_NO_DATA
        variance_add[gaps] = KB_NO_DATA
        mask_add[gaps] = 1

        # transforms the mask back into a bitmask. Note that we need to be explicit
        # about the dtypes for 0.0 and 1.0, otherwise mask_add will be cast to float64.
        mask_add = np.where(np.isclose(mask_add, 0.0, atol=0.2), np.float32(0.0), np.float32(1.0))

        # The first constituent image at this obstime supplies the PSF.
        psf = images[indices[0]].get_psf()

        if write_output:
            _write_images_to_shard(
                science_add=science_add,
                variance_add=variance_add,
                mask_add=mask_add,
                psf=psf,
                wcs=common_wcs,
                obstime=time,
                obstime_index=obstime_index,
                indices=indices,
                directory=directory,
                filename=filename,
            )
        else:
            new_layered_image = LayeredImage(
                science_add,
                variance_add,
                mask_add,
                psf,
                time,
            )
            stack.append_image(new_layered_image, force_move=True)

    if write_output:
        # Create a copy of the WorkUnit to write the global metadata.
        # We preserve the metadata for the constituent images.
        new_work_unit = copy(work_unit)
        new_work_unit._per_image_indices = unique_obstime_indices
        new_work_unit.wcs = common_wcs
        new_work_unit.reprojected = True
        new_work_unit.reprojection_frame = frame

        hdul = new_work_unit.metadata_to_hdul()
        hdul.writeto(os.path.join(directory, filename))
        # Everything went to disk; there is no in-memory WorkUnit to return.
        return None

    # Create a new WorkUnit with the new ImageStack and global WCS.
    # We preserve the metadata for the constituent images.
    return WorkUnit(
        im_stack=stack,
        config=work_unit.config,
        wcs=common_wcs,
        per_image_indices=unique_obstime_indices,
        reprojected=True,
        reprojection_frame=frame,
        barycentric_distance=work_unit.barycentric_distance,
        org_image_meta=work_unit.org_img_meta,
    )


def _reproject_work_unit_in_parallel(
    work_unit,
    common_wcs,
    frame="original",
    max_parallel_processes=MAX_PROCESSES,
    write_output=False,
    directory=None,
    filename=None,
    show_progress=False,
):
    """Given a WorkUnit and a WCS, reproject all of the images in the ImageStack
    into a common WCS. This function uses multiprocessing to reproject the images
    in parallel.

    Parameters
    ----------
    work_unit : `kbmod.WorkUnit`
        The WorkUnit to be reprojected.
    common_wcs : `astropy.wcs.WCS`
        The WCS to reproject all the images into.
    frame : `str`
        The WCS frame of reference to use when reprojecting.
        Can either be 'original' or 'ebd' to specify which WCS to
        access from the WorkUnit.
    max_parallel_processes : `int`
        The maximum number of parallel processes to use when reprojecting.
        Default is 8. For more see `concurrent.futures.ProcessPoolExecutor` in
        the Python docs.
    write_output : `bool`
        Whether or not to write the reprojection results out as a sharded
        `WorkUnit`.
    directory : `str`
        The directory where output will be written if `write_output` is set
        to True.
    filename : `str`
        The base filename where output will be written if `write_output` is
        set to True.
    show_progress : `bool`
        Whether or not to enable the `tqdm` progress bar.

    Returns
    -------
    A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None`
    in the case where `write_output` is set to True.

    Raises
    ------
    RuntimeError
        If `write_output` is True and a worker reports failure.
    ValueError
        Raised by the helpers for an invalid frame, a missing WCS, or
        overlapping images at the same obstime.
    """
    # get all the unique obstimes
    unique_obstimes, unique_obstimes_indices = work_unit.get_unique_obstimes_and_indices()

    # get the list of images from the work_unit outside the for-loop
    images = work_unit.im_stack.get_images()

    future_reprojections = []
    with concurrent.futures.ProcessPoolExecutor(max_parallel_processes) as executor:
        # for a given list of obstime indices, collect all the science, variance, and mask images.
        for obstime_index, (obstime, indices) in enumerate(zip(unique_obstimes, unique_obstimes_indices)):
            original_wcs = _validate_original_wcs(work_unit, indices, frame)

            # get the list of images for each unique obstime
            images_at_obstime = [images[i] for i in indices]

            # convert each image into a science, variance, or mask "image",
            # i.e. a list of numpy arrays.
            science_images_at_obstime = [this_image.get_science().image for this_image in images_at_obstime]
            variance_images_at_obstime = [this_image.get_variance().image for this_image in images_at_obstime]
            mask_images_at_obstime = [this_image.get_mask().image for this_image in images_at_obstime]

            if write_output:
                psf_array = _get_first_psf_at_time(work_unit, obstime)
                future_reprojections.append(
                    executor.submit(
                        _reproject_and_write,
                        science_images=science_images_at_obstime,
                        variance_images=variance_images_at_obstime,
                        mask_images=mask_images_at_obstime,
                        psf=psf_array,
                        obstime=obstime,
                        obstime_index=obstime_index,
                        indices=indices,
                        common_wcs=common_wcs,
                        original_wcs=original_wcs,
                        directory=directory,
                        filename=filename,
                    )
                )
            else:
                # call `_reproject_images` in parallel.
                future_reprojections.append(
                    executor.submit(
                        _reproject_images,
                        science_images=science_images_at_obstime,
                        variance_images=variance_images_at_obstime,
                        mask_images=mask_images_at_obstime,
                        obstime=obstime,
                        common_wcs=common_wcs,
                        original_wcs=original_wcs,
                    )
                )

        # Need to consume the generator produced by tqdm to update the
        # progress bar, so we instantiate a list.
        list(
            tqdm(
                concurrent.futures.as_completed(future_reprojections),
                total=len(future_reprojections),
                bar_format=_DEFAULT_TQDM_BAR,
                desc="Reprojecting",
                disable=not show_progress,
            )
        )

    # Make sure all the multiprocessing has finished before reading results.
    concurrent.futures.wait(future_reprojections, return_when=concurrent.futures.ALL_COMPLETED)

    if write_output:
        # `Future.result()` re-raises any exception raised in the worker;
        # otherwise the worker returns a truthy value on success.
        for result in future_reprojections:
            if not result.result():
                raise RuntimeError("one or more jobs failed.")

        new_work_unit = copy(work_unit)
        new_work_unit._per_image_indices = unique_obstimes_indices
        new_work_unit.wcs = common_wcs
        new_work_unit.reprojected = True
        new_work_unit.reprojection_frame = frame

        hdul = new_work_unit.metadata_to_hdul()
        hdul.writeto(os.path.join(directory, filename))
        # Everything went to disk; there is no in-memory WorkUnit to return.
        return None

    stack = ImageStack([])
    for result in future_reprojections:
        science_add, variance_add, mask_add, time = result.result()
        # Look the PSF up with the obstime that came back with *this* result.
        # The `obstime` variable of the submission loop above is stale here
        # (it holds the last submitted obstime).
        psf = _get_first_psf_at_time(work_unit, time)

        # And then stack the reprojected arrays into a LayeredImage.
        new_layered_image = LayeredImage(
            science_add,
            variance_add,
            mask_add,
            psf,
            time,
        )
        stack.append_image(new_layered_image, force_move=True)

    # sort by the time_stamp
    stack.sort_by_time()

    # Add the ImageStack to a new WorkUnit and return it. We preserve the
    # metadata for the constituent images.
    return WorkUnit(
        im_stack=stack,
        config=work_unit.config,
        wcs=common_wcs,
        per_image_indices=unique_obstimes_indices,
        reprojected=True,
        reprojection_frame=frame,
        barycentric_distance=work_unit.barycentric_distance,
        org_image_meta=work_unit.org_img_meta,
    )
def reproject_lazy_work_unit(
    work_unit,
    common_wcs,
    directory,
    filename,
    frame="original",
    max_parallel_processes=MAX_PROCESSES,
    show_progress=None,
):
    """Given a WorkUnit and a WCS, reproject all of the images in the ImageStack
    into a common WCS. This function is used with lazily evaluated `WorkUnit`s and
    multiprocessing to reproject the images in parallel, and only loads the individual
    image frames at runtime. Currently only works for sharded `WorkUnit`s loaded with
    the `lazy` option.

    The reprojected images are written as one FITS shard per unique obstime and
    the global metadata is written to ``os.path.join(directory, filename)``.
    Nothing is returned.

    Parameters
    ----------
    work_unit : `kbmod.WorkUnit`
        The WorkUnit to be reprojected.
    common_wcs : `astropy.wcs.WCS`
        The WCS to reproject all the images into.
    directory : `str`
        The directory where the `WorkUnit` fits shards will be output.
    filename : `str`
        The base filename (will be the actual name of the primary/metadata
        fits file and included with the index number in the filename of the
        shards).
    frame : `str`
        The WCS frame of reference to use when reprojecting.
        Can either be 'original' or 'ebd' to specify which WCS to
        access from the WorkUnit.
    max_parallel_processes : `int`
        The maximum number of parallel processes to use when reprojecting.
        Default is 8. For more see `concurrent.futures.ProcessPoolExecutor` in
        the Python docs.
    show_progress : `bool` or `None`, optional
        If `None` use default settings, when a boolean forces the progress
        bar to be displayed or hidden.

    Raises
    ------
    ValueError
        If the WorkUnit was not lazily loaded (or a helper rejects the frame
        or finds a missing WCS).
    RuntimeError
        If a worker reports failure.
    """
    show_progress = is_interactive() if show_progress is None else show_progress
    if not work_unit.lazy:
        raise ValueError("WorkUnit must be lazily loaded.")

    # get all the unique obstimes
    unique_obstimes, unique_obstimes_indices = work_unit.get_unique_obstimes_and_indices()

    future_reprojections = []
    with concurrent.futures.ProcessPoolExecutor(max_parallel_processes) as executor:
        # for a given list of obstime indices, hand the matching shard files to a worker.
        for obstime_index, (obstime, indices) in enumerate(zip(unique_obstimes, unique_obstimes_indices)):
            original_wcs = _validate_original_wcs(work_unit, indices, frame)

            # get the list of shard files for each unique obstime
            file_paths_at_obstime = [work_unit.file_paths[i] for i in indices]

            # call `_load_images_and_reproject` in parallel.
            future_reprojections.append(
                executor.submit(
                    _load_images_and_reproject,
                    file_paths=file_paths_at_obstime,
                    indices=indices,
                    obstime=obstime,
                    obstime_index=obstime_index,
                    common_wcs=common_wcs,
                    original_wcs=original_wcs,
                    directory=directory,
                    filename=filename,
                )
            )

        # Need to consume the generator produced by tqdm to update the
        # progress bar, so we instantiate a list.
        list(
            tqdm(
                concurrent.futures.as_completed(future_reprojections),
                total=len(future_reprojections),
                bar_format=_DEFAULT_TQDM_BAR,
                desc="Reprojecting",
                disable=not show_progress,
            )
        )

    concurrent.futures.wait(future_reprojections, return_when=concurrent.futures.ALL_COMPLETED)

    # `Future.result()` re-raises any exception raised in the worker;
    # otherwise the worker returns a truthy value on success.
    for result in future_reprojections:
        if not result.result():
            raise RuntimeError("one or more jobs failed.")

    # We use new metadata for the new images and the same metadata for the original images.
    new_work_unit = copy(work_unit)
    new_work_unit._per_image_indices = unique_obstimes_indices
    new_work_unit.wcs = common_wcs
    new_work_unit.reprojected = True
    # Record the frame under the same attribute the in-memory code paths use
    # (`reprojection_frame`); the previous misspelled attribute was never read.
    new_work_unit.reprojection_frame = frame

    hdul = new_work_unit.metadata_to_hdul()
    hdul.writeto(os.path.join(directory, filename))
def _validate_original_wcs(work_unit, indices, frame="original"):
    """Given a work unit and a set of indices, verify that the WCS is not None for
    any of the indices. If it is, raise a ValueError.

    Parameters
    ----------
    work_unit : `kbmod.WorkUnit`
        The WorkUnit with WCS to be validated.
    indices : list[int]
        The indices to be validated in work_unit.
    frame : `str`
        The WCS frame of reference to use when reprojecting.
        Can either be 'original' or 'ebd' to specify which WCS to
        access from the WorkUnit.

    Returns
    -------
    list[`astropy.wcs.WCS`]
        The list of validated WCS objects for these indices

    Raises
    ------
    ValueError
        If the frame is invalid, no indices are given, or any WCS objects
        are None.
    """
    if frame == "original":
        original_wcs = [work_unit.get_wcs(i) for i in indices]
    elif frame == "ebd":
        ebd_wcs = work_unit.get_constituent_meta("ebd_wcs")
        original_wcs = [ebd_wcs[i] for i in indices]
    else:
        raise ValueError("Invalid projection frame provided.")

    if len(original_wcs) == 0:
        raise ValueError(f"No WCS found for frame {frame}")

    # Collect the work_unit indices whose WCS is missing. An explicit
    # `is None` test per entry is required: reducing the list with `np.any`
    # never yields None, so such a check could not detect a missing WCS.
    work_unit_indices = [index for index, wcs in zip(indices, original_wcs) if wcs is None]
    if work_unit_indices:
        raise ValueError(f"No WCS provided for work_unit index(s) {work_unit_indices}")

    return original_wcs


def _get_first_psf_at_time(work_unit, time):
    """Given a work_unit, find the first psf object at a given time

    Parameters
    ----------
    work_unit : `kbmod.WorkUnit`
        The WorkUnit to be searched
    time : float
        The MJD of the observation(s) to search for in the work_unit.

    Returns
    -------
    The result of `get_psf()` on the first image whose observation time
    equals `time` (documented upstream as the PSF kernel, a `numpy.ndarray`).

    Raises
    ------
    ValueError
        If the time is not found in list of observation times in the work_unit,
        raise an error.
    """
    obstimes = np.asarray(work_unit.get_all_obstimes())

    # Positions of every image taken at exactly `time`; a single scan serves
    # both the membership check and the index lookup.
    matches = np.flatnonzero(obstimes == time)

    # if the time isn't in the list of times, raise an error.
    if matches.size == 0:
        raise ValueError(f"Observation time {time} not found in work unit.")

    images = work_unit.im_stack.get_images()
    return images[matches[0]].get_psf()


def _load_images_and_reproject(
    file_paths, indices, obstime, obstime_index, common_wcs, original_wcs, directory, filename
):
    """Load image data from `WorkUnit` shards. Intermediary step
    for when the `WorkUnit` is loaded lazily.

    Parameters
    ----------
    file_paths : `list[str]`
        List of strings containing the paths of the shards holding the images
        to be reprojected and stitched.
    indices : `list[int]`
        List of `WorkUnit` indices corresponding to the original positions
        of the images within the `ImageStack`.
    obstime : `float`
        observation times for set of images.
    obstime_index : `int`
        the index of the unique obstime. i.e. the new index of the
        mosaicked image in the `ImageStack`.
    common_wcs : `astropy.wcs.WCS`
        The WCS to reproject all the images into.
    original_wcs : `list[astropy.wcs.WCS]`
        The list of WCS objects for these images.
    directory : `str`
        The directory to output the new sharded and reprojected `WorkUnit`.
    filename : `str`
        The base filename for the sharded and reprojected `WorkUnit`.

    Returns
    -------
    `bool`
        The result of `_reproject_and_write` (True on success).
    """
    science_images = []
    variance_images = []
    mask_images = []
    psfs = []

    # Each shard stores its layers in extensions named "<TYPE>_<index>".
    for file_path, index in zip(file_paths, indices):
        with fits.open(file_path) as hdul:
            science_images.append(hdul[f"SCI_{index}"].data.astype(np.single))
            variance_images.append(hdul[f"VAR_{index}"].data.astype(np.single))
            mask_images.append(hdul[f"MSK_{index}"].data.astype(bool))
            psfs.append(hdul[f"PSF_{index}"].data.astype(np.single))

    # Only the PSF of the first image at this obstime is carried forward.
    return _reproject_and_write(
        science_images=science_images,
        variance_images=variance_images,
        mask_images=mask_images,
        psf=psfs[0],
        obstime=obstime,
        obstime_index=obstime_index,
        common_wcs=common_wcs,
        original_wcs=original_wcs,
        indices=indices,
        directory=directory,
        filename=filename,
    )


def _reproject_and_write(
    science_images,
    variance_images,
    mask_images,
    psf,
    obstime,
    obstime_index,
    indices,
    common_wcs,
    original_wcs,
    directory,
    filename,
):
    """Reproject a set of images and write out the output to a sharded `WorkUnit`.

    Parameters
    ----------
    science_images : `list[numpy.ndarray]`
        List of ndarrays that represent the science images to be reprojected.
    variance_images : `list[numpy.ndarray]`
        List of ndarrays that represent the variance images to be reprojected.
    mask_images : `list[numpy.ndarray]`
        List of ndarrays that represent the mask images to be reprojected.
    psf : `numpy.ndarray`
        The PSF kernel.
    obstime : `float`
        observation times for set of images.
    obstime_index : `int`
        the index of the unique obstime. i.e. the new index of the
        mosaicked image in the `ImageStack`.
    indices : `list[int]`
        List of `WorkUnit` indices corresponding to the original positions
        of the images within the `ImageStack`.
    common_wcs : `astropy.wcs.WCS`
        The WCS to reproject all the images into.
    original_wcs : `list[astropy.wcs.WCS]`
        The list of WCS objects for these images.
    directory : `str`
        The directory to output the new sharded and reprojected `WorkUnit`.
    filename : `str`
        The base filename for the sharded and reprojected `WorkUnit`.

    Returns
    -------
    `bool`
        Always True; callers use the truthy value as a success flag.
    """
    science_add, variance_add, mask_add, obstime = _reproject_images(
        science_images,
        variance_images,
        mask_images,
        obstime,
        common_wcs,
        original_wcs,
    )

    _write_images_to_shard(
        science_add=science_add,
        variance_add=variance_add,
        mask_add=mask_add,
        psf=psf,
        wcs=common_wcs,
        obstime=obstime,
        obstime_index=obstime_index,
        indices=indices,
        directory=directory,
        filename=filename,
    )

    return True


def _reproject_images(science_images, variance_images, mask_images, obstime, common_wcs, original_wcs):
    """This is the worker function that will be parallelized across multiple processes.
    Given a set of science, variance, and mask images, use astropy's reproject
    function to reproject them into a common WCS.

    Parameters
    ----------
    science_images : `list[numpy.ndarray]`
        List of ndarrays that represent the science images to be reprojected.
    variance_images : `list[numpy.ndarray]`
        List of ndarrays that represent the variance images to be reprojected.
    mask_images : `list[numpy.ndarray]`
        List of ndarrays that represent the mask images to be reprojected.
    obstime : `float`
        observation time for each image.
    common_wcs : `astropy.wcs.WCS`
        The WCS to reproject all the images into.
    original_wcs : `list[astropy.wcs.WCS]`
        The list of WCS objects for these images.

    Returns
    -------
    science_add : `numpy.ndarray`
        The reprojected science image.
    variance_add : `numpy.ndarray`
        The reprojected variance image.
    mask_add : `numpy.ndarray`
        The reprojected mask image.
    time : `float`
        The observation time of the original images.

    Raises
    ------
    ValueError
        If any images overlap, raise an error.
    """
    science_add = np.zeros(common_wcs.array_shape, dtype=np.float32)
    variance_add = np.zeros(common_wcs.array_shape, dtype=np.float32)
    mask_add = np.zeros(common_wcs.array_shape, dtype=np.float32)
    footprint_add = np.zeros(common_wcs.array_shape, dtype=np.ubyte)

    for science, variance, mask, this_original_wcs in zip(
        science_images, variance_images, mask_images, original_wcs
    ):
        # reproject science, variance, and mask images simultaneously.
        reprojected_images, footprints = reproject_image(
            [science, variance, mask], this_original_wcs, common_wcs
        )

        footprint_add += footprints

        # we'll enforce that there be no overlapping images at the same time,
        # for now. We might be able to add some ability co-add in the future.
        if np.any(footprint_add > 1):
            raise ValueError("Images with the same obstime are overlapping.")

        # change all the NaNs to zeroes so that the matrix addition works properly.
        # `footprint_add` will maintain the information about what areas of the frame
        # don't have any data so that we can change it back after we combine.
        reprojected_images[np.isnan(reprojected_images)] = 0.0

        science_add += reprojected_images[0]
        variance_add += reprojected_images[1]
        mask_add += reprojected_images[2]

    # change all the values where there is no corresponding data to `KB_NO_DATA`.
    gaps = footprint_add == 0
    science_add[gaps] = KB_NO_DATA
    variance_add[gaps] = KB_NO_DATA
    mask_add[gaps] = 1

    # transforms the mask back into a bitmask: values within 0.2 of zero
    # become 0, everything else becomes 1 (kept as float32).
    mask_add = np.where(np.isclose(mask_add, 0.0, atol=0.2), np.float32(0.0), np.float32(1.0))

    return science_add, variance_add, mask_add, obstime


def _write_images_to_shard(
    science_add, variance_add, mask_add, psf, wcs, obstime, obstime_index, indices, directory, filename
):
    """Takes in a set of post-reprojection image adds and
    writes them to a fits file.

    The shard is written to ``<directory>/<obstime_index>_<filename>`` and holds
    four image extensions named ``SCI_``, ``VAR_``, ``MSK_`` and ``PSF_``
    followed by `obstime_index`.

    Parameters
    ----------
    science_add : `numpy.ndarray`
        ndarray containing the reprojected science image add.
    variance_add : `numpy.ndarray`
        ndarray containing the reprojected variance image add.
    mask_add : `numpy.ndarray`
        ndarray containing the reprojected mask image add.
    psf : `numpy.ndarray`
        the kernel of the PSF.
    wcs : `astropy.wcs.WCS`
        the common_wcs used in reprojection.
    obstime : `float`
        observation time for each image.
    obstime_index : `int`
        the index of the unique obstime, used to name the extensions and
        the shard file.
    indices : `list[int]`
        the per image indices.
    directory : `str`
        the directory to output the `WorkUnit` shard to.
    filename : `str`
        the base filename to use for the shard.
    """
    sub_hdul = fits.HDUList()

    # Only the science extension carries the WCS and the list of constituent
    # image indices (NIND = count, IND_<j> = j-th index).
    sci_hdu = image_add_to_hdu(science_add, f"SCI_{obstime_index}", obstime, wcs)
    sci_hdu.header["NIND"] = len(indices)
    for j, index in enumerate(indices):
        sci_hdu.header[f"IND_{j}"] = index
    sub_hdul.append(sci_hdu)

    var_hdu = image_add_to_hdu(variance_add, f"VAR_{obstime_index}", obstime)
    sub_hdul.append(var_hdu)

    msk_hdu = image_add_to_hdu(mask_add, f"MSK_{obstime_index}", obstime)
    sub_hdul.append(msk_hdu)

    psf_hdu = fits.hdu.image.ImageHDU(psf)
    psf_hdu.name = f"PSF_{obstime_index}"
    sub_hdul.append(psf_hdu)

    sub_hdul.writeto(os.path.join(directory, f"{obstime_index}_{filename}"))
def image_add_to_hdu(add, name, obstime, wcs=None):
    """Wrap a post-reprojection image add in a FITS image extension.

    Parameters
    ----------
    add : `np.ndarray`
        The image to convert.
    name : `str`
        The name of the image (type + index).
    obstime : `float`
        The observation time.
    wcs : `astropy.wcs.WCS`
        An optional WCS to include in the header.

    Returns
    -------
    hdu : `astropy.io.fits.hdu.image.ImageHDU`
        The image extension.
    """
    hdu = fits.ImageHDU(data=add)
    header = hdu.header

    # When a WCS is supplied, its entries go into the header first.
    if wcs is not None:
        append_wcs_to_hdu_header(wcs, header)

    # Time stamp of the observation, then the extension name.
    header["MJD"] = obstime
    hdu.name = name
    return hdu