import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import math
import umap
from sklearn.decomposition import PCA
import soundfile as sf  
from scipy import signal
from scipy.signal import convolve2d, correlate
import scipy


#Pour les baleines a bosse d'après les spectros 
freq_min = 250  
freq_max = 650

#decouper en rectangles de 1s de long pour les bosses
duree_espece = 1

data, sample_rate = sf.read("audio/data2.wav")

# mieux en normalisant, apparemment c'est ce que librosa fait
# naturellement mais scipy non
data = data / np.max(np.abs(data), axis=0)

decim = 16

#Pour pouvoir utiliser les fcts de librosa
data = [data[::decim, i].astype(np.float64) for i in range(4)]
sample_rate //= decim 


#pour avoir de beaux spectrogrammes
n_fft = 18000 // decim
hop_length = n_fft // 8

S_db_filtered = []
S_db = []
for i, datai in enumerate(data):
    S = librosa.stft(datai, n_fft=n_fft, hop_length=hop_length)
    freqs = librosa.fft_frequencies(sr=sample_rate, n_fft=n_fft)

    min_idx = np.argmax(freqs >= freq_min)
    max_idx = np.argmax(freqs > freq_max)

    S = librosa.amplitude_to_db(np.abs(S))
    S_db.append(S)
    S_filtered = librosa.amplitude_to_db(np.abs(S[min_idx:max_idx, :]))
    S_db_filtered.append(S_filtered)

S_db_filtered = np.array(S_db_filtered)

Nf, Nt = S_db_filtered[0].shape
dt = (len(data[0]) / sample_rate) / Nt
Lt = math.floor(duree_espece / dt) + 1

def max_cross_correlation(sig1, sig2):
    corr = correlate(sig1, sig2, mode='valid')
    return np.max(np.abs(corr))

E1 = []


for deb in range(0, Nt - Lt, Lt):

    #Ici pour l'energie  plusieurs possibilites :

    S_max = [np.sum((np.abs(S_db_filtered[i][:, deb:deb + Lt]))**2) for i in range(4)] #energie L^2
    #S_max = [np.max((np.abs(S_db_filtered[i][:, deb:deb + Lt]))) for i in range(4)] #energie max
    #S_max = [np.max((np.abs(S_db_filtered[i][:, deb:deb + Lt]))**2) for i in range(4)] #energie max avec carre pour accentuer les ecarts ?
    #S_max = [np.sum((np.abs(S_db_filtered[i][:, deb:deb + Lt])))**2 for i in range(4)] #energie L^1 avec carre pour accentuer les ecarts ? 
    #S_max = [np.sum((10**(S_db_filtered[i][:, deb:deb + Lt] / 10))) for i in range(4)] #energie en repassant en lineaire, il y a quand 
                                                                                        # meme de bons resultats 
    cross_corr_12 = max_cross_correlation(S_db_filtered[0][:, deb:deb + Lt].flatten(),
                                          S_db_filtered[1][:, deb:deb + Lt].flatten())
    cross_corr_13 = max_cross_correlation(S_db_filtered[0][:, deb:deb + Lt].flatten(),
                                          S_db_filtered[2][:, deb:deb + Lt].flatten())
    cross_corr_14 = max_cross_correlation(S_db_filtered[0][:, deb:deb + Lt].flatten(),
                                          S_db_filtered[3][:, deb:deb + Lt].flatten())


    # pour prendre en compte le retard, à commenter  si on veut juste faire avec l'intensite
    S_max.extend([cross_corr_12, cross_corr_13, cross_corr_14])
    E1.append(S_max)


E1 = np.array(E1)


#test d une idee
emoy = np.mean(E1, axis=0)
std_dev = np.std(E1, axis=0)
mult = 0.0
threshold = emoy + mult * std_dev
mask = np.any(E1 > 0, axis=1)





umap_model = umap.UMAP(n_components=2, random_state=42)
E1_2D_umap = umap_model.fit_transform(E1)

plt.figure(figsize=(8, 6))
plt.scatter(E1_2D_umap[:, 0], E1_2D_umap[:, 1], color='green', label='E1 projections (UMAP)')
plt.xlabel("Composante 1")
plt.ylabel("Composante 2")
plt.title("Projection de E1 en dimension 2 (UMAP)")
plt.legend()
plt.grid(True)
plt.show()
