# Source code for kbmod.run_search

import logging
import os
import time

import astropy.units as u
import numpy as np
import psutil
from astropy.coordinates import EarthLocation, SkyCoord

import kbmod.search as kb

from .filters.clustering_filters import apply_clustering
from .filters.clustering_grid import apply_trajectory_grid_filter
from .filters.sigma_g_filter import SigmaGClipping, apply_clipped_sigma_g
from .filters.sns_filters import peak_offset_filter, predictive_line_cluster
from .filters.stamp_filters import append_all_stamps, append_coadds, filter_stamps_by_cnn
from .reprojection_utils import invert_correct_parallax_vectorized
from .results import Results, write_results_to_files_destructive
from .trajectory_generator import create_trajectory_generator
from .trajectory_utils import predict_pixel_locations

logger = kb.Logging.getLogger(__name__)


def configure_kb_search_stack(search, config):
    """Configure the kbmod SearchStack object from a search configuration.

    Parameters
    ----------
    search : `kb.StackSearch`
        The SearchStack object.
    config : `SearchConfiguration`
        The configuration parameters
    """
    img_width = search.get_image_width()
    img_height = search.get_image_height()

    # Basic filtering parameters.
    search.set_min_obs(int(config["num_obs"]))
    search.set_min_lh(config["lh_level"])

    # Starting-pixel bounds: explicit bounds take precedence over a
    # buffer extending past the image edges.
    x_bnds = config["x_pixel_bounds"]
    if x_bnds and len(x_bnds) == 2:
        search.set_start_bounds_x(x_bnds[0], x_bnds[1])
    elif config["x_pixel_buffer"] and config["x_pixel_buffer"] > 0:
        x_buf = config["x_pixel_buffer"]
        search.set_start_bounds_x(-x_buf, img_width + x_buf)

    y_bnds = config["y_pixel_bounds"]
    if y_bnds and len(y_bnds) == 2:
        search.set_start_bounds_y(y_bnds[0], y_bnds[1])
    elif config["y_pixel_buffer"] and config["y_pixel_buffer"] > 0:
        y_buf = config["y_pixel_buffer"]
        search.set_start_bounds_y(-y_buf, img_height + y_buf)

    # Set the results per pixel.
    search.set_results_per_pixel(config["results_per_pixel"])

    # If we are using gpu_filtering, enable it and set the parameters.
    if config["sigmaG_filter"] and config["gpu_filter"]:
        logger.debug("Using in-line GPU sigmaG filtering methods")
        coeff = SigmaGClipping.find_sigma_g_coeff(
            config["sigmaG_lims"][0],
            config["sigmaG_lims"][1],
        )
        search.enable_gpu_sigmag_filter(
            np.array(config["sigmaG_lims"]) / 100.0,
            coeff,
            config["lh_level"],
        )
    else:
        search.disable_gpu_sigmag_filter()

    # Drop any results cached from a previous search.
    search.clear_results()
def check_gpu_memory(config, stack, trj_generator=None):
    """Check whether we can run this search on the GPU.

    Parameters
    ----------
    config : `SearchConfiguration`
        The configuration parameters
    stack : `ImageStackPy`
        The stack of image data.
    trj_generator : `TrajectoryGenerator`, optional
        The object to generate the candidate trajectories for each pixel.

    Returns
    -------
    valid : `bool`
        Returns True if the search will fit on GPU and False otherwise.
    """
    bytes_free = kb.get_gpu_free_memory()
    logger.debug(f"Checking GPU memory needs (Free memory = {bytes_free} bytes):")

    # PSI/PHI pixels are stored at the encoded size (-1 means full 4-byte floats).
    if config["encode_num_bytes"] > 0:
        gpu_float_size = config["encode_num_bytes"]
    else:
        gpu_float_size = 4
    img_stack_size = stack.get_total_pixels() * gpu_float_size
    logger.debug(
        f" PSI/PHI encoding at {gpu_float_size} bytes per pixel.\n"
        f" PSI = {img_stack_size} bytes\n PHI = {img_stack_size} bytes"
    )

    # Memory needed to hold the candidate trajectories on GPU.
    num_candidates = 0 if trj_generator is None else len(trj_generator)
    candidate_memory = kb.TrajectoryList.estimate_memory(num_candidates)
    logger.debug(f" Candidates ({num_candidates}) = {candidate_memory} bytes.")

    # Memory needed for the results, sized from the *search* bounds
    # (explicit bounds or buffered image dimensions), not the raw image.
    w_bnds = config["x_pixel_bounds"]
    if w_bnds and len(w_bnds) == 2:
        search_width = w_bnds[1] - w_bnds[0]
    elif config["x_pixel_buffer"] and config["x_pixel_buffer"] > 0:
        search_width = stack.width + 2 * config["x_pixel_buffer"]
    else:
        search_width = stack.width

    h_bnds = config["y_pixel_bounds"]
    if h_bnds and len(h_bnds) == 2:
        search_height = h_bnds[1] - h_bnds[0]
    elif config["y_pixel_buffer"] and config["y_pixel_buffer"] > 0:
        search_height = stack.height + 2 * config["y_pixel_buffer"]
    else:
        search_height = stack.height

    num_results = search_width * search_height * config["results_per_pixel"]
    result_memory = kb.TrajectoryList.estimate_memory(num_results)
    logger.debug(f" Results ({num_results}) = {result_memory} bytes.")

    # Two image-sized buffers (PSI and PHI) plus candidates and results.
    return bytes_free > (2 * img_stack_size + result_memory + candidate_memory)
class SearchRunner:
    """A class to run the KBMOD grid search.

    Attributes
    ----------
    config : `SearchConfiguration`
        The configuration parameters.
    debug : `bool`
        If True, enable debug logging (and additional computation).
    phase_times : `dict`
        A dictionary mapping the search phase to the timing information,
        a list of [starting time, ending time] in seconds.
    phase_memory : `dict`
        A dictionary mapping the search phase the memory information,
        a list of [starting memory, ending memory] in bytes.
    timeout : `float` or `None`
        The time at which the search should timeout, in seconds since the epoch.
        This is a soft timeout that will not interrupt during a processing stage.
        None means no timeout is set.
    """

    def __init__(self, config=None):
        # Per-phase stats: phase name -> [start, end] (seconds / RSS bytes).
        self.phase_times = {}
        self.phase_memory = {}
        # No timeout until a config provides timeout_hours.
        self.timeout = None
        self.debug = False
        # apply_config() is a no-op when config is None; it may set
        # self.config, self.debug, and self.timeout.
        self.apply_config(config)
[docs] def apply_config(self, config): """Apply the configuration parameters to the search runner. This function is designed to be called at multiple points allow it to be used regardless of which level of the search is being run. Parameters ---------- config : `SearchConfiguration` The configuration parameters """ if config is None: return # Nothing to apply if not config.validate(): raise ValueError("Invalid configuration") self.config = config if config["debug"]: logging.basicConfig(level=logging.DEBUG) self.debug = True if self.timeout is None and config["timeout_hours"] is not None: logger.debug(f"Setting search timeout to {config['timeout_hours']} hours.") self.timeout = time.time() + config["timeout_hours"] * 3600.0 logger.debug(f"Search will timeout at {time.ctime(self.timeout)}.")
def _check_timeout(self): """Check if the search has exceeded the timeout. This is a soft timeout that will only quit between phases, so that each phase can correctly free its resources (for C++ GPU functions). """ if self.timeout is not None and time.time() > self.timeout: self.display_phase_stats() # Display which phases have been run so far. raise TimeoutError("Search has exceeded the maximum allowed time.") def _start_phase(self, phase_name): """Start recording stats for the current phase. Parameters ---------- phase_name : `str` The current phase. """ self._check_timeout() logger.debug(f"Starting {phase_name}.") # Record the start time. self.phase_times[phase_name] = [time.time(), None] # Record the starting memory. memory_info = psutil.Process().memory_info() self.phase_memory[phase_name] = [memory_info.rss, None] def _end_phase(self, phase_name): """Finish recording stats for the current phase. Parameters ---------- phase_name : `str` The current phase. """ self._check_timeout() if phase_name not in self.phase_times: raise KeyError(f"Phase {phase_name} has not been started.") # Record the end time. self.phase_times[phase_name][1] = time.time() delta_t = self.phase_times[phase_name][1] - self.phase_times[phase_name][0] logger.debug(f"Finished {phase_name} in {delta_t} seconds.") # Record the starting memory. memory_info = psutil.Process().memory_info() self.phase_memory[phase_name][1] = memory_info.rss
[docs] def display_phase_stats(self): """Output the statistics for each phase.""" for phase in self.phase_times: print(f"{phase}:") if self.phase_times[phase][1] is not None: delta_t = self.phase_times[phase][1] - self.phase_times[phase][0] print(f" Time (sec) = {delta_t}") else: print(f" Time (sec) = Unfinished") print(f" Memory Start (mb) = {self.phase_memory[phase][0] / (1024.0 * 1024.0)}") if self.phase_memory[phase][1] is not None: print(f" Memory End (mb) = {self.phase_memory[phase][1] / (1024.0 * 1024.0)}") else: print(f" Memory End (mb) = Unfinished")
[docs] def load_and_filter_results(self, search, config, batch_size=100_000): """This function loads results that are output by the grid search. It can then generate psi + phi curves and perform sigma-G filtering (depending on the parameter settings). Parameters ---------- search : `kbmod.search` The search function object. config : `SearchConfiguration` The configuration parameters batch_size : `int` The number of results to load at once. This is used to limit the memory usage when loading results. Default is 100000. Returns ------- keep : `Results` A Results object containing values from trajectories. """ self._start_phase("load_and_filter_results") num_times = search.get_num_images() # Set up the clipped sigmaG filter. if config["sigmaG_lims"] is not None: bnds = config["sigmaG_lims"] else: bnds = [25, 75] clipper = SigmaGClipping(bnds[0], bnds[1], 2, config["clip_negative"]) keep = Results(track_filtered=config["track_filtered"]) # Retrieve a reference to all the results and compile the results table. result_trjs = search.get_all_results() logger.info(f"Retrieving Results (total={len(result_trjs)})") if len(result_trjs) < 1: logger.info(f"No results found.") return keep logger.info(f"Max Likelihood = {result_trjs[0].lh}") logger.info(f"Min. Likelihood = {result_trjs[-1].lh}") # Perform near duplicate filtering. if config["near_dup_thresh"] is not None and config["near_dup_thresh"] > 0: self._start_phase("near duplicate removal") bin_width = config["near_dup_thresh"] max_dt = np.max(search.zeroed_times) - np.min(search.zeroed_times) logger.info(f"Prefiltering Near Duplicates (bin_width={bin_width}, max_dt={max_dt})") result_trjs, _ = apply_trajectory_grid_filter(result_trjs, bin_width, max_dt) logger.info(f"After prefiltering {len(result_trjs)} remaining.") self._end_phase("near duplicate removal") # Transform the results into a Result table in batches while doing sigma-G filtering. 
batch_start = 0 while batch_start < len(result_trjs): self._check_timeout() batch_end = min(batch_start + batch_size, len(result_trjs)) batch = result_trjs[batch_start:batch_end] batch_results = Results.from_trajectories(batch, track_filtered=config["track_filtered"]) if config["generate_psi_phi"]: psi_phi_batch = search.get_all_psi_phi_curves(batch) batch_results.add_psi_phi_data(psi_phi_batch[:, :num_times], psi_phi_batch[:, num_times:]) # Do the sigma-G filtering and subsequent stats filtering. if config["sigmaG_filter"]: if not config["generate_psi_phi"]: raise ValueError("Unable to do sigma-G filtering without psi and phi curves.") apply_clipped_sigma_g(clipper, batch_results) # Re-test the obs_count and likelihood after sigma-G has removed points. row_mask = batch_results["obs_count"] >= config["num_obs"] if config["lh_level"] > 0.0: row_mask = row_mask & (batch_results["likelihood"] >= config["lh_level"]) batch_results.filter_rows(row_mask, "sigma-g") logger.debug(f"After sigma-G filtering, batch size = {len(batch_results)}") # Append the unfiltered results to the final table. logger.debug(f"Added {len(batch_results)} results from batch [{batch_start}, {batch_end}).") keep.extend(batch_results) batch_start += batch_size # Save the timing information. self._end_phase("load_and_filter_results") # Return the extracted and unfiltered results. return keep
[docs] def run_search_from_work_unit(self, work): """Run a KBMOD search from a WorkUnit object. Parameters ---------- work : `WorkUnit` The input data and configuration. Returns ------- keep : `Results` The results. """ trj_generator = create_trajectory_generator(work.config, work_unit=work) if work.config["color_scale"] is not None: work.im_stack.apply_color_scaling(work.config["color_scale"]) return self.run_search( work.config, work.im_stack, trj_generator=trj_generator, workunit=work, )
def append_positions_to_results(workunit, results):
    """Appends predicted RA, Dec positions to the results table.

    Adds ``pred_x``/``pred_y`` pixel columns and fills ``img_ra``/``img_dec``
    per-image sky positions; when the WorkUnit has a common WCS it also adds
    ``global_ra``/``global_dec`` columns.

    Parameters
    ----------
    workunit : `WorkUnit`
        The WorkUnit with all the WCS information.
    results : `Results`
        The current table of results including the per-pixel trajectories.
        This is modified in-place.
    """
    num_results = len(results)
    if num_results == 0:
        return  # Nothing to do

    num_times = workunit.im_stack.num_times
    times = workunit.im_stack.zeroed_times

    # Predict the (fractional, corner-anchored) pixel location of every
    # trajectory at every time: arrays of shape (num_results, num_times).
    xp = predict_pixel_locations(times, results["x"], results["vx"], as_int=False, centered=False)
    yp = predict_pixel_locations(times, results["y"], results["vy"], as_int=False, centered=False)
    results.table["pred_x"] = xp
    results.table["pred_y"] = yp

    # Per-image sky positions, filled in by one of the branches below.
    all_ra = np.zeros((num_results, num_times))
    all_dec = np.zeros((num_results, num_times))

    if workunit.wcs is not None:
        logger.info("Found common WCS. Adding global_ra and global_dec columns (vectorized).")

        # Compute the global (RA, dec) for all results and all times in one call
        skypos = workunit.wcs.pixel_to_world(xp, yp)
        results.table["global_ra"] = skypos.ra.degree
        results.table["global_dec"] = skypos.dec.degree

        # Now compute img_ra, img_dec by iterating over time (not per-result)
        # This allows us to batch all results for a given time step
        if workunit.reprojected and workunit.reprojection_frame != "ebd":
            logger.warning("No EBD reprojection found. Skipping img_ra and img_dec columns.")
        else:
            # For EBD reprojection, use the vectorized invert function.
            # NOTE(review): this branch also runs when workunit.reprojected is
            # False — confirm that applying the parallax inversion is intended
            # for un-reprojected data.
            obstimes = workunit.get_all_obstimes()
            for time_idx in range(num_times):
                # Get all results' sky positions at this time, placed at the
                # assumed barycentric distance (in AU).
                time_skypos = SkyCoord(
                    ra=skypos.ra[:, time_idx],
                    dec=skypos.dec[:, time_idx],
                    distance=workunit.barycentric_distance * u.AU,
                )

                # Invert parallax for all results at this time step
                original_icrs = invert_correct_parallax_vectorized(
                    time_skypos,
                    obstimes=obstimes[time_idx],
                    point_on_earth=workunit.observatory,
                )
                all_ra[:, time_idx] = original_icrs.ra.degree
                all_dec[:, time_idx] = original_icrs.dec.degree
    else:
        logger.info("No common WCS found. Skipping global_ra and global_dec columns (vectorized).")

        # No global WCS: iterate over time using each image's own WCS,
        # batching all results for that time step in one call.
        for time_idx in range(num_times):
            wcs = workunit.get_wcs(time_idx)
            if wcs is not None:
                skypos = wcs.pixel_to_world(xp[:, time_idx], yp[:, time_idx])
                all_ra[:, time_idx] = skypos.ra.degree
                all_dec[:, time_idx] = skypos.dec.degree

    # Columns are all-zero rows for any time step that could not be computed.
    results.table["img_ra"] = all_ra
    results.table["img_dec"] = all_dec