# Source code for lsst.sims.maf.metrics.transientMetrics

from builtins import zip
import numpy as np
from .baseMetric import BaseMetric

__all__ = ['TransientMetric']

class TransientMetric(BaseMetric):
    """
    Calculate what fraction of the transients would be detected. Best paired with a spatial slicer.
    We are assuming simple light curves with no color evolution.

    Parameters
    ----------
    transDuration : float, optional
        How long the transient lasts (days). Default 10.
    peakTime : float, optional
        How long it takes to reach the peak magnitude (days). Default 5.
    riseSlope : float, optional
        Slope of the light curve before peak time (mags/day).
        This should be negative since mags are backwards (magnitudes decrease towards brighter fluxes).
        Default 0.
    declineSlope : float, optional
        Slope of the light curve after peak time (mags/day).
        This should be positive since mags are backwards.
        Default 0.
    uPeak : float, optional
        Peak magnitude in u band. Default 20.
    gPeak : float, optional
        Peak magnitude in g band. Default 20.
    rPeak : float, optional
        Peak magnitude in r band. Default 20.
    iPeak : float, optional
        Peak magnitude in i band. Default 20.
    zPeak : float, optional
        Peak magnitude in z band. Default 20.
    yPeak : float, optional
        Peak magnitude in y band. Default 20.
    surveyDuration : float, optional
        Length of survey (years).
        Default 10.
    surveyStart : float, optional
        MJD for the survey start date.
        Default None (uses the time of the first observation).
    detectM5Plus : float, optional
        An observation will be used if the light curve magnitude is brighter than m5+detectM5Plus.
        Default 0.
    nPrePeak : int, optional
        Number of observations (in any filter(s)) above the detectM5Plus threshold to demand
        before peakTime, before saying a transient has been detected.
        Default 0.
    nPerLC : int, optional
        Number of sections of the light curve that must be sampled above the detectM5Plus threshold
        (in a single filter) for the light curve to be counted.
        For example, setting nPerLC = 2 means a light curve is only considered detected if there
        is at least 1 observation in the first half of the LC, and at least one in the second half
        of the LC. nPerLC = 4 means each quarter of the light curve must be detected to count.
        Default 1.
    nFilters : int, optional
        Number of filters that must each satisfy the nPerLC sampling requirement for an object
        to be counted as detected.
        Default 1.
    nPhaseCheck : int, optional
        Sets the number of phases that should be checked.
        One can imagine pathological cadences where many objects pass the detection criteria,
        but would not if the observations were offset by a phase-shift.
        Default 1.
    countMethod : {'full' 'partialLC'}, defaults to 'full'
        Sets the method of counting max number of transients.
        If 'full', only full light curves that fit the survey duration are counted
        (the integer floor of surveyDuration / transDuration).
        If 'partialLC', light curves truncated by the end of the survey are also counted,
        so the max number of possible transients is taken to be the integer ceiling.

    Notes
    -----
    All enabled criteria (at least one point above threshold, nPrePeak, nPerLC/nFilters) must be
    satisfied simultaneously, and every criterion is evaluated using only points above the
    detectM5Plus threshold.
    """
    def __init__(self, metricName='TransientDetectMetric', mjdCol='observationStartMJD',
                 m5Col='fiveSigmaDepth', filterCol='filter',
                 transDuration=10., peakTime=5., riseSlope=0., declineSlope=0.,
                 surveyDuration=10., surveyStart=None, detectM5Plus=0.,
                 uPeak=20, gPeak=20, rPeak=20, iPeak=20, zPeak=20, yPeak=20,
                 nPrePeak=0, nPerLC=1, nFilters=1, nPhaseCheck=1, countMethod='full',
                 **kwargs):
        self.mjdCol = mjdCol
        self.m5Col = m5Col
        self.filterCol = filterCol
        super(TransientMetric, self).__init__(col=[self.mjdCol, self.m5Col, self.filterCol],
                                              units='Fraction Detected',
                                              metricName=metricName, **kwargs)
        # Peak magnitude per filter; lightCurve adds no peak offset for filters not listed here.
        self.peaks = {'u': uPeak, 'g': gPeak, 'r': rPeak, 'i': iPeak, 'z': zPeak, 'y': yPeak}
        self.transDuration = transDuration
        self.peakTime = peakTime
        self.riseSlope = riseSlope
        self.declineSlope = declineSlope
        self.surveyDuration = surveyDuration
        self.surveyStart = surveyStart
        self.detectM5Plus = detectM5Plus
        self.nPrePeak = nPrePeak
        self.nPerLC = nPerLC
        self.nFilters = nFilters
        self.nPhaseCheck = nPhaseCheck
        self.countMethod = countMethod

    def lightCurve(self, time, filters):
        """
        Calculate the magnitude of the object at each time, in each filter.

        Parameters
        ----------
        time : numpy.ndarray
            The times of the observations, relative to the start of the light curve (days).
        filters : numpy.ndarray
            The filters of the observations.

        Returns
        -------
        numpy.ndarray
            The magnitudes of the object at each time, in each filter.
            Points whose filter is not a key of ``self.peaks`` get no peak-magnitude offset.
        """
        lcMags = np.zeros(time.size, dtype=float)

        # Up to and including the peak: linear rise that reaches a 0 offset at peakTime.
        rise = np.where(time <= self.peakTime)
        lcMags[rise] += self.riseSlope * time[rise] - self.riseSlope * self.peakTime

        # After the peak: linear decline starting from the peak.
        decline = np.where(time > self.peakTime)
        lcMags[decline] += self.declineSlope * (time[decline] - self.peakTime)

        # Same shape in every filter (no color evolution), shifted by that filter's peak magnitude.
        for key in self.peaks:
            fMatch = np.where(filters == key)
            lcMags[fMatch] += self.peaks[key]
        return lcMags

    def _lcPasses(self, time, filters, snrDetected):
        """
        Decide whether a single light curve satisfies every enabled detection criterion.

        Parameters
        ----------
        time : numpy.ndarray
            Phase of each observation within this light curve (days, in [0, transDuration)).
        filters : numpy.ndarray
            Filter of each observation in this light curve.
        snrDetected : numpy.ndarray of bool
            True where the light curve is brighter than m5 + detectM5Plus.

        Returns
        -------
        bool
            True only if all enabled criteria pass (they are ANDed, never summed).
        """
        # Basic criterion: at least one point above the threshold.
        if not np.any(snrDetected):
            return False

        # Enough above-threshold points (any filter) strictly before peak.
        if self.nPrePeak > 0:
            nPre = np.count_nonzero(snrDetected & (time < self.peakTime))
            if nPre < self.nPrePeak:
                return False

        # Enough filters that each sample nPerLC distinct sections of the light curve,
        # using only above-threshold points.
        if (self.nPerLC > 1) or (self.nFilters > 1):
            detFilters = filters[snrDetected]
            phaseSections = np.floor(time[snrDetected] / self.transDuration * self.nPerLC)
            nGoodFilters = 0
            for filtName in np.unique(detFilters):
                sections = phaseSections[detFilters == filtName]
                if np.size(np.unique(sections)) >= self.nPerLC:
                    nGoodFilters += 1
            if nGoodFilters < self.nFilters:
                return False

        return True

    def run(self, dataSlice, slicePoint=None):
        """
        Calculate the detectability of a transient with the specified lightcurve.

        Parameters
        ----------
        dataSlice : numpy.array
            Numpy structured array containing the data related to the visits provided by the slicer.
        slicePoint : dict, optional
            Dictionary containing information about the slicepoint currently active in the slicer.

        Returns
        -------
        float
            The fraction of possible back-to-back transients (summed over all checked phases)
            that would be detected. Returns 0.0 when dataSlice is empty.
        """
        if np.size(dataSlice) == 0:
            return 0.0

        # Total number of transients that could go off back-to-back (surveyDuration is in years).
        nPossible = self.surveyDuration / (self.transDuration / 365.25)
        if self.countMethod == 'partialLC':
            _nTransMax = np.ceil(nPossible)
        else:
            _nTransMax = np.floor(nPossible)
        tshifts = np.arange(self.nPhaseCheck) * self.transDuration / float(self.nPhaseCheck)

        # Previously a user-supplied surveyStart left this local unassigned (UnboundLocalError).
        if self.surveyStart is None:
            surveyStart = dataSlice[self.mjdCol].min()
        else:
            surveyStart = self.surveyStart

        # Sort once by time; light-curve numbers are then non-decreasing, so each
        # light curve is a contiguous run that np.searchsorted can delimit (a numpy "group by").
        order = np.argsort(dataSlice[self.mjdCol])
        mjd = dataSlice[self.mjdCol][order]
        m5 = dataSlice[self.m5Col][order]
        filters = dataSlice[self.filterCol][order]

        nDetected = 0
        nTransMax = 0
        for tshift in tshifts:
            # Compute the total number of back-to-back transients that are possible to detect
            # given the survey duration and the transient duration; each shifted phase gets
            # one fewer (kept from the original counting scheme).
            nTransMax += _nTransMax
            if tshift != 0:
                nTransMax -= 1

            # Phase within the light curve and which light curve each point belongs to.
            # Both must use the same shifted time origin so grouping matches the phases.
            elapsed = mjd - surveyStart + tshift
            time = elapsed % self.transDuration
            lcNumber = np.floor(elapsed / self.transDuration)

            lcMags = self.lightCurve(time, filters)
            # Flag points where the transient is brighter than the (adjusted) 5-sigma depth.
            snrDetected = lcMags < m5 + self.detectM5Plus

            ulcNumber = np.unique(lcNumber)
            left = np.searchsorted(lcNumber, ulcNumber)
            right = np.searchsorted(lcNumber, ulcNumber, side='right')
            for le, ri in zip(left, right):
                if self._lcPasses(time[le:ri], filters[le:ri], snrDetected[le:ri]):
                    nDetected += 1

        return float(nDetected) / nTransMax