import copy
import math
from astropy.io import fits
from astropy.table import Table
from pathlib import Path
from yaml import dump, safe_load
from kbmod.search import Logging
# Module-level logger used to report unknown/invalid parameters and file-saving activity.
logger = Logging.getLogger(__name__)
class _ParamInfo:
    """Class to store information about a configuration parameter.

    Parameters
    ----------
    name : `str`
        The parameter name.
    default_value : any
        The default value for the parameter.
    description : `str`, optional
        A description of the parameter. If not provided, defaults to an empty string.
    section : `str`, optional
        The section the parameter belongs to. If not provided, defaults to "other".
    validate_func : `callable`, optional
        A function taking the value and returning a truthy result if it is valid.
        If not provided, defaults to None (every value is accepted).
    required : `bool`, optional
        Whether the parameter is required (i.e. may not be None). Defaults to False.
    """

    def __init__(
        self,
        name,
        default_value,
        description="",
        section="other",
        validate_func=None,
        required=False,
    ):
        self.name = name
        self.default_value = default_value
        self.description = description
        self.section = section
        self.validate_func = validate_func
        self.required = required

    def __str__(self):
        return f"{self.name}: {self.description} (Default: {self.default_value})"

    def __repr__(self):
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"default_value={self.default_value!r}, section={self.section!r})"
        )

    def validate(self, value):
        """Validate the parameter's value using the provided validation function.

        Parameters
        ----------
        value : any
            The value to validate.

        Returns
        -------
        `bool`
            True if the value is valid, False otherwise. A validation function that
            raises ``TypeError``, ``ValueError``, ``IndexError`` or ``KeyError`` while
            inspecting the value is treated as a failed validation.
        """
        if self.required and value is None:
            return False
        if self.validate_func is None:
            return True
        try:
            return bool(self.validate_func(value))
        except (TypeError, ValueError, IndexError, KeyError):
            # The validator could not even inspect the value (e.g. ``len()`` of an int,
            # comparing a str with an int, hashing a list), so the value has the wrong
            # type or shape. Report it as invalid instead of crashing the caller.
            return False
def _is_ordered_pair(x):
    """Return True if ``x`` has exactly two elements and the first is less than the second.

    Returns False (rather than raising) when ``x`` has no length or its elements
    cannot be compared.
    """
    try:
        return len(x) == 2 and bool(x[0] < x[1])
    except (TypeError, ValueError):
        return False


# List of all the supported configuration parameters (in alphabetical order).
_SUPPORTED_PARAMS = [
    _ParamInfo(
        name="clip_negative",
        default_value=False,
        description="If True remove all negative values prior to sigmaG computing the percentiles.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="cluster_eps",
        default_value=20.0,
        description="The epsilon parameter for clustering (in pixels).",
        section="clustering",
        validate_func=lambda x: isinstance(x, (int, float)) and x >= 0,
    ),
    _ParamInfo(
        name="cluster_type",
        default_value="all",
        description="The type of clustering algorithm to use (if do_clustering = True).",
        section="clustering",
        validate_func=lambda x: isinstance(x, str),
    ),
    _ParamInfo(
        name="cluster_v_scale",
        default_value=1.0,
        description="The weight of differences in velocity relative to differences in distances during clustering.",
        section="clustering",
        validate_func=lambda x: isinstance(x, (int, float)) and x >= 0,
    ),
    _ParamInfo(
        name="cnn_coadd_type",
        default_value="mean",
        description="The type of coadd to use for CNN filtering ('mean', 'median', or 'sum').",
        section="filtering",
        validate_func=lambda x: x in ["mean", "median", "sum"],
    ),
    _ParamInfo(
        name="cnn_filter",
        default_value=False,
        description="If True, applies a CNN filter to the stamps.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="cnn_model",
        default_value=None,
        description="The path to the CNN model file to use for filtering.",
        section="filtering",
        validate_func=lambda x: isinstance(x, str) or x is None,
    ),
    _ParamInfo(
        name="cnn_model_type",
        default_value="resnet18",
        description="The type of CNN model to use ('resnet18', 'resnet34', etc.) if cnn_filter is True.",
        section="filtering",
        validate_func=lambda x: isinstance(x, str),
    ),
    _ParamInfo(
        name="cnn_stamp_radius",
        default_value=49,
        description="The radius (in pixels) of the stamp to use for CNN filtering if cnn_filter is True.",
        section="filtering",
        validate_func=lambda x: isinstance(x, int) and x > 0,
    ),
    _ParamInfo(
        name="coadds",
        default_value=[],
        description="The list of coadd images to compute ('mean', 'median', 'sum', 'weighted').",
        section="stamps",
        validate_func=lambda x: isinstance(x, list)
        and all(i in ["mean", "median", "sum", "weighted"] for i in x),
    ),
    _ParamInfo(
        name="color_scale",
        default_value=None,
        description="A dictionary mapping filter names to a color scale factor to use for those images.",
        section="core",
        # A tuple (not ``dict | int | float``) keeps this working before Python 3.10
        # and matches every other validator in this list.
        validate_func=lambda x: x is None or isinstance(x, (dict, int, float)),
    ),
    _ParamInfo(
        name="compute_ra_dec",
        default_value=True,
        description="If True, compute RA and Dec for each result.",
        section="output",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="cpu_only",
        default_value=False,
        description="If True, only use the CPU for processing, even if a GPU is available.",
        section="other",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="debug",
        default_value=False,
        description="Run with debug logging enabled.",
        section="other",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="do_clustering",
        default_value=True,
        description="If true, perform clustering on the results.",
        section="clustering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="drop_columns",
        default_value=[],
        description="List of result table columns to drop.",
        section="output",
        validate_func=lambda x: isinstance(x, list) and all(isinstance(i, str) for i in x),
    ),
    _ParamInfo(
        name="encode_num_bytes",
        default_value=-1,
        description="Number of bytes to use for encoding pixel values on GPU. -1 means no encoding.",
        section="core",
        # Tuple membership uses ==, so unhashable values (e.g. lists) fail cleanly
        # instead of raising TypeError as a set lookup would.
        validate_func=lambda x: x in (-1, 1, 2, 4),
    ),
    _ParamInfo(
        name="generate_psi_phi",
        default_value=True,
        description="If True, computes the psi and phi curves and saves them with the results.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="generator_config",
        default_value={
            "name": "EclipticCenteredSearch",
            "velocities": [92.0, 526.0, 257],
            "angles": [-math.pi / 15, math.pi / 15, 129],
            "angle_units": "radian",
            "velocity_units": "pix / d",
            "given_ecliptic": None,
        },
        description="Configuration dictionary for the trajectory generator.",
        section="core",
        validate_func=lambda x: isinstance(x, dict) and "name" in x,
    ),
    _ParamInfo(
        name="gpu_filter",
        default_value=False,
        description="If True, performs initial sigmaG filtering on GPU.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="lh_level",
        default_value=10.0,
        description="The log-likelihood level above which results are kept.",
        section="filtering",
        validate_func=lambda x: isinstance(x, (int, float)),
    ),
    _ParamInfo(
        name="max_masked_pixels",
        default_value=0.5,
        description="The maximum fraction of masked pixels allowed before an input image is dropped.",
        section="core",
        validate_func=lambda x: isinstance(x, (int, float)) and 0.0 <= x <= 1.0,
    ),
    _ParamInfo(
        name="max_results",
        default_value=100_000,
        description="The maximum number of results to save after all filtering.",
        section="filtering",
        validate_func=lambda x: isinstance(x, int),
    ),
    _ParamInfo(
        name="near_dup_thresh",
        default_value=10,
        description="The threshold for considering two observations as near duplicates (in pixels).",
        section="filtering",
        validate_func=lambda x: isinstance(x, int),
    ),
    _ParamInfo(
        name="nightly_coadds",
        default_value=False,
        description="If True, generate an additional coadd for each calendar date.",
        section="stamps",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="num_obs",
        default_value=10,
        description="The minimum number of valid observations for the trajectory to be accepted.",
        section="filtering",
        validate_func=lambda x: isinstance(x, int),
    ),
    _ParamInfo(
        name="peak_offset_max",
        default_value=None,
        description="Maximum allowed offset (in pixels) between predicted and detected peak positions.",
        section="filtering",
        validate_func=lambda x: isinstance(x, (int, float)) or x is None,
    ),
    _ParamInfo(
        name="pred_line_cluster",
        default_value=False,
        description="If True, applies line clustering to the predicted lines.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="pred_line_params",
        default_value=[4.0, 2, 60],
        description="Parameters for the line prediction model.",
        section="filtering",
        validate_func=lambda x: isinstance(x, list) and len(x) == 3,
    ),
    _ParamInfo(
        name="psf_val",
        default_value=1.4,
        description="The default standard deviation of the Gaussian PSF in pixels (if not provided in the data).",
        section="core",
        validate_func=lambda x: isinstance(x, (int, float)) and x > 0.0,
    ),
    _ParamInfo(
        name="result_filename",
        default_value=None,
        description="The filename to which results will be saved.",
        section="core",
        validate_func=lambda x: isinstance(x, str) or x is None,
    ),
    _ParamInfo(
        name="results_per_pixel",
        default_value=8,
        description="The maximum number of results to return from the GPU per pixel.",
        section="filtering",
        validate_func=lambda x: isinstance(x, int) and x > 0,
    ),
    _ParamInfo(
        name="save_all_stamps",
        default_value=False,
        description="If True, save all stamps to the results.",
        section="output",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="save_config",
        default_value=True,
        description="If True, save the configuration used for processing.",
        section="output",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="separate_col_files",
        default_value=["all_stamps"],
        description="List of columns to save in separate files.",
        section="output",
        validate_func=lambda x: isinstance(x, list) and all(isinstance(i, str) for i in x),
    ),
    _ParamInfo(
        name="sigmaG_filter",
        default_value=True,
        description="If True, apply sigmaG filtering.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="sigmaG_lims",
        default_value=[25, 75],
        description="The lower and upper limits for sigmaG filtering.",
        section="filtering",
        # Check the pair shape first (safely) so scalars/None return False instead of raising.
        validate_func=lambda x: _is_ordered_pair(x) and all(isinstance(i, (int, float)) for i in x),
    ),
    _ParamInfo(
        name="stamp_radius",
        default_value=10,
        description="The radius (in pixels) of the stamp to extract.",
        section="stamps",
        validate_func=lambda x: isinstance(x, int) and x > 0,
    ),
    _ParamInfo(
        name="stamp_type",
        default_value="sum",
        description="The type of stamp to extract.",
        section="stamps",
        validate_func=lambda x: x in ["sum", "mean", "median", "weighted"],
    ),
    _ParamInfo(
        name="track_filtered",
        default_value=False,
        description="If True, track the filtered objects in the results table.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="x_pixel_bounds",
        default_value=None,
        description="The x pixel bounds for the search starting location (None = use every pixel).",
        section="core",
        validate_func=lambda x: x is None or _is_ordered_pair(x),
    ),
    _ParamInfo(
        name="x_pixel_buffer",
        default_value=None,
        description="If not None, the number of x pixels beyond the image bounds to use for starting coordinates.",
        section="core",
        validate_func=lambda x: x is None or (isinstance(x, int) and x >= 0),
    ),
    _ParamInfo(
        name="y_pixel_bounds",
        default_value=None,
        description="The y pixel bounds for the search starting location (None = use every pixel).",
        section="core",
        validate_func=lambda x: x is None or _is_ordered_pair(x),
    ),
    _ParamInfo(
        name="y_pixel_buffer",
        default_value=None,
        description="If not None, the number of y pixels beyond the image bounds to use for starting coordinates.",
        section="core",
        validate_func=lambda x: x is None or (isinstance(x, int) and x >= 0),
    ),
]
class SearchConfiguration:
    """This class stores a collection of configuration parameter settings.

    The configuration starts from the defaults of every supported parameter and can
    be overridden with ``data``. Parameters that are not in the supported list may
    also be set; they are stored and serialized but not validated.

    Parameters
    ----------
    data : `dict`, optional
        A dictionary of initial values that override the defaults.
    """

    def __init__(self, data=None):
        # Reprocess the list of supported parameters into dictionaries for easy access.
        self._param_info = {p.name: p for p in _SUPPORTED_PARAMS}
        # Deep copy each default so that mutating a list/dict value in one configuration
        # (e.g. ``config["coadds"].append("mean")``) cannot change the shared module-level
        # default and, through it, every configuration created afterwards.
        self._params = {p.name: copy.deepcopy(p.default_value) for p in _SUPPORTED_PARAMS}
        if data is not None:
            self.set_multiple(data)

    def __contains__(self, key):
        """Return True if ``key`` is set in this configuration (supported or user-added)."""
        return key in self._params

    def __getitem__(self, key):
        """Gets the value of a specific parameter.

        Parameters
        ----------
        key : `str`
            The parameter name.

        Raises
        ------
        Raises a KeyError if the parameter is not included.
        """
        return self._params[key]

    def __str__(self):
        lines = ["Configuration:"]
        lines.extend(f"{key}: {value}" for key, value in self._params.items())
        return "\n".join(lines) + "\n"

    def help(self, param=None):
        """Print help information for a specific parameter or all parameters.

        Parameters
        ----------
        param : `str`, optional
            The parameter name. If not provided, help for all parameters is printed,
            grouped by section.
        """
        if param is not None:
            if param in self._param_info:
                print(self._param_info[param])
            else:
                print(f"Parameter {param} is not recognized.")
            return

        print("KBMOD Supported Parameters:")
        for section, section_keys in self._cluster_parameters().items():
            print(f"\n--- {section.capitalize()} Parameters ---")
            for key in section_keys:
                if key in self._param_info:
                    print(self._param_info[key])
                else:
                    # User-set parameters that are not in the supported list.
                    print(f"{key}: No description available.")

    def copy(self):
        """Create a new deep copy of the configuration."""
        return copy.deepcopy(self)

    def set(self, param, value, warn_on_unknown=False):
        """Sets the value of a specific parameter.

        Parameters
        ----------
        param : `str`
            The parameter name.
        value : any
            The parameter's value.
        warn_on_unknown : `bool`
            Generate a warning if the parameter is not known.
        """
        if warn_on_unknown and param not in self._params:
            logger.warning(f"Setting unknown parameter: {param}")
        self._params[param] = value

    def set_multiple(self, overrides, warn_on_unknown=False):
        """Sets multiple parameters from a dictionary.

        Parameters
        ----------
        overrides : `dict`
            A dictionary of parameter->value to overwrite.
        warn_on_unknown : `bool`, optional
            Generate a warning for each parameter that is not known. Default False.
        """
        for key, value in overrides.items():
            self.set(key, value, warn_on_unknown=warn_on_unknown)

    def validate(self):
        """Check every supported parameter's value against its constraints.

        Parameters that are not in the supported list are not checked.

        Returns
        -------
        `bool`
            Returns True if the configuration is valid and False (logging the reason)
            if the configuration is invalid.
        """
        for key, value in self._params.items():
            param_info = self._param_info.get(key, None)
            if param_info is not None and not param_info.validate(value):
                logger.warning(f"Invalid value for parameter {key}: {value}")
                return False
        return True

    @classmethod
    def from_dict(cls, d):
        """Create a configuration from a dictionary of overrides.

        Parameters
        ----------
        d : `dict`
            A dictionary mapping parameter name to value.

        Returns
        -------
        config : `SearchConfiguration`
            The defaults with the values from ``d`` applied on top.
        """
        config = cls()
        config.set_multiple(d)
        return config

    @classmethod
    def from_table(cls, t):
        """Create a configuration from an astropy Table with a single row and
        one column for each parameter.

        Each cell is expected to hold a YAML-serialized value (as written by ``to_hdu``).

        Parameters
        ----------
        t : `~astropy.table.Table`
            Astropy Table containing the configuration parameters.

        Returns
        -------
        config : `SearchConfiguration`
            The parsed configuration.

        Raises
        ------
        Raises a ``ValueError`` if the table does not have exactly one row.
        """
        if len(t) > 1:
            raise ValueError(f"More than one row in the configuration table ({len(t)}).")
        if len(t) == 0:
            raise ValueError("No rows in the configuration table.")

        # Guaranteed to have exactly 1 row due to the checks above.
        params = {col.name: safe_load(col.value[0]) for col in t.values()}
        return cls.from_dict(params)

    @classmethod
    def from_yaml(cls, config):
        """Load a configuration from YAML data.

        Parameters
        ----------
        config : `str` or `_io.TextIOWrapper`
            The serialized YAML data.

        Returns
        -------
        config : `SearchConfiguration`
            The parsed configuration. An empty YAML document yields the defaults.
        """
        yaml_params = safe_load(config)
        if yaml_params is None:
            # ``safe_load`` returns None for an empty document.
            yaml_params = {}
        return cls.from_dict(yaml_params)

    @classmethod
    def from_hdu(cls, hdu):
        """Load a configuration from a FITS extension file.

        Parameters
        ----------
        hdu : `astropy.io.fits.BinTableHDU`
            The HDU from which to parse the configuration information.

        Returns
        -------
        config : `SearchConfiguration`
            The parsed configuration.
        """
        return cls.from_table(Table(hdu.data))

    @classmethod
    def from_file(cls, filename):
        """Load a configuration from a YAML file.

        Parameters
        ----------
        filename : `str` or `pathlib.Path`
            The path of the YAML configuration file.

        Returns
        -------
        config : `SearchConfiguration`
            The parsed configuration.
        """
        with open(filename) as ff:
            return cls.from_yaml(ff.read())

    def to_hdu(self):
        """Create a fits HDU with all the configuration parameters.

        Each parameter becomes one column of a single-row table, holding the
        YAML (flow style) serialization of its value.

        Returns
        -------
        hdu : `astropy.io.fits.BinTableHDU`
            The HDU with the configuration information.
        """
        serialized_dict = {key: dump(val, default_flow_style=True) for key, val in self._params.items()}
        t = Table(rows=[serialized_dict])
        return fits.table_to_hdu(t)

    def to_yaml(self):
        """Serialize all the parameters to a YAML string.

        Returns
        -------
        result : `str`
            The serialized YAML string.
        """
        return dump(self._params)

    def _cluster_parameters(self, sections=None):
        """Group the parameter names by section.

        Parameters
        ----------
        sections : `list` of `str`, optional
            The sections to include. If not provided, all sections are included.
            An "other" section is always present.

        Returns
        -------
        result : `dict`
            A dictionary mapping each section name to the list of parameter names in it.
            Parameters whose section is not listed, and user-set parameters that are not
            in the supported list, are placed in "other".
        """
        if sections is None:
            sections = ["core", "filtering", "stamps", "clustering", "output", "other"]

        # Create the empty dictionary; "other" is the catch-all section.
        section_to_params = {sec: [] for sec in sections}
        section_to_params.setdefault("other", [])

        # Iterate over all set parameters (not only the supported ones) so that
        # user-added parameters are not silently dropped by ``to_file`` or ``help``.
        for key in self._params:
            info = self._param_info.get(key)
            if info is not None and info.section in section_to_params:
                section = info.section
            else:
                section = "other"
            section_to_params[section].append(key)
        return section_to_params

    def to_file(self, filename, overwrite=False):
        """Save a configuration file with the parameters.

        Parameters
        ----------
        filename : str
            The filename, including path, of the configuration file.
        overwrite : bool
            Indicates whether to overwrite an existing file. If False and the file
            exists, a warning is logged and nothing is written.
        """
        if Path(filename).is_file() and not overwrite:
            logger.warning(f"Configuration file {filename} already exists.")
            return

        logger.info(f"Saving configuration to {filename}")

        # Output the configuration file in sections for easier reading. The sections
        # appear in the order returned by ``_cluster_parameters``.
        banner = "# ======================================================================\n"
        with open(filename, "w") as file:
            for section, section_keys in self._cluster_parameters().items():
                file.write(banner)
                file.write(f"# {section.capitalize()} Configuration\n")
                file.write(banner)
                for key in section_keys:
                    info = self._param_info.get(key)
                    if info is not None and info.description:
                        file.write(f"\n# {info.description}\n")
                    file.write(dump({key: self._params[key]}))
                file.write("\n")