#import gym
#from gym import error, spaces, utils
#from gym.utils import seeding
import numpy as np
from gym import Env, spaces
from gym.utils import seeding
import wget
import os
import matplotlib.pyplot as plt
import matplotlib.animation
from IPython.display import HTML
from IPython.display import clear_output

#from gym.envs.toy_text import discrete

def categorical_sample(prob_n, np_random):
    """
    Sample an index from a categorical distribution.

    ``prob_n`` is converted to an array and its cumulative sum is taken with
    ``np.cumsum``, which flattens it; a 1-D list of probabilities and an
    ``(n, 1)`` column vector both work, and the returned value indexes the
    flattened array.  The probabilities are expected to sum to 1: if every
    cumulative value is <= the uniform draw, index 0 is returned.

    ``np_random`` may be a legacy ``np.random.RandomState`` or a
    ``np.random.Generator``; both provide ``.random()`` (only ``RandomState``
    has ``.rand()``, which this function used before).
    """
    prob_n = np.asarray(prob_n)
    csprob_n = np.cumsum(prob_n)
    return (csprob_n > np_random.random()).argmax()



def load_files(i_map, use_cache = True) :
  """Load the six data arrays of map ``i_map``, downloading them if needed.

  Every array lives on the lab server as
  ``.../maps_data/map<i_map>_<suffix>.npy`` and is fetched into the current
  working directory with ``wget.download``.

  Args:
    i_map: map index, inserted verbatim into the remote file names.
    use_cache: if True (default) and a file named like the URL's last path
      component already exists in the working directory, it is loaded instead
      of being downloaded again.  ``wget.download`` does not overwrite
      existing files (it adds a `` (1)`` style suffix), so without this every
      call left another full set of copies on disk.

  Returns:
    ``(snr, source_pos, obs_pos, amp_p0, amp_p1, amp_p2)`` as numpy arrays,
    with ``obs_pos`` reshaped to ``(-1, 3)``.
  """
  base_url = "http://sabiod.lis-lab.fr/pub/nthellier/RL/maps_data/map" + str(i_map)

  def fetch(suffix):
    # Load one array, reusing a local copy when allowed.
    url = base_url + suffix
    # Assumes wget.download saves under the URL's basename when no `out` is
    # given (the name it returned originally) — TODO confirm for other wget versions.
    local_name = os.path.basename(url)
    if use_cache and os.path.isfile(local_name):
      return np.load(local_name)
    return np.load(wget.download(url))

  # Author's shape notes (not checked here): nsteps x Xbins x Ybins
  snr = fetch("_all_snr_reshaped.npy")
  # Author's note: nsteps x 3 (xyz)
  source_pos = fetch("_source_pos.npy")
  # Flattened to one (x, y, z) row per grid point.
  obs_pos = fetch("_obs_pos.npy").reshape(-1,3)
  # Author's note: shape (60, 21, 21)
  amp_p0 = fetch("_amp_p0.npy")
  amp_p1 = fetch("_amp_p1.npy")
  amp_p2 = fetch("_amp_p2.npy")
  return snr, source_pos, obs_pos, amp_p0, amp_p1, amp_p2

# Action codes used as indices into adsilEnv's action space.
# The eight compass headings are numbered so that consecutive codes are
# neighbouring headings (W, SW, S, SE, E, NE, N, NW), which the steering
# constraint relies on via ``(a +/- k) % 8``; WAIT (8) keeps the agent in place.
WEST, SW, SOUTH, SE, EAST, NE, NORTH, NW, WAIT = range(9)

#class adsilEnv(gym.Env):
#class adsilEnv(discrete.DiscreteEnv):
#class adsilEnv(discrete.DiscreteEnv):
class adsilEnv(Env):
  """
  Has the following members
  - nS: number of states
  - nA: number of actions
  - P: transitions (*)
  - isd: initial state distribution (**)
  (*) dictionary of lists, where
    P[s][a] == [(probability, nextstate, reward, done), ...]
  (**) list or array of length nS
  """

  #metadata = {'render.modes': ['human']}
  def __init__(self, i_map = 0, steering_constraint = False, init_pos = 20, verbose = True):
    """Load map ``i_map`` and build the time-expanded discrete grid MDP.

    A state encodes ``(timestep, row, col)`` as
    ``timestep * gridbins + row * ncol + col`` (see ``to_s``), giving
    ``maxsteps * gridbins`` states in total.

    Args:
      i_map: index of the map whose arrays ``load_files`` fetches.
      steering_constraint: if True, an action may resolve to one of several
        neighbouring headings (see ``fill_transition_matrix``).
      init_pos: start cell handed to ``initialize_pos``: an int grid index or
        ``"random"``.  Defaults to 20, the previously hard-coded value.
      verbose: print the derived sizes while building (the previous,
        unconditional behaviour).
    """
    self.actions = {0:"WEST", 1:"SW", 2:"SOUTH", 3:"SE", 4:"EAST", 5:"NE", 6:"NORTH", 7:"NW", 8:"WAIT"}
    self.steering_constraint = steering_constraint
    self.snr, self.source_pos, self.obs_pos, self.amp_p0, self.amp_p1, self.amp_p2 = load_files(i_map)
    self.obs_coord = self.obs_pos.reshape(-1,3)
    self.stepCounter = 0
    self.timestep = 0
    self.nA = 9 # Nb of possible actions
    self.done = False
    self.lastaction = None  # for rendering

    def log(*args):
      # Debug output of the derived sizes, silenced with verbose=False.
      if verbose:
        print(*args)

    # discrete duration of the dynamic environment: one source position per timestep
    self.maxsteps = self.source_pos.shape[0]
    log(self.maxsteps)
    # nb of points in the 2d grid; the grid must be square (see states_map below)
    self.gridbins = self.obs_pos.shape[0]
    log(self.gridbins)
    self.grid_side_bins = int(np.sqrt(self.gridbins))
    log(self.grid_side_bins)
    self.nrow = self.grid_side_bins
    self.ncol = self.grid_side_bins
    # total nb of states: one copy of the grid per timestep
    self.nS = self.maxsteps * self.gridbins
    log(self.nS)
    self.states = np.arange(0, self.nS).reshape(-1,1)
    log(self.states.shape)
    self.states_map = self.states.reshape(self.maxsteps,self.grid_side_bins,self.grid_side_bins)
    log(self.states_map.shape)

    # observation coordinates repeated for every timestep, indexed by state
    self.obs_pos_state = np.tile(self.obs_pos,(self.maxsteps,1))
    log(self.obs_pos_state.shape)
    self.snr_state = self.snr.reshape(-1,1)

    self.initialize_pos(init_pos)

    # initial state distribution: all mass on the chosen timestep-0 cell
    self.isd = np.zeros(self.nS).reshape(-1,1)
    self.isd[self.init_state]=1

    self.action_space = spaces.Discrete(self.nA)
    self.observation_space = spaces.Discrete(self.nS)

    self.seed()
    self.s = categorical_sample(self.isd, self.np_random)

    self.fill_transition_matrix()
    # Write the timestep-0 record exactly as reset() does: step() reads
    # recap[timestep - 1][3], which used to be an empty list right after
    # construction and raised IndexError.
    self.recap = {t:[] for t in range(self.maxsteps)}
    self.recap[0] = [int(self.s), None, 0, 0] # state, action, reward, cum_reward

  def update_probability_matrix(self, state, action):
    """Resolve one deterministic move from ``state`` with ``action``.

    The state is decoded into (row, col, timestep), the action is applied on
    the grid, and the landing cell is encoded in the *next* timestep slice.

    Returns:
      ``(newstate, reward, done)``: reward is the SNR value stored for
      ``newstate``; done is True when the move starts at timestep
      ``maxsteps - 2``, i.e. lands in the final slice.
    """
    row, col, t = self.s_to_rowcol(state)
    target = self.to_s(*self.inc(row, col, action), t + 1)
    payoff = self.snr_state[target, 0]
    is_last = bool(t == self.maxsteps - 2)
    return target, payoff, is_last


  # Transition Matrix filling
  def fill_transition_matrix(self) :
    """Build the gym-style transition table ``self.P``.

    ``self.P[s][a]`` is a list of ``(probability, next_state, reward, done)``
    tuples.  States of the last timestep slice (``s >= nS - gridbins``) are
    absorbing: every action stays in ``s`` with reward 0 and ``done=True``.

    With ``self.steering_constraint`` set, action ``a`` can resolve to headings
    ``a-2 .. a+2`` (mod 8) or WAIT.  Those outcomes were each listed with
    probability 1.0, so the list summed to 6 and ``categorical_sample``
    always returned the first one (heading ``a-2``).  They are now weighted
    uniformly so each list is a proper distribution.
    NOTE(review): uniform weighting is the assumed intent — confirm.
    """
    n_headings = self.nA - 1  # codes 0..nA-2 are compass headings, nA-1 is WAIT
    terminal_start = self.nS - self.gridbins
    P = {s: {a: [] for a in range(self.nA)} for s in range(self.nS)}
    for s in self.states[:,0] :
      if s >= terminal_start :
        for a in range(self.nA):
          P[s][a].append((1.0, s, 0, True))
        continue
      for a in range(self.nA):
        if self.steering_constraint:
          candidates = [(a - 2) % n_headings, (a - 1) % n_headings, a,
                        (a + 1) % n_headings, (a + 2) % n_headings, self.nA - 1]
        else:
          candidates = [a]
        prob = 1.0 / len(candidates)
        for b in candidates:
          P[s][a].append((prob, *self.update_probability_matrix(s, b)))
    # Assigned once, after the loop (it used to be re-assigned for every state).
    self.P = P





  def inc(self, row, col, a):
    """Return the ``(row, col)`` cell reached from ``(row, col)`` by action ``a``.

    Row indices grow towards SOUTH and column indices towards EAST.  Along
    each axis a move changes, the result is clamped to ``[0, nrow-1]`` /
    ``[0, ncol-1]``; WAIT returns the cell unchanged.

    Raises:
      ValueError: if ``a`` is not one of the nine action codes.  Previously an
        unknown code fell through every branch and failed with an
        UnboundLocalError on ``newrow``.
    """
    if a == WEST:
      newcol = max(col - 1, 0)
      newrow = row
    elif a == SOUTH:
      newrow = min(row + 1, self.nrow - 1)
      newcol = col
    elif a == EAST:
      newcol = min(col + 1, self.ncol - 1)
      newrow = row
    elif a == NORTH:
      newrow = max(row - 1, 0)
      newcol = col
    elif a == SW:
      newcol = max(col - 1, 0)
      newrow = min(row + 1, self.nrow - 1)
    elif a == SE:
      newrow = min(row + 1, self.nrow - 1)
      newcol = min(col + 1, self.ncol - 1)
    elif a == NE:
      newrow = max(row - 1, 0)
      newcol = min(col + 1, self.ncol - 1)
    elif a == NW:
      newrow = max(row - 1, 0)
      newcol = max(col - 1, 0)
    elif a == WAIT :
      newrow = row
      newcol = col
    else:
      raise ValueError("unknown action: %r" % (a,))

    return (newrow, newcol)

  def to_s(self, row, col, timestep):
    # Flat state index: offset of the timestep slice plus the row-major cell offset.
    cell = col + row * self.ncol
    return cell + self.gridbins * timestep

  def s_to_rowcol(self, state) :
    """Decode a flat state index into ``(row, col, timestep)``.

    Inverse of ``to_s`` for the square grid built in ``__init__`` (where
    ``ncol == grid_side_bins``):
    ``state == timestep * gridbins + row * grid_side_bins + col``.
    """
    timestep, cell = divmod(state, self.gridbins)
    row, col = divmod(cell, self.grid_side_bins)
    return (row, col, timestep)

  def seed(self, seed=None):
    """(Re)create ``self.np_random`` via gym's ``seeding.np_random``.

    Returns:
      A one-element list holding the seed value reported by
      ``seeding.np_random``.
    """
    rng, used_seed = seeding.np_random(seed)
    self.np_random = rng
    return [used_seed]

  def reset(self):
    """Start a new episode.

    Clears the per-timestep ``recap`` log, draws the start state from
    ``self.isd``, and logs it as the timestep-0 record
    ``[state, action, reward, cum_reward]`` with no action and zero reward.

    Returns:
      int: the sampled start state.
    """
    self.timestep = 0
    self.recap = {step_idx: [] for step_idx in range(self.maxsteps)}
    self.s = categorical_sample(self.isd, self.np_random)
    self.lastaction = None
    start = int(self.s)
    self.recap[0] = [start, None, 0, 0]  # state, action, reward, cum_reward
    return start

  def step(self, a):
    """Advance the episode by one timestep with action ``a``.

    One outcome is drawn from ``self.P[self.s][a]`` according to its
    probability.  It becomes the current state and is logged in
    ``self.recap[timestep]`` as ``[state, action name, reward, cum_reward]``.
    ``recap[timestep - 1]`` must already hold a record (reset() writes the
    timestep-0 one).

    Returns:
      ``(next_state, reward, done, {"prob": p})``.
    """
    self.timestep += 1
    outcomes = self.P[self.s][a]
    pick = categorical_sample([outcome[0] for outcome in outcomes], self.np_random)
    prob, next_state, reward, done = outcomes[pick]
    self.s = next_state
    self.lastaction = a
    action_name = self.actions[a]
    cum_reward = self.recap[self.timestep - 1][3] + reward
    self.recap[self.timestep] = [int(next_state), action_name, reward, cum_reward]
    return (int(next_state), reward, done, {"prob": prob})
  
  
  def initialize_pos(self, init_pos) :
    """Choose the agent's starting grid cell and store it in ``self.init_state``.

    Args:
      init_pos: either an integer cell index in ``[0, gridbins)`` (a
        timestep-0 state), or the string ``"random"``.  With ``"random"``, the
        cell is drawn with the global ``np.random`` (not the env's
        ``np_random``) among border cells: every third cell of row 0 and of
        the last row, plus column 0 and the last column of every third row.

    Raises:
      ValueError: if an integer index is out of range, or ``init_pos`` is
        neither an integer nor ``"random"`` (this used to be silently
        ignored, leaving ``init_state`` unset).
    """
    if isinstance(init_pos, (int, np.integer)) and not isinstance(init_pos, bool):
      # The old check ``init_pos >=0 & init_pos < self.gridbins`` parsed as
      # ``init_pos >= (0 & init_pos) < self.gridbins``, so the upper bound
      # was never enforced.
      if not 0 <= init_pos < self.gridbins:
        raise ValueError("problem with initialisation of agent")
      self.init_state = init_pos
    elif isinstance(init_pos, str) and init_pos == "random" :
      side = self.grid_side_bins
      total = self.gridbins
      init_choices = np.concatenate((
        np.arange(0, side, 3),                  # row 0, every third cell
        np.arange(total - side, total, 3),      # last row, every third cell
        np.arange(0, total, 3 * side),          # column 0, every third row
        np.arange(side - 1, total, 3 * side),   # last column, every third row
      ))
      init_choices = np.unique(init_choices)
      self.init_state = np.random.choice(init_choices)
    else:
      raise ValueError("problem with initialisation of agent")
      #print("please keep this init state")





  def render(self, mode='human') :
    """Plot the environment.

    mode='human': in a new figure, draw the SNR map of the current timestep,
      the full source track (black), the source's current position (blue) and
      the agent's current cell (red).
    mode='hidden': animate the episode logged in ``self.recap`` over all
      ``maxsteps`` frames, save it to ``animation.mp4`` (FFMpegWriter, so
      ffmpeg must be installed) and return an ``IPython.display.HTML`` player.
      The player used to be built and then discarded.  Every frame reads
      ``recap[t][0]``, so a complete episode must have been played.

    Returns:
      The HTML player for mode='hidden', otherwise None.
    """
    t = self.timestep
    xlims = [np.min(self.obs_coord[:,0]),np.max(self.obs_coord[:,0])]
    ylims = [np.min(self.obs_coord[:,1]),np.max(self.obs_coord[:,1])]

    extent = [xlims[0] , xlims[1], ylims[0] , ylims[1]]

    if mode == 'human' :
      plt.figure(figsize=(7,7))
      plt.imshow(self.snr[t,:,:], origin='lower',interpolation = 'bicubic', extent=extent)
      plt.plot(self.source_pos[:,0],self.source_pos[:,1],'black')
      plt.scatter(self.source_pos[t,0],self.source_pos[t,1],c='blue',s=30)
      plt.scatter(self.obs_pos_state[self.s,0],self.obs_pos_state[self.s,1],c='red',s=30)
      plt.ylim(ylims)
      plt.xlim(xlims)
      return None

    if mode == "hidden" :
      ep = self.recap

      fig, ax = plt.subplots(figsize=(7,7))
      ax.plot(self.source_pos[:,0],self.source_pos[:,1],'black')
      scat_source = ax.scatter(self.source_pos[0,0],self.source_pos[0,1],c='blue',s=50)
      scat_obs = ax.scatter(self.obs_pos_state[ep[0][0],0],self.obs_pos_state[ep[0][0],1],c='red',s=50)
      # A single image artist, updated in place on every frame.  Calling
      # ax.imshow inside animate() stacked one more alpha=0.4 layer per frame,
      # so earlier frames showed through and the artist count kept growing.
      snr_img = ax.imshow(self.snr[0,:,:], origin='lower', interpolation='bicubic', extent=extent, alpha=0.4)
      ax.set_xlim(xlims)
      ax.set_ylim(ylims)

      def animate(frame) :
        scat_source.set_offsets(self.source_pos[frame,0:2])
        scat_obs.set_offsets(self.obs_pos_state[ep[frame][0],0:2])
        snr_img.set_data(self.snr[frame,:,:])
        # Each per-frame imshow used to autoscale its colour range to that
        # frame's data; keep the same per-frame normalisation.
        snr_img.autoscale()

      ani = matplotlib.animation.FuncAnimation(fig, animate, frames=self.maxsteps, interval=500)
      html = HTML(ani.to_jshtml())

      writermp4 = matplotlib.animation.FFMpegWriter(fps=2)
      ani.save(r"animation.mp4", writer=writermp4)
      return html

    return None


      #plt.imshow(self.snr[t,:,:], origin='lower',interpolation = 'bicubic', extent=(1000, 5000, -2500, 1500))
      
      #plt.scatter(self.source_pos[t,0],self.source_pos[t,1],c='blue',s=30)
      #plt.scatter(self.obs_pos_state[self.s,0],self.obs_pos_state[self.s,1],c='red',s=30)
      #plt.ylim([-2500,1500])
      #plt.xlim([1000,5000])



    #plt.scatter(gridcoord[:,0],gridcoord[:,1])
    #f2 = plt.figure()
    #ax = f2.gca()
    #ax.scatter(gridcoord[6,0],gridcoord[6,1])
    #ax.set_xlim(1.1*xlims)
    #ax.set_ylim(1.1*ylims)

    #statemap = np.arange(0,nstate)
    #statemapgrid = statemap.reshape(gridbiny,gridbinx)



    # ATTENTION : statemap renommé en statemapgrid
    #print(statemap)
    #print(statemapgrid)
    # ATTENTION : statemap renommé en statemapgrid

    #s = np.array([[0,0,0]])
    #snr = - np.linalg.norm(s - gridcoord, axis = 1)
    ##print(snr.reshape((gridbiny,gridbinx)))
    ##print(snr)
    
    #self.nrow, self.ncol = statemapgrid.shape
    #self.nS = self.nrow * self.ncol
    #self.gridcoord = gridcoord
    #self.statemap = statemap
    #self.statemapgrid = statemapgrid
    #self.grid = grid
    #self.agents = agents
  """
    #self.reward_range = (0, 1) ## si on veut restreindre l'espace des rewards

    #self.nS = nS = self.nrow * self.ncol
    
    cond1 = self.gridcoord[:,0] == self.agents[1][0]
    cond2 = self.gridcoord[:,1] == self.agents[1][1]
    cond3 = self.gridcoord[:,2] == 0
    #gridcoordd[1,:]
    mask = cond1 & cond2 & cond3
    #print(mask)
    idxstart = np.where(mask)[0][0]
    print(idxstart)
    #self.gridcoord[idxstart,:]
    self.idxstart=idxstart
    isd = np.zeros(self.nS)
    isd[idxstart] = 1
    #isd /= isd.sum()

    


  
  def step(self, action):
    if self.stepCounter < self.nsteps :
      self.stepCounter +=1
      if self.stepCounter == self.nsteps :
        self.done = True

      super().step(action)

    else :
      pass



    #b=3
  def reset(self):
    self.stepCounter = 0
    self.done = False
    super().reset()

  #ef render(self, mode='human'):
    #d=5
  #def close(self):
    #e=6
  def compute_reward() :
    return 1

  """

  '''
  class MultiAgentEnv(gym.Env):

    def step(self, action_n):
      obs_n    = list()
      reward_n = list()
      done_n   = list()
      info_n   = {'n': []}
      # ...
      return obs_n, reward_n, done_n, info_n

  '''


  def foo1(gridparams) :
    if gridparams is None :
      gridbinx = 6
      xsidekm = 6000
      ysidekm = 4000
      xlims = np.array([-xsidekm/2,xsidekm/2])
      ylims = np.array([-ysidekm/2,ysidekm/2])
      gridbiny = int(gridbinx * ysidekm/xsidekm)
      nstate = int(gridbinx*gridbiny)
      print("nstate",nstate)

      x = np.linspace(xlims[0],xlims[1], gridbinx)
      y = np.linspace(ylims[0],ylims[1], gridbiny)

      print(x.shape)
      print(y.shape)

      xx,yy = np.meshgrid(x,y)

      gridcoordx = np.reshape(xx,(-1,1))
      gridcoordy = np.reshape(yy,(-1,1))
      gridcoordz = np.zeros((nstate,1))
      gridcoord = np.concatenate((gridcoordx,gridcoordy,gridcoordz),axis=1)

      print(gridcoord[0:8,:])


