"""Functions to help with the SigmaG clipping.
For more details see:
Sifting Through the Static: Moving Object Detection in Difference Images
by Smotherman et al. 2021
"""
import logging
import numpy as np
import torch
from scipy.special import erfinv
from kbmod.search import DebugTimer
logger = logging.getLogger(__name__)
class SigmaGClipping:
"""This class contains the basic information for performing SigmaG clipping.
Attributes
----------
low_bnd : `float`
The lower bound of the interval to use to estimate the standard deviation.
high_bnd : `float`
The upper bound of the interval to use to estimate the standard deviation.
n_sigma : `float`
The number of standard deviations to use for the bound.
clip_negative : `bool`
A Boolean indicating whether to use negative values when computing
standard deviation.
coeff : `float`
The precomputed coefficient based on the given bounds.
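    Examples
    --------
    A minimal construction sketch using the default percentile bounds, for which
    the precomputed coefficient is roughly 0.7413:

    >>> clipper = SigmaGClipping(low_bnd=25, high_bnd=75, n_sigma=2)
    >>> round(clipper.coeff, 4)
    0.7413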
"""
def __init__(self, low_bnd=25, high_bnd=75, n_sigma=2, clip_negative=False):
if low_bnd > high_bnd or low_bnd <= 0.0 or high_bnd >= 100.0:
raise ValueError(f"Invalid bounds [{low_bnd}, {high_bnd}]")
if n_sigma <= 0.0:
raise ValueError(f"Invalid n_sigma {n_sigma}")
self.low_bnd = low_bnd
self.high_bnd = high_bnd
self.n_sigma = n_sigma
self.coeff = SigmaGClipping.find_sigma_g_coeff(low_bnd, high_bnd)
self.clip_negative = clip_negative
    @staticmethod
def find_sigma_g_coeff(low_bnd, high_bnd):
"""Compute the sigma G coefficient from the upper and lower bounds
of the percentiles.
Parameters
----------
low_bnd : `float`
The lower bound of the percentile on a scale of [0, 100].
high_bnd : `float`
            The upper bound of the percentile on a scale of [0, 100].
Returns
-------
result : `float`
The corresponding sigma G coefficient.
Raises
------
        Raises a ``ValueError`` if the bounds are invalid.
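        Examples
        --------
        For the default 25th and 75th percentiles this gives roughly 0.7413, the
        usual conversion factor from a Gaussian's interquartile range to its
        standard deviation:

        >>> round(SigmaGClipping.find_sigma_g_coeff(25, 75), 4)
        0.7413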
"""
if (high_bnd <= low_bnd) or (low_bnd < 0) or (high_bnd > 100):
raise ValueError(f"Invalid percentiles for sigma G coefficient [{low_bnd}, {high_bnd}]")
x1 = SigmaGClipping.invert_gauss_cdf(low_bnd / 100.0)
x2 = SigmaGClipping.invert_gauss_cdf(high_bnd / 100.0)
return 1 / (x2 - x1)
@staticmethod
    def invert_gauss_cdf(z):
        """Invert the cumulative distribution function of a standard Gaussian,
        i.e. return the value x such that CDF(x) = z, using the inverse error
        function: x = sqrt(2) * erfinv(2 * z - 1).
        """
        if z < 0.5:
            sign = -1
        else:
            sign = 1
        x = sign * np.sqrt(2) * erfinv(sign * (2 * z - 1))
        return float(x)
    def compute_clipped_sigma_g(self, lh):
"""Compute the SigmaG clipping on the given likelihood curve.
Points are eliminated if they are more than n_sigma*sigmaG away from the median.
Parameters
----------
        lh : `numpy.ndarray`
            A single likelihood curve.
        Returns
        -------
        good_index : `numpy.ndarray`
            The indices of the points that pass the filtering for this curve.
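        Examples
        --------
        A small illustrative curve in which the bright outlier at index 4 is
        more than n_sigma * sigmaG from the median and is clipped:

        >>> clipper = SigmaGClipping()
        >>> clipper.compute_clipped_sigma_g(np.array([1.0, 1.1, 0.9, 1.0, 100.0]))
        array([0, 1, 2, 3])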
"""
if self.clip_negative:
            # Skip entries where we would clip everything (no positive lh values).
if np.count_nonzero(lh > 0) == 0:
return np.array([])
lower_per, median, upper_per = np.percentile(lh[lh > 0], [self.low_bnd, 50, self.high_bnd])
else:
lower_per, median, upper_per = np.percentile(lh, [self.low_bnd, 50, self.high_bnd])
delta = max(upper_per - lower_per, 1e-8)
sigmaG = self.coeff * delta
nSigmaG = self.n_sigma * sigmaG
good_index = np.where(np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))[0]
return good_index
    def compute_clipped_sigma_g_matrix(self, lh):
"""Compute the SigmaG clipping on a matrix containing curves where
each row is a single curve at different time points.
Points are eliminated if they are more than n_sigma*sigmaG away from the median.
Parameters
----------
lh : `numpy.ndarray`
            An N x T matrix with N curves, each with T time steps.
Returns
-------
index_valid : `numpy.ndarray`
            An N x T matrix of Booleans indicating if each point is valid (True)
or has been filtered (False).
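        Examples
        --------
        A rough sketch showing that each row is clipped independently and the
        Boolean mask matches the shape of the input:

        >>> clipper = SigmaGClipping()
        >>> lh = np.random.default_rng(42).normal(1.0, 0.1, size=(10, 20))
        >>> clipper.compute_clipped_sigma_g_matrix(lh).shape
        (10, 20)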
"""
# Use a GPU if one is available.
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
# Move the likelihood data to the device.
torch_lh = torch.tensor(lh, device=device, dtype=torch.float32)
# Mask out the negative values (if clip negative is true).
if self.clip_negative:
masked_lh = torch.where(torch_lh > 0.0, torch_lh, np.nan)
else:
masked_lh = torch_lh
# Compute the quantiles for each row.
q_values = torch.tensor(
[self.low_bnd / 100.0, 0.5, self.high_bnd / 100.0],
device=device,
dtype=torch.float32,
)
lower_per, median, upper_per = torch.nanquantile(masked_lh, q_values, dim=1)
# Compute the bounds for each row, enforcing a minimum gap in case all the
# points are identical (upper_per == lower_per).
delta = upper_per - lower_per
delta[delta < 1e-5] = 1e-5
nSigmaG = self.n_sigma * self.coeff * delta
num_rows = lh.shape[0]
num_cols = lh.shape[1]
lower_bnd = torch.reshape((median - nSigmaG), (num_rows, 1)).expand(-1, num_cols)
upper_bnd = torch.reshape((median + nSigmaG), (num_rows, 1)).expand(-1, num_cols)
# Check whether the values fall within the bounds.
index_valid = torch.isfinite(torch_lh) & (torch_lh < upper_bnd) & (torch_lh > lower_bnd)
# Return as a numpy array on the CPU.
return index_valid.cpu().numpy().astype(bool)
def apply_clipped_sigma_g(clipper, result_data):
"""This function applies a clipped median filter to the results of a KBMOD
    search using sigmaG as a robust estimator of standard deviation.
Parameters
----------
clipper : `SigmaGClipping`
The object to apply the SigmaG clipping.
result_data : `Results`
The values from trajectories. This data gets modified directly by the filtering.
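    Examples
    --------
    A rough usage sketch, assuming ``results`` is a `Results` object that was
    populated by a KBMOD search (not constructed here):

    >>> clipper = SigmaGClipping(low_bnd=25, high_bnd=75, n_sigma=2)
    >>> apply_clipped_sigma_g(clipper, results)  # updates ``results`` in place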
"""
if len(result_data) == 0:
logger.info("SigmaG Clipping : skipping, nothing to filter.")
return
filter_timer = DebugTimer("sigma-g filtering", logger)
lh = result_data.compute_likelihood_curves(filter_obs=True, mask_value=np.nan)
obs_valid = clipper.compute_clipped_sigma_g_matrix(lh)
result_data.update_obs_valid(obs_valid)
filter_timer.stop()