import copy
import math
from astropy.io import fits
from astropy.table import Table
from pathlib import Path
from yaml import dump, safe_load
from kbmod.search import Logging
# Module-level logger, obtained through the Logging helper imported from kbmod.search
# and named after this module.
logger = Logging.getLogger(__name__)
class _ParamInfo:
    """Class to store information about a configuration parameter.
    Parameters
    ----------
    name : `str`
        The parameter name.
    default_value : any, optional
        The default value for the parameter. If not provided, defaults to None.
    description : `str`, optional
        A description of the parameter. If not provided, defaults to an empty string.
    section : `str`, optional
        The section the parameter belongs to. If not provided, defaults to "other".
    validate_func : `callable`, optional
        A function to validate the parameter's value. If not provided, defaults to None.
    required : `bool`, optional
        Whether the parameter is required. If not provided, defaults to False.
    """
    def __init__(
        self,
        name,
        default_value,
        description="",
        section="other",
        validate_func=None,
        required=False,
    ):
        self.name = name
        self.default_value = default_value
        self.description = description
        self.section = section
        self.validate_func = validate_func
        self.required = required
    def __str__(self):
        return f"{self.name}: {self.description} (Default: {self.default_value})"
    def validate(self, value):
        """Validate the parameter's value using the provided validation function.
        Parameters
        ----------
        value : any
            The value to validate.
        Returns
        -------
        `bool`
            True if the value is valid, False otherwise.
        """
        if self.required and value is None:
            return False
        if self.validate_func is not None:
            return self.validate_func(value)
        return True
def _is_increasing_pair(value, numeric_only=False):
    """Check that a value is a two-element sequence in strictly increasing order.

    Parameters
    ----------
    value : any
        The candidate value (normally a two-element list).
    numeric_only : `bool`, optional
        If True, both elements must also be `int` or `float` instances.

    Returns
    -------
    `bool`
        True if ``value`` has exactly two elements and ``value[0] < value[1]``.
        Values that do not support ``len``, indexing or comparison are reported
        as invalid (False) instead of raising.
    """
    try:
        if len(value) != 2:
            return False
        if numeric_only and not all(isinstance(i, (int, float)) for i in value):
            return False
        return bool(value[0] < value[1])
    except (TypeError, ValueError, KeyError, IndexError):
        return False


# List of all the supported configuration parameters (in roughly alphabetical order).
# The order of this list determines the order in which parameters are stored and
# displayed, so new entries should be inserted rather than the list re-sorted.
# Each validate_func returns a truth value and should not raise on malformed input.
_SUPPORTED_PARAMS = [
    _ParamInfo(
        name="clip_negative",
        default_value=False,
        description="If True remove all negative values prior to sigmaG computing the percentiles.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="cluster_eps",
        default_value=20.0,
        description="The epsilon parameter for clustering (in pixels).",
        section="clustering",
        validate_func=lambda x: isinstance(x, (int, float)) and x >= 0,
    ),
    _ParamInfo(
        name="cluster_type",
        default_value="all",
        description="The type of clustering algorithm to use (if do_clustering = True).",
        section="clustering",
        validate_func=lambda x: isinstance(x, str),
    ),
    _ParamInfo(
        name="cluster_v_scale",
        default_value=1.0,
        description="The weight of differences in velocity relative to differences in distances during clustering.",
        section="clustering",
        validate_func=lambda x: isinstance(x, (int, float)) and x >= 0,
    ),
    _ParamInfo(
        name="color_scale",
        default_value=None,
        description="A dictionary mapping filter names to a color scale factor to use for those images.",
        section="core",
        # A single numeric scale factor is accepted in addition to a dictionary.
        validate_func=lambda x: x is None or isinstance(x, (dict, int, float)),
    ),
    _ParamInfo(
        name="cnn_filter",
        default_value=False,
        description="If True, applies a CNN filter to the stamps.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="cnn_model",
        default_value=None,
        description="The path to the CNN model file to use for filtering.",
        section="filtering",
        validate_func=lambda x: isinstance(x, str) or x is None,
    ),
    _ParamInfo(
        name="cnn_coadd_type",
        default_value="mean",
        description="The type of coadd to use for CNN filtering ('mean', 'median', or 'sum').",
        section="filtering",
        validate_func=lambda x: x in ["mean", "median", "sum"],
    ),
    _ParamInfo(
        name="cnn_stamp_radius",
        default_value=49,
        description="The radius (in pixels) of the stamp to use for CNN filtering if cnn_filter is True.",
        section="filtering",
        validate_func=lambda x: isinstance(x, int) and x > 0,
    ),
    _ParamInfo(
        name="cnn_model_type",
        default_value="resnet18",
        description="The type of CNN model to use ('resnet18', 'resnet34', etc.) if cnn_filter is True.",
        section="filtering",
        validate_func=lambda x: isinstance(x, str),
    ),
    _ParamInfo(
        name="coadds",
        default_value=[],
        description="The list of coadd images to compute ('mean', 'median', 'sum', 'weighted').",
        section="stamps",
        validate_func=lambda x: isinstance(x, list)
        and all(i in ["mean", "median", "sum", "weighted"] for i in x),
    ),
    _ParamInfo(
        name="compute_ra_dec",
        default_value=True,
        description="If True, compute RA and Dec for each result.",
        section="output",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="cpu_only",
        default_value=False,
        description="If True, only use the CPU for processing, even if a GPU is available.",
        section="other",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="debug",
        default_value=False,
        description="Run with debug logging enabled.",
        section="other",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="do_clustering",
        default_value=True,
        description="If true, perform clustering on the results.",
        section="clustering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="drop_columns",
        default_value=[],
        description="List of result table columns to drop.",
        section="output",
        validate_func=lambda x: isinstance(x, list) and all(isinstance(i, str) for i in x),
    ),
    _ParamInfo(
        name="encode_num_bytes",
        default_value=-1,
        description="Number of bytes to use for encoding pixel values on GPU. -1 means no encoding.",
        section="core",
        # Tuple membership compares by equality, so unhashable values (e.g. lists)
        # are reported as invalid instead of raising a TypeError.
        validate_func=lambda x: x in (-1, 1, 2, 4),
    ),
    _ParamInfo(
        name="generator_config",
        default_value={
            "name": "EclipticCenteredSearch",
            "velocities": [92.0, 526.0, 257],
            "angles": [-math.pi / 15, math.pi / 15, 129],
            "angle_units": "radian",
            "velocity_units": "pix / d",
            "given_ecliptic": None,
        },
        description="Configuration dictionary for the trajectory generator.",
        section="core",
        validate_func=lambda x: isinstance(x, dict) and "name" in x,
    ),
    _ParamInfo(
        name="generate_psi_phi",
        default_value=True,
        description="If True, computes the psi and phi curves and saves them with the results.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="gpu_filter",
        default_value=False,
        description="If True, performs initial sigmaG filtering on GPU.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="lh_level",
        default_value=10.0,
        description="The log-likelihood level above which results are kept.",
        section="filtering",
        validate_func=lambda x: isinstance(x, (int, float)),
    ),
    _ParamInfo(
        name="max_masked_pixels",
        default_value=0.5,
        description="The maximum fraction of masked pixels allowed before an input image is dropped.",
        section="core",
        validate_func=lambda x: isinstance(x, (int, float)) and 0.0 <= x <= 1.0,
    ),
    _ParamInfo(
        name="max_results",
        default_value=100_000,
        description="The maximum number of results to save after all filtering.",
        section="filtering",
        validate_func=lambda x: isinstance(x, int),
    ),
    _ParamInfo(
        name="near_dup_thresh",
        default_value=10,
        description="The threshold for considering two observations as near duplicates (in pixels).",
        section="filtering",
        validate_func=lambda x: isinstance(x, int),
    ),
    _ParamInfo(
        name="nightly_coadds",
        default_value=False,
        description="If True, generate an additional coadd for each calendar date.",
        section="stamps",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="num_obs",
        default_value=10,
        description="The minimum number of valid observations for the trajectory to be accepted.",
        section="filtering",
        validate_func=lambda x: isinstance(x, int),
    ),
    _ParamInfo(
        name="peak_offset_max",
        default_value=None,
        description="Maximum allowed offset (in pixels) between predicted and detected peak positions.",
        section="filtering",
        validate_func=lambda x: isinstance(x, (int, float)) or x is None,
    ),
    _ParamInfo(
        name="pred_line_cluster",
        default_value=False,
        description="If True, applies line clustering to the predicted lines.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="pred_line_params",
        default_value=[4.0, 2, 60],
        description="Parameters for the line prediction model.",
        section="filtering",
        validate_func=lambda x: isinstance(x, list) and len(x) == 3,
    ),
    _ParamInfo(
        name="psf_val",
        default_value=1.4,
        description="The default standard deviation of the Gaussian PSF in pixels (if not provided in the data).",
        section="core",
        validate_func=lambda x: isinstance(x, (int, float)) and x > 0.0,
    ),
    _ParamInfo(
        name="result_filename",
        default_value=None,
        description="The filename to which results will be saved.",
        section="core",
        validate_func=lambda x: isinstance(x, str) or x is None,
    ),
    _ParamInfo(
        name="results_per_pixel",
        default_value=8,
        description="The maximum number of results to return from the GPU per pixel.",
        section="filtering",
        validate_func=lambda x: isinstance(x, int) and x > 0,
    ),
    _ParamInfo(
        name="save_all_stamps",
        default_value=False,
        description="If True, save all stamps to the results.",
        section="output",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="save_config",
        default_value=True,
        description="If True, save the configuration used for processing.",
        section="output",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="separate_col_files",
        default_value=["all_stamps"],
        description="List of columns to save in separate files.",
        section="output",
        validate_func=lambda x: isinstance(x, list) and all(isinstance(i, str) for i in x),
    ),
    _ParamInfo(
        name="sigmaG_filter",
        default_value=True,
        description="If True, apply sigmaG filtering.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="sigmaG_lims",
        default_value=[25, 75],
        description="The lower and upper limits for sigmaG filtering.",
        section="filtering",
        # Element types are checked before comparing, so mixed-type pairs are invalid
        # rather than raising.
        validate_func=lambda x: _is_increasing_pair(x, numeric_only=True),
    ),
    _ParamInfo(
        name="stamp_radius",
        default_value=10,
        description="The radius (in pixels) of the stamp to extract.",
        section="stamps",
        validate_func=lambda x: isinstance(x, int) and x > 0,
    ),
    _ParamInfo(
        name="stamp_type",
        default_value="sum",
        description="The type of stamp to extract.",
        section="stamps",
        validate_func=lambda x: x in ["sum", "mean", "median", "weighted"],
    ),
    _ParamInfo(
        name="track_filtered",
        default_value=False,
        description="If True, track the filtered objects in the results table.",
        section="filtering",
        validate_func=lambda x: isinstance(x, bool),
    ),
    _ParamInfo(
        name="x_pixel_bounds",
        default_value=None,
        description="The x pixel bounds for the search starting location (None = use every pixel).",
        section="core",
        validate_func=lambda x: x is None or _is_increasing_pair(x),
    ),
    _ParamInfo(
        name="x_pixel_buffer",
        default_value=None,
        description="If not None, the number of x pixels beyond the image bounds to use for starting coordinates.",
        section="core",
        validate_func=lambda x: x is None or (isinstance(x, int) and x >= 0),
    ),
    _ParamInfo(
        name="y_pixel_bounds",
        default_value=None,
        description="The y pixel bounds for the search starting location (None = use every pixel).",
        section="core",
        validate_func=lambda x: x is None or _is_increasing_pair(x),
    ),
    _ParamInfo(
        name="y_pixel_buffer",
        default_value=None,
        description="If not None, the number of y pixels beyond the image bounds to use for starting coordinates.",
        section="core",
        validate_func=lambda x: x is None or (isinstance(x, int) and x >= 0),
    ),
]
class SearchConfiguration:
    """This class stores a collection of configuration parameter settings.

    Every supported parameter (see ``_SUPPORTED_PARAMS``) starts at its default value.
    Parameters that are not in the supported list may also be stored with ``set``;
    they are kept and serialized, but have no description or validation.

    Parameters
    ----------
    data : `dict`, optional
        A dictionary of initial values that override the defaults.
    """

    def __init__(self, data=None):
        # Reprocess the list of supported parameters into dictionaries for easy access.
        self._param_info = {p.name: p for p in _SUPPORTED_PARAMS}
        # Deep copy the defaults so that mutable defaults (lists, dicts) are not shared
        # between configurations or with the module-level parameter definitions.
        self._params = {p.name: copy.deepcopy(p.default_value) for p in _SUPPORTED_PARAMS}
        if data is not None:
            self.set_multiple(data)

    def __contains__(self, key):
        return key in self._params

    def __getitem__(self, key):
        """Gets the value of a specific parameter.

        Parameters
        ----------
        key : `str`
            The parameter name.

        Raises
        ------
        Raises a KeyError if the parameter is not included.
        """
        return self._params[key]

    def __str__(self):
        lines = [f"{key}: {value}\n" for key, value in self._params.items()]
        return "Configuration:\n" + "".join(lines)

    def help(self, param=None):
        """Print help information for a specific parameter or all parameters.

        Parameters
        ----------
        param : `str`, optional
            The parameter name. If not provided, help for all parameters is printed.
        """
        if param is not None:
            if param in self._param_info:
                print(self._param_info[param])
            else:
                print(f"Parameter {param} is not recognized.")
            return

        print("KBMOD Supported Parameters:")
        for section, section_keys in self._cluster_parameters().items():
            print(f"\n--- {section.capitalize()} Parameters ---")
            for key in section_keys:
                if key in self._param_info:
                    print(self._param_info[key])
                else:
                    # Parameters stored via set() that are not in the supported list.
                    print(f"{key}: No description available.")

    def copy(self):
        """Create a new deep copy of the configuration."""
        return copy.deepcopy(self)

    def set(self, param, value, warn_on_unknown=False):
        """Sets the value of a specific parameter.

        Parameters
        ----------
        param : `str`
            The parameter name.
        value : any
            The parameter's value.
        warn_on_unknown : `bool`
            Generate a warning if the parameter is not known.
        """
        if warn_on_unknown and param not in self._params:
            logger.warning(f"Setting unknown parameter: {param}")
        self._params[param] = value

    def set_multiple(self, overrides):
        """Sets multiple parameters from a dictionary.

        Parameters
        ----------
        overrides : `dict`
            A dictionary of parameter->value to overwrite.
        """
        for key, value in overrides.items():
            self.set(key, value)

    def validate(self):
        """Check that the configuration has the necessary parameters.

        Only parameters with an entry in the supported parameter list are checked.
        Checking stops at the first invalid value.

        Returns
        -------
        `bool`
            Returns True if the configuration is valid and False (logging the reason)
            if the configuration is invalid.
        """
        # Check parameters that have known constraints.
        for key, value in self._params.items():
            param_info = self._param_info.get(key, None)
            if param_info is not None and not param_info.validate(value):
                logger.warning(f"Invalid value for parameter {key}: {value}")
                return False
        return True

    @classmethod
    def from_dict(cls, d):
        """Create a configuration from a dictionary of overrides.

        Parameters
        ----------
        d : `dict`
            A dictionary mapping parameter name to value.

        Returns
        -------
        config : `SearchConfiguration`
            A configuration with the defaults overridden by the entries of ``d``.
        """
        config = cls()
        config.set_multiple(d)
        return config

    @classmethod
    def from_table(cls, t):
        """Create a configuration from an astropy Table with a single row and
        one column for each parameter.

        Each cell is parsed with ``safe_load``, i.e. it is expected to hold the
        YAML-serialized value of the parameter (the format written by ``to_hdu``).

        Parameters
        ----------
        t : `~astropy.table.Table`
            Astropy Table containing the configuration parameters.

        Raises
        ------
        Raises a ``ValueError`` if the table is the wrong shape (it must have
        exactly one row).
        """
        if len(t) > 1:
            raise ValueError(f"More than one row in the configuration table ({len(t)}).")
        if len(t) == 0:
            # Without this check the indexing below fails with an IndexError.
            raise ValueError("No rows in the configuration table.")

        # guaranteed to only have 1 element due to checks above
        params = {col.name: safe_load(col.value[0]) for col in t.values()}
        return cls.from_dict(params)

    @classmethod
    def from_yaml(cls, config):
        """Create a configuration from serialized YAML data.

        Parameters
        ----------
        config : `str` or `_io.TextIOWrapper`
            The serialized YAML data.

        Returns
        -------
        config : `SearchConfiguration`
            The parsed configuration. Empty YAML input yields the default configuration.
        """
        yaml_params = safe_load(config)
        if yaml_params is None:
            # safe_load returns None for an empty document; treat it as "no overrides".
            yaml_params = {}
        return cls.from_dict(yaml_params)

    @classmethod
    def from_hdu(cls, hdu):
        """Load a configuration from a FITS extension file.

        Parameters
        ----------
        hdu : `astropy.io.fits.BinTableHDU`
            The HDU from which to parse the configuration information.
        """
        t = Table(hdu.data)
        return cls.from_table(t)

    @classmethod
    def from_file(cls, filename):
        """Load a configuration from a YAML file on disk.

        Parameters
        ----------
        filename : `str`
            The filename, including path, of the YAML configuration file.
        """
        with open(filename) as ff:
            return cls.from_yaml(ff.read())

    def to_hdu(self):
        """Create a fits HDU with all the configuration parameters.

        Each parameter becomes one column whose single cell is the YAML
        serialization of the parameter's value.

        Returns
        -------
        hdu : `astropy.io.fits.BinTableHDU`
            The HDU with the configuration information.
        """
        serialized_dict = {key: dump(val, default_flow_style=True) for key, val in self._params.items()}
        t = Table(rows=[serialized_dict])
        return fits.table_to_hdu(t)

    def to_yaml(self):
        """Serialize the configuration parameters as YAML.

        Returns
        -------
        result : `str`
            The serialized YAML string.
        """
        return dump(self._params)

    def _cluster_parameters(self, sections=None):
        """Group the parameter names by the section they belong to.

        Parameters
        ----------
        sections : `list` of `str`, optional
            The sections to include, in output order. If not provided, all sections
            are included. An "other" section is always present.

        Returns
        -------
        result : `dict`
            A dictionary mapping section name to the list of parameter names in that
            section. Parameters whose section is not listed, and parameters that are
            set but not in the supported list, are placed in "other".
        """
        if sections is None:
            sections = ["core", "filtering", "stamps", "clustering", "output", "other"]

        # Create the empty dictionary (insertion order = requested section order).
        section_to_params = {sec: [] for sec in sections}
        section_to_params.setdefault("other", [])

        # Fill the dictionary from the parameter info.
        for key, info in self._param_info.items():
            section = info.section if info.section in section_to_params else "other"
            section_to_params[section].append(key)

        # Include parameters that were stored via set() but have no parameter info,
        # so they are not silently dropped from help() and to_file().
        section_to_params["other"].extend(key for key in self._params if key not in self._param_info)
        return section_to_params

    def to_file(self, filename, overwrite=False):
        """Save a configuration file with the parameters.

        The file is YAML, grouped into commented sections with each parameter's
        description written as a comment above its value.

        Parameters
        ----------
        filename : str
            The filename, including path, of the configuration file.
        overwrite : bool
            Indicates whether to overwrite an existing file. If False and the file
            exists, a warning is logged and nothing is written.
        """
        if Path(filename).is_file() and not overwrite:
            logger.warning(f"Configuration file {filename} already exists.")
            return
        logger.info(f"Saving configuration to {filename}")

        # Output the configuration file in sections for easier reading. We add the sections
        # in the order we want them to appear.
        section_to_params = self._cluster_parameters()
        banner = "# ======================================================================\n"
        with open(filename, "w") as file:
            for section, section_keys in section_to_params.items():
                file.write(banner)
                file.write(f"# {section.capitalize()} Configuration\n")
                file.write(banner)
                for key in section_keys:
                    info = self._param_info.get(key)
                    if info is not None and info.description:
                        file.write(f"\n# {info.description}\n")
                    else:
                        file.write("\n")
                    # Always write the value, even when there is no description.
                    file.write(dump({key: self._params[key]}))
                file.write("\n")