import math
import time
import random
from pyglet.window import Window
from pyglet.window import mouse
from pyglet.app import run
from pyglet.shapes import Circle
from pyglet.graphics import Batch
from pyglet import clock

# Full disclosure:
# I used AI tools to help write some of this code:
# - mainly GitHub Copilot inline suggestions while writing;
# - Google Search's AI Mode to help reason about the logic, physics, and implementation.

# Features added:
# - planet-planet collisions and gravitational interaction
# - dragging the planets and the sun with the mouse, with a fling velocity applied on release
# - bouncing off the edges of the window
# - many small planets

def hex_to_rgb(hex_color):
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints in 0-255.

    Accepted forms:
    - ``"#RRGGBB"`` (the form used throughout this file) or ``"RRGGBB"``;
    - ``"#RGB"`` / ``"RGB"`` shorthand, where each digit is doubled;
    - ``"#RRGGBBAA"`` / ``"RRGGBBAA"``, where the alpha byte is ignored.

    Raises:
        ValueError: if the string is not a hex colour of one of those forms.
    """
    digits = hex_color[1:] if hex_color.startswith("#") else hex_color
    if len(digits) == 3:
        # Shorthand, e.g. "F80" -> "FF8800".
        digits = "".join(ch * 2 for ch in digits)
    # int(..., 16) would also accept signs/whitespace ("-1", " F"), so the
    # characters are checked explicitly instead of relying on it.
    if len(digits) not in (6, 8) or any(ch not in "0123456789abcdefABCDEF" for ch in digits):
        raise ValueError(f"invalid hex colour: {hex_color!r}")
    return tuple(int(digits[start:start + 2], 16) for start in (0, 2, 4))


class SolarSystem(Window):
    """Interactive toy solar system in a resizable pyglet window.

    A sun sits at ``(center_x, center_y)``. Planets are pulled towards the sun
    and towards each other by softened inverse-square gravity, and they bounce
    off the sun, the window edges, and each other. The left mouse button drags
    the sun, or picks up a planet and flings it on release.
    """

    def __init__(self):
        super().__init__(640, 640, "Mini Solar System", resizable=True)

        # Physics tuning. Set up before the planets because their launch
        # speeds are derived from the sun's gravity.
        self.gravitational_constant = 6000.0
        self.sun_gravity_mass = 2500.0
        # Softening lengths keep the 1/r^2 pulls finite at very short range.
        self.sun_gravity_softening = 10.0
        self.planet_gravity_softening = 10.0
        self.edge_bounce_factor = 0.95
        self.planet_collision_restitution = 0.9
        # Fraction of the inward speed kept when a planet bounces off the sun.
        self.sun_bounce_restitution = 0.8
        # If the mouse was held still longer than this (seconds) before the
        # button is released, the planet is dropped instead of being flung
        # with a stale drag velocity.
        self.fling_max_idle_time = 0.1

        # Mouse interaction state.
        self.is_dragging_sun = False
        self.sun_drag_offset_x = 0.0
        self.sun_drag_offset_y = 0.0
        self.selected_planet = None
        self.last_mouse_drag_time = None
        self.last_mouse_drag_x = 0.0
        self.last_mouse_drag_y = 0.0
        self.fling_velocity_x = 0.0
        self.fling_velocity_y = 0.0

        self.draw_batch = Batch()
        self.center_x, self.center_y = 320, 320
        self.sun_circle = Circle(self.center_x, self.center_y, 25,
                                 color=hex_to_rgb("#F9F871"),
                                 batch=self.draw_batch)

        # Each spec is (orbit distance px, launch-speed factor, radius px, colour).
        self.planet_specs = [
            (80, 1.0, 8, "#00D2FC"),
            (80, 1.0, 8, "#00D2FC"),
            (140, 0.5, 10, "#FFC75F"),
            (200, 0.2, 12, "#C34A36"),
            (100, 10.0, 5, "#FFFFFF"),
            (250, 0.1, 15, "#FACCFF"),
            (300, 0.05, 20, "#FF6F91"),
        ]
        # Pad the list with small random white planets up to 100 in total.
        for i in range(len(self.planet_specs), 100):
            self.planet_specs.append((
                80 + i + 10 * random.random(),  # orbit distance
                0.5 * random.random(),          # launch-speed factor
                3.0 * random.random(),          # radius
                "#FFFFFF",
            ))

        # sqrt(G*M/r) is the circular-orbit speed around the sun, and the speed
        # factor scales twice that: a factor of ~0.5 gives a roughly circular
        # orbit and anything above ~0.71 (sqrt(2)/2) escapes. These are
        # approximate because of softening and planet-planet gravity.
        sun_gm = self.gravitational_constant * self.sun_gravity_mass
        self.planet_bodies = []
        for orbit_distance, speed_factor, planet_radius, color_hex in self.planet_specs:
            planet_circle = Circle(self.center_x + orbit_distance, self.center_y,
                                   planet_radius,
                                   color=hex_to_rgb(color_hex),
                                   batch=self.draw_batch)
            launch_speed = math.sqrt(4.0 * sun_gm / max(orbit_distance, 1.0))
            self.planet_bodies.append({
                "mass": max(1.0, planet_radius * planet_radius * 0.8),
                "vx": 0.0,
                "vy": launch_speed * speed_factor,
                "circle": planet_circle,
            })

    def update(self, dt):
        """Advance the simulation by ``dt`` seconds (capped at 1/30 s)."""
        dt = min(dt, 1 / 30)

        held_planet = self.selected_planet
        if held_planet is not None:
            # The held planet is positioned by the mouse, not by physics, but
            # its velocity still feeds the collision response, so keep it in
            # sync with how the cursor is actually moving.
            held_planet["vx"], held_planet["vy"] = self._current_fling_velocity()

        # Snapshot all positions first so every planet feels the others at the
        # same instant, independent of the order of planet_bodies.
        snapshot = [(body["circle"].x, body["circle"].y, body["mass"])
                    for body in self.planet_bodies]

        for index, planet_body in enumerate(self.planet_bodies):
            if planet_body is held_planet:
                continue

            acceleration_x, acceleration_y = self._acceleration_on(index, snapshot)

            # Semi-implicit Euler: update velocity, then move with the new velocity.
            planet_body["vx"] += acceleration_x * dt
            planet_body["vy"] += acceleration_y * dt
            planet_circle = planet_body["circle"]
            planet_circle.x += planet_body["vx"] * dt
            planet_circle.y += planet_body["vy"] * dt

            self._bounce_off_sun(planet_body)
            self._bounce_off_edges(planet_body)

        for index, planet_body_a in enumerate(self.planet_bodies):
            for planet_body_b in self.planet_bodies[index + 1:]:
                self._resolve_planet_collision(planet_body_a, planet_body_b)

    def _acceleration_on(self, index, snapshot):
        """Return the gravitational acceleration ``(ax, ay)`` on planet ``index``.

        ``snapshot`` holds ``(x, y, mass)`` for every planet, in
        ``planet_bodies`` order. The sun and every other planet contribute a
        softened inverse-square pull.
        """
        planet_x, planet_y, _ = snapshot[index]

        sun_offset_x = self.center_x - planet_x
        sun_offset_y = self.center_y - planet_y
        sun_distance_sq = (sun_offset_x * sun_offset_x + sun_offset_y * sun_offset_y
                           + self.sun_gravity_softening * self.sun_gravity_softening)
        sun_pull = (self.gravitational_constant * self.sun_gravity_mass
                    / (sun_distance_sq * math.sqrt(sun_distance_sq)))
        acceleration_x = sun_pull * sun_offset_x
        acceleration_y = sun_pull * sun_offset_y

        gravity = self.gravitational_constant
        softening_sq = self.planet_gravity_softening * self.planet_gravity_softening
        for other_index, (other_x, other_y, other_mass) in enumerate(snapshot):
            if other_index == index:
                continue
            offset_x = other_x - planet_x
            offset_y = other_y - planet_y
            distance_sq = offset_x * offset_x + offset_y * offset_y + softening_sq
            pull = gravity * other_mass / (distance_sq * math.sqrt(distance_sq))
            acceleration_x += pull * offset_x
            acceleration_y += pull * offset_y

        return acceleration_x, acceleration_y

    def _push_out_of_sun(self, planet_circle):
        """Move ``planet_circle`` radially out of the sun if they overlap.

        The planet is placed so its edge is 2 px clear of the sun's edge.
        Returns the outward unit normal ``(nx, ny)`` if the planet was moved,
        otherwise ``None``.
        """
        offset_x = planet_circle.x - self.center_x
        offset_y = planet_circle.y - self.center_y
        distance = math.hypot(offset_x, offset_y)
        minimum_distance = self.sun_circle.radius + planet_circle.radius + 2
        if distance >= minimum_distance:
            return None

        if distance > 1e-6:
            normal_x = offset_x / distance
            normal_y = offset_y / distance
        else:
            # Planet sits exactly on the sun's centre (easy to do when dropping
            # it with the mouse at integer coordinates). Without a fallback the
            # normal would be (0, 0) and the planet would stay inside the sun.
            normal_x, normal_y = 1.0, 0.0

        planet_circle.x = self.center_x + normal_x * minimum_distance
        planet_circle.y = self.center_y + normal_y * minimum_distance
        return normal_x, normal_y

    def _bounce_off_sun(self, planet_body):
        """Push a planet out of the sun and reflect any inward velocity."""
        normal = self._push_out_of_sun(planet_body["circle"])
        if normal is None:
            return
        normal_x, normal_y = normal
        radial_velocity = planet_body["vx"] * normal_x + planet_body["vy"] * normal_y
        if radial_velocity < 0:
            # Reflect the inward component, keeping sun_bounce_restitution of it.
            kick = (1.0 + self.sun_bounce_restitution) * radial_velocity
            planet_body["vx"] -= kick * normal_x
            planet_body["vy"] -= kick * normal_y

    def _bounce_off_edges(self, planet_body):
        """Clamp a planet inside the window and damp-reflect velocity at the edges."""
        planet_circle = planet_body["circle"]

        if planet_circle.x - planet_circle.radius < 0:
            planet_circle.x = planet_circle.radius
            if planet_body["vx"] < 0:
                planet_body["vx"] = -planet_body["vx"] * self.edge_bounce_factor
        elif planet_circle.x + planet_circle.radius > self.width:
            planet_circle.x = self.width - planet_circle.radius
            if planet_body["vx"] > 0:
                planet_body["vx"] = -planet_body["vx"] * self.edge_bounce_factor

        if planet_circle.y - planet_circle.radius < 0:
            planet_circle.y = planet_circle.radius
            if planet_body["vy"] < 0:
                planet_body["vy"] = -planet_body["vy"] * self.edge_bounce_factor
        elif planet_circle.y + planet_circle.radius > self.height:
            planet_circle.y = self.height - planet_circle.radius
            if planet_body["vy"] > 0:
                planet_body["vy"] = -planet_body["vy"] * self.edge_bounce_factor

    def _current_fling_velocity(self):
        """Return the smoothed drag velocity of the held planet.

        Returns ``(0.0, 0.0)`` if no drag timing is recorded, or if the mouse
        has not moved for longer than ``fling_max_idle_time`` seconds.
        """
        if self.last_mouse_drag_time is None:
            return 0.0, 0.0
        if time.perf_counter() - self.last_mouse_drag_time > self.fling_max_idle_time:
            return 0.0, 0.0
        return self.fling_velocity_x, self.fling_velocity_y

    def _planet_under_cursor(self, x, y):
        """Return the topmost planet whose disc contains ``(x, y)``, else ``None``.

        The list is scanned in reverse, so later (drawn-later) planets win.
        """
        for planet_body in reversed(self.planet_bodies):
            planet_circle = planet_body["circle"]
            cursor_offset_x = x - planet_circle.x
            cursor_offset_y = y - planet_circle.y
            if (cursor_offset_x * cursor_offset_x + cursor_offset_y * cursor_offset_y
                    <= planet_circle.radius * planet_circle.radius):
                return planet_body
        return None

    def _closest_planet_to_cursor(self, x, y):
        """Return the planet whose centre is nearest ``(x, y)``, or ``None`` if there are none.

        On ties the earliest planet in ``planet_bodies`` wins.
        """
        def squared_distance(planet_body):
            planet_circle = planet_body["circle"]
            offset_x = x - planet_circle.x
            offset_y = y - planet_circle.y
            return offset_x * offset_x + offset_y * offset_y

        return min(self.planet_bodies, key=squared_distance, default=None)

    def _sun_under_cursor(self, x, y):
        """Return True if ``(x, y)`` lies within the sun's disc."""
        cursor_offset_x = x - self.sun_circle.x
        cursor_offset_y = y - self.sun_circle.y
        return (cursor_offset_x * cursor_offset_x + cursor_offset_y * cursor_offset_y
                <= self.sun_circle.radius * self.sun_circle.radius)

    def _resolve_planet_collision(self, planet_body_a, planet_body_b):
        """Separate two overlapping planets and apply a restitution impulse.

        A planet held by the mouse gets inverse mass 0. It is neither moved nor
        has its velocity changed; only the other planet is pushed and deflected.
        """
        circle_a = planet_body_a["circle"]
        circle_b = planet_body_b["circle"]

        separation_x = circle_b.x - circle_a.x
        separation_y = circle_b.y - circle_a.y
        min_distance = circle_a.radius + circle_b.radius

        # Cheap reject: if either axis alone already spans min_distance the
        # discs cannot overlap, so most far-apart pairs skip the square root.
        if abs(separation_x) >= min_distance or abs(separation_y) >= min_distance:
            return

        distance = math.hypot(separation_x, separation_y)
        if distance <= 1e-6:
            # Coincident centres: separate them along +x.
            distance = 1e-6
            separation_x = 1e-6
            separation_y = 0.0

        if distance >= min_distance:
            return

        normal_x = separation_x / distance
        normal_y = separation_y / distance
        overlap = min_distance - distance

        inverse_mass_a = 0.0 if planet_body_a is self.selected_planet else 1.0 / planet_body_a["mass"]
        inverse_mass_b = 0.0 if planet_body_b is self.selected_planet else 1.0 / planet_body_b["mass"]
        inverse_mass_sum = inverse_mass_a + inverse_mass_b
        if inverse_mass_sum <= 0.0:
            return

        # Positional correction, shared in proportion to inverse mass.
        correction_x = normal_x * overlap
        correction_y = normal_y * overlap
        circle_a.x -= correction_x * (inverse_mass_a / inverse_mass_sum)
        circle_a.y -= correction_y * (inverse_mass_a / inverse_mass_sum)
        circle_b.x += correction_x * (inverse_mass_b / inverse_mass_sum)
        circle_b.y += correction_y * (inverse_mass_b / inverse_mass_sum)

        relative_vx = planet_body_b["vx"] - planet_body_a["vx"]
        relative_vy = planet_body_b["vy"] - planet_body_a["vy"]
        velocity_along_normal = relative_vx * normal_x + relative_vy * normal_y
        if velocity_along_normal > 0.0:
            return  # already moving apart

        impulse = -(1.0 + self.planet_collision_restitution) * velocity_along_normal
        impulse /= inverse_mass_sum

        if planet_body_a is not self.selected_planet:
            planet_body_a["vx"] -= impulse * inverse_mass_a * normal_x
            planet_body_a["vy"] -= impulse * inverse_mass_a * normal_y

        if planet_body_b is not self.selected_planet:
            planet_body_b["vx"] += impulse * inverse_mass_b * normal_x
            planet_body_b["vy"] += impulse * inverse_mass_b * normal_y

    def on_mouse_press(self, x, y, button, modifiers):
        """Left button: start dragging the sun, or pick up a planet.

        A click on the sun drags the sun. Otherwise the planet under the cursor
        is picked up, falling back to the nearest planet. Other buttons are
        ignored, so they cannot disturb an ongoing left-button drag.
        """
        if button != mouse.LEFT:
            return

        if self._sun_under_cursor(x, y):
            self.is_dragging_sun = True
            self.sun_drag_offset_x = self.center_x - x
            self.sun_drag_offset_y = self.center_y - y
            self.selected_planet = None
            return

        self.is_dragging_sun = False
        self.selected_planet = self._planet_under_cursor(x, y)
        if self.selected_planet is None:
            self.selected_planet = self._closest_planet_to_cursor(x, y)

        if self.selected_planet is not None:
            self.last_mouse_drag_time = time.perf_counter()
            self.last_mouse_drag_x = float(x)
            self.last_mouse_drag_y = float(y)
            self.fling_velocity_x = 0.0
            self.fling_velocity_y = 0.0

    def on_mouse_drag(self, x, y, drag_delta_x, drag_delta_y, buttons, modifiers):
        """Move the dragged sun or planet, and track a smoothed fling velocity."""
        if self.is_dragging_sun and (buttons & mouse.LEFT):
            self.center_x = x + self.sun_drag_offset_x
            self.center_y = y + self.sun_drag_offset_y
            self.sun_circle.x = self.center_x
            self.sun_circle.y = self.center_y
            return

        if self.selected_planet is not None and (buttons & mouse.LEFT):
            planet_circle = self.selected_planet["circle"]
            planet_circle.x = x
            planet_circle.y = y

            now = time.perf_counter()
            elapsed = now - self.last_mouse_drag_time if self.last_mouse_drag_time is not None else 0.0
            if elapsed > 1e-4:
                instant_velocity_x = (x - self.last_mouse_drag_x) / elapsed
                instant_velocity_y = (y - self.last_mouse_drag_y) / elapsed
                # Exponential smoothing so one jittery event doesn't dominate the fling.
                self.fling_velocity_x = 0.65 * self.fling_velocity_x + 0.35 * instant_velocity_x
                self.fling_velocity_y = 0.65 * self.fling_velocity_y + 0.35 * instant_velocity_y

            self.last_mouse_drag_time = now
            self.last_mouse_drag_x = float(x)
            self.last_mouse_drag_y = float(y)

    def on_mouse_release(self, x, y, button, modifiers):
        """Left button: stop dragging the sun, or drop and fling the held planet."""
        if button != mouse.LEFT:
            return

        if self.is_dragging_sun:
            self.is_dragging_sun = False
            return

        if self.selected_planet is None:
            return

        planet_body = self.selected_planet
        # Don't drop the planet inside the sun.
        self._push_out_of_sun(planet_body["circle"])
        # Read the fling before clearing the drag timestamp it depends on.
        planet_body["vx"], planet_body["vy"] = self._current_fling_velocity()

        self.selected_planet = None
        self.last_mouse_drag_time = None

    def on_draw(self):
        """Clear the window and draw the sun and planets."""
        self.clear()
        self.draw_batch.draw()

    def on_resize(self, width, height):
        """Re-centre the sun and shift every planet by the same offset."""
        old_center_x, old_center_y = self.center_x, self.center_y
        self.center_x, self.center_y = width // 2, height // 2
        shift_x = self.center_x - old_center_x
        shift_y = self.center_y - old_center_y

        self.sun_circle.x = self.center_x
        self.sun_circle.y = self.center_y

        for planet_body in self.planet_bodies:
            planet_circle = planet_body["circle"]
            planet_circle.x += shift_x
            planet_circle.y += shift_y

        return super().on_resize(width, height)


# Seconds between scheduled calls to SolarSystem.update (60 times per second).
UPDATE_INTERVAL = 1.0 / 60.0

game = SolarSystem()
clock.schedule_interval(game.update, UPDATE_INTERVAL)
# Hand control to pyglet's event loop.
run()
 