#
# Takes index csv with frequency bins and creates a False Color LTS image
#

import os
from pathlib import Path
import soundfile as sf
import maad
import numpy as np 
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt 
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from scipy.io import wavfile
from scipy import signal
import json
from soundscape_IR.soundscape_viewer import lts_viewer, lts_maker, interactive_matrix, audio_visualization
from soundscape_IR.soundscape_viewer.utility import matrix_operation
from plotnine import ggplot, aes, geom_boxplot, theme_classic, scale_x_discrete, theme, element_text, scale_fill_manual, facet_wrap, ggsave
import argparse

# Defaults
# Root folders; the script selects one of them from the GPSVR environment variable
LOCAL_ROOT = "/Users/patrickmclean/GPDev/bioacoustics/mks/plots/"
SERVER_ROOT = "D:/Bioacoustics/Projects/mks/plots/"
# Sub-directory appended to the root; leave blank to process everything under the root
sub_dir = "plot1/20220927/"
# Channel to be processed, numbered 1 to 4
ch = 2

# Indices turned into image layers, and the CSV columns kept when loading
indices = ['ACI_per_bin', 'Ht_per_bin', 'EVNspCount_per_bin']
columns = ['Date', 'frequencies', 'LTS', *indices]

fig_size = (20, 20)

# ACI normalization range
ACI_MIN = 0.4
ACI_MAX = 0.7
# Normalized ACI values under this threshold are set to 0
ACI_THRESHOLD = 0.7
# EVNspCount amplification factor
EVN_MULT = 2

# Create json of file times & get start and stop time
def get_start_and_stop_time(dir, cur_plot, cur_date):
    """Record the files found in ``dir`` and return the time range they span.

    The entries of ``dir`` whose name does not start with a dot are sorted by
    name and written as two JSON lists into the ``data`` directory next to
    ``dir`` (created if missing):

    * ``{cur_plot}_{cur_date}_files.json``      - the file names
    * ``{cur_plot}_{cur_date}_file_times.json`` - characters 9-14 of each name

    NOTE(review): the slicing assumes names shaped like ``YYYYMMDD_HHMMSS.wav``
    (hour at [9:11], minute at [11:13]) - confirm against the recorder's
    naming scheme.

    Args:
        dir: directory whose files are listed.
        cur_plot: plot name, used in the log line and the output file names.
        cur_date: date label, used in the log line and the output file names.

    Returns:
        ``(start_time, stop_time)`` in decimal hours, read from the first and
        last file name, or ``(-1, -1)`` when ``dir`` holds no visible entries.

    Raises:
        OSError: if ``dir`` cannot be listed or the JSON files cannot be written.
        ValueError: if a file name has no digits at the expected positions.
    """
    print(f"Processing start and stop on {cur_plot}:{cur_date} ")

    # Write into the directory we create, independent of the working directory
    data_dir = Path(dir).parent.absolute() / "data"
    os.makedirs(data_dir, exist_ok=True)

    files = sorted(name for name in os.listdir(dir) if not name.startswith('.'))
    if not files:
        print("No files found")
        return (-1, -1)

    def decimal_hours(name):
        # Hours plus minutes expressed as a fraction of an hour
        return int(name[9:11]) + int(name[11:13]) / 60

    file_times = [name[9:15] for name in files]
    start_time = decimal_hours(files[0])
    stop_time = decimal_hours(files[-1])
    print(f"Day range (decimal hours): {start_time:.2f} to {stop_time:.2f}")

    with open(data_dir / f"{cur_plot}_{cur_date}_files.json", 'w') as outfile:
        json.dump(files, outfile)
    with open(data_dir / f"{cur_plot}_{cur_date}_file_times.json", 'w') as outfile:
        json.dump(file_times, outfile)
    return (start_time, stop_time)

# Keep only the last `rows` rows of every index image
def _keep_last_rows(values, rows):
    """Return a dict in which each 2-D image is cut down to its final ``rows`` rows.

    The slice starts at ``height - rows`` rather than at a fixed offset, so
    exactly ``rows`` rows remain whatever the original height is.
    """
    return {name: image[image.shape[0] - rows:, :] for name, image in values.items()}


# Combine the three index images into one RGB picture and write it to disk
def _save_false_color(values, out_path):
    """Save ACI (red), Ht (green) and EVNspCount (blue) as one RGB image.

    The three images must share the same shape; ``plt.imsave`` writes the
    array directly, so no pyplot figure is created.
    """
    red = values['ACI_per_bin']
    false_color = np.zeros(red.shape + (3,), dtype=np.float32)
    false_color[:, :, 0] = red                              # Red
    false_color[:, :, 1] = values['Ht_per_bin']             # Green
    false_color[:, :, 2] = values['EVNspCount_per_bin']     # Blue
    plt.imsave(out_path, false_color)


# Display a 2d image of the index
def ai_2ddisplay(root, target_dir, bin_csv_file,date,plot, ch, start_time, stop_time):
    """Build false-color images from a per-bin index CSV and save them as PNGs.

    For every name in the module-level ``indices`` list the matching CSV
    column is parsed (each non-null cell is a JSON array string), the first
    element of every array is dropped, the bin order is reversed, and the
    result is transposed so that one CSV row becomes one image column. The
    image is then padded on the left and right so that its width represents
    a full 24 hours, using ``start_time`` and ``stop_time`` as the covered range.

    ACI is normalized with ``ACI_MIN``/``ACI_MAX`` and thresholded, Ht is
    inverted, EVNspCount is scaled by ``EVN_MULT`` over its maximum; all three
    are clipped to [0, 1]. Two PNGs are written to ``target_dir``:

    * ``{plot}_{date}_fclts_68k_ch{ch}.png`` - the last half of the rows
    * ``{plot}_{date}_fclts_22k_ch{ch}.png`` - the last third of that half

    NOTE(review): the "68k"/"22k" tags are the original author's cut-off
    labels; the actual frequencies depend on the recording sample rate - confirm.

    Args:
        root: directory containing ``bin_csv_file``.
        target_dir: existing directory that receives the PNG files.
        bin_csv_file: name of the ``;``-separated CSV holding the per-bin indices.
        date: date label used in output names and messages.
        plot: plot label used in output names and messages.
        ch: channel number used in output names.
        start_time: start of the recorded range, in decimal hours.
        stop_time: end of the recorded range, in decimal hours.

    Returns:
        None. Any exception is caught and reported on stdout so that a batch
        run can continue with the next directory.
    """
    try:
        # An empty or reversed range cannot be mapped onto image columns
        if stop_time <= start_time:
            raise ValueError(f"invalid time range {start_time} to {stop_time}")

        bin_df = pd.read_csv(f"{root}/{bin_csv_file}", sep=";")
        bin_df = bin_df[columns]

        # Turn each index into a greyscale image
        values = {}
        for index in indices:
            # Some columns have null values; only string cells hold JSON arrays
            cells = [x for x in bin_df[index].tolist() if isinstance(x, str)]
            image = np.array([json.loads(x) for x in cells])
            image = np.swapaxes(np.fliplr(image[:, 1:]), 0, 1)
            image_height = image.shape[0]
            image_width = image.shape[1]

            # Pad the image if partial day
            segment_duration = image_width / (stop_time - start_time)
            left_pad_len = int(start_time * segment_duration)
            right_pad_len = int((24 - stop_time) * segment_duration)
            # Ht is inverted later on, so padding it with 1 ends up as 0 like the others
            pad_value = 1 if index == 'Ht_per_bin' else 0
            left_pad = np.full((image_height, left_pad_len), pad_value, dtype=float)
            right_pad = np.full((image_height, right_pad_len), pad_value, dtype=float)
            values[index] = np.hstack((left_pad, image, right_pad))
            print(f"Partial day. Padding left {left_pad_len}, right {right_pad_len}")

        # Normalization steps
        # ACI is normalized and thresholded
        # Ht is inverted, otherwise unchanged
        # EVNspCount is normalized then amplified
        aci = values['ACI_per_bin']
        print(f"ACI Max before normalization: {aci.max()}")
        aci = (aci - ACI_MIN) / (ACI_MAX - ACI_MIN)
        aci[aci < ACI_THRESHOLD] = 0
        values['ACI_per_bin'] = np.clip(aci, 0, 1)

        values['Ht_per_bin'] = np.clip(1 - values['Ht_per_bin'], 0, 1)

        evn = values['EVNspCount_per_bin']
        evn_max = evn.max()
        # An all-zero layer stays zero instead of turning into NaN (0 / 0)
        if evn_max != 0:
            evn = EVN_MULT * evn / evn_max
        values['EVNspCount_per_bin'] = np.clip(evn, 0, 1)

        # Cut the height of the image to 1/2 - 68kHz cutoff
        image_height = image_height // 2
        values = _keep_last_rows(values, image_height)
        _save_false_color(values, f"{target_dir}/{plot}_{date}_fclts_68k_ch{ch}.png")

        # Cut the height again by a third - 22kHz cutoff
        image_height = image_height // 3
        values = _keep_last_rows(values, image_height)
        _save_false_color(values, f"{target_dir}/{plot}_{date}_fclts_22k_ch{ch}.png")

    except Exception as e:
        print(f"Error in ai_2ddisplay {plot} {date}: {e}")

####################
# Main #
####################

# Read command line arguments, otherwise we use the default values
env = os.getenv("GPSVR")
argParser = argparse.ArgumentParser()
argParser.add_argument("-d", "--dir", required=False, help="directory to process")
argParser.add_argument("-c", "--channel", required=False, type=int, help="channel to process")
args = vars(argParser.parse_args())
if args['dir'] is not None:
    sub_dir = args['dir']
if args['channel'] is not None:
    ch = args['channel']
# GPSVR=SERVER selects the server root; anything else falls back to the local root
root_dir = (SERVER_ROOT if env == "SERVER" else LOCAL_ROOT) + sub_dir

# Iterate over all index files in the directory
# Expected layout: <plot>/<date>/indices holds the CSVs, <plot>/<date>/wav the
# recordings, and <plot>/<date>/lts receives the images
for root, dirs, data_files in os.walk(root_dir):
    if os.path.basename(root) != "indices":
        continue

    # Work from the indices directory so relative paths resolve against it
    os.chdir(root)
    parent = Path(root).parent.absolute()
    target_dir = str(parent) + '/lts'
    cur_date = os.path.basename(parent)
    cur_plot = os.path.basename(parent.parent.absolute())
    print(f"Adding 2d indices for {cur_plot}:{cur_date} {str(parent)}")
    (start_time, stop_time) = get_start_and_stop_time(f"{parent}/wav", cur_plot, cur_date)

    # Sorted so the choice is deterministic when several files match
    candidates = [f for f in sorted(data_files) if f.endswith(f"ch{ch}.csv") and "bin" in f]
    if not candidates:
        print("No index file found")
        continue

    # ai_2ddisplay reports its own errors; keeping it outside any try block
    # lets Ctrl-C and real failures surface instead of being mislabelled
    ai_2ddisplay(root, target_dir, candidates[0], cur_date, cur_plot, ch, start_time, stop_time)