from typing import List, Tuple
import numpy as np
import pandas as pd
import scipy.interpolate as interp
from astropy.table import Table
from scipy.ndimage.filters import gaussian_filter
from casper.interface.gisic import norm_functions
from casper.interface.gisic.segment import Segment
from casper.utils.logger_config import setup_logger
logger = setup_logger(__name__)
[docs]
class Spectrum:
def __init__(self, wavelength: np.ndarray, flux: np.ndarray) -> None:
    """
    Create a spectrum from matching wavelength and flux arrays.

    Parameters
    ----------
    wavelength : np.ndarray
        Wavelength samples of the spectrum.
    flux : np.ndarray
        Flux value measured at each wavelength sample.

    Sets
    ----
    self.wavelength : np.ndarray
        The array passed in as `wavelength` (stored by reference, not copied).
    self.flux : np.ndarray
        The array passed in as `flux` (stored by reference, not copied).
    self.segments : Optional[List]
        Placeholder for the segment list; starts as None.
    self.mad_global : Optional[float]
        Placeholder for the global MAD statistic; starts as None.
    """
    # Derived quantities are filled in later by the segmentation methods.
    self.segments = None
    self.mad_global = None

    self.wavelength, self.flux = wavelength, flux
[docs]
def generate_segments(self, bins: int = 25, lower: int = 70) -> None:
    """
    Split the spectrum into consecutive chunks and build a Segment for each.

    Parameters
    ----------
    bins : int, optional
        Number of chunks handed to ``np.array_split``; chunks are of
        near-equal length. Default is 25.
    lower : int, optional
        Value forwarded as the first argument of ``Segment.get_statistics``
        for every segment. Default is 70.

    Sets
    ----
    self.segments : List[Segment]
        One Segment per chunk, in wavelength-array order. ``is_edge("left")``
        is called on the first segment and ``is_edge("right")`` on the last,
        then ``get_statistics(lower)`` is called on every segment.
    """
    wave_chunks = np.array_split(self.wavelength, bins)
    flux_chunks = np.array_split(self.flux, bins)
    self.segments = [Segment(chunk_wave, chunk_flux) for chunk_wave, chunk_flux in zip(wave_chunks, flux_chunks)]

    # Flag the two outermost segments before statistics are computed.
    first, last = self.segments[0], self.segments[-1]
    first.is_edge("left")
    last.is_edge("right")

    for piece in self.segments:
        piece.get_statistics(lower)
[docs]
def generate_inflection_segments(
    self,
    sigma: int = 30,
    width: int = 30,
    cahk: bool = False,
    cahkwidth: int = 2,
    band_check: bool = False,
    flux_min: float = 70,
) -> None:
    """
    Build segments centred on concave extrema of the smoothed spectrum.

    The flux is Gaussian-smoothed and differentiated twice. Zero crossings of
    the second derivative delimit intervals; for every interval in which more
    than 80% of the samples have a negative second derivative, the sample(s)
    with the most negative second derivative become segment centres. A
    segment of the raw flux is then cut around every centre, and two edge
    segments covering the first and last `width` samples are added.

    Parameters
    ----------
    sigma : int, optional
        Standard deviation of the Gaussian kernel applied to the flux.
        Default is 30.
    width : int, optional
        Full wavelength width of the window cut around each centre
        (``int(width / 2)`` on either side), and also the number of samples
        in each of the two edge segments. Default is 30.
    cahk : bool, optional
        If True, additionally add the highest-flux sample within
        ``cahkwidth`` of 3916 and of 3991 as centres, then sort and
        de-duplicate all centres by wavelength. A window that contains no
        samples is skipped with a warning. Default is False.
        NOTE(review): 3916/3991 look like continuum anchors bracketing the
        Ca II H&K region rather than the line centres - confirm.
    cahkwidth : int, optional
        Half-width of the search window used when `cahk` is True. Default is 2.
    band_check : bool, optional
        If True, centres for which
        ``norm_functions.in_molecular_band(wave, tol=10)`` is truthy are
        skipped. Default is False.
    flux_min : float, optional
        Forwarded as ``flux_min`` to ``Segment.get_statistics``. Default is 70.

    Sets
    ----
    self.smooth : np.ndarray
        Gaussian-smoothed flux.
    self.d1, self.d2 : np.ndarray
        First and second derivative of `self.smooth`, each divided by its own
        maximum when that maximum is positive (left unscaled otherwise).
    self.frame : pd.DataFrame
        Columns ``wave``, ``flux`` (raw flux), ``d1`` and ``d2``.
    self.ZEROS : List[float]
        Wavelengths at which the second derivative changes sign or is
        exactly zero.
    self.segments : List[Segment]
        Left edge segment, one segment per retained centre, right edge
        segment; statistics are computed for each.
    """
    self.smooth = gaussian_filter(self.flux, sigma=sigma)
    self.d1 = np.gradient(self.smooth)
    self.d2 = np.gradient(self.d1)

    # Scale by the peak only when it is positive: a zero peak would produce
    # NaN/inf and a negative peak would flip the sign of the derivative,
    # inverting the concavity test below.
    d1_peak = np.max(self.d1)
    if d1_peak > 0:
        self.d1 = self.d1 / d1_peak
    d2_peak = np.max(self.d2)
    if d2_peak > 0:
        self.d2 = self.d2 / d2_peak

    # Round-trip through an astropy Table with explicit "f8" columns,
    # presumably to coerce every column to native float64 - TODO confirm.
    hack = Table(
        [self.wavelength, self.flux, self.d1, self.d2],
        names=("wave", "flux", "d1", "d2"),
        dtype=("f8", "f8", "f8", "f8"),
    )
    self.frame = pd.DataFrame({name: hack[name] for name in ("wave", "flux", "d1", "d2")})
    frame_wave = self.frame["wave"]

    def window(lower: float, upper: float) -> pd.DataFrame:
        # Rows whose wavelength lies in the closed interval [lower, upper].
        return self.frame[frame_wave.between(lower, upper, inclusive="both")]

    # Zero crossings of d2: a sign change between neighbours is placed at the
    # midpoint of the two wavelengths, an exact zero at its own wavelength.
    wave_values = frame_wave.to_numpy()
    left, right = self.d2[:-1], self.d2[1:]
    sign_change = left * right < 0.0
    exact_zero = left == 0.0
    candidates = np.where(sign_change, (wave_values[:-1] + wave_values[1:]) / 2.0, wave_values[:-1])
    self.ZEROS = candidates[sign_change | exact_zero].tolist()

    # Keep the most concave sample(s) of every predominantly concave interval.
    extrema_waves: List[float] = []
    for lower_zero, upper_zero in zip(self.ZEROS[:-1], self.ZEROS[1:]):
        interval = window(lower_zero, upper_zero)
        interval_d2 = interval["d2"]
        if (interval_d2 < 0.0).sum() > 0.8 * len(interval):
            extrema_waves.extend(interval.loc[interval_d2 == interval_d2.min(), "wave"].tolist())

    if cahk:
        for anchor in (3916, 3991):
            nearby = window(anchor - cahkwidth, anchor + cahkwidth)
            if nearby.empty:
                # The spectrum does not cover this anchor; nothing to force.
                logger.warning("No samples within %s of %s; skipping forced anchor", cahkwidth, anchor)
                continue
            extrema_waves.extend(nearby.loc[nearby["flux"] == nearby["flux"].max(), "wave"].tolist())
        # Sorted, de-duplicated centres (only done when anchors are forced).
        extrema_waves = np.unique(extrema_waves).tolist()

    half_width = int(width / 2)
    self.segments = []
    for centre in extrema_waves:
        if band_check and norm_functions.in_molecular_band(centre, tol=10):
            continue
        cut = window(centre - half_width, centre + half_width)
        self.segments.append(Segment(np.array(cut["wave"]), np.array(cut["flux"])))

    # Edge segments are defined by sample count, not by wavelength width.
    self.segments.insert(
        0, Segment(np.array(frame_wave.iloc[0:width]), np.array(self.frame["flux"].iloc[0:width]))
    )
    self.segments.append(
        Segment(np.array(frame_wave.iloc[-width:]), np.array(self.frame["flux"].iloc[-width:]))
    )
    self.segments[0].is_edge("left")
    self.segments[-1].is_edge("right")

    for segment in self.segments:
        segment.get_statistics(flux_min=flux_min)
[docs]
def assess_segment_variation(self) -> None:
    """
    Summarise the spread of the per-segment normalized MAD values.

    Reads ``mad_normal`` from every segment in `self.segments`, stores the
    values together with their median, minimum, maximum and range, and
    rescales them linearly onto [0, 1].

    Sets
    ----
    self.mad_array : np.ndarray
        ``mad_normal`` of every segment, as floats.
    self.mad_global : float
        Median of `self.mad_array`.
    self.mad_min : float
        Smallest value in `self.mad_array`.
    self.mad_max : float
        Largest value in `self.mad_array`.
    self.mad_range : float
        ``mad_max - mad_min``.
    self.mad_relative_array : np.ndarray
        ``(mad_array - mad_min) / mad_range``; all zeros when the range is
        zero (single segment or identical MAD values).

    Raises
    ------
    ValueError
        If `self.segments` is None or empty.
    """
    if not self.segments:
        raise ValueError(
            "No segments available; call generate_segments() or generate_inflection_segments() first."
        )

    self.mad_array = np.array([segment.mad_normal for segment in self.segments], dtype=float)
    self.mad_global = np.median(self.mad_array)
    self.mad_min = np.min(self.mad_array)
    self.mad_max = np.max(self.mad_array)
    self.mad_range = self.mad_max - self.mad_min

    if self.mad_range == 0:
        # Every segment is equally variable: avoid 0/0 -> NaN and report
        # "no relative variation" instead.
        self.mad_relative_array = np.zeros_like(self.mad_array)
    else:
        self.mad_relative_array = np.divide(self.mad_array - self.mad_min, self.mad_range)
[docs]
def define_cont_points(self, boost: float) -> None:
    """
    Ask every segment to define its continuum point.

    Each Segment in `self.segments` has ``define_cont_point`` called with
    `self.mad_min`, `self.mad_range` and the given `boost`.

    Parameters
    ----------
    boost : float
        Forwarded unchanged as the ``boost`` keyword of
        ``Segment.define_cont_point``.

    Returns
    -------
    None

    Notes
    -----
    `self.mad_min` and `self.mad_range` are read here, so
    `self.assess_segment_variation()` has to run first.
    """
    floor, spread = self.mad_min, self.mad_range
    for piece in self.segments:
        piece.define_cont_point(floor, spread, boost=boost)
[docs]
def set_segment_midpoints(self) -> np.ndarray:
    """
    Gather the ``midpoint`` attribute of every segment.

    The values are kept as a plain list in `self.midpoints` and also handed
    back as a float array.

    Returns
    -------
    np.ndarray
        Float array holding the midpoint of each segment, in segment order.
    """
    collected = []
    for piece in self.segments:
        collected.append(piece.midpoint)
    self.midpoints = collected
    return np.array(collected, dtype=float)
[docs]
def set_segment_continuum(self) -> np.ndarray:
    """
    Gather the ``continuum_point`` attribute of every segment.

    The values are kept as a plain list in `self.fluxpoints` and also handed
    back as a float array.

    Returns
    -------
    np.ndarray
        Float array holding the continuum point of each segment, in segment
        order.
    """
    collected = []
    for piece in self.segments:
        collected.append(piece.continuum_point)
    self.fluxpoints = collected
    return np.array(collected, dtype=float)
[docs]
def add_continuum_point(self, point: Tuple[float, float]) -> None:
    """
    Add a continuum anchor point and keep the anchors sorted by wavelength.

    Parameters
    ----------
    point : tuple of float
        ``(wavelength, flux)`` of the anchor to add.

    Raises
    ------
    ValueError
        If `self.midpoints` and `self.fluxpoints` do not have the same
        length, since the pairing between them would be ambiguous.

    Notes
    -----
    `self.midpoints` and `self.fluxpoints` are replaced by new lists ordered
    by ascending midpoint wavelength; the previous containers are not
    modified, so they may be lists, tuples or arrays. A stable sort is used,
    so anchors sharing a wavelength keep their relative order.
    """
    if len(self.midpoints) != len(self.fluxpoints):
        raise ValueError(
            f"midpoints ({len(self.midpoints)}) and fluxpoints ({len(self.fluxpoints)}) differ in length"
        )

    midpoints = [*self.midpoints, point[0]]
    fluxpoints = [*self.fluxpoints, point[1]]

    # One permutation applied to both lists keeps every pair together.
    order = np.argsort(midpoints, kind="stable")
    self.midpoints = list(np.asarray(midpoints)[order])
    self.fluxpoints = list(np.asarray(fluxpoints)[order])
[docs]
def remove_point(self, index: List[int]) -> None:
    """
    Remove the continuum anchor points at the given positions.

    Parameters
    ----------
    index : list of int
        Positions to delete from both `self.midpoints` and
        `self.fluxpoints`. Positions refer to the lists as they are before
        any deletion. Order does not matter, duplicates are ignored,
        negative positions count from the end, and a single integer is
        accepted as well.

    Raises
    ------
    IndexError
        If any position is out of range. Nothing is deleted in that case.

    Notes
    -----
    All positions are validated first and the deletions are then performed
    from the highest position down, so earlier deletions never shift the
    positions that are still pending.
    """
    count = len(self.midpoints)
    positions = set()
    for raw in np.atleast_1d(index):
        position = int(raw)
        if position < 0:
            position += count
        if not 0 <= position < count:
            raise IndexError(f"continuum point index {int(raw)} is out of range for {count} points")
        positions.add(position)

    for position in sorted(positions, reverse=True):
        del self.midpoints[position]
        del self.fluxpoints[position]
[docs]
def get_continuum_points(self) -> None:
    """
    Log every continuum anchor point at INFO level.

    One message is emitted per entry of `self.midpoints`, containing the
    position, the midpoint wavelength and the flux value stored at the same
    position in `self.fluxpoints`.

    Returns
    -------
    None
    """
    for i, _ in enumerate(self.midpoints):
        logger.info(f"{i}: {self.midpoints[i]} {self.fluxpoints[i]}")
[docs]
def set_wavelength(self, wavelength: list[float]) -> None:
    """
    Replace the stored continuum midpoint wavelengths.

    Despite its name, this does not touch `self.wavelength`; the given
    object is stored, without copying, as `self.midpoints`.

    Parameters
    ----------
    wavelength : list of float
        Wavelengths of the continuum anchor points.

    Returns
    -------
    None
    """
    self.midpoints = wavelength
[docs]
def set_fluxpoints(self, flux_values: list[float]) -> None:
    """
    Replace the stored continuum flux values.

    The given object is stored, without copying, as `self.fluxpoints`.

    Parameters
    ----------
    flux_values : list of float
        Flux of each continuum anchor point, in the same order as the
        midpoint wavelengths.

    Returns
    -------
    None
    """
    self.fluxpoints = flux_values
[docs]
def spline_continuum(self, k: int = 3, s: float = 5.0) -> None:
    """
    Fit a smoothing spline through the anchor points and sample it.

    The spline representation is computed from `self.midpoints` and
    `self.fluxpoints` and then evaluated at every value of
    `self.wavelength`; the result is stored in `self.continuum`.

    Parameters
    ----------
    k : int, optional
        Spline degree passed to ``splrep``. Default is 3 (cubic).
    s : float, optional
        Smoothing condition passed to ``splrep``. Default is 5.0.

    Returns
    -------
    None
    """
    anchor_wave, anchor_flux = self.midpoints, self.fluxpoints
    spline_repr = interp.splrep(anchor_wave, anchor_flux, k=k, s=s, quiet=True)
    evaluated = interp.splev(self.wavelength, spline_repr)
    self.continuum = evaluated
[docs]
def normalize(self) -> None:
    """
    Divide the flux by the fitted continuum and tame out-of-range values.

    The ratio ``flux / continuum`` is stored in `self.flux_norm`. Every
    sample below 0 is set to 0.0 and every sample above 2 is set to 1.0.
    The replacement is applied to each offending sample, including the case
    where only a single sample is out of range.

    Returns
    -------
    None

    Notes
    -----
    `self.continuum` is read here, so a continuum has to be available first
    (for example from `spline_continuum`).
    """
    ratio = np.divide(self.flux, self.continuum)
    # Negative ratios are floored at zero; ratios above 2 are reset to the
    # continuum level (1.0) rather than capped at 2.
    ratio[ratio < 0.0] = 0.0
    ratio[ratio > 2.0] = 1.0
    self.flux_norm = ratio