Source code for kbmod.filters.clustering_filters

import numpy as np

from sklearn.cluster import DBSCAN

from kbmod.filters.clustering_grid import TrajectoryClusterGrid
from kbmod.results import Results
import kbmod.search as kb

logger = kb.Logging.getLogger(__name__)


class DBSCANFilter:
    """Cluster the candidates using DBSCAN and only keep a single
    representative trajectory from each cluster.

    Attributes
    ----------
    cluster_eps : `float`
        The clustering threshold (in pixels).
    cluster_type : `str`
        The type of clustering.
    cluster_args : `dict`
        Additional arguments to pass to the clustering algorithm.
    """

    def __init__(self, cluster_eps, **kwargs):
        """Create a DBSCANFilter.

        Parameters
        ----------
        cluster_eps : `float`
            The clustering threshold.
        **kwargs : `dict`
            Additional arguments to pass to the clustering algorithm.
        """
        self.cluster_eps = cluster_eps
        self.cluster_type = ""
        self.cluster_args = dict(eps=self.cluster_eps, min_samples=1, n_jobs=-1)

    def get_filter_name(self):
        """Get the name of the filter.

        Returns
        -------
        str
            The filter name.
        """
        return f"DBSCAN_{self.cluster_type} eps={self.cluster_eps}"

    def _build_clustering_data(self, result_data):
        """Build the specific data set for this clustering approach.

        Parameters
        ----------
        result_data : `Results`
            The set of results to filter.

        Returns
        -------
        data : `numpy.ndarray`
            The N x D matrix to cluster, where N is the number of results
            and D is the number of attributes.
        """
        raise NotImplementedError()

    def keep_indices(self, result_data):
        """Determine which of the results' indices to keep.

        Parameters
        ----------
        result_data : `Results`
            The set of results to filter.

        Returns
        -------
        `list`
            A list of indices (int) indicating which rows to keep.
        """
        # Build a numpy array of the trajectories to cluster, one row per trajectory.
        data = self._build_clustering_data(result_data)

        # Set up and run the clustering algorithm.
        cluster = DBSCAN(**self.cluster_args)
        cluster.fit(data)

        # Get the best index per cluster. If the data is sorted by likelihood, this
        # should always be the first point in the cluster, but we do an argmax in
        # case the user has manually sorted the data by something else.
        top_vals = []
        for cluster_num in np.unique(cluster.labels_):
            cluster_vals = np.where(cluster.labels_ == cluster_num)[0]
            top_ind = np.argmax(result_data["likelihood"][cluster_vals])
            top_vals.append(cluster_vals[top_ind])
        return top_vals
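
# Illustrative sketch (not part of the module): how the per-cluster argmax in
# ``keep_indices`` picks one representative. Two nearby points fall into one
# DBSCAN cluster and the higher-likelihood one wins; the isolated point forms
# its own cluster. All values below are made up.
#
#   >>> import numpy as np
#   >>> from sklearn.cluster import DBSCAN
#   >>> data = np.array([[0.0, 0.0], [0.5, 0.0], [10.0, 10.0]], dtype=np.float32)
#   >>> likelihood = np.array([5.0, 7.0, 6.0])
#   >>> labels = DBSCAN(eps=1.0, min_samples=1).fit(data).labels_
#   >>> keep = [np.where(labels == c)[0][np.argmax(likelihood[labels == c])]
#   ...         for c in np.unique(labels)]
#   # keep -> [1, 2]: row 1 beats row 0 in the shared cluster; row 2 stands alone
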
class ClusterPredictionFilter(DBSCANFilter):
    """Cluster the candidates using their predicted positions at specific times.

    Attributes
    ----------
    times : list-like
        The times at which to evaluate the trajectories (in days).
    """

    def __init__(self, cluster_eps, pred_times=[0.0], **kwargs):
        """Create a ClusterPredictionFilter.

        Parameters
        ----------
        cluster_eps : `float`
            The clustering threshold.
        pred_times : `list`
            The times at which to predict the positions (in days).
            Default = [0.0] (starting position only)
        """
        super().__init__(cluster_eps, **kwargs)

        # Confirm we have at least one prediction time.
        if len(pred_times) == 0:
            raise ValueError("No prediction times given.")
        self.times = np.array(pred_times, dtype=np.float32)

        # Set up the clustering algorithm's name.
        self.cluster_type = f"position t={self.times}"

    def _build_clustering_data(self, result_data):
        """Build the specific data set for this clustering approach.

        Parameters
        ----------
        result_data : `Results`
            The set of results to filter.

        Returns
        -------
        data : `numpy.ndarray`
            The N x D matrix to cluster, where N is the number of results
            and D is the number of attributes.
        """
        # Predict each trajectory's x and y pixel positions at every sample time.
        x0_arr = result_data["x"][:, np.newaxis].astype(np.float32)
        xv_arr = result_data["vx"][:, np.newaxis].astype(np.float32)
        pred_x = x0_arr + xv_arr * self.times[np.newaxis, :]

        y0_arr = result_data["y"][:, np.newaxis].astype(np.float32)
        yv_arr = result_data["vy"][:, np.newaxis].astype(np.float32)
        pred_y = y0_arr + yv_arr * self.times[np.newaxis, :]

        return np.hstack([pred_x, pred_y])
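
# Illustrative sketch (not part of the module): the broadcasting in
# ``_build_clustering_data`` maps N trajectories and T prediction times to an
# N x 2T matrix of predicted pixel positions. All values below are made up.
#
#   >>> import numpy as np
#   >>> times = np.array([0.0, 10.0], dtype=np.float32)   # T = 2 (days)
#   >>> x0 = np.array([[5.0], [8.0]], dtype=np.float32)   # N = 2 starting x (pixels)
#   >>> vx = np.array([[1.0], [-0.5]], dtype=np.float32)  # x velocities (pixels/day)
#   >>> x0 + vx * times[np.newaxis, :]
#   # -> [[5., 15.],
#   #     [8.,  3.]]  (each row is one trajectory's predicted x at both times)
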
class ClusterPosVelFilter(DBSCANFilter):
    """Cluster the candidates using their starting positions and velocities."""

    def __init__(self, cluster_eps, cluster_v_scale=1.0, **kwargs):
        """Create a ClusterPosVelFilter.

        Parameters
        ----------
        cluster_eps : `float`
            The clustering threshold (in pixels).
        cluster_v_scale : `float`
            The relative scaling of velocity differences compared to
            position differences. Default: 1.0 (no difference).
        """
        super().__init__(cluster_eps, **kwargs)
        if cluster_v_scale < 0.0:
            raise ValueError("cluster_v_scale cannot be negative.")
        self.cluster_v_scale = cluster_v_scale
        self.cluster_type = "all"

    def _build_clustering_data(self, result_data):
        """Build the specific data set for this clustering approach.

        Parameters
        ----------
        result_data : `Results`
            The set of results to filter.

        Returns
        -------
        data : `numpy.ndarray`
            The N x D matrix to cluster, where N is the number of results
            and D is the number of attributes.
        """
        data = np.empty((len(result_data), 4), dtype=np.float32)
        data[:, 0] = result_data["x"].astype(np.float32)
        data[:, 1] = result_data["y"].astype(np.float32)
        data[:, 2] = result_data["vx"] * self.cluster_v_scale
        data[:, 3] = result_data["vy"] * self.cluster_v_scale
        return data
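
# Illustrative sketch (not part of the module): ``cluster_v_scale`` sets how
# much a velocity difference counts toward DBSCAN's Euclidean distance. With a
# scale of 0.0, two results starting at the same pixel always merge; with a
# scale of 1.0, their 2 pixel/day velocity gap exceeds eps and they split.
# All values below are made up.
#
#   >>> import numpy as np
#   >>> from sklearn.cluster import DBSCAN
#   >>> def features(v_scale):
#   ...     # Two results: identical (x, y) starts, vx of 1.0 vs 3.0 pixels/day.
#   ...     return np.array([[10.0, 10.0, 1.0 * v_scale, 0.0],
#   ...                      [10.0, 10.0, 3.0 * v_scale, 0.0]], dtype=np.float32)
#   >>> DBSCAN(eps=1.0, min_samples=1).fit(features(0.0)).labels_  # -> [0, 0] merged
#   >>> DBSCAN(eps=1.0, min_samples=1).fit(features(1.0)).labels_  # -> [0, 1] split
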
class NNSweepFilter:
    """Filter out any result that has a neighboring trajectory with a higher
    likelihood within the threshold.

    Attributes
    ----------
    thresh : `float`
        The filtering threshold to use (in pixels).
    times : list-like
        The times at which to evaluate the trajectories (in days).
    batch_size : `int`
        The size of batching to use for kd-tree lookups. A batch size of 1
        turns off multi-threading and runs everything in series.
        Default: 1000
    """

    def __init__(self, cluster_eps, pred_times, batch_size=1_000):
        """Create an NNSweepFilter.

        Parameters
        ----------
        cluster_eps : `float`
            The filtering threshold to use.
        pred_times : list-like
            The times at which to evaluate the trajectories.
        batch_size : `int`
            The size of batching to use for kd-tree lookups. A batch size of 1
            turns off multi-threading and runs everything in series.
            Default: 1000
        """
        if cluster_eps <= 0.0:
            raise ValueError("Threshold must be > 0.0.")
        self.thresh = cluster_eps

        self.times = np.asarray(pred_times, dtype=np.float32)
        if len(self.times) == 0:
            raise ValueError("Empty time array provided.")

        if batch_size <= 0:
            raise ValueError("batch_size must be > 0.")
        self.batch_size = batch_size

    def get_filter_name(self):
        """Get the name of the filter.

        Returns
        -------
        str
            The filter name.
        """
        return f"NNFilter times={self.times} eps={self.thresh}"

    def _build_clustering_data(self, result_data):
        """Build the specific data set for this clustering approach.

        Parameters
        ----------
        result_data : `Results`
            The set of results to filter.

        Returns
        -------
        data : `numpy.ndarray`
            The N x D matrix to cluster, where N is the number of results
            and D is the number of attributes.
        """
        # Predict each trajectory's x and y pixel positions at every sample time.
        x0_arr = result_data["x"][:, np.newaxis].astype(np.float32)
        xv_arr = result_data["vx"][:, np.newaxis].astype(np.float32)
        pred_x = x0_arr + xv_arr * self.times[np.newaxis, :]

        y0_arr = result_data["y"][:, np.newaxis].astype(np.float32)
        yv_arr = result_data["vy"][:, np.newaxis].astype(np.float32)
        pred_y = y0_arr + yv_arr * self.times[np.newaxis, :]

        return np.hstack([pred_x, pred_y])

    def keep_indices(self, result_data):
        """Determine which of the results' indices to keep.

        Parameters
        ----------
        result_data : `Results`
            The set of results to filter.

        Returns
        -------
        `list`
            A list of indices (int) indicating which rows to keep.
        """
        from scipy.spatial import KDTree

        # Predict the trajectories' locations at the given times and put the
        # resulting points in a KDTree.
        build_data_timer = kb.DebugTimer("NNSweepFilter building data", logger)
        cart_data = self._build_clustering_data(result_data)
        kd_tree = KDTree(cart_data)
        build_data_timer.stop()

        num_pts = len(result_data)
        lh_data = result_data["likelihood"]

        # For each point, search for all neighbors within the threshold and only
        # keep the point if it has the highest likelihood in that range. We do
        # this in batches to benefit from multi-threaded KDTree queries while
        # avoiding too much memory for the match data.
        num_workers = -1 if self.batch_size > 1 else 1
        can_skip = np.full(num_pts, False)
        keep_vals = []
        batch_start = 0
        while batch_start < num_pts:
            # Get the next batch of indices to search. Each batch only includes
            # results that have not already been eliminated, to avoid unnecessary
            # searches.
            batch_end = min(num_pts, batch_start + self.batch_size)
            batch_inds = np.asanyarray([i for i in range(batch_start, batch_end) if not can_skip[i]])

            # Skip all the work if there is nothing to query in this batch.
            if len(batch_inds) == 0:
                batch_start = batch_end
                continue

            # Do the (multi-threaded) KD-tree search for this batch of indices.
            batch_matches = kd_tree.query_ball_point(
                cart_data[batch_inds, :],
                self.thresh,
                workers=num_workers,
            )

            # Check if each index is the best in its neighborhood.
            for batch_idx, total_idx in enumerate(batch_inds):
                if not can_skip[total_idx]:
                    matches = np.asanyarray(batch_matches[batch_idx])
                    if lh_data[total_idx] >= np.max(lh_data[matches]):
                        keep_vals.append(total_idx)

                        # Everything found in this query (including the current
                        # point) doesn't need to be searched in the future,
                        # because we have found the maximum value in this area.
                        can_skip[matches] = True

            batch_start = batch_end

        return keep_vals
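
# Illustrative sketch (not part of the module): the core of the sweep above on
# toy 1-D data, without batching. A point survives only if its likelihood is
# the maximum among all neighbors within the radius. All values are made up.
#
#   >>> import numpy as np
#   >>> from scipy.spatial import KDTree
#   >>> pts = np.array([[0.0], [0.5], [5.0]])
#   >>> lh = np.array([3.0, 4.0, 2.0])
#   >>> tree = KDTree(pts)
#   >>> keep = [i for i in range(len(pts))
#   ...         if lh[i] >= lh[tree.query_ball_point(pts[i], 1.0)].max()]
#   # keep -> [1, 2]: point 0 loses to its higher-likelihood neighbor at 0.5
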
class ClusterGridFilter:
    """Use a discrete grid to cluster the points. Each trajectory is mapped to
    a bin and only the best trajectory per bin is retained.

    Attributes
    ----------
    bin_width : `int`
        The width of the grid bins (in pixels).
    cluster_grid : `TrajectoryClusterGrid`
        The grid of best result trajectories seen.
    max_dt : `float`
        The maximum difference between times in pred_times.
    """

    def __init__(self, cluster_eps, pred_times):
        """Create a ClusterGridFilter.

        Parameters
        ----------
        cluster_eps : `float`
            The bin width to use (in pixels).
        pred_times : list-like
            The times at which to evaluate the trajectories (in days).
        """
        self.bin_width = np.ceil(cluster_eps)
        if self.bin_width <= 0:
            raise ValueError("Bin width must be > 0.0.")

        self.times = np.asarray(pred_times)
        if len(self.times) == 0:
            self.times = np.array([0.0])
        self.max_dt = np.max(self.times) - np.min(self.times)

        # Create the actual grid to store the results.
        self.cluster_grid = TrajectoryClusterGrid(
            bin_width=self.bin_width,
            max_time=self.max_dt,
        )

    def get_filter_name(self):
        """Get the name of the filter.

        Returns
        -------
        str
            The filter name.
        """
        return f"ClusterGridFilter bin_width={self.bin_width}, max_dt={self.max_dt}"

    def keep_indices(self, result_data):
        """Determine which of the results' indices to keep.

        Parameters
        ----------
        result_data : `Results`
            The set of results to filter.

        Returns
        -------
        `list`
            A list of indices (int) indicating which rows to keep.
        """
        # Insert every trajectory into the grid; the grid keeps the best per bin.
        trj_list = result_data.make_trajectory_list()
        for idx, trj in enumerate(trj_list):
            self.cluster_grid.add_trajectory(trj, idx)

        keep_vals = np.sort(self.cluster_grid.get_indices())
        return list(keep_vals)
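
# Illustrative sketch (not part of the module): the binning idea behind
# ``TrajectoryClusterGrid``, approximated here with a plain dict keyed on the
# quantized start position (the real class also accounts for motion over
# ``max_time``). The dict stand-in and all values are made up for illustration.
#
#   >>> import numpy as np
#   >>> bin_width = 10.0
#   >>> xy = np.array([[12.0, 7.0], [14.0, 9.0], [55.0, 60.0]])  # start positions
#   >>> lh = np.array([3.0, 5.0, 4.0])
#   >>> best = {}
#   >>> for idx, (x, y) in enumerate(xy):
#   ...     key = (int(x // bin_width), int(y // bin_width))
#   ...     if key not in best or lh[idx] > lh[best[key]]:
#   ...         best[key] = idx
#   >>> sorted(best.values())
#   # -> [1, 2]: rows 0 and 1 share bin (1, 0) and row 1 has the higher likelihood
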
def apply_clustering(result_data, cluster_params):
    """Cluster results that have similar trajectories, keeping only a single
    representative per cluster.

    Parameters
    ----------
    result_data : `Results`
        The set of results to filter. This data gets modified directly
        by the filtering.
    cluster_params : `dict`
        Contains values concerning the image and search settings, including:
        cluster_type, cluster_eps, times, and cluster_v_scale (optional).

    Raises
    ------
    KeyError
        If the ``cluster_type`` or ``times`` parameter is missing.
    ValueError
        If the parameters are not valid.
    TypeError
        If ``result_data`` is of an unsupported type.
    """
    if "cluster_type" not in cluster_params:
        raise KeyError("Missing cluster_type parameter")
    cluster_type = cluster_params["cluster_type"]

    # Skip clustering if there is nothing to cluster.
    if len(result_data) == 0:
        logger.info("Clustering : skipping, no results.")
        return

    # Get the times used for prediction clustering.
    if "times" not in cluster_params:
        raise KeyError("Missing times parameter in the clustering parameters.")
    all_times = np.sort(cluster_params["times"])
    zeroed_times = np.array(all_times) - all_times[0]

    # Do the clustering and the filtering.
    if cluster_type == "all" or cluster_type == "pos_vel":
        filt = ClusterPosVelFilter(**cluster_params)
    elif cluster_type == "position" or cluster_type == "start_position":
        cluster_params["pred_times"] = [0.0]
        filt = ClusterPredictionFilter(**cluster_params)
    elif cluster_type == "mid_position":
        cluster_params["pred_times"] = [np.median(zeroed_times)]
        filt = ClusterPredictionFilter(**cluster_params)
    elif cluster_type == "start_end_position":
        cluster_params["pred_times"] = [0.0, zeroed_times[-1]]
        filt = ClusterPredictionFilter(**cluster_params)
    elif cluster_type == "nn_start_end":
        filt = NNSweepFilter(cluster_params["cluster_eps"], [0.0, zeroed_times[-1]])
    elif cluster_type == "nn_start":
        filt = NNSweepFilter(cluster_params["cluster_eps"], [0.0])
    elif cluster_type == "grid_start_end":
        filt = ClusterGridFilter(cluster_params["cluster_eps"], [0.0, zeroed_times[-1]])
    elif cluster_type == "grid_start":
        filt = ClusterGridFilter(cluster_params["cluster_eps"], [0.0])
    else:
        raise ValueError(f"Unknown clustering type: {cluster_type}")
    logger.info(f"Clustering {len(result_data)} results using {filt.get_filter_name()}")

    # Do the actual filtering.
    indices_to_keep = filt.keep_indices(result_data)
    result_data.filter_rows(indices_to_keep, filt.get_filter_name())
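
# Illustrative sketch (not part of the module): calling ``apply_clustering``
# with a typical parameter dictionary. ``results`` and ``obstimes`` are
# placeholders for an already-populated ``kbmod.results.Results`` and the list
# of observation times; the parameter values are made up.
#
#   >>> cluster_params = {
#   ...     "cluster_type": "nn_start_end",  # any type dispatched above
#   ...     "cluster_eps": 20.0,             # threshold in pixels
#   ...     "times": obstimes,               # observation times in days
#   ... }
#   >>> apply_clustering(results, cluster_params)  # filters ``results`` in place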