from matplotlib.widgets import Button
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
from scipy import signal
import pandas as pd
import os
import sounddevice as sd

# Table of rows to annotate; each row is expected to carry fn, time and number columns
# (presumably audio path, event time in seconds and an event index — confirm against df0.csv).
df = pd.read_csv(filepath_or_buffer='df0.csv')

fs = 22050

out = {}

rowtostr = lambda row : row.fn + '_'  + str(row.time) +'_'+str(row.number)

class ppp(object):
    """Interactive spectrogram annotator for the rows of the module-level ``df``.

    For each row, +/-2 s of audio around ``row.time`` (seconds) is read from
    ``row.fn``, resampled to the module-level ``fs`` and shown as a spectrogram
    on the global ``ax``. Clicking or dragging on the spectrogram marks
    time-frequency bins in the boolean mask ``self.mat`` (frequency rows x
    time columns). "next" stores the mask (or ``False`` when nothing is marked)
    in the global ``out`` dict under ``rowtostr(row)`` and loads the next row.
    """

    def __init__(self):
        # Key of the most recently stored annotation (used by dropprev).
        self.prev = ''
        self.iter = df.iterrows()
        self.get()

    def onclick(self, event):
        """Mark the spectrogram bin nearest the pointer while a mouse button is held.

        Connected to both press and motion events, so dragging paints bins.
        """
        # Only react inside the spectrogram axes: events over the control
        # buttons would otherwise be mapped onto the grid and mark a spurious bin.
        if event.button is None or event.inaxes is not ax:
            return
        ti = np.argmin(np.abs(self.t - event.xdata))
        fi = np.argmin(np.abs(self.f - event.ydata))
        if self.mat[fi, ti]:
            return  # already marked; skip a redundant redraw while dragging
        self.mat[fi, ti] = True
        self.plot()

    def down(self, a):
        """Move every marked bin to half its frequency index (one octave down)."""
        y, x = np.where(self.mat)
        self.mat[y, x] = False
        self.mat[y // 2, x] = True
        self.plot()

    def get(self, a=None):
        """Store the current annotation and load the next unannotated row.

        ``a`` is the (ignored) event when called from the "next" button. When
        no rows remain, ``out`` is saved with ``np.save('annot_pitch', ...)``
        and the program exits.
        """
        print(len(out.keys()), ' annotations so far')
        if hasattr(self, 'row'):
            key = rowtostr(self.row)
            out[key] = self.mat if self.mat.any() else False
            self.prev = key
        nxt = self._next_unannotated_row()
        if nxt is None:
            np.save('annot_pitch', out)
            print('finnished !')
            raise SystemExit
        self.row = nxt
        self._load_window()
        self.plot()

    def _next_unannotated_row(self):
        """Return the next row from ``self.iter`` whose key is not in ``out``, or None."""
        for _, row in self.iter:
            if rowtostr(row) not in out:
                return row
        return None

    def _load_window(self):
        """Read +/-2 s around ``self.row.time``, resample to ``fs``, build the spectrogram."""
        # Metadata must come from this row's own file: files may differ in
        # sample rate and duration.
        info = sf.info(self.row.fn)
        self.fs, self.dur = info.samplerate, info.duration
        # Scale to samples before truncating so start and stop round the same
        # way and the window stays centred on row.time.
        start = max(0, int((self.row.time - 2) * self.fs))
        stop = min(int((self.row.time + 2) * self.fs), int(self.dur * self.fs))
        sig, self.fs = sf.read(self.row.fn, start=start, stop=stop, always_2d=True)
        self.sig = sig[:, 0]  # first channel only
        if self.fs != fs:
            self.sig = signal.resample(self.sig, int(len(self.sig) / self.fs * fs))
        self.f, self.t, power = signal.spectrogram(
            self.sig, fs=fs, noverlap=1000 // 2, nfft=1024 // 2, nperseg=1024 // 2)
        # Log-scaled power for display (zero-power bins become -inf).
        self.Sxx = 20 * np.log(power)
        self.mat = np.zeros(self.Sxx.shape, dtype=bool)

    def plot(self):
        """Redraw the spectrogram and the marked bins on the global ``ax``."""
        ax.clear()
        ax.imshow(self.Sxx, origin='lower', aspect='auto',
                  extent=[0, self.t[-1], 0, self.f[-1]])
        ax.set_title(self.row.fn + ' ' + str(self.row.time))
        ax.set_ylim(200, 12000)
        y, x = np.where(self.mat)
        ax.scatter(self.t[x], self.f[y], 10)
        # draw_idle coalesces the many redraw requests produced while dragging.
        fig.canvas.draw_idle()

    def clear(self, a):
        """Remove every mark from the current row."""
        self.mat = np.zeros(self.Sxx.shape, dtype=bool)
        self.plot()

    def dropprev(self, a):
        """Remove the most recently stored annotation from ``out`` (no-op if absent)."""
        # pop() rather than del: pressing this before anything was stored, or
        # twice in a row, must not raise KeyError.
        out.pop(self.prev, None)

    def play(self, a, rate=44100):
        """Play the current window through sounddevice at ``rate`` Hz.

        NOTE(review): ``self.sig`` is sampled at ``fs``, so the default 44100 Hz
        plays it at twice ``fs`` when ``fs`` is 22050 (double speed, one octave
        up). Kept as-is in case that is intentional; pass ``rate=fs`` for
        real-time playback.
        """
        sd.play(self.sig, rate)

fig, ax = plt.subplots()
# Reserve room at the bottom for the button row (y = 0.05 .. 0.125); with the
# default layout the buttons overlap the spectrogram's time-axis labels.
fig.subplots_adjust(bottom=0.2)

pp = ppp()


def _save(event):
    """Save all annotations collected so far with np.save('annot_pitch', out)."""
    np.save('annot_pitch', out)


# Each Button is bound to a module-level name: matplotlib widgets stay
# responsive only while a reference to them is kept.
bndraw = Button(plt.axes([0.7, 0.05, 0.1, 0.075]), 'down')
bndraw.on_clicked(pp.down)
bnclear = Button(plt.axes([0.8, 0.05, 0.1, 0.075]), 'clear')
bnclear.on_clicked(pp.clear)
bnnext = Button(plt.axes([0.6, 0.05, 0.1, 0.075]), 'next')
bnnext.on_clicked(pp.get)
bnsave = Button(plt.axes([0.5, 0.05, 0.1, 0.075]), 'save')
bnsave.on_clicked(_save)
bndropprev = Button(plt.axes([0.4, 0.05, 0.1, 0.075]), 'drop prev')
bndropprev.on_clicked(pp.dropprev)
bnplay = Button(plt.axes([0.3, 0.05, 0.1, 0.075]), 'play')
bnplay.on_clicked(pp.play)

# Press and drag (motion with a button held) both mark bins via pp.onclick.
cid_press = fig.canvas.mpl_connect('button_press_event', pp.onclick)
cid_motion = fig.canvas.mpl_connect('motion_notify_event', pp.onclick)
plt.show()
