# Source code for showapi.compliancetests.compliancetests

import numpy as np
import scipy.optimize
import math
import matplotlib.pyplot as plt
from typing import Pattern, List, Dict, Tuple, Any, Union
from show_config.show_configuration import Parameters
from .. level0 import SHOWLevel0
from .. level0.level0imagearray import ImageStats, Level0_Image,Level0_ImageCollection
from .. level0.l0algorithms import Level1_Spectra, L0Algorithms

#------------------------------------------------------------------------------
#           plot_show_image
#------------------------------------------------------------------------------

def plot_show_image( image: np.ndarray, title :str ='Title Not Set', subplot=None, figurenum=None ) -> None :
    """
    Display a 2-D image with a 2nd to 98th percentile colour stretch, a colour bar and axis labels.

    :param image: 2-D array to display. It is drawn with ``origin='lower'``; columns are labelled as the spectral
                  dimension and rows as the altitude dimension.
    :param title: Title placed above the image.
    :param subplot: Optional matplotlib subplot specification (e.g. ``211``) to draw into.
    :param figurenum: Optional matplotlib figure number to draw into. The figure is selected *before* the subplot is
                      created so that both options can be combined and the subplot ends up in the requested figure.
    :return: None
    """
    # Select the figure first; creating the subplot first would put it in whatever figure was current before.
    if figurenum is not None:
        plt.figure(figurenum)
    if subplot is not None:
        plt.subplot(subplot)

    # NaN-aware percentiles so a few bad pixels cannot turn the colour limits into NaN.
    # For fully finite images this is identical to np.percentile.
    p2  = np.nanpercentile(image, 2)
    p98 = np.nanpercentile(image, 98)

    plt.imshow( image, clim=[p2,p98], origin='lower', aspect='auto')
    plt.colorbar()
    plt.title(title)
    plt.xlabel('Spectral Dimension')
    plt.ylabel('Altitude Dimension')

#------------------------------------------------------------------------------
#           oldschool_straightline_fit
#------------------------------------------------------------------------------

def oldschool_straightline_fit( x: np.ndarray, y:np.ndarray, weight:np.ndarray )->Tuple[ np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Fit a straight line y = c + m*x to every pixel using the closed-form weighted least-squares formulae.

    The formulae are not checked for ill-conditioning. A singular system (e.g. all x equal) gives a zero determinant
    and inf/nan results. Otherwise they are simple and fully vectorized over all pixels.

    :param x: A 1-d array of N independent variables, X must be of shape (N,). N usually corresponds to integration time
    :param y: An array of shape (N,M) (or (N,) for a single pixel). Each column M corresponds to a pixel on the detector
              and the N dimension corresponds to each integration time
    :param weight: An array broadcastable to the shape of y. This is 1/stddev for each measurement in Y.
                   (Note it is not 1/stddev**2 ).
    :return: Tuple of straight line parameters [ c(M,), m(M,), dc(M,), dm(M,)] where c is the y intercept, m is the
             gradient, dc is the standard error in c and dm is the standard error in m. Scalars are returned for 1-d y.
    :raises ValueError: if the first dimension of y does not match the length of x.
    """
    x = np.asarray(x, dtype=float).reshape(-1)
    y = np.asarray(y, dtype=float)
    npts = x.size
    if y.ndim < 1 or y.shape[0] != npts:
        raise ValueError("oldschool_straightline_fit: y must have shape (N, ...) with N == len(x) == %d" % (npts,))

    # Shape x as (N, 1, ..., 1) so it broadcasts against y whatever the number of pixel dimensions.
    xb = x.reshape((npts,) + (1,) * (y.ndim - 1))
    w  = np.broadcast_to(np.asarray(weight, dtype=float) ** 2, y.shape)    # weights are 1/stddev, least squares needs 1/stddev**2

    sw  = np.sum(w,           axis=0)
    sx  = np.sum(w * xb,      axis=0)
    sy  = np.sum(w * y,       axis=0)
    sx2 = np.sum(w * xb * xb, axis=0)
    sxy = np.sum(w * xb * y,  axis=0)

    denom = sx2 * sw - sx * sx
    m  = (sw * sxy - sx * sy) / denom
    # Determinant form of the intercept. It is algebraically equal to the old (sxy - m*sx2)/sx but does not
    # divide by sx, which is zero whenever the weighted x values sum to zero (e.g. x symmetric about 0).
    c  = (sx2 * sy - sx * sxy) / denom
    dm = np.sqrt(sw / denom)
    dc = np.sqrt(sx2 / denom)
    return c, m, dc, dm

#------------------------------------------------------------------------------
#           fit_straight_line
#------------------------------------------------------------------------------

def fit_straight_line( x: np.ndarray, y:np.ndarray, weight:np.ndarray )->Tuple[ np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Fit an independent straight line y = c + m*x to every pixel (column) of ``y``.

    This is a thin wrapper around ``oldschool_straightline_fit``; see that function for the expected shapes of
    ``x``, ``y`` and ``weight`` (``weight`` is 1/stddev, not 1/stddev**2).

    :return: Tuple ``(c, m, dc, dm)``: intercepts, gradients and their errors, one value per pixel.
    """
    return oldschool_straightline_fit(x, y, weight)

#------------------------------------------------------------------------------
#           class ComplianceTests

#------------------------------------------------------------------------------

class ComplianceTests:
    """
    Compliance tests for the SHOW instrument.

    The object holds the instrument design parameters (``Parameters(instrumentname)``) and provides the analysis used
    for two test configurations:

    * configuration A (``analyze_configuration_A``): Krypton lamp data, used to estimate the Littrow wavelength and the
      spectral resolution (FWHM).
    * configuration B (``analyze_configuration_B``): white light data, used to check the spectral range and the noise
      statistics of the interferograms and their Fourier transforms.

    Results are printed to stdout and plotted with matplotlib.
    """

    def __init__(self, instrumentname : str = "er2_2017") -> None :
        """
        :param instrumentname: Name of the instrument whose design parameters are loaded with ``Parameters``.
        """
        self.params = Parameters( instrumentname )                     # fetch the instrument design parameters

    #------------------------------------------------------------------------------
    #           ComplianceTests::design_params
    #------------------------------------------------------------------------------

    def design_params(self):
        """
        Returns the instrument design parameters object created in the constructor.
        """
        return self.params

    #------------------------------------------------------------------------------
    #               ComplianceTests::instrument_name
    #------------------------------------------------------------------------------

    def instrument_name(self):
        """
        Returns the instrument name stored in the design parameters (``self.params.instrumentname``),
        usually er2_2017.
        """
        return self.params.instrumentname

    #------------------------------------------------------------------------------
    #           Private numerical helpers
    #------------------------------------------------------------------------------

    @staticmethod
    def _unwrap_phase_jumps(phi: np.ndarray) -> np.ndarray:
        """
        Remove 2*pi wrap-arounds from a sequence of fitted phases, in place.

        Only steps between consecutive raw phases whose magnitude lies between 1.6*pi and 2.6*pi are treated as wraps.
        A positive jump subtracts 2*pi from that sample and all later ones; a negative jump adds 2*pi.

        :param phi: 1-D array of phases. It is modified in place.
        :return: The same array, for convenience.
        """
        lastp  = phi[0]
        offset = 0.0
        for ih in range(len(phi)):
            ds = phi[ih] - lastp                                   # step relative to the previous *raw* phase
            if 1.6*math.pi < abs(ds) < 2.6*math.pi:                # a step close to 2 pi is a wrap-around
                offset += -2.0*math.pi if ds > 0 else 2.0*math.pi
            lastp    = phi[ih]
            phi[ih] += offset
        return phi

    @staticmethod
    def _zero_crossing_contrast(s: np.ndarray) -> Tuple[np.ndarray, Union[float, None]]:
        """
        Estimate the fringe contrast (peak - trough)/peak along one row of an interferogram.

        Local maxima and minima are located from sign changes of the sample-to-sample difference. Maxima and minima are
        paired in order (truncated to the shorter list). Only pairs whose separation is within +/- 2 pixels of the
        median separation are kept, and maxima that are not positive are discarded.

        :param s: 1-D row of the interferogram.
        :return: ``(row, maxcontrast)``. ``row`` has the same length as ``s`` and holds the contrast linearly
                 interpolated between the kept maxima (0 outside them). ``maxcontrast`` is the largest contrast, or
                 None if no valid maximum/minimum pair was found.
        """
        s   = np.asarray(s, dtype=float)
        x   = np.arange(s.size)
        row = np.zeros(s.size)
        if s.size < 3:
            return row, None

        dy = np.roll(s, 1) - s                   # dy[i] = s[i-1] - s[i]
        dy[0]  = dy[1]                           # dy[0] wrapped around to the last sample, replace it
        dy[-1] = dy[-2]
        dy2 = np.roll(dy, 1)                     # dy2[i] = dy[i-1]

        # A sign change between dy[i-1] and dy[i] marks an extremum at sample i-1, so step back one sample.
        # Index 0 is dropped because its "previous" sample wraps around to the end of the row.
        lowzero = np.where((dy2 >= 0) & (dy < 0))[0]            # minima at lowzero-1
        hihzero = np.where((dy2 < 0) & (dy >= 0))[0]            # maxima at hihzero-1
        lowpeak = lowzero[lowzero > 0] - 1
        hihpeak = hihzero[hihzero > 0] - 1

        n = min(lowpeak.size, hihpeak.size)      # the counts of maxima and minima usually differ by one
        if n == 0:
            return row, None
        lowpeak = lowpeak[:n]
        hihpeak = hihpeak[:n]

        dz   = hihpeak - lowpeak                 # separation between paired maxima and minima
        zm   = np.median(dz)
        good = (dz > (zm - 2)) & (dz < (zm + 2)) # Krypton fringes should fall within +/- 2 pixels of the median
        hihpeak = hihpeak[good]
        lowpeak = lowpeak[good]

        highs = s[hihpeak]
        lows  = s[lowpeak]
        positive = highs > 0                     # avoid dividing by zero (or negative) peaks
        highs, lows, hihpeak = highs[positive], lows[positive], hihpeak[positive]
        if highs.size == 0:
            return row, None

        contrast = (highs - lows) / highs
        row = np.interp(x, hihpeak, contrast, left=0, right=0)
        return row, float(np.max(contrast))

    @staticmethod
    def _parabolic_peaks(spectrum: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Locate the peak of each spectrum row to sub-bin precision with a 5 point parabola fit.

        The 5 point window is centred on the maximum sample where possible and shifted inwards when the maximum lies
        within 2 bins of either end of the row.

        :param spectrum: 2-D array (H, F) with F >= 5.
        :return: ``(pk, maxindex)``: float array (H,) with the fitted parabola vertex in FFT bins, and integer array
                 (H,) with the bin of the maximum sample of each row.
        :raises ValueError: if the spectrum has fewer than 5 bins.
        """
        H, F = spectrum.shape
        if F < 5:
            raise ValueError("_parabolic_peaks: need at least 5 spectral bins, got %d" % (F,))
        maxindex = np.argmax(spectrum, 1)
        minindex = np.clip(maxindex - 2, 0, F - 5)         # keep the whole 5 point window inside the spectrum
        y = np.zeros((5, H))                               # (5, H) so a single polyfit call fits every height
        for ih in range(H):
            y[:, ih] = spectrum[ih, minindex[ih]:minindex[ih] + 5]
        x = np.linspace(0, 4, 5)                           # x coordinates relative to minindex
        p = np.polyfit(x, y, 2)
        a = p[0, :]                                        # x**2 coefficient for each height
        b = p[1, :]                                        # x coefficient for each height
        pk = -b / (2.0 * a) + minindex                     # parabola vertex, shifted back to absolute bins
        return pk, maxindex

    @staticmethod
    def _fwhm_at_each_height(spectrum: np.ndarray, maxindex: np.ndarray, yhalfmax: float = 0.5) -> np.ndarray:
        """
        Measure the full width of each spectrum row at the level ``yhalfmax``.

        Starting at ``maxindex[ih]`` the row is walked left and right until the signal is at or below ``yhalfmax``
        (or the end of the row is reached). The crossing on each side is then found by linear interpolation.

        :param spectrum: 2-D array (H, F); rows are expected to be normalized so that ``yhalfmax`` is half of the peak.
        :param maxindex: Integer array (H,) with the bin of the maximum of each row.
        :param yhalfmax: Signal level at which the width is measured.
        :return: Array (H,) of widths in FFT bins.
        """
        H, F = spectrum.shape
        fwhm = np.zeros((H,))
        for ih in range(H):
            ixmax = maxindex[ih]
            y     = spectrum[ih, :]

            ileft = ixmax                                                  # walk left until below half max
            while y[ileft] > yhalfmax:
                ileft -= 1
                if ileft < 0:
                    ileft = 0
                    break
            x     = np.arange(ileft, ixmax + 1, 1)
            xleft = np.interp([yhalfmax], y[x], x)                         # y is ascending along this interval

            iright = ixmax                                                 # walk right until below half max
            while y[iright] > yhalfmax:
                iright += 1
                if iright >= F:                                            # clamp to the number of spectral bins
                    iright = F - 1
                    break
            x      = np.arange(ixmax, iright + 1, 1)
            xright = np.interp([yhalfmax], np.flipud(y[x]), np.flipud(x))  # reversed so y is ascending for interp

            fwhm[ih] = xright[0] - xleft[0]
        return fwhm

    # ------------------------------------------------------------------------------
    #           ComplianceTests::analyze_contrast
    # ------------------------------------------------------------------------------

    def analyze_contrast(self, algo: "L0Algorithms", krdata: np.ndarray):
        """
        Fit a cosine modulated by a gaussian envelope to every height row of the Krypton interferogram sub-window.

        The fit is done by ``algo.fit_cosine_with_gaussian_envelope``; per the original comments the model is
        y = A cos(w x + phi) exp(-(x - x0)^2 / (2 sigma^2)) + K. The fitted phase is unwrapped along height and a
        straight line is fitted to it. The sign of the slope indicates which side of Littrow the Krypton line is on,
        which only works if the phase data are clean. Diagnostics are drawn in matplotlib figures 10 and 11.

        :param algo: Level 0 algorithms object used to extract the sub-window and do the fit.
        :param krdata: Krypton interferogram passed to ``algo.interferogram_subwindow``.
        :return: Tuple ``(fitresults, math.nan, phase_slope > 0)``. Rows 0..5 of ``fitresults`` hold A, w, phi, K, x0,
                 sigma for each height; row 2 (phi) is unwrapped in place. The second element is a placeholder for
                 the maximum contrast, which this method does not compute.
        """
        image, error, bounds = algo.interferogram_subwindow(krdata)
        fitresults, fitimage = algo.fit_cosine_with_gaussian_envelope(image)      # fit the fringes at each height in the sub-window
        H, M  = image.shape
        A     = fitresults[0, :]                                    # amplitude for each height
        w     = fitresults[1, :]*M/(2*math.pi)                      # spatial frequency mapped to FFT space
        phi   = fitresults[2, :]                                    # phase (a view, so unwrapping also updates fitresults)
        K     = fitresults[3, :]                                    # constant offset
        x0    = fitresults[4, :]                                    # gaussian peak location
        sigma = np.absolute(fitresults[5, :])                       # gaussian standard deviation

        # For good Krypton lines the phase slopes steadily with height but wraps by 2 pi: remove the wraps, then fit a
        # straight line. The slope tells us which side of Littrow we are on.
        self._unwrap_phase_jumps(phi)
        phase_slope = np.polyfit(np.arange(0, H), phi, 1)[0]

        plt.figure(11)
        plt.clf()
        rowindex = [2, H-2, int(H/3), int(2*H/3)]
        for ig, i in enumerate(rowindex):
            plt.subplot(2, 2, ig+1)
            plt.plot( image[i, :], 'k-' )
            plt.plot( fitimage[i, :], 'r-')
            plt.title('Row %d'%(i,))

        plt.figure(10)
        plt.clf()
        panels = [(A, "Amplitude"), (w, "Spatial Frequency"), (phi, "Phase"),
                  (K, "Constant K"), (x0, "Gaussian Max"), (sigma, "Gaussian Sigma")]
        for ip, (values, title) in enumerate(panels):
            plt.subplot(3, 2, ip+1)
            plt.plot(values)
            plt.title(title)
            if ip >= 4:
                plt.xlabel('Height Row')

        return fitresults, math.nan, (phase_slope > 0)

    #------------------------------------------------------------------------------
    #           ComplianceTests::analyze_contrast_zero_crossing
    #------------------------------------------------------------------------------

    def analyze_contrast_zero_crossing(self, algo : "L0Algorithms", krdata : "Level0_Image" ):
        """
        Estimate the Krypton fringe contrast (peak - trough)/peak at every height from local maxima and minima.

        :param algo: Level 0 algorithms object used to extract the interferogram sub-window.
        :param krdata: Krypton image passed to ``algo.interferogram_subwindow``.
        :return: ``(contrastimage, p95)``: an (H, M) image of contrast interpolated between fringe peaks (rows without
                 valid fringes are 0), and the 95th percentile of the per-row maximum contrast (nan if no row had
                 valid fringes).
        """
        image, error, bounds = algo.interferogram_subwindow( krdata )
        H, M = image.shape
        contrastimage = np.zeros( (H, M) )
        maxcontrast = []
        for ih in range(H):
            row, rowmax = self._zero_crossing_contrast(image[ih, :])
            contrastimage[ih, :] = row
            if rowmax is not None:
                maxcontrast.append(rowmax)
        p95 = np.percentile(maxcontrast, 95) if len(maxcontrast) > 0 else math.nan
        return contrastimage, p95

    #------------------------------------------------------------------------------
    #           ComplianceTests::analyze_configuration_A
    #------------------------------------------------------------------------------

    def analyze_configuration_A( self, netcdf_filename    : str,
                                       krypton_group      : str,
                                       dark_group         : str,
                                       arma_group         : str,
                                       armb_group         : str,
                                       plot_height_pixels : Tuple[int, int, int] = (10,75,140)):
        """
        Analyze configuration A: Krypton lamp measurement of the Littrow wavelength and spectral resolution.

        The Krypton images are dark-current and flat-field corrected and averaged. They are then converted to spectra
        and normalized to the maximum at each height. A 5 point parabola fit locates the Krypton peak at each height,
        and the FWHM is measured at 0.5 of the normalized maximum. The results are printed, and diagnostic figures
        10, 11, 15, 16, 17 and 18 are drawn.

        :param netcdf_filename: NetCDF file containing all of the groups below.
        :param krypton_group: Group holding the Krypton lamp images.
        :param dark_group: Group holding the dark current images.
        :param arma_group: Group passed to ``make_flat_field`` as the arm A flat-field images.
        :param armb_group: Group passed to ``make_flat_field`` as the arm B flat-field images.
        :param plot_height_pixels: Height rows drawn in the line plots.
        """
        krdata = SHOWLevel0(self.instrument_name(), netcdf_filename, krypton_group)                       # Load in the Krypton image headers
        try:
            dark      = krdata.make_dark_current( netcdf_filename, dark_group)                                 # dark current averages for the krypton images
            ff        = krdata.make_flat_field  ( netcdf_filename, arma_group, netcdf_filename, armb_group, dark) # flat field corrections for the krypton images
            krstats   = krdata.average_images  (darkcurrent=dark, flatfield=ff)                               # average, std-dev and error signals from the krypton data
            algorithm = L0Algorithms( krdata.params )

            contrast, maxcontrast, kr_above_littrow = self.analyze_contrast( algorithm, krstats.average)

            level1, level1error, ifgramwh, ifgramw = algorithm.interferogram_to_spectrum(krstats.average)
            spectrum = np.absolute(level1)                                   # magnitude of each Fourier component
            H, F = spectrum.shape                                            # H = num heights, F = num frequencies
            spectrum /= np.max(spectrum, axis=1, keepdims=True)              # normalize each height to its maximum

            # -- 5 point parabola fit to the Krypton peak. The Krypton line position essentially fixes the Littrow wavenumber
            pk, maxindex = self._parabolic_peaks(spectrum)
            avgpk      = np.average(pk)
            peak_error = np.std(pk)

            # ---- FWHM of the Krypton line at each height. 0.5 seems better than half of the parabola max.
            fwhm       = self._fwhm_at_each_height(spectrum, maxindex, yhalfmax=0.5)
            avgfwhm    = np.average(fwhm)
            fwhm_error = np.std(fwhm)

            # ---- derived wavelength / wavenumber quantities
            dsigma = self.params.fft_step_size_wavenumber()                   # wavenumber change per FFT spectral bin in cm-1
            dlamda = dsigma*1364.0*1364.0*1.0E-07                            # wavelength change per FFT bin in nm, evaluated at 1364 nm
            sign   = -1.0 if (kr_above_littrow) else 1.0
            krypt_wavenum = 1.0E7/1363.42206                                 # Krypton line wavenumber in cm-1 in air (NIST Atomic Spectra Database)
            # krypt_wavenum = 1.0E7/1363.79                                  # Krypton wavelength on a vacuum scale
            littrow_wavenum       = krypt_wavenum - sign*avgpk*dsigma        # Littrow wavenumber (zero frequency in FFT space)
            littrow_wavenum_below = krypt_wavenum + avgpk*dsigma
            littrow_wavenum_above = krypt_wavenum - avgpk*dsigma

            error_wavenum            = peak_error*dsigma
            littrow_wavelength       = 1.0E7/littrow_wavenum
            littrow_wavelength_below = 1.0E7/littrow_wavenum_below
            littrow_wavelength_above = 1.0E7/littrow_wavenum_above
            error_wavelen            = 1.0E7*error_wavenum/(littrow_wavenum*littrow_wavenum)
            fwhm_wavenum             = avgfwhm*dsigma
            fwhm_err_wavenum         = fwhm_error*dsigma
            fwhm_wavelen             = fwhm_wavenum*1.0E7/(krypt_wavenum*krypt_wavenum)
            fwhm_err_wavelen         = fwhm_err_wavenum*1.0E7/(krypt_wavenum*krypt_wavenum)
            minwavenumber            = littrow_wavenum
            maxwavenumber            = littrow_wavenum + F*dsigma
            maxwavelength            = 1.0E7/minwavenumber
            minwavelength            = 1.0E7/maxwavenumber

            # ---- print the results out to the screen
            abovestr = "BELOW" if kr_above_littrow else "ABOVE"
            print("Configuration A results from instrument [%s]:"%( self.instrument_name(),))
            print()
            print("Average krypton peak position = (%7.4f +/- %6.4f) fft bins" % (avgpk, peak_error))
            print('Krypton wavenumber            = %10.5f cm-1'%(krypt_wavenum,))
            print('Krypton wavelength            = %10.5f nm'%(1.0E7/krypt_wavenum,))
            print('Wavenumbers per fft bin       = %10.6f cm-1'%(dsigma,))
            print('Wavelength per fft bin        = %10.6f nm' % (dlamda,))
            print()
            print("Estimated Littrow wavelength  = (%11.6f +/- %8.6f) nm  if Littrow is below or (%11.6f +/- %8.6f) nm if Littrow is above"%(littrow_wavelength_below,error_wavelen,  littrow_wavelength_above,error_wavelen))
            print("Tilt of Krypton lines places Littrow wavelength %s the Krypton wavelength"%( abovestr, ))
            print("Spectral Resolution FWHM      = (%11.6f +/- %8.6f) nm"%(fwhm_wavelen,fwhm_err_wavelen))
            print("Nominal spectral range        = (   %8.3f to  %8.3f) nm"%(minwavelength,maxwavelength))
            print()
            # These values are wavenumbers in cm-1 (the old message labelled them as wavelengths in nm).
            print("Estimated Littrow wavenumber  = (%11.6f +/- %8.6f) cm-1 if Littrow is below or (%11.6f +/- %8.6f) cm-1 if Littrow is above"%(littrow_wavenum_below,error_wavenum,  littrow_wavenum_above,error_wavenum))
            print("Spectral Resolution FWHM      = (%11.6f +/- %8.6f) cm-1"%(fwhm_wavenum,fwhm_err_wavenum))
            print("Nominal spectral range        = (   %8.3f to  %8.3f) cm-1"%(minwavenumber,maxwavenumber))
            print()
            print('Maximum contrast in Kr fringes= %7.5f'%(maxcontrast,))

            # ---- plot the interferograms
            plt.figure(15)
            plt.clf()
            plot_show_image(ifgramw, 'Krypton Interferogram, %s'%(self.instrument_name(),), subplot=211)
            plt.subplot(212)
            lines_if = []
            for ipix in plot_height_pixels:
                l, = plt.plot(ifgramwh[ipix, :], '-', label='Height pixel %d' % (ipix,) )
                lines_if.append(l)
            plt.legend(handles=lines_if)
            plt.title('Krypton Lamp Apodized Interferogram, %s'%( self.instrument_name(),))
            plt.xlabel('Windowed pixel')
            plt.ylabel('DC Biased Signal')

            # ---- plot the FFT spectra
            plt.figure(16)
            plt.clf()
            plot_show_image(spectrum, 'Normalized Krypton Spectra, %s'%(self.instrument_name(),), subplot=211)
            plt.subplot(212)
            lines_sh = []
            wav = np.arange(0, F)*dlamda + littrow_wavelength
            for ipix in plot_height_pixels:
                l, = plt.plot( wav, spectrum[ipix, : ], '-', label='Height pixel %d'%(ipix,)  )
                lines_sh.append(l)
            l, = plt.plot(wav, np.average(spectrum, 0), '-', label='Average')
            lines_sh.append(l)
            plt.legend( handles=lines_sh)
            plt.title('Krypton Lamp Spectra, %s'%( self.instrument_name(),))
            plt.xlabel('Wavelength')
            plt.ylabel('Height Normalized Signal')
            plt.gca().ticklabel_format(useOffset=False)

            # ---- plot the peak fit results
            plt.figure(17)
            plt.clf()
            plt.subplot(211)
            plt.plot( pk, 'ks')
            plt.errorbar( [H/2], [avgpk], yerr=[peak_error], ecolor='r')
            plt.plot( [0, H-1], [avgpk, avgpk], 'r-')
            plt.text(0.95, 0.95, 'Peak = (%7.4f +/- %6.4f) fft bins' % (avgpk, peak_error), horizontalalignment='right', verticalalignment='top', transform=plt.gca().transAxes)
            plt.title('Krypton Peak Location vs. Height')
            plt.ylabel('Peak location, FFT bins')
            plt.xlabel('Height bin')

            # ---- plot the FWHM results
            plt.subplot(212)
            plt.plot( fwhm, 'ks')
            plt.errorbar( [H/2], [avgfwhm], yerr=[fwhm_error], ecolor='r')
            plt.plot( [0, H-1], [avgfwhm, avgfwhm], 'r-')
            plt.text(0.95, 0.95, 'FWHM = (%7.4f +/- %6.4f) fft bins' % (avgfwhm, fwhm_error), horizontalalignment='right', verticalalignment='top', transform=plt.gca().transAxes)
            plt.title('Krypton FWHM vs. Height')
            plt.ylabel('FWHM, FFT bins')
            plt.xlabel('Height bin')

            # ---- plot selected Krypton fringes. This was previously drawn in subplot(223) of figure 17, where it
            #      overlapped the FWHM plot in subplot(212); it now has its own figure.
            plt.figure(18)
            plt.clf()
            lines_fr = []
            for ipix in plot_height_pixels:
                l, = plt.plot(ifgramw[ipix, :], '-', label='Height pixel %d' % (ipix,))
                lines_fr.append(l)
            plt.legend(handles=lines_fr)
            plt.title('Select Krypton Fringes, %s' % (self.instrument_name(),))
            plt.xlabel('Spectral dimension')
            plt.ylabel('Signal DN')
        finally:
            krdata.close()                                       # always release the level 0 data, even on error

        print()
        print("End of SHOW Littrow wavelength and spectral resolution compliance test.")
        print(" **** Delete the Figure windows to return to command prompt ****** ")

    #------------------------------------------------------------------------------
    #           ComplianceTests::make_spectralaverage_from_L0collection
    #------------------------------------------------------------------------------

    def make_spectralaverage_from_L0collection(self, L0: "SHOWLevel0", algorithm: "L0Algorithms", darkcurrent=None, ff=None, imageindex=None ):
        """
        Fourier transform every image of a level 0 collection and get the average and the standard deviation.

        Each record ``i`` is fetched with ``L0.correct_image_ff_and_dc(i, darkcurrent=..., ff=...)`` and converted with
        ``algorithm.interferogram_to_spectrum``. The complex spectra are averaged before taking the magnitude. The
        standard deviation of the magnitudes uses the 1/N form, so it also works with a single image.

        :param L0: Level 0 collection providing ``numrecords()`` and ``correct_image_ff_and_dc``.
        :param algorithm: Object providing ``interferogram_to_spectrum``.
        :param darkcurrent: Dark current correction passed to ``correct_image_ff_and_dc``.
        :param ff: Flat field correction passed to ``correct_image_ff_and_dc``.
        :param imageindex: Record whose magnitude spectrum is returned as the sample spectrum. Defaults to ``int(N/2)``.
        :return: ``(avg, sd, oneimage)``: the magnitude of the mean complex spectrum, the RMS deviation of the magnitudes
                 about it, and the magnitude spectrum of record ``imageindex`` (None if that index is out of range).
        :raises ValueError: if the collection holds no records.
        """
        N = L0.numrecords()
        if N < 1:
            raise ValueError("make_spectralaverage_from_L0collection: the level 0 collection has no records")
        if imageindex is None:
            imageindex = int(N/2)

        avg      = None
        sumsq    = None
        oneimage = None
        for i in range(N):
            l0 = L0.correct_image_ff_and_dc( i, darkcurrent=darkcurrent, ff=ff)                      # DC and flat field corrected image
            complexspectrum, error, ifgramwh, ifgramw = algorithm.interferogram_to_spectrum( l0 )
            magnitude = np.absolute(complexspectrum)
            if i == imageindex:
                oneimage = magnitude
            if avg is None:
                # Copy so the in-place accumulation below never modifies the array returned by the algorithm.
                avg   = np.array(complexspectrum, copy=True)                  # average is kept in complex form
                sumsq = magnitude * magnitude                                 # magnitudes are used for the error analysis
            else:
                avg   += complexspectrum
                sumsq += magnitude * magnitude

        avg = np.absolute(avg) / N
        q   = np.maximum((sumsq - N * (avg * avg)) / N, 0.0)      # rounding can push a few points slightly below 0
        sd  = np.sqrt(q)
        return avg, sd, oneimage

    #------------------------------------------------------------------------------
    #           ComplianceTests::analyze_configuration_B
    #------------------------------------------------------------------------------

    def analyze_configuration_B( self, netcdf_filename    : str,
                                       white_groupname    : str,
                                       dark_groupname     : str,
                                       arma_groupname     : str,
                                       armb_groupname     : str,
                                       littrow_wavenum    : float = 7336.016707,
                                       fwhm_wavenum       : float = 0.268485,
                                       plot_height_pixels : Tuple[int, int, int] = (12,75,135),
                                       plot_h2o_xsection : bool = False) -> bool :
        """
        Executes compliance tests for instrument configuration B. This is a white light configuration used to
        perform the spectral range validation. The spectral range is limited by an interference filter from 1363 nm to
        1366 nm. The test checks that the filter has not delaminated and that its bandpass still covers a spectral range
        that includes the expected water lines.

        This test also provides a quick way to look at white light spectra. Spectra taken under normal laboratory
        conditions should show water absorption features, and the user can optionally plot the water cross-section
        alongside them.

        A signal-to-noise check is also applied: the measured standard deviations of the interferograms and of their
        Fourier transforms are compared with the theoretical noise.

        :param netcdf_filename: NetCDF file containing all of the groups below.
        :param white_groupname: Group holding the white light images.
        :param dark_groupname: Group holding the dark current images.
        :param arma_groupname: Group passed to ``make_flat_field`` as the arm A flat-field images.
        :param armb_groupname: Group passed to ``make_flat_field`` as the arm B flat-field images.
        :param littrow_wavenum: Littrow wavenumber in cm-1, used to assign a wavenumber to each FFT bin.
        :param fwhm_wavenum: Spectral resolution in cm-1. Currently not used by this method.
        :param plot_height_pixels: Height rows drawn in the line plots.
        :param plot_h2o_xsection: If True, overlay the normalized H2O cross-section on the spectra plot.
        :return: True
        """
        wtdata = SHOWLevel0(self.instrument_name(), netcdf_filename, white_groupname)                       # Load in the white light image headers
        try:
            dark      = wtdata.make_dark_current( netcdf_filename, dark_groupname)                             # dark current averages
            ff        = wtdata.make_flat_field  ( netcdf_filename, arma_groupname, netcdf_filename, armb_groupname, dark )  # flat field corrections
            wtstats   = wtdata.average_images  (darkcurrent=dark, flatfield=ff)                               # average, std-dev and error signals
            algorithm = L0Algorithms( wtdata.params )
            complexspectra, error, ifgramwh, ifgramw = algorithm.interferogram_to_spectrum(wtstats.average, userdefinederror=wtstats.stddev )
            spectrum = np.absolute(complexspectra)                                                             # magnitude of each Fourier component

            # ---- noise statistics of the Fourier transforms
            imageindex = int(wtdata.numrecords()/2)                                      # the "one image" in the collection that we plot
            sdx   = algorithm.subwindow_array(wtstats.stddev)                            # std dev of the interferogram clipped to the sub-window
            Ef,Ex = algorithm.rms_frequency_error_from_rms_spatial_error( sdx )          # propagated error on the spectrum
            avgspectrum, sd_spectrum, one_spectrum = self.make_spectralaverage_from_L0collection( wtdata,
                                                                                                   algorithm,
                                                                                                   darkcurrent = dark,
                                                                                                   ff          = ff,
                                                                                                   imageindex  = imageindex)
            H, N = avgspectrum.shape
            # Row used for the single-row noise plots. 110 is the historical choice; it is clamped so smaller
            # images do not raise IndexError. NOTE(review): assumes the interferogram sub-window has the same height.
            heightindex = min(110, H - 1)
            Eff = np.tile( Ef, (N,1)).transpose()
            spectrum_error_ratio   = sd_spectrum/Eff
            avg_spectrum_err_ratio = np.median(spectrum_error_ratio)

            # ---- noise statistics of the interferograms
            avginterferogram = algorithm.subwindow_array( wtstats.average )              # average interferogram in the sub-window
            userdefinederror = algorithm.subwindow_array( wtstats.stddev  )              # measured std dev of the interferograms
            record           = wtdata.correct_image_ff_and_dc( imageindex, darkcurrent=dark, ff=ff)   # one image, DC and flat-field corrected
            oneinterferogram = algorithm.subwindow_array( record )
            oneifgramerror   = algorithm.rms_signal_error_from_detector_specs( oneinterferogram )     # theoretical error from detector stats
            ifgram_error_ratio = userdefinederror/oneifgramerror

            Edn = self.params.config['level0']['electrons_per_DN']
            avg_ifgram_err_ratio = np.median( ifgram_error_ratio )
            Edn_new = Edn/ (avg_ifgram_err_ratio*avg_ifgram_err_ratio)
            print("Configuration B results  from instrument [%s]:"%(self.instrument_name(),))
            print()
            print('Analysis of Interferogram Statistics')
            print("Median ratio of Std-dev to theoretical noise      = %6.2f"%(avg_ifgram_err_ratio,))
            print("Current configured Electrons per DN               = %6.2f"%(Edn,))
            print("New estimate of Electrons per DN to match std-dev = %6.2f"%(Edn_new,))
            print()
            print('Analysis of Fourier transform Statistics')
            print("Median ratio of transform Std-dev to noise propagated from interferogram std dev = %6.2f"%(avg_spectrum_err_ratio,))

            plt.figure(4)
            plt.clf()
            xf = np.arange(0, N, 1)
            de = np.zeros( [N,]) + Ef[heightindex]
            plt.plot( xf, avgspectrum[heightindex,:], 'k-', label='Average spectrum')
            plt.plot( xf, one_spectrum[heightindex,:], 'r-', label='One spectrum')
            plt.errorbar( xf, one_spectrum[heightindex, :], yerr=de,                         ecolor='b', fmt='none')
            plt.errorbar( xf, one_spectrum[heightindex,:],  yerr=sd_spectrum[heightindex,:], ecolor='g', fmt='none')
            plt.legend( ('Average spectrum','One spectrum','Theoretical RMS Error','FFT Standard Dev.'), loc=1  )
            plt.title('SHOW Spectral Noise Analysis, %s'%(self.instrument_name(),))

            plt.figure(5)
            plt.clf()
            xi = np.arange(0, avginterferogram.shape[1], 1)
            plt.plot    ( xi, avginterferogram[heightindex,:], 'k-', label='Average Ifgram, Row %d'%(heightindex,))
            plt.plot    ( xi, oneinterferogram[heightindex,:], 'r-', label='One Ifgram Row %d'%(heightindex,))
            plt.errorbar( xi, oneinterferogram[heightindex,:], yerr=oneifgramerror  [heightindex, :], ecolor='r', fmt='none')
            plt.errorbar( xi, oneinterferogram[heightindex,:], yerr=userdefinederror[heightindex, :], ecolor='g', fmt='none')
            plt.legend( ('Average Ifgram, Row %d'%(heightindex,),'One Ifgram Row %d'%(heightindex,),'Detector Stats Error','Ifgram Standard Dev.'), loc=1 )
            plt.title('SHOW Interferogram Noise Analysis, %s'%(self.instrument_name(),))

            plot_show_image(spectrum_error_ratio, 'Spectral Transform Noise Ratio, %s'%(self.instrument_name(),), figurenum=6)
            plot_show_image(ifgram_error_ratio,   'Interferogram Noise Ratio, %s'%(self.instrument_name(),), figurenum=7)

            H, F = spectrum.shape                                        # H = num heights, F = num frequencies
            spectrum /= np.max(spectrum, axis=1, keepdims=True)          # normalize each height to its maximum

            plt.figure(1)
            plt.clf()
            plot_show_image(ifgramw, 'White Light Interferogram, %s'%(self.instrument_name(),), subplot=211)
            plt.subplot(212)
            lines_if = []
            for ipix in plot_height_pixels:
                l, = plt.plot(ifgramwh[ipix, :], '-', label='Height pixel %d' % (ipix,) )
                lines_if.append(l)
            plt.legend(handles=lines_if)
            plt.title('White Light Apodized Interferogram, %s'%(self.instrument_name(),))
            plt.xlabel('Windowed pixel')
            plt.ylabel('DC Biased Signal')

            # ---- plot the white light spectra
            plot_show_image(spectrum, 'White Light Normalized Spectra, %s'%(self.instrument_name(),), figurenum=2)

            plt.figure(3)
            plt.clf()
            lines_sh = []
            if plot_h2o_xsection:
                # The original code called `show_configuration.read_pickle_h20xsection()` without importing
                # `show_configuration`, which raised NameError on every call. It is imported here, and the pickle is
                # only read when it is actually plotted.
                # NOTE(review): assumes the function lives in show_config.show_configuration; confirm.
                from show_config import show_configuration
                h2owavenum, h2oabsxs = show_configuration.read_pickle_h20xsection()      # HITRAN H2O cross-sections
                # A vacuum-to-air conversion (h2owavenum *= nair, with nair = 1.0002734 at 1.363 microns) is deliberately not applied.
                h2oxs = h2oabsxs/np.max(h2oabsxs)
                l, = plt.plot( 1.0E07/h2owavenum, h2oxs, 'k-', label='H2O cross-section')
                lines_sh.append(l)
            dsigma = self.params.fft_step_size_wavenumber()                              # wavenumbers per FFT spectral bin
            pix_wavenum = littrow_wavenum - np.arange(0, F, 1)*dsigma                    # nominal wavenumber of each FFT bin
            pix_wavelen = 1.0E7/pix_wavenum
            for ipix in plot_height_pixels:
                l, = plt.plot( pix_wavelen, spectrum[ipix, : ], '-', label='Height pixel %d'%(ipix,)  )
                lines_sh.append(l)
            l, = plt.plot(pix_wavelen, np.average(spectrum, 0), '-', label='Average')
            lines_sh.append(l)
            plt.legend( handles=lines_sh)
            plt.title('White Light Spectra, %s'%(self.instrument_name(),))
            plt.xlabel('Wavelength nm')
            plt.ylabel('Height Normalized Signal')
            plt.gca().ticklabel_format(useOffset=False)
        finally:
            wtdata.close()                                       # always release the level 0 data, even on error

        print()
        print("End of SHOW spectral range compliance test.")
        print(" **** Delete the Figure windows to return to command prompt ****** ")
        plt.show()

        return True

#------------------------------------------------------------------------------
#               ComplianceTests::analyze_configuration_C
#------------------------------------------------------------------------------

def analyze_configuration_C( self, dark_directories : Union[str,List[str]] ):
    """
    This function implements the Benchmark: Dark Current and Bias described in the SHOW
    Compliance Test and Performance Benchmarking document. We expect to perform this
    test with 5 unique exposure times ( ~0, 100, 200, 300 and 400 milli-seconds).

    A weighted straight line is fitted, independently for every pixel, to the average dark
    signal versus exposure time. The y-intercept gives the DC bias and the gradient gives
    the dark current. The fit needs 2 or more unique exposure times. The per-exposure average
    and standard deviation images are laid out 3 per row (at least 2 rows), so any number
    of exposure times can be plotted.

    :param dark_directories: A list of directories containing (only) dark current images. The code will read in all images from
        the list of directories and sort them by exposure time. It is the caller's responsibility to ensure that all the images are only
        dark current.

    :return: True

    :raises ValueError: if fewer than 2 unique exposure times are found in the dark images.
    """

    print("SHOW Benchmark:Dark Current and Bias.")
    l0dark      = SHOWLevel0( self.instrument_name(), dirnames = dark_directories )      # Load the dark current image headers from the given list of directories and sort according to exposure time
    print("... generating average and standard deviations. This may take a while.")
    l0darkstats = l0dark.average_images()                                                # Generate a dictionary of {integration time and image statistics}
    k           = list(l0darkstats.keys())                                               # Get the exposure times (usecs) from the dictionary keys
    ndark       = len(k)                                                                 # Get the number of dark exposure times

    if ndark < 2:                                                                        # A straight line fit is degenerate with fewer than 2 points
        raise ValueError('The dark current and bias fit needs at least 2 unique exposure times but %d were found' % (ndark,))

    ncols = 3                                                                            # 3 panels per row; at least 2 rows so the original
    nrows = max(2, int(math.ceil(ndark / ncols)))                                        # 2x3 layout is kept for 6 or fewer exposure times

    plt.figure(1)                                                                        # Plot the average image for each exposure time
    for i, exptime in enumerate(k):
        plt.subplot(nrows, ncols, i + 1)
        plot_show_image(  l0darkstats[exptime].average.image, 'Dark Average, %5.1f ms'%(exptime/1000.0,))

    plt.figure(2)                                                                        # Plot the standard deviation image for each exposure time
    for i, exptime in enumerate(k):
        plt.subplot(nrows, ncols, i + 1)
        plot_show_image(  l0darkstats[exptime].stddev, 'Dark Std Dev, %5.1f ms' % (exptime/1000.0,))

    print("... performing straight line fit to ", ndark, " points/integration times" )   # now perform the weighted straight line fit

    dims    = l0darkstats[k[0]].average.image.shape                                      # we must reshape the arrays into 2d arrays of ( exposuretimes, pixels)
    npix    = l0darkstats[k[0]].average.image.size                                       # number of pixels in each image (the first average is just a convenient image)
    x       = np.array( k )/1000.0                                                       # exposure times in milli-seconds, more intuitive for the fit results
    y_avg   = np.zeros( (ndark,) + tuple(dims) )                                         # stack of average images, first dimension is integration time
    y_std   = np.zeros( (ndark,) + tuple(dims) )                                         # stack of std-dev images, first dimension is integration time
    for ik, exptime in enumerate(k):                                                     # load the individual images for each integration time
        y_avg[ik] = l0darkstats[exptime].average.image
        y_std[ik] = l0darkstats[exptime].stddev

    y         = y_avg.reshape( (ndark, npix) )                                           # flatten each image, eg change (4, 640,512) to (4, 327680)
    # NOTE(review): a pixel whose std-dev is 0 gets an infinite weight here; confirm fit_straight_line tolerates that.
    weight    = 1.0/y_std.reshape( (ndark, npix) )                                       # weight for the fit is 1/stddev
    c,m,dc,dm = fit_straight_line( x, y, weight )                                        # do the straight line fit, all 4 params come back as 1-D arrays eg c(327680)
    c      = c.reshape( dims )                                                           # reshape the fit values back to the original image eg. (327680) -> (640, 512)
    m      = m.reshape( dims )
    dc     = dc.reshape( dims )
    dm     = dm.reshape( dims )

    plt.figure(3)                                                                        # plot the straight line fit results
    plot_show_image( c,  'DC Bias, DN',                      subplot=221)                # DC bias from the y intercept at zero integration time
    plot_show_image( m,  'Dark Current, DN/milli-sec',       subplot=222)                # dark current from the gradient
    plot_show_image( dc, 'DC Bias Error, DN',                subplot=223)                # error in the DC bias
    plot_show_image( dm, 'Dark Current Error, DN/milli-sec', subplot=224)                # error in the dark current

    print('Configuration C: Dark Current and Bias Test Results')
    print()
    print('Median DC bias      = %10.2f DN'%(np.median(c),))
    print('Median Dark Current = %10.5f DN/millisecond'%(np.median(m),))
    print()
    print('Completed SHOW Dark Current and Bias Compliance Test')
    print(" **** Delete the Figure windows to return to command prompt ****** ")
    plt.show()                                                                           # make sure the figures are displayed; blocks until they are closed
    return True

# ------------------------------------------------------------------------------
#               ComplianceTests::test_detector_noise
# ------------------------------------------------------------------------------

def test_detector_noise(self, white_light_directories: Union[str, List[str]], dark_directories: Union[str, List[str]],):
    """
    This function implements the Benchmark: Detector Noise described in the SHOW
    Compliance Test and Performance Benchmarking document.

    Dark current images and white screen images are loaded and averaged. The dark
    statistics are passed to the white screen averaging as ``darkcurrent`` so that the
    dark current is subtracted. The white screen average and standard deviation images
    are then plotted for every exposure time, 3 panels per row.

    :param white_light_directories: A list of directories containing (only) white screen images. The code will read in all images
        from the list of directories and sort them by exposure time. It is the caller's responsibility to ensure that all the images
        are only white screen images.

    :param dark_directories: A list of directories containing (only) dark current images. The code will read in all images from
        the list of directories and sort them by exposure time. It is the caller's responsibility to ensure that all the images are only
        dark current.

    :return: True
    """

    print("SHOW Benchmark: Detector Noise.")
    l0dark   = SHOWLevel0(self.instrument_name(), dirnames=dark_directories)                 # Load the dark current image headers and sort according to exposure time
    l0white  = SHOWLevel0(self.instrument_name(), dirnames=white_light_directories)          # Load the white screen image headers and sort according to exposure time
    print("... generating dark current average and standard deviations. This may take a while.")
    l0darkstats  = l0dark.average_images()                                                   # Generate a dictionary of {integration time and image statistics}
    print("... generating white screen average and standard deviations. This may take a while.")
    l0whitestats = l0white.average_images( darkcurrent=l0darkstats)                          # Generate a dictionary of white light images with dark current subtracted
    k = list(l0whitestats.keys())                                                            # Get the exposure times (usecs) from the dictionary keys

    nexp  = len(k)
    ncols = max(1, min(3, nexp))                                                             # up to 3 panels per row; a single exposure time gets a 1x1 grid
    nrows = max(1, int(math.ceil(nexp / ncols)))                                             # as many rows as needed for all exposure times

    plt.figure(1)                                                                            # Plot the average image for each exposure time
    for i, exptime in enumerate(k):
        plt.subplot(nrows, ncols, i + 1)
        plot_show_image(l0whitestats[exptime].average.image, 'White Screen Average, %6.1f ms' % (exptime / 1000.0,))

    plt.figure(2)                                                                            # Plot the standard deviation image for each exposure time
    for i, exptime in enumerate(k):
        plt.subplot(nrows, ncols, i + 1)
        plot_show_image(l0whitestats[exptime].stddev, 'White Screen Std Dev, %6.1f ms' % (exptime / 1000.0,))

    print('Completed SHOW Detector Noise Benchmark')                                         # print the "Test Finished" statement.
    print('Close the figures to continue')
    plt.show()                                                                               # display the figures; blocks until they are closed

    return True