import os
import csv
import time
import heapq
import multiprocessing as mp
from collections import OrderedDict
import numpy as np
import astropy.units as u
from astropy.io import fits
from astropy.wcs import WCS
import astropy.coordinates as astroCoords
from scipy.special import erfinv # import mpmath
from sklearn.cluster import DBSCAN, OPTICS
from .file_utils import *
from .image_info import *
from .filters import *
from .result_list import *
import kbmod.search as kb
class Interface(SharedTools):
"""This class manages the KBMOD interface with the local filesystem, the cpp
KBMOD code, and the PostProcess python filtering functions. It is
responsible for loading in data from .fits files, initializing the kbmod
object, loading results from the kbmod object into python, and saving
results to file.
"""
def __init__(self):
return
def load_images(
self,
im_filepath,
time_file,
psf_file,
mjd_lims,
default_psf,
verbose=False,
):
"""This function loads images and ingests them into a search object.
Parameters
----------
im_filepath : string
Image file path from which to load images.
time_file : string
File name containing image times.
psf_file : string
File name containing the image-specific PSFs.
If set to None, the code will use the provided default PSF for
all images.
mjd_lims : list of ints
Optional MJD limits on the images to search.
default_psf : `psf`
The default PSF in case no image-specific PSF is provided.
verbose : bool
Use verbose output (mainly for debugging).
Returns
-------
stack : `kbmod.image_stack`
The stack of images loaded.
img_info : `ImageInfo`
The information for the images loaded.
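Examples
--------
A minimal sketch of a call (the directory name is illustrative, and passing
``None`` for the optional time and PSF files is assumed to be acceptable, as
suggested by the comments in the body below):

>>> interface = Interface()
>>> stack, img_info = interface.load_images(
...     im_filepath="./data/images",
...     time_file=None,
...     psf_file=None,
...     mjd_lims=None,
...     default_psf=kb.psf(2.0),
... )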
"""
print("---------------------------------------")
print("Loading Images")
print("---------------------------------------")
# Load a mapping from visit numbers to the visit times. This dictionary stays
# empty if no time file is specified.
image_time_dict = FileUtils.load_time_dictionary(time_file)
if verbose:
print(f"Loaded {len(image_time_dict)} time stamps.")
# Load a mapping from visit numbers to PSFs. This dictionary stays
# empty if no PSF file is specified.
image_psf_dict = FileUtils.load_psf_dictionary(psf_file)
if verbose:
print(f"Loaded {len(image_psf_dict)} image PSFs stamps.")
# Retrieve the list of visits (file names) in the data directory.
patch_visits = sorted(os.listdir(im_filepath))
# Load the images themselves.
img_info = ImageInfoSet()
images = []
visit_times = []
for visit_file in patch_visits:
# Skip non-fits files.
if not ".fits" in visit_file:
if verbose:
print(f"Skipping non-FITS file {visit_file}")
continue
# Compute the full file path for loading.
full_file_path = os.path.join(im_filepath, visit_file)
# Load the image info from the FITS header.
header_info = ImageInfo()
header_info.populate_from_fits_file(full_file_path)
# Skip files without a valid visit ID.
if header_info.visit_id is None:
if verbose:
print(f"WARNING: Unable to extract visit ID for {visit_file}.")
continue
# Compute the time stamp as a MJD float. If there is an entry in the
# timestamp file, defer to that. Otherwise use the value from the header.
time_stamp = -1.0
if header_info.visit_id in image_time_dict:
time_stamp = image_time_dict[header_info.visit_id]
else:
time_obj = header_info.get_epoch(none_if_unset=True)
if time_obj is not None:
time_stamp = time_obj.mjd
if time_stamp <= 0.0:
if verbose:
print(f"WARNING: No valid timestamp provided for {visit_file}.")
continue
# Check if we should filter the record based on the time bounds.
if mjd_lims is not None and (time_stamp < mjd_lims[0] or time_stamp > mjd_lims[1]):
if verbose:
print(f"Pruning file {visit_file} by timestamp={time_stamp}.")
continue
# Check if the image has a specific PSF.
psf = default_psf
if header_info.visit_id in image_psf_dict:
psf = kb.psf(image_psf_dict[header_info.visit_id])
# Load the image file and set its time.
if verbose:
print(f"Loading file: {full_file_path}")
img = kb.layered_image(full_file_path, psf)
img.set_time(time_stamp)
# Save the file, time, and image information.
img_info.append(header_info)
visit_times.append(time_stamp)
images.append(img)
print(f"Loaded {len(images)} images")
stack = kb.image_stack(images)
# Create a list of visit times and visit times shifted to 0.0.
img_info.set_times_mjd(np.array(visit_times))
times = img_info.get_zero_shifted_times()
stack.set_times(times)
print("Times set", flush=True)
return (stack, img_info)
def save_results(self, res_filepath, out_suffix, keep, all_times):
"""This function saves results from a given search method.
Parameters
----------
res_filepath : string
The filepath for the results.
out_suffix : string
Suffix to append to the output file name.
keep : `ResultList`
ResultList object containing the values to keep and print to file.
all_times : list
A list of times.
"""
print("---------------------------------------")
print("Saving Results")
print("---------------------------------------", flush=True)
keep.save_to_files(res_filepath, out_suffix)
class PostProcess(SharedTools):
"""This class manages the post-processing utilities used to filter out and
otherwise remove false positives from the KBMOD search. This includes,
for example, Kalman filtering to remove outliers, stamp filtering to remove
results with non-Gaussian postage stamps, and clustering to remove similar
results.
"""
def __init__(self, config, mjds):
self.coeff = None
self.num_cores = config["num_cores"]
self.sigmaG_lims = config["sigmaG_lims"]
self.eps = config["eps"]
self.cluster_type = config["cluster_type"]
self.cluster_function = config["cluster_function"]
self.clip_negative = config["clip_negative"]
self.mask_bits_dict = config["mask_bits_dict"]
self.flag_keys = config["flag_keys"]
self.repeated_flag_keys = config["repeated_flag_keys"]
self._mjds = mjds
def apply_mask(self, stack, mask_num_images=2, mask_threshold=None, mask_grow=10):
"""This function applys a mask to the images in a KBMOD stack. This mask
sets a high variance for masked pixels
Parameters
----------
stack : `kbmod.image_stack`
The stack before the masks have been applied.
mask_num_images : int
The minimum number of images in which a masked pixel must appear in
order for it to be masked out. E.g. if mask_num_images = 2, then an
object must appear in the same place in at least two images in order
for the variance at that location to be increased.
mask_threshold : float
Any pixel with a flux greater than mask_threshold is masked out.
mask_grow : int
The number of pixels by which to grow the mask.
Returns
-------
stack : `kbmod.image_stack`
The stack after the masks have been applied.
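Examples
--------
A minimal sketch (assumes ``post_process`` was constructed with a config
whose mask_bits_dict, flag_keys, and repeated_flag_keys match the data):

>>> stack = post_process.apply_mask(stack, mask_num_images=2, mask_grow=10)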
"""
mask_bits_dict = self.mask_bits_dict
flag_keys = self.flag_keys
global_flag_keys = self.repeated_flag_keys
flags = 0
for bit in flag_keys:
flags += 2 ** mask_bits_dict[bit]
flag_exceptions = [0]
# mask any pixels which have any of these flags
global_flags = 0
for bit in global_flag_keys:
global_flags += 2 ** mask_bits_dict[bit]
# Apply masks if needed.
if len(flag_keys) > 0:
stack.apply_mask_flags(flags, flag_exceptions)
if mask_threshold:
stack.apply_mask_threshold(mask_threshold)
if len(global_flag_keys) > 0:
stack.apply_global_mask(global_flags, mask_num_images)
# Grow the masks by 'mask_grow' pixels.
stack.grow_mask(mask_grow, True)
return stack
def load_and_filter_results(
self,
search,
lh_level,
chunk_size=500000,
max_lh=1e9,
):
"""This function loads results that are output by the gpu grid search.
Results are loaded in chunks and evaluated to see if the minimum
likelihood level has been reached. If not, another chunk of results is
fetched. The results are filtered using a clipped-sigmaG filter as they
are loaded and only the passing results are kept.
Parameters
----------
search : `kbmod.search`
The search function object.
lh_level : float
The minimum likelihood threshold for an acceptable result. Results below
this likelihood level will be discarded.
chunk_size : int
The number of results to load at a given time from search.
max_lh : float
The maximum likelihood threshold for an acceptable result.
Results ABOVE this likelihood level will be discarded.
Returns
-------
keep : `ResultList`
A ResultList object containing values from trajectories.
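Examples
--------
A minimal sketch (``search`` is assumed to be a completed GPU search
object and 10.0 is an illustrative likelihood threshold):

>>> keep = post_process.load_and_filter_results(search, lh_level=10.0)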
"""
keep = ResultList(self._mjds)
likelihood_limit = False
res_num = 0
total_count = 0
print("---------------------------------------")
print("Retrieving Results")
print("---------------------------------------")
while not likelihood_limit:
print("Getting results...")
results = search.get_results(res_num, chunk_size)
print("---------------------------------------")
print("Chunk Start = %i" % res_num)
print("Chunk Max Likelihood = %.2f" % results[0].lh)
print("Chunk Min. Likelihood = %.2f" % results[-1].lh)
print("---------------------------------------")
result_batch = ResultList(self._mjds)
for i, trj in enumerate(results):
# Stop as soon as we hit a result below our limit, because anything after
that is not guaranteed to be valid due to potential on-GPU filtering.
if trj.lh < lh_level:
likelihood_limit = True
break
if trj.lh < max_lh:
row = ResultRow(trj, len(self._mjds))
psi_curve = np.array(search.psi_curves(trj))
phi_curve = np.array(search.phi_curves(trj))
row.set_psi_phi(psi_curve, phi_curve)
result_batch.append_result(row)
total_count += 1
batch_size = result_batch.num_results()
print("Extracted batch of %i results for total of %i" % (batch_size, total_count))
if batch_size > 0:
self.apply_clipped_sigmaG(result_batch)
if lh_level > 0.0:
result_batch.apply_filter(LHFilter(lh_level, None))
result_batch.apply_filter(NumObsFilter(3))
# Add the results to the final set.
keep.extend(result_batch)
res_num += chunk_size
return keep
def get_all_stamps(self, result_list, search, stamp_radius):
"""Get the stamps for the final results from a kbmod search.
Parameters
----------
result_list : `ResultList`
The values from trajectories. The stamps are inserted into this data structure.
search : `kbmod.stack_search`
The search object
stamp_radius : int
The radius of the stamps to create.
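Examples
--------
A minimal sketch (``keep`` and ``search`` are assumed to come from an
earlier search and filtering pass):

>>> post_process.get_all_stamps(keep, search, stamp_radius=10)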
"""
stamp_edge = stamp_radius * 2 + 1
for row in result_list.results:
stamps = search.science_viz_stamps(row.trajectory, stamp_radius)
row.all_stamps = np.array([np.array(stamp).reshape(stamp_edge, stamp_edge) for stamp in stamps])
def apply_clipped_sigmaG(self, result_list):
"""This function applies a clipped median filter to the results of a KBMOD
search using sigmaG as a robust estimator of standard deviation.
Parameters
----------
result_list : `ResultList`
The values from trajectories. This data gets modified directly
by the filtering.
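Examples
--------
A minimal sketch (``keep`` is assumed to be a `ResultList` whose rows
already have their psi/phi curves set):

>>> post_process.apply_clipped_sigmaG(keep)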
"""
print("Applying Clipped-sigmaG Filtering")
start_time = time.time()
# Compute the coefficients for the filtering.
if self.coeff is None:
if self.sigmaG_lims is not None:
self.percentiles = self.sigmaG_lims
else:
self.percentiles = [25, 75]
self.coeff = self._find_sigmaG_coeff(self.percentiles)
if self.num_cores > 1:
zipped_curves = result_list.zip_phi_psi_idx()
keep_idx_results = []
print("Starting pooling...")
pool = mp.Pool(processes=self.num_cores)
keep_idx_results = pool.starmap_async(self._clipped_sigmaG, zipped_curves)
pool.close()
pool.join()
keep_idx_results = keep_idx_results.get()
for i, res in enumerate(keep_idx_results):
result_list.results[i].filter_indices(res[1])
else:
for i, row in enumerate(result_list.results):
single_res = self._clipped_sigmaG(row.psi_curve, row.phi_curve, i)
row.filter_indices(single_res[1])
end_time = time.time()
time_elapsed = end_time - start_time
print("{:.2f}s elapsed".format(time_elapsed))
print("Completed filtering.", flush=True)
print("---------------------------------------")
def _find_sigmaG_coeff(self, percentiles):
z1 = percentiles[0] / 100
z2 = percentiles[1] / 100
x1 = self._invert_Gaussian_CDF(z1)
x2 = self._invert_Gaussian_CDF(z2)
coeff = 1 / (x2 - x1)
print("sigmaG limits: [{},{}]".format(percentiles[0], percentiles[1]))
print("sigmaG coeff: {:.4f}".format(coeff), flush=True)
return coeff
def _invert_Gaussian_CDF(self, z):
if z < 0.5:
sign = -1
else:
sign = 1
x = sign * np.sqrt(2) * erfinv(sign * (2 * z - 1)) # mpmath.erfinv(sign * (2 * z - 1))
return float(x)
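# Worked example (a sketch, not used by the code): for the default
# percentiles [25, 75], _invert_Gaussian_CDF returns roughly -0.6745 and
# +0.6745, so the interquartile range of a standard normal is ~1.3490 and
# _find_sigmaG_coeff returns 1 / 1.3490 ~= 0.7413, the familiar factor that
# converts an interquartile-range estimate into an equivalent standard deviation.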
def _clipped_sigmaG(self, psi_curve, phi_curve, index, n_sigma=2):
"""This function applies a clipped median filter to a set of likelihood
values. Points are eliminated if they are more than n_sigma*sigmaG away
from the median.
Parameters
----------
psi_curve : numpy array
A single Psi curve, likely from a `ResultRow`.
phi_curve : numpy array
A single Phi curve, likely from a `ResultRow`.
index : int
The index of the ResultRow being processed. Used to track
multiprocessing.
n_sigma : int
The number of standard deviations away from the median that
the largest likelihood values (N=num_clipped) must be in order
to be eliminated.
Returns
-------
index : int
The index of the ResultRow being processed. Used to track multiprocessing.
good_index: numpy array
The indices that pass the filtering for a given set of curves.
new_lh : float
The new maximum likelihood of the set of curves, computed using only
the indices in good_index.
"""
masked_phi = np.copy(phi_curve)
masked_phi[masked_phi == 0] = 1e9
lh = psi_curve / np.sqrt(masked_phi)
good_index = self._exclude_outliers(lh, n_sigma)
if len(good_index) == 0:
new_lh = 0
good_index = []
else:
new_lh = kb.calculate_likelihood_psi_phi(psi_curve[good_index], phi_curve[good_index])
return (index, good_index, new_lh)
def _exclude_outliers(self, lh, n_sigma):
if self.clip_negative:
lower_per, median, upper_per = np.percentile(
lh[lh > 0], [self.percentiles[0], 50, self.percentiles[1]]
)
sigmaG = self.coeff * (upper_per - lower_per)
nSigmaG = n_sigma * sigmaG
good_index = np.where(
np.logical_and(lh != 0, np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))
)[0]
else:
lower_per, median, upper_per = np.percentile(lh, [self.percentiles[0], 50, self.percentiles[1]])
sigmaG = self.coeff * (upper_per - lower_per)
nSigmaG = n_sigma * sigmaG
good_index = np.where(np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))[0]
return good_index
def apply_stamp_filter(
self,
result_list,
search,
center_thresh=0.03,
peak_offset=[2.0, 2.0],
mom_lims=[35.5, 35.5, 1.0, 0.25, 0.25],
chunk_size=1000000,
stamp_type="sum",
stamp_radius=10,
):
"""This function filters result postage stamps based on their Gaussian
Moments. Results with stamps that are similar to a Gaussian are kept.
Parameters
----------
result_list : `ResultList`
The values from trajectories. This data gets modified directly by
the filtering.
search : `kbmod.stack_search`
The search object.
center_thresh : float
The fraction of the total flux that must be contained in a single
central pixel.
peak_offset : list of floats
How far the brightest pixel in the stamp can be from the central
pixel.
mom_lims : list of floats
The maximum limit of the xx, yy, xy, x, and y central moments of
the stamp.
chunk_size : int
How many stamps to load and filter at a time.
stamp_type : string
Which method to use to generate stamps.
One of 'median', 'cpp_median', 'mean', 'cpp_mean', or 'sum'.
stamp_radius : int
The radius of the stamp.
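Examples
--------
A minimal sketch using the default thresholds (``keep`` and ``search``
are assumed to come from an earlier search):

>>> post_process.apply_stamp_filter(keep, search, stamp_type="cpp_median", stamp_radius=10)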
"""
# Set the stamp creation and filtering parameters.
params = kb.stamp_parameters()
params.radius = stamp_radius
params.do_filtering = True
params.center_thresh = center_thresh
params.peak_offset_x = peak_offset[0]
params.peak_offset_y = peak_offset[1]
params.m20 = mom_lims[0]
params.m02 = mom_lims[1]
params.m11 = mom_lims[2]
params.m10 = mom_lims[3]
params.m01 = mom_lims[4]
if stamp_type == "cpp_median" or stamp_type == "median":
params.stamp_type = kb.StampType.STAMP_MEDIAN
elif stamp_type == "cpp_mean" or stamp_type == "mean":
params.stamp_type = kb.StampType.STAMP_MEAN
else:
params.stamp_type = kb.StampType.STAMP_SUM
# Save some useful helper data.
num_times = search.get_num_images()
all_valid_inds = []
# Run the stamp creation and filtering in batches of chunk_size.
print("---------------------------------------")
print("Applying Stamp Filtering")
print("---------------------------------------", flush=True)
start_time = time.time()
start_idx = 0
if result_list.num_results() <= 0:
print("Skipping. Nothing to filter.")
return
print("Stamp filtering %i results" % result_list.num_results())
while start_idx < result_list.num_results():
end_idx = min([start_idx + chunk_size, result_list.num_results()])
# Create a subslice of the results and the Boolean indices.
# Note that the sum stamp type does not filter out lc_index.
inds_to_use = [i for i in range(start_idx, end_idx)]
trj_slice = [result_list.results[i].trajectory for i in inds_to_use]
if params.stamp_type != kb.StampType.STAMP_SUM:
bool_slice = [result_list.results[i].valid_indices_as_booleans() for i in inds_to_use]
else:
# For the sum stamp, use all the indices for each trajectory.
all_true = [True] * num_times
bool_slice = [all_true for _ in inds_to_use]
# Create and filter the results.
stamps_slice = search.gpu_coadded_stamps(trj_slice, bool_slice, params)
for ind, stamp in enumerate(stamps_slice):
if stamp.get_width() > 1:
result_list.results[ind + start_idx].stamp = np.array(stamp)
all_valid_inds.append(ind + start_idx)
# Move to the next chunk.
start_idx += chunk_size
# Do the actual filtering of results
result_list.filter_results(all_valid_inds)
print("Keeping %i results" % result_list.num_results(), flush=True)
end_time = time.time()
time_elapsed = end_time - start_time
print("{:.2f}s elapsed".format(time_elapsed))
def apply_clustering(self, result_list, cluster_params):
"""This function clusters results that have similar trajectories.
Parameters
----------
result_list : `ResultList`
The values from trajectories. This data gets modified directly by
the filtering.
cluster_params : dict
Contains values concerning the image and search settings including:
x_size, y_size, vel_lims, ang_lims, and mjd.
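Examples
--------
A minimal sketch of the expected dictionary (the numeric values are
illustrative only, and ``mjd_times`` is a placeholder for an array of
the image MJDs):

>>> cluster_params = {
...     "x_size": 4200,
...     "y_size": 4200,
...     "vel_lims": [92.0, 526.0],
...     "ang_lims": [0.0, 1.5],
...     "mjd": mjd_times,
... }
>>> post_process.apply_clustering(keep, cluster_params)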
"""
# Skip clustering if there is nothing to cluster.
if result_list.num_results() == 0:
return
print("Clustering %i results" % result_list.num_results(), flush=True)
# Do the clustering and the filtering.
cluster_idx = self._cluster_results(
np.array([row.trajectory for row in result_list.results]),
cluster_params["x_size"],
cluster_params["y_size"],
cluster_params["vel_lims"],
cluster_params["ang_lims"],
cluster_params["mjd"],
)
result_list.filter_results(cluster_idx)
def _cluster_results(self, results, x_size, y_size, v_lim, ang_lim, mjd_times, cluster_args=None):
"""This function clusters results and selects the highest-likelihood
trajectory from a given cluster.
Parameters
----------
results : list
A list of kbmod trajectory results.
x_size : int
The width of the images (in pixels) used in the kbmod stack, such
as are stored in image_params['x_size'].
y_size : int
The height of the images (in pixels) used in the kbmod stack such
as are stored in image_params['y_size'].
v_lim : list
The velocity limits of the search, such as are stored in
image_params['v_lim']. The first two elements are used and represent
the minimum (v_lim[0]) and maximum (v_lim[1]) velocities used in the
search.
ang_lim : list
The angle limits of the search, such as are stored in image_params['ang_lim'].
The first two elements are used and represent the minimum (ang_lim[0]) and
maximum (ang_lim[1]) angles used in the search.
mjd_times : list or numpy array
The MJD times of the images used in the search. Used to compute the
median time for mid-position clustering.
cluster_args : dict
Arguments to pass to DBSCAN or OPTICS.
Returns
-------
top_vals : numpy array
An array of the indices for the best trajectories of each individual cluster.
"""
if self.cluster_function == "DBSCAN":
default_cluster_args = dict(eps=self.eps, min_samples=1, n_jobs=-1)
elif self.cluster_function == "OPTICS":
default_cluster_args = dict(max_eps=self.eps, min_samples=2, n_jobs=-1)
if cluster_args is not None:
default_cluster_args.update(cluster_args)
cluster_args = default_cluster_args
x_arr = []
y_arr = []
vx_arr = []
vy_arr = []
vel_arr = []
ang_arr = []
times = mjd_times - mjd_times[0]
for line in results:
x_arr.append(line.x)
y_arr.append(line.y)
vx_arr.append(line.x_v)
vy_arr.append(line.y_v)
vel_arr.append(np.sqrt(line.x_v**2.0 + line.y_v**2.0))
ang_arr.append(np.arctan2(line.y_v, line.x_v))
x_arr = np.array(x_arr)
y_arr = np.array(y_arr)
vx_arr = np.array(vx_arr)
vy_arr = np.array(vy_arr)
vel_arr = np.array(vel_arr)
ang_arr = np.array(ang_arr)
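# Normalize each feature so clustering distances are comparable: positions
# are scaled by the image dimensions, while velocities and angles are scaled
# by the extent of the searched ranges (with a guard against zero-width limits).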
scaled_x = x_arr / x_size
scaled_y = y_arr / y_size
v_scale = (v_lim[1] - v_lim[0]) if v_lim[1] != v_lim[0] else 1.0
scaled_vel = (vel_arr - v_lim[0]) / v_scale
a_scale = (ang_lim[1] - ang_lim[0]) if ang_lim[1] != ang_lim[0] else 1.0
scaled_ang = (ang_arr - ang_lim[0]) / a_scale
if self.cluster_function == "DBSCAN":
cluster = DBSCAN(**cluster_args)
elif self.cluster_function == "OPTICS":
cluster = OPTICS(**cluster_args)
if self.cluster_type == "all":
cluster.fit(np.array([scaled_x, scaled_y, scaled_vel, scaled_ang], dtype=float).T)
elif self.cluster_type == "position":
cluster.fit(np.array([scaled_x, scaled_y], dtype=float).T)
elif self.cluster_type == "mid_position":
median_time = np.median(times)
mid_x_arr = x_arr + median_time * vx_arr
mid_y_arr = y_arr + median_time * vy_arr
scaled_mid_x = mid_x_arr / x_size
scaled_mid_y = mid_y_arr / y_size
cluster.fit(np.array([scaled_mid_x, scaled_mid_y], dtype=float).T)
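# Keep one representative per cluster. The input trajectories are expected to
# arrive ordered by decreasing likelihood, so the first index in each cluster
# is taken as its highest-likelihood member.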
top_vals = []
for cluster_num in np.unique(cluster.labels_):
cluster_vals = np.where(cluster.labels_ == cluster_num)[0]
top_vals.append(cluster_vals[0])
return top_vals