import tkinter as tk
import random
import math
from pybirdsreynolds.args import compute_args, display_range
import signal
import sys
import copy
from importlib.metadata import version
import time
from tkinter import font
from pybirdsreynolds.const import *
import types

# variables
# Module-level state shared (through ``global`` statements) with the
# callbacks defined in app().

# Settings without a command-line option start from their defaults.
refresh_ms = REFRESH_MS_DEFAULT
color = COLOR_DEFAULT
triangles = TRIANGLES_DEFAULT
size = SIZE_DEFAULT
font_size = FONT_SIZE_DEFAULT
font_type = FONT_TYPE_DEFAULT
version_prog = version("pybirdsreynolds")

# Settings returned by compute_args().
options = compute_args()
max_speed, neighbor_radius, num_birds = options.max_speed, options.neighbor_radius, options.num_birds
width, height = options.width, options.height
random_speed, random_angle = options.random_speed, options.random_angle
sep_weight, align_weight, coh_weight = options.sep_weight, options.align_weight, options.coh_weight

# Animation / UI state: the simulation starts paused.
paused = True
blink_state = True
frame_count = 0
last_time = time.time()
fps_value = 0
fonts = []
fps = False
free = options.free
count = not paused
resizing = False

# Colour scheme: only the background depends on ``color``.
canvas_bg = "blue" if color else "black"
fill_color, outline_color = "white", "black"

margin = 1
selected_index = 0
shift_pressed = False


# Start-up snapshot of every tunable setting; restore_options() copies these
# back to implement the "r" (reset) command.
(
    max_speed_init, neighbor_radius_init, num_birds_init, width_init, height_init,
    refresh_ms_init, random_speed_init, random_angle_init, sep_weight_init,
    align_weight_init, coh_weight_init, size_init, font_size_init, font_type_init,
    triangles_init, free_init, color_init,
) = copy.deepcopy((
    max_speed, neighbor_radius, num_birds, width, height,
    refresh_ms, random_speed, random_angle, sep_weight,
    align_weight, coh_weight, size, font_size, font_type,
    triangles, free, color,
))
# Auto-repeat state of the clickable </> arrows (see start_repeat/stop_repeat).
repeating = {"active": False, "job": None}


# Documented, user-tunable parameters: each ``<NAME>_DOC`` constant maps the
# parameter ``<name>`` to its help text.  Their definition order is the order
# of the parameter panel and of the Up/Down selection.
param_docs = {}
for _doc_name, _doc_text in list(globals().items()):
    if _doc_name.endswith("_DOC"):
        param_docs[_doc_name.lower().removesuffix("_doc")] = _doc_text
del _doc_name, _doc_text
param_order = list(param_docs)

def app():
    """Create the Tk window of the boids simulation and run its main loop.

    The callbacks below are closures over the window, canvas and flock
    (``birds``) created at the end of this function; the tunable settings
    are module-level globals.
    """

    def restore_options():
        """Reset every tunable setting to the value snapshotted at import time.

        Also recomputes the colour scheme derived from ``color``.  Birds and
        the canvas are not touched: callers regenerate and redraw afterwards.
        """
        global max_speed, neighbor_radius, num_birds, width, height
        global refresh_ms, random_speed, random_angle
        global sep_weight, align_weight, coh_weight
        global paused, size, triangles, color, canvas_bg, font_size, font_type, fonts
        global fill_color, outline_color, fps, free

        max_speed = copy.deepcopy(max_speed_init)
        neighbor_radius = copy.deepcopy(neighbor_radius_init)
        num_birds = copy.deepcopy(num_birds_init)
        width = copy.deepcopy(width_init)
        height = copy.deepcopy(height_init)
        refresh_ms = copy.deepcopy(refresh_ms_init)
        random_speed = copy.deepcopy(random_speed_init)
        random_angle = copy.deepcopy(random_angle_init)
        sep_weight = copy.deepcopy(sep_weight_init)
        align_weight = copy.deepcopy(align_weight_init)
        coh_weight = copy.deepcopy(coh_weight_init)
        size = copy.deepcopy(size_init)
        color = copy.deepcopy(color_init)
        free = copy.deepcopy(free_init)
        triangles = copy.deepcopy(triangles_init)
        font_size = copy.deepcopy(font_size_init)
        # Previously copied font_type onto itself, so the font was never
        # restored.  ``fonts`` lists only installed families and the font
        # cycling code looks the current font up in it, so fall back to the
        # first installed one when the start-up font is not available.
        font_type = copy.deepcopy(font_type_init)
        if fonts and font_type not in fonts:
            font_type = fonts[0]

        # Only the background depends on ``color``.
        canvas_bg = "blue" if color else "black"
        fill_color = "white"
        outline_color = "black"

    def start_repeat(ligne, direction):
        """Apply ``direction`` to the parameter of status line ``ligne`` now,
        then again every 100 ms until stop_repeat() clears the active flag."""
        def tick():
            on_click(ligne, direction)
            if not repeating["active"]:
                return
            repeating["job"] = canvas.after(100, tick)  # repeat every 100 ms

        repeating["active"] = True
        tick()

    def stop_repeat():
        """End the auto-repeat started by start_repeat() and cancel its pending callback."""
        repeating["active"] = False
        pending = repeating["job"]
        if not pending:
            return
        canvas.after_cancel(pending)
        repeating["job"] = None

    def _shift_handler(is_down):
        """Build a Tk callback that records in ``shift_pressed`` whether Shift is held.

        on_other_key() falls back to this flag for synthetic events (clicks
        on the GUI arrows), which carry no modifier ``state``.
        """
        def handler(event):
            global shift_pressed
            shift_pressed = is_down
        return handler

    on_shift_press = _shift_handler(True)
    on_shift_release = _shift_handler(False)

    def draw():
        """Redraw everything: canvas size/colour, parameter panel, birds, border and FPS."""
        draw_canvas()
        draw_status(False)
        for painter in (draw_points, draw_rectangle, draw_fps):
            painter()

    def draw_fps():
        """Show the FPS counter at the top-left of the simulation area when enabled.

        Shows "NA" while paused and "..." until a first measurement exists.
        """
        canvas.delete("fps")
        if not fps:
            return
        if paused:
            value = "NA"
        elif fps_value == 0:
            value = "..."
        else:
            value = f"{fps_value:.1f}"
        canvas.create_text(
            WIDTH_PARAMS_DEFAULT,
            0,
            anchor="nw",
            fill="yellow",
            font=(font_type, font_size, "bold"),
            tags="fps",
            text=f" FPS : {value}",
        )

    def draw_paused():
        """Blink a red "PAUSED" banner at the bottom of the panel while paused.

        Each call draws the banner (on every other call) and, while still
        paused, schedules itself again 500 ms later.  The id of that pending
        call is stored on the function, and any pending call is cancelled
        first.  Without this, pausing, resuming and pausing again within
        500 ms left two blink loops running, and the banner flickered
        irregularly.
        """
        global blink_state
        pending = getattr(draw_paused, "_after_id", None)
        if pending is not None:
            # tkinter ignores ids that already fired, e.g. the current callback.
            canvas.after_cancel(pending)
            draw_paused._after_id = None
        canvas.delete("paused")
        if not paused:
            return
        if blink_state:
            canvas.create_text(
                WIDTH_PARAMS_DEFAULT,
                max(height, HEIGHT_PARAMS_DEFAULT),
                anchor="sw",
                fill="red",
                font=(font_type, font_size, "bold"),
                tags="paused",
                text=" PAUSED - press Space - "
            )
        blink_state = not blink_state
        draw_paused._after_id = canvas.after(500, draw_paused)

    def toggle_pause(event=None):
        """Flip the paused state (Space key or play button).

        Also refreshes the parameter panel and restarts the blinking PAUSED
        banner with the banner visible.
        """
        global paused, blink_state
        paused, blink_state = not paused, True
        draw_status(False)
        draw_paused()

    def change_value(type, val, free):
        """Return the global parameter named ``type`` shifted by ``val`` and clamped.

        Bounds come from the ``<TYPE>_MIN`` / ``<TYPE>_MAX`` globals, or from
        ``<TYPE>_FREE_MIN`` / ``<TYPE>_FREE_MAX`` when ``free`` is true.  A
        missing bound leaves that side unbounded.  The global itself is not
        modified; callers assign the result.
        """
        # ``type`` shadows the builtin but is kept for interface compatibility.
        prefix = type.upper()
        scope = "_FREE" if free else ""
        low = globals().get(f"{prefix}{scope}_MIN")
        high = globals().get(f"{prefix}{scope}_MAX")
        value = globals().get(type) + val
        if high is not None:
            value = min(value, high)
        if low is not None:
            value = max(value, low)
        return value

    def on_other_key(event):
        """Handle every key press except Space, plus the clicks of the GUI controls.

        ``event`` is either a Tk key event or a ``types.SimpleNamespace``
        carrying only ``keysym`` (synthetic events from on_click and the
        control buttons).

        * Up/Down: move the selection in the parameter list, wrapping around.
        * Left/Right: decrease/increase the selected parameter by 1, or by 10
          while Shift is held.  Boolean parameters are toggled and
          ``font_type`` is cycled.  Toggling ``free`` re-clamps every numeric
          parameter to the newly active bounds.
        * ``r``: restore the start-up options and resize/redraw everything.
        * ``n``: discard the flock and spawn a new one.
        * ``f``: toggle the FPS display.

        The parameter panel is redrawn after every key.
        """
        global selected_index, num_birds, max_speed
        global neighbor_radius, sep_weight, align_weight
        global coh_weight, size, random_speed, random_angle
        global triangles, free, refresh_ms, width, height, fonts
        global color, canvas_bg, fill_color, outline_color, fps, font_type
        global velocities

        keysym = getattr(event, "keysym", "")
        state = getattr(event, "state", None)
        # Synthetic events carry no modifier state: fall back to the flag
        # maintained by the Shift press/release bindings.
        shift = (state & 0x1) != 0 if state is not None else shift_pressed
        step = 10 if shift else 1
        val = step if keysym == "Right" else -step
        param = param_order[selected_index]

        if keysym == "Up":
            selected_index = (selected_index - 1) % len(param_order)
        elif keysym == "Down":
            selected_index = (selected_index + 1) % len(param_order)
        elif keysym in ("Right", "Left"):
            if param == "triangles":
                triangles = not triangles
                draw()
            elif param == "font_type":
                if fonts:
                    # Tolerate a current font missing from the list instead of
                    # raising ValueError: start cycling from the first entry.
                    current_index = fonts.index(font_type) if font_type in fonts else 0
                    font_type = fonts[(current_index + val) % len(fonts)]
                draw()
                # The font changed: rebuild the clickable arrows/buttons too.
                draw_status(True)
            elif param == "color":
                color = not color
                canvas_bg = "blue" if color else "black"
                fill_color = "white"
                outline_color = "black"
                draw()
            elif param == "free":
                free = not free
                # Re-clamp every numeric parameter to the newly active bounds.
                for other in param_order:
                    if other not in ("free", "color", "triangles", "font_type"):
                        globals()[other] = change_value(other, 0, free)
            else:
                globals()[param] = change_value(param, val, free)

            # Parameters that affect the flock or the window need a regeneration.
            if param == "num_birds":
                generate_points_and_facultative_move(False)
                draw_points()
            elif param in ("width", "height", "free", "size"):
                generate_points_and_facultative_move(False)
                draw()
        elif keysym.lower() == "r":
            previous_font = (font_type, font_size)
            restore_options()
            generate_points_and_facultative_move(False)
            root.geometry(f"{WIDTH_PARAMS_DEFAULT+width}x{max(height,HEIGHT_PARAMS_DEFAULT)}+0+0")
            draw()
            if (font_type, font_size) != previous_font:
                # The arrows/buttons were positioned with the old font metrics.
                draw_status(True)
            draw_canvas()
            root.state('withdrawn')
            root.state('normal')
            root.geometry(f"{WIDTH_PARAMS_DEFAULT+width}x{max(height,HEIGHT_PARAMS_DEFAULT)}+0+0")
        elif keysym.lower() == "n":
            # Empty the flock shared with the drawing/moving closures so that
            # it is fully regenerated (the old code rebound an unused
            # module-level ``birds`` instead of clearing this list).
            # NOTE(review): the old code also assigned an unused local
            # ``pause = True`` here; it presumably meant to pause the
            # simulation. Confirm the intended behavior.
            velocities = []
            birds.clear()
            generate_points_and_facultative_move(False)
            draw_points()
        elif keysym.lower() == "f":
            fps = not fps
            draw_fps()
        draw_status(False)


    def on_resize(event):
        """Follow canvas resizes.

        The simulation area is the canvas minus the parameter panel on the
        left.  The flock is refitted to the new area, and the panel, birds,
        border and FPS counter are redrawn.
        """
        global width, height
        width, height = event.width - WIDTH_PARAMS_DEFAULT, event.height
        generate_points_and_facultative_move(False)
        draw_status(False)
        for redraw in (draw_points, draw_rectangle, draw_fps):
            redraw()


    def draw_canvas():
        """Size window and canvas to panel + simulation area and apply the background colour.

        The height is never smaller than the parameter panel's height.
        """
        total_width = WIDTH_PARAMS_DEFAULT + width
        total_height = max(height, HEIGHT_PARAMS_DEFAULT)
        root.geometry(f"{total_width}x{total_height}+0+0")
        canvas.config(width=total_width, height=total_height, bg=canvas_bg)

    def on_click(l, sens):
        """Select the parameter shown on status line ``l`` and apply key ``sens`` ("Left"/"Right") to it.

        Only parameter lines carry the clickable arrows, and their first word
        is the parameter name.  The index is therefore looked up directly in
        ``param_order``, falling back to 0 for an unknown name.  Rebuilding
        every formatted status line just to find it is unnecessary.
        """
        global selected_index
        words = l.split()
        first_word = words[0] if words else None
        selected_index = param_order.index(first_word) if first_word in param_order else 0
        on_other_key(types.SimpleNamespace(keysym=sens))

    def draw_status(fullRefresh):
        """Draw the parameter panel on the left part of the canvas.

        The panel shows one text line per documented parameter (every
        ``*_DOC`` global) with its current value, followed by the
        ``COMMON_CONTROLS`` lines.  The selected parameter is drawn in red and
        lines containing ``[`` (key controls) in yellow.  The selected
        parameter's documentation and range are shown in green underneath.

        When ``fullRefresh`` is true, the canvas window items are deleted and
        re-created at positions measured with the current font: ``<``/``>``
        arrows on parameter lines and icon buttons on control lines.  The
        font_type handler uses this after a font change.
        """
        normal_font = font.Font(family=font_type, size=font_size, weight="normal")
        italic_font = font.Font(family=font_type, size=font_size, slant="italic", weight="bold")
        x_text = 10
        y_text = 10

        # Control-line buttons, tried in this order: (marker, glyph, colour, click action).
        control_buttons = (
            ("[Space]", "⏯", "green", lambda e: toggle_pause()),
            ("[r]", "🔄", "orange", lambda e: on_other_key(types.SimpleNamespace(keysym='r'))),
            ("[n]", "🪶", "purple", lambda e: on_other_key(types.SimpleNamespace(keysym='n'))),
            ("[f]", "⏱", "brown", lambda e: on_other_key(types.SimpleNamespace(keysym='f'))),
        )

        def add_arrows(line, x, y):
            # "<" / ">" labels; holding the mouse button auto-repeats Left / Right.
            for text, direction, dx in (("<", "Left", 1), (">", "Right", 18)):
                label = tk.Label(canvas, text=text, fg="blue", bg="white", font=normal_font)
                label.bind("<ButtonPress-1>", lambda e, l=line, d=direction: start_repeat(l, d))
                label.bind("<ButtonRelease-1>", lambda e: stop_repeat())
                canvas.create_window(x + dx, y, anchor="nw", window=label,
                                     tags=(f"line_{direction.lower()}",))

        def add_button(line, x, y):
            # Icon button for the first control marker found in the line, if any.
            for marker, glyph, colour, action in control_buttons:
                if marker in line:
                    button = tk.Label(
                        canvas, text=glyph, fg=colour, bg="white",
                        font=("Courier", 9), width=2, height=1, anchor="center",
                        highlightbackground="yellow", highlightthickness=2,
                    )
                    button.bind("<Button-1>", action)
                    canvas.create_window(x + 2, y, anchor="nw", window=button, tags=("line_btn",))
                    return

        lines = [
            f"{name.lower().removesuffix('_doc'):15} :    {str(globals()[name.lower().removesuffix('_doc')]).split(maxsplit=1)[0]}"
            for name in globals()
            if name.endswith("_DOC")
        ] + COMMON_CONTROLS

        canvas.delete("status")
        if fullRefresh:
            for item in canvas.find_all():
                if canvas.type(item) == "window":
                    canvas.delete(item)

        for i, line in enumerate(lines):
            y_pos = y_text + i * 2.1 * font_size
            if "[" in line:
                fill = "yellow"
            elif i == selected_index:
                fill = "red"
            else:
                fill = fill_color
            canvas.create_text(
                x_text,
                y_pos,
                anchor="nw",
                fill=fill,
                font=normal_font,
                tags="status",
                text=line,
            )
            if fullRefresh:
                # Widgets sit right after the first colon; measure with the
                # panel font itself (no per-line font copy needed).
                x_widget = x_text + normal_font.measure(line[:line.find(":") + 1])
                if "[" in line:
                    add_button(line, x_widget, y_pos)
                else:
                    add_arrows(line, x_widget, y_pos)

        param_name = param_order[selected_index]
        doc_text = param_docs.get(param_name, "") + " (" + display_range(param_name.upper()) + ")"
        canvas.create_text(
            x_text,
            y_text + (len(lines) * 2.1 * font_size),
            anchor="nw",
            fill="green",
            font=italic_font,
            tags="status",
            text=param_name + " : " + doc_text,
            width=WIDTH_PARAMS_DEFAULT - 2 * x_text
        )

    def draw_rectangle():
        """Outline the simulation area to the right of the parameter panel."""
        canvas.delete("boundary")
        left, right = WIDTH_PARAMS_DEFAULT, WIDTH_PARAMS_DEFAULT + width
        canvas.create_rectangle(left, 0, right, height,
                                outline=fill_color, width=margin, tags="boundary")

    def generate_points_and_facultative_move(with_move):
        """Make the flock match the current settings, optionally stepping it.

        On an empty flock, ``num_birds`` birds are spawned and ``with_move``
        is ignored.  Otherwise birds outside the simulation area are dropped.
        Random birds are then spawned or removed until there are
        ``num_birds``, and move() runs when ``with_move`` is true.
        """
        global velocities
        global new_velocities

        left = margin + WIDTH_PARAMS_DEFAULT
        right = width - margin + WIDTH_PARAMS_DEFAULT
        top = margin
        bottom = height - margin

        def spawn_one():
            # Random integer position in the area, random heading, speed in [0, max_speed].
            birds.append((random.randint(left, right), random.randint(top, bottom)))
            heading = random.uniform(0, 2 * math.pi)
            magnitude = random.uniform(0, max_speed)
            velocities.append((magnitude * math.cos(heading), magnitude * math.sin(heading)))

        if not birds:
            velocities = []
            for _ in range(num_birds):
                spawn_one()
            return

        # Keep only the birds that are still inside the area.
        kept = [
            (point, velocity)
            for point, velocity in zip(birds, velocities)
            if left <= point[0] <= right and top <= point[1] <= bottom
        ]
        birds[:] = [point for point, _ in kept]
        velocities[:] = [velocity for _, velocity in kept]
        new_velocities = []

        missing = num_birds - len(birds)
        if missing > 0:
            for _ in range(missing):
                spawn_one()
        elif missing < 0:
            for _ in range(-missing):
                victim = random.randint(0, len(birds) - 1)
                birds.pop(victim)
                velocities.pop(victim)

        if with_move:
            move()
    def move():
        """Advance the flock by one step using Reynolds' boids rules.

        Phase 1 computes each bird's new velocity.  Neighbours within
        ``neighbor_radius`` contribute separation (push away from each
        neighbour), alignment (average neighbour velocity) and cohesion
        (towards the neighbours' centre), each scaled by its weight.  Optional
        random speed/heading jitter is then applied and the speed is capped
        by limit_speed().  Phase 2 moves every bird and reflects it off the
        edges of the simulation area.

        New velocities are appended to the module-level ``new_velocities``
        list, which generate_points_and_facultative_move() resets beforehand.
        """
        global velocities
        global new_velocities
        # TODO: the neighbour search is O(n^2); a uniform grid / cell list would scale better.
        for i, (x, y) in enumerate(birds):
            move_sep_x, move_sep_y = 0, 0
            move_align_x, move_align_y, move_align_x_tmp, move_align_y_tmp = 0, 0, 0, 0
            move_coh_x, move_coh_y, move_coh_x_tmp, move_coh_y_tmp = 0, 0, 0, 0
            neighbors = 0
            vx, vy = velocities[i]
            if neighbor_radius > 0 and not (sep_weight == 0 and align_weight == 0 and coh_weight == 0):
                for j, (x2, y2) in enumerate(birds):
                    if i == j:
                        continue
                    dist = math.sqrt((x2 - x)**2 + (y2 - y)**2)
                    if dist < neighbor_radius and dist > 0:
                        # SEPARATION: unit vector pointing away from the neighbour.
                        move_sep_x += (x - x2) / dist
                        move_sep_y += (y - y2) / dist
                        # ALIGNMENT: sum neighbour velocities (averaged below).
                        vx2, vy2 = velocities[j]
                        move_align_x_tmp += vx2
                        move_align_y_tmp += vy2
                        # COHESION: sum neighbour positions (averaged below) to
                        # steer towards the local centre of the group.
                        move_coh_x_tmp += x2
                        move_coh_y_tmp += y2
                        neighbors += 1

                if neighbors > 0:
                    move_align_x = move_align_x_tmp/neighbors
                    move_align_y = move_align_y_tmp/neighbors
                    move_coh_x = move_coh_x_tmp/neighbors
                    move_coh_y = move_coh_y_tmp/neighbors
                    move_coh_x = move_coh_x - x
                    move_coh_y = move_coh_y - y

                vx += sep_weight * move_sep_x + align_weight * move_align_x + coh_weight * move_coh_x
                vy += sep_weight * move_sep_y + align_weight * move_align_y + coh_weight * move_coh_y

            # RANDOM SPEED: Gaussian jitter that is strongest at max_speed / 2
            # and vanishes at 0 and max_speed, plus a weak pull towards max_speed / 2.
            # Skipped when max_speed is 0, which would divide by zero below.
            speed = math.sqrt(vx**2 + vy**2)
            if random_speed != 0 and max_speed > 0:
                target_speed = max_speed / 2
                sigma_percent = random_speed       # maximum standard deviation, in % of max_speed
                adjust_strength = 0.05    # strength of the pull towards target_speed
                sigma_base = (sigma_percent / 100) * max_speed
                weight = 4 * speed * (max_speed - speed) / (max_speed ** 2)
                sigma = sigma_base * weight
                delta_speed = random.gauss(0, sigma)
                new_speed = speed + delta_speed
                new_speed += (target_speed - new_speed) * adjust_strength
                new_speed = max(0.1, min(max_speed, new_speed))
                if speed > 0:
                    factor = new_speed / speed
                    vx *= factor
                    vy *= factor
                else:
                    # A stationary bird has no heading to rescale (division by
                    # zero); launch it in a random direction instead.
                    heading = random.uniform(0, 2 * math.pi)
                    vx = new_speed * math.cos(heading)
                    vy = new_speed * math.sin(heading)
            # RANDOM ANGLE: rotate the heading by up to +/- random_angle degrees.
            if random_angle != 0:
                angle = math.atan2(vy, vx)
                angle += math.radians(random.uniform(-1 * random_angle, random_angle))
                speed = math.sqrt(vx**2 + vy**2)
                vx = speed * math.cos(angle)
                vy = speed * math.sin(angle)
            vx, vy = limit_speed(vx, vy)

            new_velocities.append((vx, vy))

        velocities = new_velocities

        # Update positions, reflecting birds off the edges of the area.
        new_points = []
        for idx, ((x, y), (vx, vy)) in enumerate(zip(birds, velocities)):
            nx = x + vx
            ny = y + vy
            while nx < margin + WIDTH_PARAMS_DEFAULT or nx > width - margin + WIDTH_PARAMS_DEFAULT:
                if nx < margin + WIDTH_PARAMS_DEFAULT:
                    overshoot = (margin + WIDTH_PARAMS_DEFAULT) - nx
                    nx = (margin + WIDTH_PARAMS_DEFAULT) + overshoot
                    vx = abs(vx)
                elif nx > width - margin + WIDTH_PARAMS_DEFAULT:
                    overshoot = nx - (width - margin + WIDTH_PARAMS_DEFAULT)
                    nx = (width - margin + WIDTH_PARAMS_DEFAULT) - overshoot
                    vx = -abs(vx)
            while ny < margin or ny > height - margin:
                if ny < margin:
                    overshoot = margin - ny
                    ny = margin + overshoot
                    vy = abs(vy)
                elif ny > height - margin:
                    overshoot = ny - (height - margin)
                    ny = (height - margin) - overshoot
                    vy = -abs(vy)
            # Write back by position in the list: birds.index((x, y)) picked
            # the wrong bird when two birds shared a position, and was O(n).
            velocities[idx] = (vx, vy)
            new_points.append((nx, ny))
        birds[:] = new_points


    def draw_points():
        """Redraw every bird.

        A bird is drawn as a circle of radius ``size``, or, when ``triangles``
        is set, as a triangle pointing along its velocity.
        """
        if point_ids:
            canvas.delete(*point_ids)
        point_ids.clear()

        nose = 6*size
        half_base = 4*size

        for (x, y), (vx, vy) in zip(birds, velocities):
            if triangles:
                heading = math.atan2(vy, vx)
                port = heading + math.radians(150)
                starboard = heading - math.radians(150)
                item = canvas.create_polygon(
                    x + math.cos(heading) * nose, y + math.sin(heading) * nose,
                    x + math.cos(port) * half_base, y + math.sin(port) * half_base,
                    x + math.cos(starboard) * half_base, y + math.sin(starboard) * half_base,
                    fill=fill_color, outline=outline_color
                )
            else:
                item = canvas.create_oval(
                    x - size, y - size,
                    x + size, y + size,
                    fill=fill_color, outline=outline_color)
            point_ids.append(item)

    def limit_speed(vx, vy):
        """Return (vx, vy) rescaled to ``max_speed`` when faster; otherwise unchanged."""
        speed = math.sqrt(vx*vx + vy*vy)
        if not speed > max_speed:
            return vx, vy
        return (vx / speed) * max_speed, (vy / speed) * max_speed

    def update():
        """Animation tick, rescheduled every ``refresh_ms`` milliseconds.

        Steps and redraws the flock unless paused.  The FPS value is
        recomputed over windows of at least one second and reset while
        paused.
        """
        global frame_count, last_time, fps_value, count
        if paused:
            # Reset the FPS measurement so it restarts cleanly on resume.
            frame_count = 0
            count = False
            fps_value = 0
        else:
            generate_points_and_facultative_move(True)
            draw_points()
            draw_fps()
            frame_count += 1
            now = time.time()
            if not count:
                last_time, count = now, True
            elapsed = now - last_time
            if elapsed >= 1.0:
                fps_value = frame_count / elapsed
                frame_count = 0
                last_time = now
        root.after(refresh_ms, update)

    def signal_handler(sig, frame):
        """SIGINT handler: announce the shutdown, close the Tk window and exit with status 0."""
        message = "Interrupted! Closing application..."
        print(message)
        root.destroy()
        raise SystemExit(0)
    root = tk.Tk()
    root.title(f"pybirdsreynolds - {version_prog}")
    root.minsize(WIDTH_PARAMS_DEFAULT + WIDTH_MIN, max(height, HEIGHT_PARAMS_DEFAULT))

    global font_type, fonts
    # Fonts offered by the font_type parameter: the installed entries of
    # FONT_TYPE_LIST first, then every installed family whose name contains
    # "mono", without duplicates.
    installed_fonts = font.families()
    installed_set = set(installed_fonts)
    default_fonts = [f for f in FONT_TYPE_LIST if f in installed_set]
    mono_fonts = [f for f in installed_fonts if "mono" in f.lower()]
    fonts = list(dict.fromkeys(default_fonts + mono_fonts))
    if not fonts:
        # Nothing suitable is installed: keep the configured font (presumably
        # Tk substitutes a fallback family) so that fonts[0] cannot raise and
        # the font cycling code still finds the current font.
        fonts = [font_type]
    if font_type not in fonts:
        font_type = fonts[0]

    canvas = tk.Canvas(root, width=width+WIDTH_PARAMS_DEFAULT, height=height, bg=canvas_bg)
    canvas.pack(fill="both", expand=True)

    # Bird positions and their canvas item ids, shared by the closures above
    # (velocities live in the module-level ``velocities`` list).
    birds = []
    point_ids = []

    generate_points_and_facultative_move(True)
    draw()
    draw_status(True)
    draw_paused()
    root.bind("<space>", toggle_pause)
    root.bind("<Key>", on_other_key)
    root.bind_all("<Shift_L>", on_shift_press)
    root.bind_all("<Shift_R>", on_shift_press)
    root.bind_all("<KeyRelease-Shift_L>", on_shift_release)
    root.bind_all("<KeyRelease-Shift_R>", on_shift_release)

    canvas.bind("<Configure>", on_resize)

    signal.signal(signal.SIGINT, signal_handler)
    update()
    root.mainloop()


