#include "USTV_SystemTrace.h"
#include "stdio.h"
#include "stdlib.h"
#include "cnn.h"
#include "math.h"
#include <string.h>
#include "peripheral/gpio/plib_gpio.h"
#include "app_sd_tasks.h"

/* Species selector, declared extern (defined elsewhere). load_CNN compares it
   with CACHA to choose between the cacha_* and rorqual_* model files. */
extern short SPECIE;
/* Size of the shared text buffer below; read_line also passes it to
   SYS_FS_FileStringGet as the maximum read length. */
#define LINE_MAX_LENGTH 1000
/* Shared text buffer used while parsing the model files in this file. */
char line[LINE_MAX_LENGTH];

/*
 * Load the CNN for the current species and run it on the detection data's
 * mel spectrogram.
 *
 * The memory at app_cnn_detectionData->startaddr is carved up as:
 *   [NCONV x Conv_t][one FrontBN_t][float scratch space used by forward()]
 *
 * Returns 1 once forward() has written its predictions into preds, or 0 if
 * load_CNN() failed (preds is left untouched in that case).
 */
int run_CNN(APP_CNN_DETECTION_DATA* app_cnn_detectionData, short lenspec, short melfeat, float* preds){
    Conv_t* layers = (Conv_t*) app_cnn_detectionData->startaddr;
    FrontBN_t* front_bn = (FrontBN_t*) &layers[NCONV];
    float* scratch = (float*) (front_bn + 1);
    int loaded = load_CNN(layers, front_bn, scratch);

    if(!loaded){
        #ifdef USE_SYSTEM_TRACE
            USTV_SystemTraceWriteLine("APP_CNN: failed to load CNN");
        #endif
        return 0;
    }

    forward(app_cnn_detectionData->meldspectro, layers, front_bn, scratch, lenspec, melfeat, preds);
    return 1;
}
    

/*
 * Read the next line of `file` and parse `size` comma-separated floats from
 * it into table[0..size-1].
 *
 * line:  caller-supplied buffer of at least LINE_MAX_LENGTH chars; it is
 *        overwritten and tokenised in place with strtok.
 * file:  open SYS_FS handle positioned at the line to parse.
 * table: destination; must have room for `size` floats.
 * size:  number of values to store; nothing is written when size <= 0.
 *
 * Values missing from the line (short line, end of file, read failure) are
 * stored as 0.0f. Previously a missing value passed strtok's NULL result to
 * atof (undefined behaviour), and size <= 0 still wrote table[0].
 */
void read_line(char* line, SYS_FS_HANDLE file, float* table, int size){
	int i;
	char* token;
	/* Start from an empty string so a failed read yields "no values"
	   instead of re-parsing whatever the buffer held before. */
	line[0] = '\0';
	SYS_FS_FileStringGet(file, line, LINE_MAX_LENGTH);
	token = strtok(line, ",");
	for(i = 0; i < size; i++){
		table[i] = (token != NULL) ? (float) atof(token) : 0.0f;
		token = strtok(NULL, ",");
	}
}

/*
 * Skip the next n lines of `file`.
 *
 * Each skipped line is consumed up to and including its '\n', even when it is
 * longer than the 50-byte chunk buffer. Per the Harmony SYS_FS documentation,
 * SYS_FS_FileStringGet stops after size-1 characters, at a '\n' (which it
 * stores) or at end of file. The former single read per line therefore left
 * the tail of a long line behind, where it was then counted as the next line,
 * desynchronising the caller's parsing.
 *
 * Stops early when a read returns no data (end of file or read failure).
 */
void jumpLines(SYS_FS_HANDLE file, int n){
    char chunk[50];
    size_t len;
    int skipped;

    for(skipped = 0; skipped < n; skipped++){
        /* A completely filled chunk without a trailing '\n' means the line
           continues: keep reading until the line ends. */
        do{
            chunk[0] = '\0';  /* a failed read then leaves an empty string */
            SYS_FS_FileStringGet(file, chunk, sizeof chunk);
            len = strlen(chunk);
        }while(len == sizeof chunk - 1 && chunk[len - 1] != '\n');

        if(len == 0){
            return;  /* nothing left to read */
        }
    }
}

int load_CNN(Conv_t* model, FrontBN_t* frontbn, void* ddr){
	int i, j, infeat, outfeat, kernel, stride;
    SYS_FS_HANDLE ftpr=SYS_FS_HANDLE_INVALID;
    char* path;
    if(SPECIE==CACHA){
        path=USTV_GetFilePath("cacha_archi.txt");
    }else{
        path=USTV_GetFilePath("rorqual_archi.txt");
    }
    ftpr = SYS_FS_FileOpen(path, (SYS_FS_FILE_OPEN_READ));
    if(ftpr == SYS_FS_HANDLE_INVALID){
        #ifdef USE_SYSTEM_TRACE
            USTV_SystemTraceWriteLine("APP_CNN: run_cnn fail : failed to open archi file ");
        #endif
        return 0;
    }
	for(i=0; i<NCONV; i++){ //load architecture
        SYS_FS_FileStringGet(ftpr, line, 15); 
		infeat = atoi(strtok(line, ","));
		outfeat = atoi(strtok(NULL, ","));
		kernel = atoi(strtok(NULL, ","));
		stride = atoi(strtok(NULL, ","));
		*(model + i) = (Conv_t){infeat, outfeat, kernel, stride};
	}
    SYS_FS_FileClose(ftpr);
    #ifdef USE_SYSTEM_TRACE
        USTV_SystemTraceWriteLine("APP_CNN: done loading archi file");
    #endif
    SYS_FS_HANDLE fp_stdc=SYS_FS_HANDLE_INVALID;
    if(SPECIE==CACHA){
        path=USTV_GetFilePath("cacha_stdc.txt");
    }else{
        path=USTV_GetFilePath("rorqual_stdc.txt");
    }
    fp_stdc = SYS_FS_FileOpen(path, (SYS_FS_FILE_OPEN_READ));
    if(fp_stdc == SYS_FS_HANDLE_INVALID){
        #ifdef USE_SYSTEM_TRACE
            USTV_SystemTraceWriteLine("APP_CNN: run_cnn fail : failed to open stdc file ");
        #endif
        return 0;
    }
	jumpLines(fp_stdc, 4);
    //load front batch norm
    frontbn->nfeat = model[0].in_feat;
    read_line(line, fp_stdc, frontbn->bn_weight, frontbn->nfeat);

    jumpLines(fp_stdc, 2);
    read_line(line, fp_stdc, frontbn->bn_bias, frontbn->nfeat);
    jumpLines(fp_stdc, 2);
    read_line(line, fp_stdc, frontbn->bn_runningmean, frontbn->nfeat);
    jumpLines(fp_stdc, 2);
    read_line(line, fp_stdc, frontbn->bn_runningvar, frontbn->nfeat);
    jumpLines(fp_stdc, 5);

	for(i=0; i<NCONV; i++){
		for(j=0; j<model[i].in_feat; j++){
			read_line(line, fp_stdc, model[i].depthwise_weights[j], model[i].kernel);
		}
		jumpLines(fp_stdc, 2);
		read_line(line, fp_stdc, model[i].depthwise_bias, model[i].in_feat);
		jumpLines(fp_stdc, 2);
		for(j=0; j<model[i].out_feat; j++){
			read_line(line, fp_stdc, model[i].pointwise_weight[j], model[i].in_feat);
		}
		jumpLines(fp_stdc, 2);
		read_line(line, fp_stdc, model[i].pointwise_bias, model[i].out_feat);
        if(i==NCONV -1){
            break; // no batchnorm after last layer
        }
		jumpLines(fp_stdc, 2);
		read_line(line, fp_stdc, model[i].bn_weight, model[i].out_feat);
		jumpLines(fp_stdc, 2);
		read_line(line, fp_stdc, model[i].bn_bias, model[i].out_feat);
		jumpLines(fp_stdc, 2);
		read_line(line, fp_stdc, model[i].bn_runningmean, model[i].out_feat);
		jumpLines(fp_stdc, 2);
		read_line(line, fp_stdc, model[i].bn_runningvar, model[i].out_feat);
		jumpLines(fp_stdc, 5);
	}
	SYS_FS_FileClose(fp_stdc);
	return 1;
}

/*
 * Run the CNN on a mel spectrogram and write one output value per output
 * time step into preds.
 *
 * Steps, as implemented below:
 *   1. Front batch norm, applied IN PLACE to spectro for features
 *      0..frontbn->nfeat-1.
 *   2. For each of the NCONV layers:
 *      - Depthwise: every input feature is filtered along time with its own
 *        `kernel`-tap filter plus bias (no padding, step `stride`).
 *      - Pointwise: every output feature is a weighted sum of all input
 *        features at the same time step plus bias. This is followed by
 *        batch norm and ReLU; the last layer uses identity BN and no ReLU.
 *   3. preds[t] = feature 0 of the last layer's output, for each output step t.
 *
 * spectro: input indexed as spectro[t * melfeat + feat] for t < lenspec.
 *          It is overwritten by the front batch norm.
 * model:   NCONV layers. WARNING: each layer's bn_weight/bn_bias arrays are
 *          overwritten with folded coefficients (identity for the last layer),
 *          so a second forward() on the same model without reloading would
 *          apply the folding twice. The visible caller, run_CNN, calls
 *          load_CNN before every forward().
 * frontbn: front batch-norm parameters (only read here).
 * ddr:     scratch space. The first lenspec * MAX_FEAT floats hold the
 *          activations x[t * MAX_FEAT + feat]. They are followed by `temp`,
 *          indexed up to the layer output length (depthwise) and up to
 *          out_feat (pointwise).
 * lenspec: number of input time steps.
 * melfeat: row stride of spectro (values per time step).
 * preds:   receives one float per output time step of the last layer.
 *          Nothing is written when that length drops to 0 or below (input
 *          too short for the kernels/strides).
 */
void forward(float* spectro, Conv_t* model, FrontBN_t* frontbn, float* ddr, int lenspec, short melfeat, float* preds){
	float* x = ddr;  //[LENSPEC][MAX_FEAT]
    ddr += lenspec * MAX_FEAT;
	float* temp = ddr;  // scratch after x: depthwise outputs per time step / pointwise outputs per feature
	float sum, nextlen, bn_coeff, bn_bias;
	int feat, t, i, lay;
	nextlen = lenspec;  // current sequence length, kept as float for the floorf() below
    // front BN, applied in place to spectro:
    // y = (x - mean) / sqrt(var + 1e-5) * weight + bias, rewritten as y = bn_coeff * x + bn_bias
    for(feat=0; feat<frontbn->nfeat; feat++){
        bn_coeff = 1.0 / sqrtf(frontbn->bn_runningvar[feat] + 0.00001);
        bn_bias = frontbn->bn_bias[feat] - frontbn->bn_weight[feat] * frontbn->bn_runningmean[feat] * bn_coeff;
        bn_coeff = bn_coeff * frontbn->bn_weight[feat];
        for(t=0; t<lenspec; t++){
            *(spectro + t *melfeat + feat) = bn_coeff * (*(spectro + t *melfeat + feat)) + bn_bias;
        }
    }
    for(lay=0; lay<NCONV; lay++){
		// output length of an unpadded conv: floor((len - kernel) / stride + 1)
		nextlen = floorf((nextlen - model[lay].kernel)/model[lay].stride + 1);
		// depthwise Conv
		for(feat=0; feat<model[lay].in_feat; feat++){ // iterate over features
			for(t=0; t<nextlen; t++){ // slide the kernel
				sum = 0.0; // compute the kernel activation
                for(i=0; i<model[lay].kernel; i++){
					if(lay==0){ // for the first layer we use spectro instead of x as input
                        sum += model[lay].depthwise_weights[feat][i] * (*(spectro + (t * model[lay].stride + i)*melfeat + feat));
					}else{
						sum += model[lay].depthwise_weights[feat][i] * (*(x + (t * model[lay].stride + i)*MAX_FEAT + feat));
					}
				}
				*(temp + t) = sum + model[lay].depthwise_bias[feat];
			}
			for(t=0; t<nextlen; t++){ // store the activations in x (column feat), staged in temp above
				*(x + t * MAX_FEAT + feat) = *(temp + t);
			}
		}
        if(lay < NCONV - 1){
            //precompute BN weights: fold BN into a per-feature scale (bn_weight) and
            //offset (bn_bias), overwriting the model's BN arrays in place
            for(feat=0; feat<model[lay].out_feat; feat++){
                bn_coeff = 1.0 / sqrtf(model[lay].bn_runningvar[feat] + 0.00001);
                model[lay].bn_bias[feat] = model[lay].bn_bias[feat] - model[lay].bn_weight[feat] * model[lay].bn_runningmean[feat] * bn_coeff;
                model[lay].bn_weight[feat] = bn_coeff * model[lay].bn_weight[feat];
            }
        }else{
            // last layer: identity BN so the pointwise formula below can stay uniform
            for(feat=0; feat<model[lay].out_feat; feat++){
                model[lay].bn_bias[feat] = 0;
                model[lay].bn_weight[feat] = 1;
            }
        }
        // pointwise Conv: row t of x is both input and output, so the new row
        // is built in temp first and copied back after all features are computed
		for(t=0; t<nextlen; t++){
			for(feat=0; feat<model[lay].out_feat; feat++){ // iterate over features
				sum = 0.0;
				for(i=0; i<model[lay].in_feat; i++){
					sum += model[lay].pointwise_weight[feat][i] * (*(x + t * MAX_FEAT + i));
				}
				*(temp + feat) = (sum + model[lay].pointwise_bias[feat]) * model[lay].bn_weight[feat] + model[lay].bn_bias[feat];
			}
            for(feat=0; feat<model[lay].out_feat; feat++){
                if((*(temp + feat) > 0 ) || (lay == NCONV - 1)){ //RELU (not applied on the last layer)
                    *(x + t * MAX_FEAT + feat) = *(temp + feat);
                }else{
                    *(x + t * MAX_FEAT + feat) = 0;
                }
            }
        }
	}

	// prediction = feature 0 of the final activations at each output step
	for(t=0; t<nextlen; t++){
        preds[t] = *(x + t * MAX_FEAT);
	}
	return;
}

