#
# Calculate acoustic indices
# Input is all the wav files in the /wav directory
# Output is 2 csv files in the /indices directory
# One for indices by time and one for indices by time and frequency bin
# Files are segmented into 60s chunks for consistency
#

import os
import sys
from pathlib import Path
import soundfile as sf
import maad
import numpy as np 
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt 
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from scipy.io import wavfile
from scipy import signal
import json
from soundscape_IR.soundscape_viewer import lts_viewer, lts_maker, interactive_matrix, audio_visualization
from soundscape_IR.soundscape_viewer.utility import matrix_operation
import argparse

# Defaults (the sub directory and channel can be overridden on the command line)

# Root of the plots tree used when the GPSVR environment variable is not "SERVER"
LOCAL_ROOT = "/Users/patrickmclean/GPDev/bioacoustics/mks/plots/"

# Root of the plots tree used when the GPSVR environment variable is "SERVER"
SERVER_ROOT = "D:/Bioacoustics/Projects/mks/plots/"

# Sub directory (relative to the root) to process.
# Set as blank to process everything in the root_dir
sub_dir = "plot1/20220926/"

# Channel to be processed, using numbering [1,2,3,4]
ch = 2

# Calculate acoustic indices for all wav files in a directory
def acoustic_indices(path, df, cur_plot, cur_date, channel):
    """Compute acoustic indices for the wav files listed in ``df`` and save them as CSV.

    Every file is read from ``{path}/wav/`` and cut into consecutive 60 second
    segments (a trailing partial segment is discarded). For each segment the
    temporal and spectral alpha indices are computed with scikit-maad.
    Two semicolon separated files are written to ``{path}/indices/``
    (the directory is created when missing):

    * ``indices_{cur_plot}_{cur_date}_ch{channel}.csv``: indices per segment
    * ``bin_indices_{cur_plot}_{cur_date}_ch{channel}.csv``: indices per
      segment and frequency bin

    Args:
        path: Directory holding the ``wav`` input sub directory; the
            ``indices`` output sub directory is created next to it.
        df: DataFrame with one row per file and at least a ``filename``
            column. All columns of the row are copied into the output next
            to the segment ``offset`` (in seconds). Rows whose file cannot be
            read are dropped from ``df`` in place.
        cur_plot: Plot name, only used to build the output file names.
        cur_date: Date label, only used to build the output file names.
        channel: 1-based channel to analyse when a file has several channels.
    """
    segment_seconds = 60

    # Fixed up front so the progress message does not change while rows are
    # being discarded.
    total_files = len(df.index)

    unreadable = []        # index labels of files that could not be read
    index_frames = []      # one single-row frame per processed segment
    per_bin_frames = []    # same, for the per frequency bin indices

    # Loop over all the files
    for index, row in df.iterrows():
        filename = row['filename']
        print(f"Processing file {index} of {total_files}: {filename}")
        try:
            wave, fs = sf.read(f"{path}/wav/{filename}")
        except Exception as e:
            # Missing or corrupt file (e.g. no EOF): report it and remove the
            # row once the iteration is over.
            print(f"Skipping unreadable file {filename}: {e}")
            unreadable.append(index)
            continue

        # Split audio file into 60 second segments
        file_duration = wave.shape[0] / fs
        full_segments = int(file_duration / segment_seconds)
        for offset in range(0, full_segments * segment_seconds, segment_seconds):
            wav_segment = wave[offset * fs:(offset + segment_seconds) * fs]
            if len(wav_segment.shape) != 1:  # Subset to the channel of interest
                wav_segment = wav_segment[:, (channel - 1)]
            try:
                # compute all the audio indices and store them into a DataFrame
                df_audio_ind = maad.features.all_temporal_alpha_indices(
                    wav_segment, fs, verbose=False, display=False)
                Sxx_power, tn, fn, ext = maad.sound.spectrogram(
                    wav_segment, fs, window='hanning', nperseg=2048,
                    noverlap=2048 // 2, verbose=False, display=False,
                    savefig=None)
                df_spec_ind, df_spec_ind_per_bin = maad.features.all_spectral_alpha_indices(
                    Sxx_power, tn, fn,
                    flim_low=[0, 1500],
                    flim_mid=[1500, 10000],
                    flim_hi=[10000, 128000],
                    verbose=False,
                    R_compatible='soundecology',
                    display=False)
            except Exception as e:
                print(f"Audio processing error {e}")
                continue

            # Single-row frame (index 0) with the file metadata plus the
            # segment offset, so it lines up with the index frames on axis=1.
            # pd.concat replaces Series.append, which was removed in pandas 2.0.
            data = pd.concat(
                [row, pd.Series([offset], index=['offset'])]).to_frame().T
            index_frames.append(
                pd.concat([df_audio_ind, df_spec_ind, data], axis=1))
            per_bin_frames.append(
                pd.concat([df_spec_ind_per_bin, data], axis=1))

    if unreadable:
        df.drop(unreadable, inplace=True)

    # Concatenate once at the end (DataFrame.append was removed in pandas 2.0
    # and was quadratic anyway); pd.concat refuses an empty list.
    if index_frames:
        df_indices = pd.concat(index_frames, ignore_index=True)
    else:
        df_indices = pd.DataFrame()
    if per_bin_frames:
        df_indices_per_bin = pd.concat(per_bin_frames, ignore_index=True)
    else:
        df_indices_per_bin = pd.DataFrame()

    # Save files
    os.makedirs(f"{path}/indices", exist_ok=True)
    df_indices.to_csv(f"{path}/indices/indices_{cur_plot}_{cur_date}_ch{channel}.csv", sep=";",date_format='%Y-%m-%d %H:%M:%S')
    df_indices_per_bin.to_csv(f"{path}/indices/bin_indices_{cur_plot}_{cur_date}_ch{channel}.csv", sep=";",date_format='%Y-%m-%d %H:%M:%S')


####################
# Main #
####################

# Read command line arguments, otherwise we use the default values
env = os.getenv("GPSVR")
argParser = argparse.ArgumentParser()
argParser.add_argument("-d", "--dir", required=False, help="sub directory to process")
argParser.add_argument("-c", "--channel", required=False, type=int, help="channel to process")
args = vars(argParser.parse_args())
if args['dir'] is not None:
    sub_dir = args['dir']
if args['channel'] is not None:
    ch = args['channel']
# GPSVR == "SERVER" selects the server tree, anything else the local one
root_dir = (SERVER_ROOT if env == "SERVER" else LOCAL_ROOT) + sub_dir

# Iterate over all the plots and dates under a root directory and create index files.
# Expected layout: <plot>/<date>/wav/*.wav, results go to <plot>/<date>/indices/
for root, dirs, data_files in os.walk(root_dir):
    if os.path.basename(root) != "wav":
        continue

    parent = Path(root).parent.absolute()
    cur_date = parent.name
    cur_plot = parent.parent.name
    print(f"Processing indices for {cur_plot}:{cur_date} ch {ch}")

    # exist_ok avoids the check-then-create race of exists() + mkdir()
    os.makedirs(f"{parent}/indices", exist_ok=True)

    # os.walk already lists the files of this directory, so there is no need
    # to chdir into it (changing the cwd while walking is fragile). Sorting
    # makes the row order of the output files reproducible.
    wav_files = sorted(f for f in data_files if f.endswith(".wav"))
    if not wav_files:
        continue

    df = pd.DataFrame(wav_files, columns=['filename'])
    # The first 15 characters of the file name are stored as the 'Date' column
    df['Date'] = df['filename'].str[0:15]
    acoustic_indices(f"{parent}", df, cur_plot, cur_date, channel=ch)