"""
Module Data Science 01 - Qu'est-ce que l'apprentissage automatique ?

Reproduit l'analyse du module de bout en bout :
  1. charge player_season.csv (un joueur-saison par ligne) ;
  2. oppose une regle ecrite a la main a un modele appris (erreur moyenne) ;
  3. predit les points pour des joueurs a 5, 10, 15, 20, 25 tirs/match ;
  4. trace le nuage de points + la relation apprise f(X) + une prediction ;
  5. mesure la generalisation via une separation entrainement / test.

Usage :
    python script.py [chemin_vers_player_season.csv]

Sans argument, lit /data/player_season.csv (chemin Pyodide dans la webapp).
random_state fixe partout pour des resultats reproductibles.
"""

import sys

import numpy as np
import pandas as pd
import matplotlib

matplotlib.use("Agg")  # backend non interactif (sauvegarde fichier)
import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error

RANDOM_STATE = 42


def charger_donnees(chemin: str) -> pd.DataFrame:
    df = pd.read_csv(chemin)
    print(f"observations (joueurs) : {len(df)}")
    print(f"colonnes disponibles : {df.columns.tolist()}")
    return df


def regle_vs_modele(df: pd.DataFrame) -> LinearRegression:
    """Compare une regle de bon sens et un modele appris."""
    X = df[["fga_pg"]]  # variable explicative (entree)
    y = df["ppg"]       # cible (sortie)

    # Regle ecrite a la main : 1 tir = 1 point
    erreur_regle = mean_absolute_error(y, df["fga_pg"])

    # Modele APPRIS : il trouve lui-meme la relation
    modele = LinearRegression().fit(X, y)
    erreur_modele = mean_absolute_error(y, modele.predict(X))

    print("\n--- Regle a la main vs modele appris ---")
    print(f"Regle a la main (pts = tirs) : erreur moyenne {erreur_regle:.2f} points")
    print(f"Modele appris                : erreur moyenne {erreur_modele:.2f} points")
    print(f"Relation apprise : 1 tir = {modele.coef_[0]:.2f} point")
    return modele


def predire_quelques_cas(modele: LinearRegression) -> None:
    print("\n--- Predictions pour des cas nouveaux ---")
    for tirs in [5, 10, 15, 20, 25]:
        pred = modele.predict(pd.DataFrame({"fga_pg": [tirs]}))[0]
        print(f"{tirs} tirs/match  ->  {pred:.1f} points predits")


def tracer_relation(df: pd.DataFrame, sortie: str = "module01_relation.png") -> None:
    X = df[["fga_pg"]]
    y = df["ppg"]
    modele = LinearRegression().fit(X, y)

    grille = np.linspace(df["fga_pg"].min(), df["fga_pg"].max(), 100)
    ligne = modele.predict(pd.DataFrame({"fga_pg": grille}))

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.scatter(df["fga_pg"], df["ppg"], s=18, alpha=0.4, label="Joueurs reels")
    ax.plot(grille, ligne, color="crimson", lw=2.5, label="Relation apprise f(X)")

    x0 = 15
    y0 = modele.predict(pd.DataFrame({"fga_pg": [x0]}))[0]
    ax.scatter([x0], [y0], color="black", s=90, zorder=5)
    ax.annotate(
        f"{x0} tirs -> {y0:.1f} pts",
        (x0, y0),
        xytext=(x0 + 1.5, y0 - 5),
        arrowprops=dict(arrowstyle="->"),
    )

    ax.set_xlabel("Tirs tentes par match (fga_pg)")
    ax.set_ylabel("Points par match (ppg)")
    ax.set_title("Le modele apprend la tendance, puis predit")
    ax.legend()
    plt.tight_layout()
    plt.savefig(sortie, dpi=110)
    print(f"\nFigure enregistree : {sortie}")


def evaluer_generalisation(df: pd.DataFrame) -> None:
    """Erreur d'entrainement vs erreur de test (joueurs jamais vus)."""
    X = df[["fga_pg"]]
    y = df["ppg"]

    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.25, random_state=RANDOM_STATE
    )

    modele = LinearRegression().fit(X_tr, y_tr)
    err_tr = mean_absolute_error(y_tr, modele.predict(X_tr))
    err_te = mean_absolute_error(y_te, modele.predict(X_te))

    print("\n--- Generalisation (separation entrainement / test) ---")
    print(f"Erreur sur l'entrainement (joueurs vus)  : {err_tr:.2f} points")
    print(f"Erreur sur le test (joueurs jamais vus)  : {err_te:.2f} points")


def main() -> None:
    chemin = sys.argv[1] if len(sys.argv) > 1 else "/data/player_season.csv"
    df = charger_donnees(chemin)
    modele = regle_vs_modele(df)
    predire_quelques_cas(modele)
    tracer_relation(df)
    evaluer_generalisation(df)


if __name__ == "__main__":
    main()
