import numpy as np
from sklearn.cluster import DBSCAN
from kbmod.filters.clustering_grid import TrajectoryClusterGrid
from kbmod.results import Results
import kbmod.search as kb
logger = kb.Logging.getLogger(__name__)
class DBSCANFilter:
"""Cluster the candidates using DBSCAN and only keep a
single representative trajectory from each cluster.
Attributes
----------
cluster_eps : `float`
The clustering threshold (in pixels).
cluster_type : `str`
The type of clustering.
cluster_args : `dict`
Additional arguments to pass to the clustering algorithm.
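    Examples
    --------
    A minimal sketch using a concrete subclass defined later in this module
    (the threshold value is illustrative only; the subclass supplies
    ``_build_clustering_data``):

    >>> filt = ClusterPosVelFilter(1.0)
    >>> filt.get_filter_name()
    'DBSCAN_all eps=1.0'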
"""
def __init__(self, cluster_eps, **kwargs):
"""Create a DBSCANFilter.
Parameters
----------
cluster_eps : `float`
The clustering threshold.
**kwargs : `dict`
Additional arguments to pass to the clustering algorithm.
"""
self.cluster_eps = cluster_eps
self.cluster_type = ""
self.cluster_args = dict(eps=self.cluster_eps, min_samples=1, n_jobs=-1)
    def get_filter_name(self):
"""Get the name of the filter.
Returns
-------
str
The filter name.
"""
return f"DBSCAN_{self.cluster_type} eps={self.cluster_eps}"
def _build_clustering_data(self, result_data):
"""Build the specific data set for this clustering approach.
Parameters
----------
        result_data : `Results`
The set of results to filter.
Returns
-------
data : `numpy.ndarray`
The N x D matrix to cluster where N is the number of results
and D is the number of attributes.
"""
raise NotImplementedError()
    def keep_indices(self, result_data):
        """Determine which of the results' indices to keep.
Parameters
----------
        result_data : `Results`
The set of results to filter.
Returns
-------
`list`
A list of indices (int) indicating which rows to keep.
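        Examples
        --------
        A toy sketch of the per-cluster selection step only, with made-up
        labels standing in for ``cluster.labels_`` (plain numpy, no DBSCAN
        run needed):

        >>> import numpy as np
        >>> labels = np.array([0, 0, 1, 1, 1])
        >>> lh = np.array([5.0, 7.0, 2.0, 9.0, 4.0])
        >>> [int(np.where(labels == c)[0][np.argmax(lh[labels == c])])
        ...  for c in np.unique(labels)]
        [1, 3]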
"""
# Build a numpy array of the trajectories to cluster with one row for each trajectory.
data = self._build_clustering_data(result_data)
# Set up the clustering algorithm
cluster = DBSCAN(**self.cluster_args)
cluster.fit(data)
# Get the best index per cluster. If the data is sorted by LH, this should always
# be the first point in the cluster. But we do an argmax in case the user has
# manually sorted the data by something else.
top_vals = []
for cluster_num in np.unique(cluster.labels_):
cluster_vals = np.where(cluster.labels_ == cluster_num)[0]
top_ind = np.argmax(result_data["likelihood"][cluster_vals])
top_vals.append(cluster_vals[top_ind])
return top_vals
class ClusterPredictionFilter(DBSCANFilter):
"""Cluster the candidates using their positions at specific times.
Attributes
----------
times : list-like
The times at which to evaluate the trajectories (in days).
"""
    def __init__(self, cluster_eps, pred_times=(0.0,), **kwargs):
        """Create a ClusterPredictionFilter.
Parameters
----------
cluster_eps : `float`
The clustering threshold.
        pred_times : list-like
            The times at which to predict the positions (in days).
            Default: (0.0,) (the starting position only).
"""
super().__init__(cluster_eps, **kwargs)
# Confirm we have at least one prediction time.
if len(pred_times) == 0:
raise ValueError("No prediction times given.")
self.times = np.array(pred_times, dtype=np.float32)
# Set up the clustering algorithm's name.
self.cluster_type = f"position t={self.times}"
def _build_clustering_data(self, result_data):
"""Build the specific data set for this clustering approach.
Parameters
----------
        result_data : `Results`
The set of results to filter.
Returns
-------
data : `numpy.ndarray`
The N x D matrix to cluster where N is the number of results
and D is the number of attributes.
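        Examples
        --------
        A toy illustration of the broadcasting used below, with plain numpy
        stand-ins for the ``Results`` columns (values are illustrative only):

        >>> import numpy as np
        >>> x0 = np.array([[10.0], [20.0]], dtype=np.float32)  # N x 1 positions
        >>> vx = np.array([[1.0], [-2.0]], dtype=np.float32)   # N x 1 velocities
        >>> times = np.array([0.0, 5.0], dtype=np.float32)     # D prediction times
        >>> x0 + vx * times[np.newaxis, :]
        array([[10., 15.],
               [20., 10.]], dtype=float32)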
"""
x0_arr = result_data["x"][:, np.newaxis].astype(np.float32)
xv_arr = result_data["vx"][:, np.newaxis].astype(np.float32)
pred_x = x0_arr + xv_arr * self.times[np.newaxis, :]
y0_arr = result_data["y"][:, np.newaxis].astype(np.float32)
yv_arr = result_data["vy"][:, np.newaxis].astype(np.float32)
pred_y = y0_arr + yv_arr * self.times[np.newaxis, :]
return np.hstack([pred_x, pred_y])
class ClusterPosVelFilter(DBSCANFilter):
"""Cluster the candidates using their starting position and velocities."""
    def __init__(self, cluster_eps, cluster_v_scale=1.0, **kwargs):
        """Create a ClusterPosVelFilter.
Parameters
----------
cluster_eps : `float`
The clustering threshold (in pixels).
cluster_v_scale : `float`
The relative scaling of velocity differences compared to position
            differences. Default: 1.0 (no rescaling).
"""
super().__init__(cluster_eps, **kwargs)
if cluster_v_scale < 0.0:
raise ValueError("cluster_v_scale cannot be negative.")
self.cluster_v_scale = cluster_v_scale
self.cluster_type = "all"
def _build_clustering_data(self, result_data):
"""Build the specific data set for this clustering approach.
Parameters
----------
        result_data : `Results`
The set of results to filter.
Returns
-------
data : `numpy.ndarray`
The N x D matrix to cluster where N is the number of results
and D is the number of attributes.
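        Examples
        --------
        A toy sketch of the scaled feature matrix, with plain numpy stand-ins
        for the ``Results`` columns and ``cluster_v_scale=10.0`` (values are
        illustrative only):

        >>> import numpy as np
        >>> data = np.empty((2, 4), dtype=np.float32)
        >>> data[:, 0] = np.array([1.0, 2.0])           # x
        >>> data[:, 1] = np.array([3.0, 4.0])           # y
        >>> data[:, 2] = np.array([0.1, 0.2]) * 10.0    # vx, scaled
        >>> data[:, 3] = np.array([-0.1, 0.0]) * 10.0   # vy, scaled
        >>> data
        array([[ 1.,  3.,  1., -1.],
               [ 2.,  4.,  2.,  0.]], dtype=float32)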
"""
data = np.empty((len(result_data), 4), dtype=np.float32)
data[:, 0] = result_data["x"].astype(np.float32)
data[:, 1] = result_data["y"].astype(np.float32)
data[:, 2] = result_data["vx"] * self.cluster_v_scale
data[:, 3] = result_data["vy"] * self.cluster_v_scale
return data
class NNSweepFilter:
    """Filter out any trajectory that has a neighboring trajectory
    with a higher likelihood within the threshold.
Attributes
----------
thresh : `float`
The filtering threshold to use (in pixels).
times : list-like
The times at which to evaluate the trajectories (in days).
batch_size : `int`
The size of batching to use for kd-tree lookups. A batch size of 1
turns off multi-threading and runs everything in series.
Default: 1000
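    Examples
    --------
    A minimal construction sketch (toy values):

    >>> filt = NNSweepFilter(2.0, [0.0, 1.0])
    >>> filt.get_filter_name()
    'NNFilter times=[0. 1.] eps=2.0'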
"""
    def __init__(self, cluster_eps, pred_times, batch_size=1_000):
        """Create an NNSweepFilter.
Parameters
----------
cluster_eps : `float`
The filtering threshold to use.
pred_times : list-like
The times at which to evaluate the trajectories.
batch_size : `int`
The size of batching to use for kd-tree lookups. A batch size of 1
turns off multi-threading and runs everything in series.
Default: 1000
"""
        if cluster_eps <= 0.0:
            raise ValueError("Threshold must be > 0.0.")
        self.thresh = cluster_eps
        self.times = np.asarray(pred_times, dtype=np.float32)
        if len(self.times) == 0:
            raise ValueError("Empty time array provided.")
        if batch_size <= 0:
            raise ValueError("batch_size must be > 0.")
self.batch_size = batch_size
    def get_filter_name(self):
"""Get the name of the filter.
Returns
-------
str
The filter name.
"""
return f"NNFilter times={self.times} eps={self.thresh}"
def _build_clustering_data(self, result_data):
"""Build the specific data set for this clustering approach.
Parameters
----------
        result_data : `Results`
The set of results to filter.
Returns
-------
data : `numpy.ndarray`
The N x D matrix to cluster where N is the number of results
and D is the number of attributes.
"""
x0_arr = result_data["x"][:, np.newaxis].astype(np.float32)
xv_arr = result_data["vx"][:, np.newaxis].astype(np.float32)
pred_x = x0_arr + xv_arr * self.times[np.newaxis, :]
y0_arr = result_data["y"][:, np.newaxis].astype(np.float32)
yv_arr = result_data["vy"][:, np.newaxis].astype(np.float32)
pred_y = y0_arr + yv_arr * self.times[np.newaxis, :]
return np.hstack([pred_x, pred_y])
    def keep_indices(self, result_data):
        """Determine which of the results' indices to keep.
Parameters
----------
        result_data : `Results`
The set of results to filter.
Returns
-------
`list`
A list of indices (int) indicating which rows to keep.
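        Examples
        --------
        A toy sketch of the sweep on three 1-D points with threshold 2.0,
        using plain scipy/numpy (no ``Results`` needed). Points 0 and 1 are
        mutual neighbors, so only the higher-likelihood one survives:

        >>> import numpy as np
        >>> from scipy.spatial import KDTree
        >>> pts = np.array([[0.0], [1.0], [10.0]])
        >>> lh = np.array([3.0, 5.0, 1.0])
        >>> tree = KDTree(pts)
        >>> [i for i in range(len(pts))
        ...  if lh[i] >= np.max(lh[tree.query_ball_point(pts[i], 2.0)])]
        [1, 2]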
"""
from scipy.spatial import KDTree
# Predict the Trajectory's locations at the given times and put the
# resulting points in a KDTree.
build_data_timer = kb.DebugTimer("NNSweepFilter building data", logger)
cart_data = self._build_clustering_data(result_data)
kd_tree = KDTree(cart_data)
build_data_timer.stop()
num_pts = len(result_data)
lh_data = result_data["likelihood"]
# For each point, search for all neighbors within the threshold and
# only keep the point if it has the highest likelihood in that range.
# We do this in batches to benefit from multi-threaded KDTree queries
# while avoiding too much memory for the match data.
num_workers = -1 if self.batch_size > 1 else 1
can_skip = np.full(num_pts, False)
keep_vals = []
batch_start = 0
while batch_start < num_pts:
# Get the next batch of indices to search. Each batch only includes those
# results that have not already been eliminated to avoid unnecessary searches.
batch_end = min(num_pts, batch_start + self.batch_size)
batch_inds = np.asanyarray([i for i in range(batch_start, batch_end) if not (can_skip[i])])
# Skip all the work if there is nothing to query in this batch.
if len(batch_inds) == 0:
batch_start = batch_end
continue
            # Do the (multi-threaded) KD-tree search for this batch of indices.
batch_matches = kd_tree.query_ball_point(
cart_data[batch_inds, :],
self.thresh,
workers=num_workers,
)
# Check if each index is the best in its neighborhood.
for batch_idx, total_idx in enumerate(batch_inds):
if not can_skip[total_idx]:
matches = np.asanyarray(batch_matches[batch_idx])
if lh_data[total_idx] >= np.max(lh_data[matches]):
keep_vals.append(total_idx)
# Everything found in this run (including the current point)
# doesn't need to be searched in the future, because we have
# found the maximum value in this area.
can_skip[matches] = True
batch_start = batch_end
return keep_vals
class ClusterGridFilter:
    """Use a discrete grid to cluster the points. Each trajectory
    is assigned to a bin and only the best trajectory per bin is retained.
Attributes
----------
bin_width : `int`
The width of the grid bins (in pixels).
cluster_grid : `TrajectoryClusterGrid`
The grid of best result trajectories seen.
max_dt : `float`
        The maximum difference between the times in pred_times (in days).
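    Examples
    --------
    A minimal construction sketch (toy values; ``bin_width`` is rounded up
    from ``cluster_eps``):

    >>> filt = ClusterGridFilter(10.0, [0.0, 2.0])
    >>> filt.bin_width, float(filt.max_dt)
    (10, 2.0)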
"""
def __init__(self, cluster_eps, pred_times):
"""Create a ClusterGridFilter.
Parameters
----------
cluster_eps : `float`
The bin width to use (in pixels).
pred_times : list-like
The times at which to evaluate the trajectories (in days).
"""
        self.bin_width = int(np.ceil(cluster_eps))
        if self.bin_width <= 0:
            raise ValueError("Bin width must be > 0.")
self.times = np.asarray(pred_times)
if len(self.times) == 0:
self.times = np.array([0.0])
self.max_dt = np.max(self.times) - np.min(self.times)
# Create the actual grid to store the results.
self.cluster_grid = TrajectoryClusterGrid(
bin_width=self.bin_width,
max_time=self.max_dt,
)
    def get_filter_name(self):
"""Get the name of the filter.
Returns
-------
str
The filter name.
"""
        return f"ClusterGridFilter bin_width={self.bin_width}, max_dt={self.max_dt}"
    def keep_indices(self, result_data):
        """Determine which of the results' indices to keep.
Parameters
----------
        result_data : `Results`
The set of results to filter.
Returns
-------
`list`
A list of indices (int) indicating which rows to keep.
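        Examples
        --------
        A conceptual sketch of the binning idea with plain numpy (the real
        bookkeeping lives in `TrajectoryClusterGrid`, which keeps the best
        trajectory per bin; this toy version just keeps one index per bin):

        >>> import numpy as np
        >>> xs = np.array([10.0, 11.0, 53.0])
        >>> bins = np.floor(xs / 25.0).astype(int)  # bin_width = 25
        >>> bins
        array([0, 0, 2])
        >>> sorted({b: i for i, b in enumerate(bins)}.values())
        [1, 2]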
"""
trj_list = result_data.make_trajectory_list()
for idx, trj in enumerate(trj_list):
self.cluster_grid.add_trajectory(trj, idx)
keep_vals = np.sort(self.cluster_grid.get_indices())
return list(keep_vals)
def apply_clustering(result_data, cluster_params):
    """Cluster results that have similar trajectories and keep only
    a single representative trajectory from each cluster.
Parameters
----------
    result_data : `Results`
The set of results to filter. This data gets modified directly by
the filtering.
cluster_params : dict
Contains values concerning the image and search settings including:
cluster_type, cluster_eps, times, and cluster_v_scale (optional).
Raises
------
    ValueError
        If the parameters are not valid.
    KeyError
        If ``cluster_params`` is missing a required parameter.
    TypeError
        If ``result_data`` is of an unsupported type.
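    Examples
    --------
    A minimal usage sketch. It assumes ``results`` is an already-populated
    `Results` table with ``x``, ``y``, ``vx``, ``vy``, and ``likelihood``
    columns, and that ``obstimes`` holds the observation times from the
    search; the parameter values are illustrative only::

        cluster_params = {
            "cluster_type": "all",
            "cluster_eps": 20.0,
            "cluster_v_scale": 1.0,
            "times": obstimes,
        }
        apply_clustering(results, cluster_params)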
"""
if "cluster_type" not in cluster_params:
raise KeyError("Missing cluster_type parameter")
cluster_type = cluster_params["cluster_type"]
# Skip clustering if there is nothing to cluster.
if len(result_data) == 0:
logger.info("Clustering : skipping, no results.")
return
# Get the times used for prediction clustering.
    if "times" not in cluster_params:
raise KeyError("Missing times parameter in the clustering parameters.")
all_times = np.sort(cluster_params["times"])
    zeroed_times = all_times - all_times[0]
# Do the clustering and the filtering.
if cluster_type == "all" or cluster_type == "pos_vel":
filt = ClusterPosVelFilter(**cluster_params)
elif cluster_type == "position" or cluster_type == "start_position":
cluster_params["pred_times"] = [0.0]
filt = ClusterPredictionFilter(**cluster_params)
elif cluster_type == "mid_position":
cluster_params["pred_times"] = [np.median(zeroed_times)]
filt = ClusterPredictionFilter(**cluster_params)
elif cluster_type == "start_end_position":
cluster_params["pred_times"] = [0.0, zeroed_times[-1]]
filt = ClusterPredictionFilter(**cluster_params)
elif cluster_type == "nn_start_end":
filt = NNSweepFilter(cluster_params["cluster_eps"], [0.0, zeroed_times[-1]])
elif cluster_type == "nn_start":
filt = NNSweepFilter(cluster_params["cluster_eps"], [0.0])
elif cluster_type == "grid_start_end":
filt = ClusterGridFilter(cluster_params["cluster_eps"], [0.0, zeroed_times[-1]])
elif cluster_type == "grid_start":
filt = ClusterGridFilter(cluster_params["cluster_eps"], [0.0])
else:
raise ValueError(f"Unknown clustering type: {cluster_type}")
logger.info(f"Clustering {len(result_data)} results using {filt.get_filter_name()}")
# Do the actual filtering.
indices_to_keep = filt.keep_indices(result_data)
result_data.filter_rows(indices_to_keep, filt.get_filter_name())