Source code for kbmod.configuration

import copy
import math

from astropy.io import fits
from astropy.table import Table
from pathlib import Path
from yaml import dump, safe_load
from kbmod.search import Logging


# Module-level logger, named after this module and obtained through the
# ``Logging`` interface imported from ``kbmod.search``.
logger = Logging.getLogger(__name__)


class SearchConfiguration:
    """This class stores a collection of configuration parameter settings.

    A new instance starts from a built-in table of default values; entries
    supplied by the caller are then layered on top of those defaults.

    Parameters
    ----------
    data : `dict`, optional
        A dictionary of initial values that override (or extend) the defaults.
    """

    def __init__(self, data=None):
        # Names of parameters that ``validate`` requires to be non-None.
        self._required_params = set()

        # The defaults are rebuilt for every instance so that mutable values
        # (lists, nested dictionaries) are never shared between instances.
        self._params = {
            "chunk_size": 500000,
            "clip_negative": False,
            "cluster_eps": 20.0,
            "cluster_type": "all",
            "cluster_v_scale": 1.0,
            "coadds": [],
            "debug": False,
            "do_clustering": True,
            "do_mask": True,
            "encode_num_bytes": -1,
            "generator_config": {
                "name": "EclipticCenteredSearch",
                "velocities": [92.0, 526.0, 257],
                "angles": [-math.pi / 15, math.pi / 15, 129],
                "angle_units": "radian",
                "velocity_units": "pix / d",
                "given_ecliptic": None,
            },
            "gpu_filter": False,
            "lh_level": 10.0,
            "max_lh": 1000.0,
            "num_obs": 10,
            "psf_val": 1.4,
            "result_filename": None,
            "results_per_pixel": 8,
            "save_all_stamps": False,
            "sigmaG_filter": True,
            "sigmaG_lims": [25, 75],
            "stamp_radius": 10,
            "stamp_type": "sum",
            "track_filtered": False,
            "x_pixel_bounds": None,
            "x_pixel_buffer": None,
            "y_pixel_bounds": None,
            "y_pixel_buffer": None,
        }

        if data is not None:
            self.set_multiple(data)

    def __contains__(self, key):
        """Return True if ``key`` is the name of a stored parameter."""
        return key in self._params

    def __getitem__(self, key):
        """Gets the value of a specific parameter.

        Parameters
        ----------
        key : `str`
            The parameter name.

        Raises
        ------
        Raises a ``KeyError`` if the parameter is not included.
        """
        return self._params[key]

    def __str__(self):
        """Return a header line followed by one ``name: value`` line per parameter."""
        body = "".join(f"{key}: {value}\n" for key, value in self._params.items())
        return "Configuration:\n" + body

    def copy(self):
        """Create a new deep copy of the configuration."""
        return copy.deepcopy(self)

    def set(self, param, value, warn_on_unknown=False):
        """Sets the value of a specific parameter.

        Parameters
        ----------
        param : `str`
            The parameter name.
        value : any
            The parameter's value.
        warn_on_unknown : `bool`
            Generate a warning if the parameter is not known.
        """
        if warn_on_unknown and param not in self._params:
            logger.warning(f"Setting unknown parameter: {param}")
        # Unknown parameters are still stored; the flag only controls logging.
        self._params[param] = value

    def set_multiple(self, overrides):
        """Sets multiple parameters from a dictionary.

        Parameters
        ----------
        overrides : `dict`
            A dictionary of parameter->value to overwrite.
        """
        for key, value in overrides.items():
            self.set(key, value)

    def validate(self):
        """Check that the configuration has the necessary parameters.

        Raises
        ------
        Raises a ``ValueError`` if a required parameter is missing or is None.
        """
        for p in self._required_params:
            if self._params.get(p, None) is None:
                raise ValueError(f"Required configuration parameter {p} missing.")

    @classmethod
    def from_dict(cls, d):
        """Create a configuration from the defaults plus the values in a dictionary.

        Parameters
        ----------
        d : `dict`
            A dictionary mapping parameter name to value.

        Returns
        -------
        config : `SearchConfiguration`
            A new instance of the class the method was called on.
        """
        # Build through ``cls`` so subclasses get an instance of their own type.
        config = cls()
        config.set_multiple(d)
        return config

    @classmethod
    def from_table(cls, t):
        """Create a configuration from an astropy Table with a single row
        and one column for each parameter.

        Each cell is parsed with ``yaml.safe_load`` to recover the value.

        Parameters
        ----------
        t : `~astropy.table.Table`
            Astropy Table containing the configuration parameters.

        Returns
        -------
        config : `SearchConfiguration`
            The configuration built from the table's single row.

        Raises
        ------
        Raises a ``ValueError`` if the table does not have exactly one row.
        """
        num_rows = len(t)
        if num_rows > 1:
            raise ValueError(f"More than one row in the configuration table ({len(t)}).")
        if num_rows == 0:
            # Without this check an empty table with columns would fail later
            # with an IndexError when the first cell of a column is read.
            raise ValueError("No rows in the configuration table.")

        # guaranteed to have exactly 1 element due to the checks above
        params = {col.name: safe_load(col.value[0]) for col in t.values()}
        return cls.from_dict(params)

    @classmethod
    def from_yaml(cls, config):
        """Load a configuration from a YAML file.

        Parameters
        ----------
        config : `str` or `_io.TextIOWrapper`
            The serialized YAML data.

        Returns
        -------
        config : `SearchConfiguration`
            The configuration built from the YAML data. An empty YAML
            document yields the default configuration.
        """
        yaml_params = safe_load(config)
        if yaml_params is None:
            # ``safe_load`` returns None for an empty document; treat that as
            # "no overrides" instead of failing on ``None.items()``.
            yaml_params = {}
        return cls.from_dict(yaml_params)

    @classmethod
    def from_hdu(cls, hdu):
        """Load a configuration from a FITS extension file.

        Parameters
        ----------
        hdu : `astropy.io.fits.BinTableHDU`
            The HDU from which to parse the configuration information.

        Returns
        -------
        config : `SearchConfiguration`
            The configuration parsed from the HDU's table data.
        """
        t = Table(hdu.data)
        return cls.from_table(t)

    @classmethod
    def from_file(cls, filename):
        """Load a configuration from a YAML file on disk.

        Parameters
        ----------
        filename : `str`
            The filename, including path, of the configuration file.

        Returns
        -------
        config : `SearchConfiguration`
            The configuration parsed from the file's contents.
        """
        with open(filename) as ff:
            return cls.from_yaml(ff.read())

    def to_hdu(self):
        """Create a fits HDU with all the configuration parameters.

        Returns
        -------
        hdu : `astropy.io.fits.BinTableHDU`
            The HDU with the configuration information.
        """
        # Each value is stored as a YAML string so that ``from_table`` can
        # parse it back with ``safe_load``.
        serialized_dict = {key: dump(val, default_flow_style=True) for key, val in self._params.items()}
        t = Table(
            rows=[
                serialized_dict,
            ]
        )
        return fits.table_to_hdu(t)

    def to_yaml(self):
        """Serialize the configuration parameters as YAML.

        Returns
        -------
        result : `str`
            The serialized YAML string.
        """
        return dump(self._params)

    def to_file(self, filename, overwrite=False):
        """Save a configuration file with the parameters.

        If the file already exists and ``overwrite`` is False, a warning is
        logged and nothing is written.

        Parameters
        ----------
        filename : str
            The filename, including path, of the configuration file.
        overwrite : bool
            Indicates whether to overwrite an existing file.
        """
        if Path(filename).is_file() and not overwrite:
            logger.warning(f"Configuration file {filename} already exists.")
            return

        with open(filename, "w") as file:
            file.write(self.to_yaml())