# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# import cv2
# import tensorflow as tf
# from yolov3.utils import detect_image, detect_realtime, detect_video, Load_Yolo_model, detect_video_realtime_mp
# from yolov3.configs import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import cm

import numpy as np
import random as rd
from math import cos, sin, tan, atan


#yolo = Load_Yolo_model()

### Map ###

# pixel_x_ext, pixel_y_ext = [242, 220, 165, 110, 63, 33, 22, 34, 63, 110, 165, 220, 243, 310, 334, 388, 443, 490, 521, 531, 520, 489, 443, 388, 333, 310], [76, 64, 52, 64, 95, 141, 196, 252, 298, 330, 340, 328, 318, 316, 328, 339, 329, 298, 251, 196, 142, 95, 64, 53, 64, 77]
# pixel_x_int, pixel_y_int = [245, 238, 222, 196, 166, 134, 108, 91, 85, 90, 109, 134, 165, 196, 222, 239, 308, 314, 332, 358, 388, 419, 445, 462, 468, 462, 445, 419, 388, 359, 332, 314], [201, 167, 140, 123, 116, 123, 140, 165, 195, 228, 253, 270, 277, 270, 253, 227, 200, 226, 253, 270, 277, 270, 253, 228, 197, 166, 140, 122, 117, 123, 140, 166]
# diametre = 225
# centre_x, centre_y = 278, 200
# coord_x_ext, coord_y_ext = [i/diametre for i in pixel_x_ext], [i/diametre for i in pixel_y_ext]
# coord_x_int, coord_y_int = [i/diametre for i in pixel_x_int], [i/diametre for i in pixel_y_int]

#coord_ext = [(i/diametre , j/diametre) for i,j in zip(pixel_x_ext, pixel_y_ext)]
#coord_int = [(i/diametre , j/diametre) for i,j in zip(pixel_x_int, pixel_y_int)]


# Cone map: two lobes of 16 angular positions each, centred at (-1.2, 0) and
# (1.2, 0). Each lobe has an inner circle of radius r_in and an outer circle of
# radius r_ext (units presumably metres — TODO confirm).
coord_x_int = []
coord_y_int = []

coord_x_ext = []
coord_y_ext = []

# Inner and outer radii; 0.14 and 0.394 are presumably the cone width and the
# lane width — TODO confirm against the physical track.
r_in = 1-0.14/2
r_ext = r_in + 0.394 + 0.14

for i in range(16):
    # Exact pi: the former 3.1415 approximation placed the points off their
    # ideal angles (e.g. sin(theta) ~ 9e-5 instead of 0 at theta = pi).
    theta = 2*np.pi*i/16

    # Inner circle: left lobe then right lobe.
    coord_x_int.append(-1.2+r_in*cos(theta))
    coord_y_int.append(r_in*sin(theta))

    coord_x_int.append(1.2+r_in*cos(theta))
    coord_y_int.append(r_in*sin(theta))

    # Outer circle: r_ext > 1.2, so the two outer circles overlap between the
    # lobes; the points facing the other lobe (left: i in {0, 1, 15},
    # right: i in {7, 8, 9}) are left out.
    if 1<i<15:
        coord_x_ext.append(-1.2+r_ext*cos(theta))
        coord_y_ext.append(r_ext*sin(theta))

    if not 7<=i<=9:
        coord_x_ext.append(1.2+r_ext*cos(theta))
        coord_y_ext.append(r_ext*sin(theta))

# Same points as (x, y) tuples.
coord_ext = list(zip(coord_x_ext, coord_y_ext))
coord_int = list(zip(coord_x_int, coord_y_int))

### Parameters ###
# Standard deviation of the Gaussian noise added to x and y at each motion step.
sigma_position = 0.05
# Standard deviation of the Gaussian noise added to the heading: 8 degrees in
# radians (exact pi instead of the former 3.1415 approximation).
sigma_direction = 8*np.pi/180
# Minimum detection score for a bounding box to be used as a cone observation.
seuil_cone = 0.6


# Starting pose (x, y, heading in radians = 50 degrees) and initial particle set.
dep_x, dep_y, dep_theta = 1.314, 1.162, np.pi/3.6
nb_particule = 30
pos = [[dep_x, dep_y, dep_theta] for _ in range(nb_particule)]

# Camera focal length used by boite2coord (previously tried values:
# 194.97, 1/0.0051/0.8, 3957/3, 156.25).
F = 145.31
h_reel = 1  # Real height of a cone
y_0 = 3264/2  # Middle of the image (presumably in pixels)

dist_roue = 0.25  # Distance between the front and rear wheels (wheelbase)


# class position:
    
#     def __init__(self, pos_x, pos_y, nb_particule):
#         self.__pos = [pos_x, pos_y]*nb_particule


def normalisation(W):
    """Return the weights W rescaled so that they sum to 1.

    If every weight is zero (sensor_update returns 0 for a particle that
    expects to see no cone, so this can happen for the whole population),
    the weights carry no preference and uniform weights are returned instead
    of dividing by zero. An empty list gives an empty list.
    """
    total = sum(W)
    if total == 0:
        n = len(W)
        return [1/n for _ in W]
    return [w/total for w in W]

def distance(x_1, y_1, x_2, y_2):
    """Euclidean distance between the points (x_1, y_1) and (x_2, y_2)."""
    dx = x_1 - x_2
    dy = y_1 - y_2
    return np.sqrt(dx**2 + dy**2)

def boite2coord(x1,y1,x2,y2):
    """Convert a detection bounding box into a cone position in the robot frame.

    The range is estimated from the apparent box height, F*h_reel/|y1-y2|,
    minus a fixed 0.3982 offset. The bearing comes from the horizontal offset
    of the box centre relative to 960/2, normalised and scaled by 0.6 into a
    tangent. Boxes whose centre lies at a larger x than 960/2 give a negative
    lateral coordinate. (Box corners presumably in pixels of a 960-wide image
    — TODO confirm.)

    Returns:
        (x, y): x along the bearing-zero axis, y lateral.
    """
    demi_largeur = 960/2
    echelle_tan = 0.6

    # Range from the apparent height of the box.
    portee = F*h_reel/abs(y1-y2) - 0.3982

    # Bearing from the horizontal position of the box centre.
    centre_x = (x1+x2)/2
    tan_gisement = (centre_x-demi_largeur)/demi_largeur*echelle_tan
    gisement = -atan(tan_gisement)

    return portee*cos(gisement), portee*sin(gisement)

def rotation(point, centre, angle):
    """Rotate `point` by `angle` radians about `centre` (counter-clockwise in a y-up frame).

    Returns the rotated point as an (x, y) tuple.
    """
    (px, py), (cx, cy) = point, centre
    dx, dy = px - cx, py - cy
    c, s = cos(angle), sin(angle)
    return dx*c - dy*s + cx, dx*s + dy*c + cy

def distance_Chamfer(A_x, A_y, B_x, B_y):
    m = len(A_x)
    n = len(B_x)
    
    if m==0 or n==0:
        return 0.0001
    
    res = 0
    
    tab = [[distance(A_x[i], A_y[i], B_x[j], B_y[j]) for i in range(m)] for j in range(n)]
    tab = np.array(tab)
    for i in range(m):
        res += np.min(tab[:,i])
    for j in range(n):
        res += np.min(tab[j,:])
    
    return res


# def orientation_voiture(x,y):
#     x1,y1 = x-centre_x, y-centre_y
#     if x<0:
#         x2,y2 = x1+diametre/2, y1
#         theta = atan(y2/x2)
#         return theta-3.1415/2
#     else:
#         x2,y2 = x1-diametre/2, y1
#         theta = atan(y2/x2)
#         return theta-3.1415/2
    
    
def motion_update(commande, position):
    """Propagate one particle through the kinematic bicycle model, then add noise.

    Args:
        commande: (vitesse, direction, FPS). These are speed, steering angle in
            radians and frame rate; one step lasts 1/FPS.
        position: (x, y, theta) pose of the particle, theta in radians.

    Returns:
        New (x, y, theta) tuple after the motion. Gaussian noise with standard
        deviation sigma_position is added to x and y, and sigma_direction to theta.
    """
    vitesse, direction, FPS = commande
    x,y,theta = position
    dt = 1/FPS

    if direction == 0:
        # Straight-line motion along the current heading.
        new_x = x + dt*vitesse*cos(theta)
        new_y = y + dt*vitesse*sin(theta)
        new_theta = theta

    else:
        # Signed turning radius: R > 0 makes the heading increase, R < 0 decrease.
        R = dist_roue/tan(direction)
        angle_rotation = dt*vitesse/R

        # The centre of rotation lies on the car's left normal (-sin, cos) at
        # signed distance R. This is valid for both steering signs. The former
        # direction>0 branch put the centre on the right, which made the car
        # move backwards while its heading turned left.
        centre_rotation_x = x - R*sin(theta)
        centre_rotation_y = y + R*cos(theta)

        new_x, new_y = rotation((x,y),(centre_rotation_x, centre_rotation_y), angle_rotation)
        new_theta = theta + angle_rotation

    new_x += rd.gauss(0, sigma_position)
    new_y += rd.gauss(0, sigma_position)
    new_theta += rd.gauss(0, sigma_direction)

    return (new_x, new_y, new_theta)



def sensor_update(observation, position):
    """Return the unnormalised likelihood weight of a particle given cone observations.

    The known cones (coord_int, coord_ext) are expressed in the particle's
    frame. Those in front of it and inside the angular window are the cones
    the particle expects to see. The weight is 1/d**4, where d is the Chamfer
    distance between those expected cones and the observed ones.

    Args:
        observation: iterable of (x, y) cone positions in the robot frame,
            as produced by boite2coord.
        position: (x, y, theta) particle pose in the map frame.

    Returns:
        0 if the particle expects to see no cone. Otherwise 1/d**4, with d
        floored at 1e-4 so that a perfect match cannot divide by zero or
        produce an infinite weight.
    """
    x, y, theta = position
    c, s = cos(theta), sin(theta)

    cones_vu_x = []
    cones_vu_y = []
    # Inner cones are accepted within |y| < 1.15*x and outer cones within
    # |y| < 0.6*x (both half-widths as originally tuned).
    for carte, ouverture in ((coord_int, 1.15), (coord_ext, 0.6)):
        for px, py in carte:
            # Map frame -> robot frame: translate by -(x, y), rotate by -theta.
            vx = (px-x)*c + (py-y)*s
            vy = -(px-x)*s + (py-y)*c
            if vx > 0 and abs(vy) < ouverture*vx:
                cones_vu_x.append(vx)
                cones_vu_y.append(vy)

    if not cones_vu_x:
        return 0

    obs_x = [o[0] for o in observation]
    obs_y = [o[1] for o in observation]
    # NOTE(review): with an empty observation, distance_Chamfer returns 0.0001,
    # which gives a very large weight to every particle that expects cones.
    # Confirm that this is intended.
    d = distance_Chamfer(cones_vu_x, cones_vu_y, obs_x, obs_y)
    return 1/max(d, 1e-4)**4






def particle_filter(pos,u_t,z_t): # position, command, observation
    """Run one predict / weight / resample step of the particle filter.

    Args:
        pos: list of particle poses (x, y, theta).
        u_t: command (vitesse, direction, FPS) passed to motion_update.
        z_t: list of observed cone positions (x, y) in the robot frame.

    Returns:
        (X_t, W): the resampled particles and their weights. After resampling,
        every particle represents the same probability mass, so W is uniform
        (1/M each) and W[i] belongs to X_t[i]. The pre-resampling weights no
        longer line up with the duplicated/reordered particles, so returning
        them would make the callers' weighted mean meaningless.
    """
    M = len(pos)
    if M == 0:
        return [], []

    # Prediction: move every particle with the (noisy) motion model.
    X = [motion_update(u_t, p) for p in pos]
    # Correction: weight each predicted particle by the observation likelihood.
    W = normalisation([sensor_update(z_t, x) for x in X])

    X_t = low_variance_resampling(X, W)
    return X_t, [1/M]*M


def low_variance_resampling(X,W):
    """Draw len(X) particles from X with probabilities W (low-variance resampling).

    One random offset r in [0, 1/J) is drawn. The J equally spaced pointers
    r, r + 1/J, ..., r + (J-1)/J each select the particle whose cumulative
    weight interval contains them.

    Args:
        X: particles.
        W: normalised weights, same length as X (expected to sum to 1).

    Returns:
        List of J particles (elements of X, repeated as drawn). An empty X
        gives an empty list.
    """
    J = len(X)
    if J == 0:
        return []

    X_t = []
    r = rd.random()/J
    c = W[0]
    i = 0
    for j in range(J):
        # 0-based j: pointers are r + j/J (the textbook (m-1)/M uses 1-based m).
        U = r + j/J
        # The i < J-1 guard stops float round-off (sum(W) slightly below 1)
        # from running past the last particle.
        while U > c and i < J - 1:
            i += 1
            c += W[i]
        X_t.append(X[i])
    return X_t


def get_position(boxes,commande,pos):
    """Run one filter step from detections and a command, and estimate the pose.

    Args:
        boxes: detections. Each entry is indexed as [x1, y1, x2, y2, score, ...],
            and only entries with score >= seuil_cone are used.
        commande: command (vitesse, direction, FPS) for the motion model.
        pos: current list of particles.

    Returns:
        (pos_calc, pos, cone_x, cone_y). These are the weighted-mean pose
        [x, y, theta], the updated particles, and the observed cones projected
        into the map frame with pos_calc.
    """
    # Cone observations in the robot frame, one per confident box.
    z_t = [boite2coord(b[0], b[1], b[2], b[3]) for b in boxes if b[4] >= seuil_cone]

    # Localisation step.
    pos, W = particle_filter(pos, commande, z_t)

    # Weighted mean of the particle poses.
    pos_calc = [0, 0, 0]
    for i, particule in enumerate(pos):
        for k in range(3):
            pos_calc[k] += particule[k]*W[i]

    # Robot frame -> map frame: rotate by the estimated heading, then translate.
    x_est, y_est, theta_est = pos_calc
    cone_x = []
    cone_y = []
    for cx, cy in z_t:
        rx, ry = rotation((cx, cy), (0, 0), theta_est)
        cone_x.append(rx + x_est)
        cone_y.append(ry + y_est)

    return pos_calc, pos, cone_x, cone_y

        

if __name__ == "__main__":
    # Visual check of the filter with a stationary robot and a fixed set of
    # detections. Each of the 100 iterations redraws the particle cloud, its
    # weighted mean, and the observed cones projected with the mean pose.

    # Sample detections, one row per box, indexed as [x1, y1, x2, y2, score, ...]
    # (last column not used here).
    detection = [
        np.array([115.43045807, 210.114151, 394.80041504, 559.36151123, 0.95864254, 0.0]),
        np.array([515.29907227, 304.63769531, 671.36590576, 497.72848511, 0.93361914, 0.0]),
        np.array([7.85924133e+02, 3.29671387e+02, 9.00264526e+02, 4.77042267e+02, 7.98994482e-01, 0.0]),
    ]

    nb_particle = 100
    # Alternative start: particles spread uniformly over the track.
    # pos = [(2.5*rd.random(), 1.6*rd.random(), 2*np.pi*rd.random()) for _ in range(nb_particle)]
    pos = [(1.2, 1.19, -0.3) for _ in range(nb_particle)]

    # Cone observations in the robot frame from the confident detections.
    z_t = [boite2coord(d[0], d[1], d[2], d[3]) for d in detection if d[4] >= seuil_cone]

    plt.ion()
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(coord_x_ext, coord_y_ext, '+')
    ax.plot(coord_x_int, coord_y_int, '+')

    line1, = ax.plot([p[0] for p in pos], [p[1] for p in pos], '.')  # particles
    line2, = ax.plot(0, 0, 'o')  # weighted mean position
    line3, = ax.plot(0, 0, 'o')  # observed cones in the map frame

    for _ in range(100):
        # Command (speed 0, steering 0, 1 FPS): the robot does not move, so only
        # the motion noise spreads the particles between measurement updates.
        pos, W = particle_filter(pos, (0, 0, 1), z_t)

        pos_moy_x = sum(w*p[0] for p, w in zip(pos, W))
        pos_moy_y = sum(w*p[1] for p, w in zip(pos, W))
        pos_moy_theta = sum(w*p[2] for p, w in zip(pos, W))

        cone_x, cone_y = [], []
        for cx, cy in z_t:
            x, y = rotation((cx, cy), (0, 0), pos_moy_theta)
            cone_x.append(x + pos_moy_x)
            cone_y.append(y + pos_moy_y)

        # Line2D data must be sequences: scalar x/y data is deprecated since
        # Matplotlib 3.7 and rejected by newer releases.
        line1.set_data([p[0] for p in pos], [p[1] for p in pos])
        line2.set_data([pos_moy_x], [pos_moy_y])
        line3.set_data(cone_x, cone_y)

        fig.canvas.draw()
        plt.pause(0.1)
# NOTE(review): the triple-quoted string below is an inert string literal that
# holds an older test script. It is never executed. It is also stale and would
# not run as-is:
#   - it uses `pi`, which this module does not import;
#   - it builds a 2-element command, while motion_update unpacks (vitesse, direction, FPS);
#   - it calls get_position(commande, pos), but get_position takes
#     (boxes, commande, pos) and returns four values.
# Update it before re-enabling, or delete it.
''' 
if __name__ == "__main__":
    # pos = [(1.314, 1.162, 3.14/3.6) for i in range(10)]
    # #pos = [(2.5*rd.random(), 1.5*rd.random(), 6.28*rd.random()) for i in range(50)]
    
    # liste_x = [0.37409758380473157,
    #              0.6517064494114153,
    #              0.23060853761333963,
    #              0.5278583908503303,
    #              0.14161368355256793,
    #              0.5134652832573952]
    
    # liste_y = [0.14021924676581576,
    #              -0.3119493901540909,
    #              -0.3464004029844368,
    #              0.01390277627039628,
    #              -0.2754514724880131,
    #              -0.5902545559074325]
    
    # observation = []
    # for i in range(len(liste_x)):
    #     observation.append((liste_x[i], liste_y[i]))
    
    commande = (0.138841, -pi/6)
    
    # a,W = particle_filter(pos, commande, observation)
    #print(a,W)
    
    
    
    
    
    
    pos_initiale = (1.314, 1.162, 3.14/3.6)
    pos_finale = (1.4, 1.271, 3.14/5.5)
    
    pos = [pos_initiale for i in range(30)]
    
    pos_calc,pos = get_position(commande,pos)
    
    
    plt.figure(1)
    plt.plot(pos_initiale[0], pos_initiale[1], 'o', label='Position initale')
    plt.arrow(pos_initiale[0], pos_initiale[1], 0.07*cos(pos_initiale[2]), 0.07*sin(pos_initiale[2]))
    plt.plot(pos_finale[0], pos_finale[1], 'o', label='Position finale')
    plt.arrow(pos_finale[0], pos_finale[1], 0.07*cos(pos_finale[2]), 0.07*sin(pos_finale[2]))
    
    plt.plot(coord_x_ext, coord_y_ext,'+')
    plt.plot(coord_x_int, coord_y_int,'+')
    
    pos_x = []
    pos_y = []
    for k in range(len(pos)):
        i = pos[k]
        pos_x.append(i[0])
        pos_y.append(i[1])
    
    plt.plot(pos_x,pos_y,'x', label='Positions possibles')
    
    # pos_calc = [0,0,0]
    # for i in range(len(a)):
    #     pos_calc[0] += a[i][0]*W[i]
    #     pos_calc[1] += a[i][1]*W[i]
    #     pos_calc[2] += a[i][2]*W[i]
    
    plt.plot(pos_calc[0], pos_calc[1], 'o', label='Moyenne pondérée des positions calculées')
    plt.arrow(pos_calc[0], pos_calc[1], 0.07*cos(pos_calc[2]), 0.07*sin(pos_calc[2]))
    
    
    plt.legend()
    plt.show()
    
    '''