from typing import List,Dict,Union,Tuple
import math
import numpy as np
import scipy.optimize
from show_config.show_configuration import Parameters
from .level0imagearray import Level0_ImageCollection, Level0_Image,  ImageStats
#------------------------------------------------------------------------------
#           Level1_Spectra
#------------------------------------------------------------------------------
class Level1_Spectra:
    """
    Plain data record holding one Level 1 spectral image together with the housekeeping
    values that accompany it. All numeric fields start at 0.0 and all array/string fields
    start as ``None``; the code that builds the record is expected to fill them in.

    Instance Attributes:

    .. py:attribute:: mjd

        The UTC time of the image represented as a modified julian date. Float.

    .. py:attribute:: exposure_time

        The exposure time in micro-seconds.  Float.

    .. py:attribute:: sensor_temp

        The detector temperature in Celsius. Float.

    .. py:attribute:: setpoint_temp

        The temperature set point (presumably the detector set point in Celsius -- TODO confirm). Float.

    .. py:attribute:: sensor_pcb_temp

        The detector printed circuit board temperature in Celsius. Float.

    .. py:attribute:: top_intferom_temp

        The temperature in Celsius of the top of the interferometer. Float.

    .. py:attribute:: bottom_intferom_temp

        The temperature in Celsius of the bottom of the interferometer.  Float.

    .. py:attribute:: optobox_temp

        The temperature in Celsius of the optics box. Float.

    .. py:attribute:: Q7_temp

        The temperature in Celsius of Q7. Float.

    .. py:attribute:: high_gain

        The detector gain setting. The value is `True` for high gain and `False` for low gain. A freshly constructed
        record holds the falsy placeholder ``0.0``.

    .. py:attribute:: comment

        An optional comment string for each image. The comment is typically supplied by the operator during data acquisition.

    .. py:attribute:: spectrum

        The Level 1 spectral image expressed as a numpy 2-D float array of dimension (H,M) where H is the number of height rows and M is the number
        of transform frequencies, typically half the number of interferogram pixels.

    .. py:attribute:: error

        The error on the Level 1 spectrum. The error field may be None meaning no error value is available. If it is not None then it will be
        a numpy 2-D float array of dimension (H,M) where H is the number of height rows and M is the number of transform frequencies.
        It will be the same size as the spectrum attribute.

    .. py:attribute:: ifgram_bounds

        The sub-window bounds used to select the useful detector area from the original interferogram. Once set this is a four element
        tuple (x0,x1,y0,y1). A freshly constructed record holds an empty list.
    """
    def __init__(self):
        self.mjd                  = 0.0
        self.exposure_time        = 0.0
        self.sensor_temp          = 0.0
        self.setpoint_temp        = 0.0
        self.sensor_pcb_temp      = 0.0
        self.top_intferom_temp    = 0.0
        self.bottom_intferom_temp = 0.0
        self.optobox_temp         = 0.0
        self.Q7_temp              = 0.0
        # self.tec_on             = 0.0               # field currently disabled
        self.high_gain            = 0.0               # placeholder until the real gain flag is assigned
        self.comment              = None
        self.spectrum             = None
        self.error                = None
        self.ifgram_bounds        = []                # new list per instance; replaced by (x0,x1,y0,y1) once known
#------------------------------------------------------------------------------
#           class L0Algorithms:
#------------------------------------------------------------------------------
class L0Algorithms:
    """
    Level 0 interferogram processing algorithms: detector noise estimates, sub-windowing of the detector image,
    apodization with DC bias removal, Fourier transformation of interferograms to spectra and fitting of a
    cosine with Gaussian envelope to fringe data.

    :param parameters:
        The instrument configuration object. The methods of this class read ``parameters.config['level0']``
        (keys ``electrons_per_DN`` and ``detector_readout_noise``) and the sub-window attributes
        ``fft_window_start``, ``fft_window_end``, ``height_window_start`` and ``height_window_end``.
    """

    def __init__(self, parameters: 'Parameters') -> None:
        # The annotation is quoted so it is not evaluated when the class body is executed.
        self.params  = parameters                                        # type: Parameters

    #------------------------------------------------------------------------------
    #           L0Algorithms::rms_signal_error_from_detector_specs
    #------------------------------------------------------------------------------
    def rms_signal_error_from_detector_specs(self, dnsignal: np.ndarray ) ->np.ndarray:
        """
        Fetch the theoretical error on the signal read from the detector in DN.  Considers Poisson counting statistics
        on total number of electrons and the detector readout noise.

        Negative signal values (which can occur after a DC bias has been subtracted) are treated as zero electrons so
        the returned error never drops below the readout noise and never becomes NaN.

        :param dnsignal: The signal readout from the detector. Note that you may have to subtract any DC bias
         on the detector before calling this routine
        :return: The calculated error in DN. Same shape as *dnsignal*.
        """
        level0_config = self.params.config['level0']
        Edn = level0_config['electrons_per_DN']                             # Get electrons per DN
        Rde = level0_config['detector_readout_noise']                       # get readout noise in electrons
        Ne = np.maximum(dnsignal*Edn, 0.0)                                  # number of electrons in signal = Poisson variance, which cannot be negative
        Netot = np.sqrt( Ne + Rde*Rde )                                     # get the total electron error, Ne poisson error^2 + readout error^2
        return Netot/Edn                                                    # convert total electron error to DN

    #------------------------------------------------------------------------------
    #               L0Algorithms::apodized_interferogram_with_zero_bias
    #------------------------------------------------------------------------------
    def apodized_interferogram_with_zero_bias(self, x: np.ndarray, y : np.ndarray ) -> np.ndarray:
        """
        Apply an apodization function, *y*,  to each height row of the interferogram, *x*. Apply
        a dc offset to the interferogram, *x*, such that the average of
        the product of the interferogram, *x*, and the apodization function, *y*, is zero. This eliminates the large
        zero order component bleeding into nearby frequencies after we perform the FFT.

        Note that removing DC bias from the interferogram before apodization is applied results in a slight non-zero
        dc component that makes a quite large spike in the zero harmonic which  bleeds into neighbouring frequencies. This technique
        removes that bias and helps avoid bleed through of the zero harmonic in the fft spectrum.

        The offset for each row is ``mean(x*y)/mean(y)``, so the row average of *y* must be non-zero.

        :param x:  Original interferogram, a 2d-array of size (H,M). We need to average over M
        :param y: The apodization function, a 2d array of size (H,M). We need to average over M
        :return: the apodized interferogram with zero bias for each height row, 2-d array (H,M)
        """
        xyav    = np.mean( x*y, axis=1)                 # Get the average signal of the apodized signal at each height,     array (H)
        yav     = np.mean( y,   axis=1)                 # Get the average value of the apodization function at each height, array (H)
        b       = xyav/yav                              # Get the DC correction to subtract from original signal to zero bias apodized signal, array (H)
        return (x - b[:, np.newaxis])*y                 # broadcast the per-row correction across the M columns

    #------------------------------------------------------------------------------
    #           L0Algorithms::rms_frequency_error_from_rms_spatial_error
    #------------------------------------------------------------------------------
    def rms_frequency_error_from_rms_spatial_error(self, spatial_error_array: np.ndarray, include_apodization=True, real_component_only= False ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Calculates the theoretical root mean square error in Fourier transform space given the root mean square error in interferogram space. The algorithm is applied to
        each height row independently. For each height row only one value, the RMS error, is returned.

        :param spatial_error_array: 2D array (H,N) of errors for N points  at each height (H) row.
        :param include_apodization: Default True. If True then divide theoretical RMS error by a factor of 2
        :param real_component_only: Default False. If True then divide theoretical RMS error by sqrt(2)
        :return:
            A two element tuple (Ef, Ex) where Ef is the theoretical root mean square error in transform space at each height row, array (H), and
            Ex is the root mean square error in interferogram space at each height row, array (H).
        """
        H,N = spatial_error_array.shape
        Ex = np.sqrt( np.sum( spatial_error_array*spatial_error_array , 1)/N )      # Get the rms error in spatial space, Ex(H)
        factor = math.sqrt(N)                                                       # scaling of the RMS error by an N point fourier transform
        if include_apodization:
            factor = 0.5*factor                                                     # apodization halves the RMS error
        if real_component_only:
            factor = factor/math.sqrt(2)                                            # real component only carries 1/sqrt(2) of the real+imag error
        Ef     = factor*Ex                                                          # and get the RMS error in transform space.
        return Ef, Ex

    #------------------------------------------------------------------------------
    #           L0Algorithms::_subwindow_bounds
    #------------------------------------------------------------------------------
    def _subwindow_bounds(self) -> Tuple[int, int, int, int]:
        """
        Fetch the inclusive sub-window bounds (x0,x1,y0,y1) from the instrument parameters.
        """
        return (self.params.fft_window_start,
                self.params.fft_window_end,
                self.params.height_window_start,
                self.params.height_window_end)

    #------------------------------------------------------------------------------
    #           L0Algorithms::interferogram_subwindow
    #------------------------------------------------------------------------------
    def interferogram_subwindow(self, level0: np.ndarray, userdefinederror : Union[np.ndarray, None] = None) -> Tuple[np.ndarray, Union[np.ndarray, None], Tuple[int, int, int, int]]:
        """
        Fetch the interferogram image and error arrays sub-windowed into the useful detector area.

        :param level0: The incoming interferogram image as a 2-D array in the shape of the detector.
        :param userdefinederror: Default None. If not None then a 2-D array, in the shape of the detector, holding the error on the image.
        :return:
            A three element tuple containing, (i) a copy of the sub-windowed image, (ii) a copy of the sub-windowed error, or None if no
            error array was given, and (iii) the detector bounding region as (x0,x1,y0,y1). The bounds are inclusive.
        """
        ifgram  = self.subwindow_array(level0)                                                          # Select the subset of pixels used for the interferogram
        error   = self.subwindow_array(userdefinederror) if userdefinederror is not None else None      # sub-window the interferogram error if its not none
        return ifgram, error, self._subwindow_bounds()

    #------------------------------------------------------------------------------
    #           L0Algorithms::subwindow_array
    #------------------------------------------------------------------------------
    def subwindow_array(self, image: np.ndarray) -> np.ndarray:
        """
        Fetch this instruments sub-window from a given image.

        :param image: An incoming 2-D array which will be sub-windowed. It is assumed it is in the same shape as the detector
        :return: The sub-windowed image as a 2-D array. This is a copy, not a view of *image*.
        """
        x0, x1, y0, y1 = self._subwindow_bounds()
        return np.copy( image[ y0:(y1+1), x0:(x1+1)])                           # bounds are inclusive, hence the +1 on the slice ends

    #------------------------------------------------------------------------------
    #               L0Algorithms::interferogram_to_spectrum
    #------------------------------------------------------------------------------
    def interferogram_to_spectrum( self, level0 : np.ndarray, userdefinederror : Union[np.ndarray, None] = None )-> Tuple[ np.ndarray, Union[np.ndarray, None], np.ndarray, np.ndarray] :
        """
        Converts the interferogram to spectra by applying an apodization (Hanning) window and taking the FFT at each
        selected height level. A sub-window (see below) is selected to create clean interferograms of dimension (H,M).
        The clean interferogram is apodized and fourier transformed to make a spectral 2-D image of size (H, M/2+1), ie we remove
        frequencies above the nyquist limit.

        The code selects a sub-window as defined by the instrument parameters fft_window_start,
        fft_window_end, height_window_start and height_window_end. This sub-window is meant to eliminate all the useless
        edge areas of the detector. There are no requirements on the size of the sub-window, for example it does not have to
        be a power of 2 etc. Users should eliminate bad and marginal regions near the edges of the detector with the sub-window
        as this is much better than leaving them in where they ultimately corrupt the entire signal.

        The code removes a DC bias from the original interferogram such that the zero order harmonic of the FFT is zero. There
        is a subtlety in this process as we remove a DC bias from the original signal so the average of the product of the
        original signal times the apodization function is zero. This is not quite the same as ensuring the average of the
        original signal is zero. If this correction is not made then we see the zero harmonic get quite large and it bleeds out
        into neighbouring spectral pixels during the FFT.

        :param level0:
            The Level 0 interferogram image as a 2-D array in the shape of the detector.

        :param userdefinederror:
            Default = None. A 2-d array, in the shape of the detector, specifying the error on the Level 0 interferogram image. If provided
            this array is propagated to an error estimate on the spectrum. If None then no error estimate is made.

        :return:
            A four element tuple containing (i) the complex spectrum, array (H, M/2+1), (ii) the error on the spectrum, array (H, M/2+1), or None
            if no error array was given, (iii) the windowed and apodized interferogram, array(H,M) and (iv) the windowed but not apodized
            interferogram, array (H,M)
        """
        ifgram, iferror, bounds  =  self.interferogram_subwindow(level0, userdefinederror=userdefinederror)   # Select the subset of pixels used for the interferogram
        H,M       = ifgram.shape                                                      # get the shape of the sub-window
        han       = np.tile(np.hanning(M), (H, 1))                                    # Select the apodization function for the FFT
        ifgramwh  = self.apodized_interferogram_with_zero_bias( ifgram, han )         # apply apodization and zero bias the average of the product
        spectra   = np.fft.rfft( ifgramwh )                                           # Compute the real FFT. The real fft does not compute the negative frequencies
        error     = None
        if iferror is not None:
            # One RMS value per height row, sqrt(sum(err^2)/2). This is the same value that
            # rms_frequency_error_from_rms_spatial_error gives with include_apodization=False, real_component_only=True.
            EFrms = np.sqrt( np.sum( iferror*iferror, 1 )/2)
            FF    = spectra.shape[1]                                                  # number of spectral elements, about half of M
            error = np.repeat( EFrms[:, np.newaxis], FF, axis=1)                      # copy the row error across the spectral elements at each height
        return spectra, error, ifgramwh, ifgram

    # ------------------------------------------------------------------------------
    #               _cosinefunc( param, M, s):
    # ------------------------------------------------------------------------------
    @staticmethod
    def _cosinefunc(param, M, s):
        """
        Evaluates the difference between a cosine function modulated by gaussian envelope and measured spectrum ``s``.
        This is the residual function passed to the least squares fit.

        :param param: The six model parameters in the order (A, w, phi, K, x0, sigma)
        :param M: The number of interferogram bins. The model is evaluated at x = 0..M-1
        :param s: The measured signal, array (M)
        :return: The difference between modelled function and measured signal, array (M)
        """
        A, w, phi, K, x0, sigma = (param[i] for i in range(6))
        x = np.arange(0, M)
        g = np.exp(np.square(x - x0)/(-2 * sigma * sigma))               # gaussian envelope
        return (A * np.cos(w * x + phi)*g + K) - s

    # ------------------------------------------------------------------------------
    #               _cosinefunc_jacobian(param, M, s):
    # ------------------------------------------------------------------------------
    @staticmethod
    def _cosinefunc_jacobian(param, M, s):
        """
        Evaluates the analytic jacobian of :meth:`_cosinefunc`, i.e. the partial derivative of each of the M residuals
        with respect to each of the six model parameters.

        :param param: The six model parameters in the order (A, w, phi, K, x0, sigma)
        :param M: The number of interferogram bins. The model is evaluated at x = 0..M-1
        :param s: The measured signal. Not used as the derivatives do not depend on it, but required so the signature
            matches :meth:`_cosinefunc`
        :return: The jacobian, array (M,6). Column order matches the parameter order.
        """
        A, w, phi, K, x0, sigma = (param[i] for i in range(6))
        x      = np.arange(0, M)
        J      = np.zeros([M, 6])
        sigma2 = sigma*sigma
        x1     = (x-x0)
        x2     = np.square(x1)
        env    = np.exp( x2/(-2*sigma2) )                   # gaussian envelope without the amplitude
        c      = np.cos(w * x + phi)*env                    # d/dA. Computed directly rather than as y/A so it stays finite when A is 0
        y      = A*c                                        # model without the constant K
        z      = -A*np.sin(w * x + phi)*env                 # d/dphi
        J[:, 0] = c
        J[:, 1] = z*x
        J[:, 2] = z
        J[:, 3] = 1.0
        J[:, 4] = y*x1/(sigma2)
        J[:, 5] = y*x2/(sigma2*sigma)
        return J

    # ------------------------------------------------------------------------------
    #               L0Algorithms::fit_cosine_with_gaussian_envelope
    # ------------------------------------------------------------------------------
    def fit_cosine_with_gaussian_envelope(self, krdata_subwindow: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Fits a cosine with a gaussian envelope to each row of a sub-windowed interferogram using a least squares approach

        :param krdata_subwindow:
            A 2-D interferogram sub-window of size (H,M) where H is the number of height rows and M is the number of interferogram bins.
            The cosine and gaussian envelope is fitted to each row of the sub-window.

        :return:
            A two element tuple containing (i) a 2-D array of size (6,H) where H is the number of height rows and 6 is the number of fitted
            parameters, in the order listed below, and (ii) the fitted model evaluated at every bin, a 2-D array of size (H,M).

        A warning is issued for every height row where the least squares fit reports that it did not converge. The
        parameters of such a row are still returned.

        Theory: The code fits the following function to the measured interferogram,

        .. math::

           y = A\\cos (\\omega x + \\phi)e^{-\\frac{(x-x_0)^2}{2\\sigma^2} } + K

        where :math:`x` is the interferogram bin number and we fit for the following 6 variables,

        1. :math:`A`
        2. :math:`\\omega`
        3. :math:`\\phi`
        4. :math:`K`
        5. :math:`x_0`
        6. :math:`\\sigma`

        The least squares algorithm uses the analytic differentials to find the best fit parameters. Let

        .. math::

            z = A\\sin (\\omega x + \\phi)e^{-\\frac{(x-x_0)^2}{2\\sigma^2}}

        then the differentials are given by

        .. math::

           \\frac{\\partial y}{\\partial A} &= \\frac{y-K}{A} \\\\
           \\frac{\\partial y}{\\partial \\omega} &= -zx \\\\
           \\frac{\\partial y}{\\partial \\phi} &= -z \\\\
           \\frac{\\partial y}{\\partial K} &= 1 \\\\
           \\frac{\\partial y}{\\partial x_0} &= (y-K)\\frac{(x-x_0)}{\\sigma^2} \\\\
           \\frac{\\partial y}{\\partial \\sigma} &= (y-K)\\frac{(x-x_0)^2}{\\sigma^3} \\\\
        """
        import warnings                                                                 # local import, only needed to report fits that fail

        image    = krdata_subwindow
        H, M     = image.shape                                                          # Get the Number of heights and number of interferogram pixel
        han      = np.tile(np.hanning(M), (H, 1))                                       # Select the apodization function for the FFT
        ifgramwh = self.apodized_interferogram_with_zero_bias(image, han)               # apply apodization and zero bias the average of the product
        spectra  = np.fft.rfft(ifgramwh)                                                # and get the real Fourier transform, used for the initial guess
        fitimage   = np.zeros( image.shape )
        fitresults = np.zeros( [6,H] )
        for ih in range(H):                                                             # For each valid height
            s = image[ih, :]                                                            # get the measured data for this height
            ffts = spectra[ih, :]                                                       # Get the spectrum for this height
            imax = np.argmax(np.absolute(ffts))                                         # Find the strongest frequency in the spectrum.
            w = 2 * math.pi * imax / M                                                  # Find the nominal frequency of this component.
            phi = np.angle(ffts[imax])                                                  # Get the phase of this component
            A = 4 * np.absolute(ffts[imax]) / M                                         # and the magnitude of the cosine
            K = np.average(s)                                                           # Guess the nominal background signal
            xx0 = np.argmax(s)                                                          # Centre of gaussian envelope is on the maximum interferogram signal
            sigma = M/2                                                                 # Guess the standard deviation is half the width
            guess = [A, w, phi, K, xx0, sigma ]                                         # Set up the least squares vector
            results = scipy.optimize.least_squares(L0Algorithms._cosinefunc, guess, jac=L0Algorithms._cosinefunc_jacobian, args=(M, s))
            if not results.success:                                                     # report, but do not discard, fits that did not converge
                warnings.warn('fit_cosine_with_gaussian_envelope: fit did not converge for height row %d: %s' % (ih, results.message))
            fitresults[:,ih] = results.x                                                # get the results from the fit.
            fitimage[ih,:]   = results.fun + s                                          # residual + measurement = fitted model
        return fitresults, fitimage