"""
Decision trees — scouter le poste d'un joueur NBA (PG/SG/SF/PF/C).

Module machine-learning / 01-arbres-de-decision.
Reproduit l'analyse complete de la lecon :
  - arbre lisible (profondeur 2) + lecture de ses regles
  - calcul de l'impurete de Gini a la main pour la premiere coupure
  - diagnostic du sur-apprentissage (profondeur + validation croisee)
  - elagage par cost-complexity pruning (ccp_alpha)
  - regions de decision (taille x poids)
  - matrice de confusion + importance des variables
  - comparaison avec une regression logistique (modele lineaire)

Cible (y)    : position
Variables (X): physique (taille, poids) + stats par match
random_state : 42 partout, pour des resultats reproductibles.

Usage :
    python script.py [chemin/vers/player_season.csv]
Sans argument, lit /data/player_season.csv (chemin Pyodide du webapp).
"""

import sys

import numpy as np
import pandas as pd
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier, export_text

# Single seed shared by the split and by every estimator, so that two runs on
# the same CSV print the same numbers.
RANDOM_STATE = 42

# Predictor columns read from the CSV. Kept as a list (not a tuple) because it
# is used directly as a pandas column selector.
FEATS = [
    # Physique -- plotted later as inches / pounds.
    "height_in",
    "weight",
    # Per-game stats -- presumably rebounds (total / offensive / defensive),
    # assists, blocks, steals, three-pointers, points, field-goal attempts and
    # free throws; confirm against the dataset's data dictionary.
    "trb_pg",
    "orb_pg",
    "drb_pg",
    "ast_pg",
    "blk_pg",
    "stl_pg",
    "fg3_pg",
    "ppg",
    "fga_pg",
    "ft_pg",
]


def gini(labels: pd.Series) -> float:
    """Return the Gini impurity ``1 - sum(p_k ** 2)`` of a label series.

    ``p_k`` is the share of each distinct label, so the value is the error
    rate of a naive guesser that answers according to the observed class
    frequencies.

    Args:
        labels: Class labels of the samples in one node. Missing values are
            left out of the frequencies (``value_counts`` drops them).

    Returns:
        The impurity as a plain Python ``float``: ``0.0`` for a pure node and
        ``0.0`` as well for a node without any label, so that an empty child
        never looks maximally impure.
    """
    p = labels.value_counts(normalize=True).to_numpy()
    if p.size == 0:
        # No label at all: nothing to misclassify. Without this guard the
        # formula would yield 1.0, higher than any real node can reach.
        return 0.0
    return float(1.0 - np.sum(p ** 2))


def main(path: str) -> None:
    """Run the six steps of the decision-tree lesson on the CSV at *path*.

    The CSV must provide every column listed in ``FEATS`` plus a ``position``
    column (the target). Results are printed to stdout and two figures,
    ``arbre_profondeur.png`` and ``arbre_regions.png``, are written to the
    current working directory.

    Args:
        path: Location of the player-season CSV file.
    """
    df = pd.read_csv(path)
    X, y = df[FEATS], df["position"]
    # Stratified 70/30 split: each position keeps its share in both parts.
    Xtr, Xte, ytr, yte = train_test_split(
        X, y, test_size=0.3, random_state=RANDOM_STATE, stratify=y
    )
    postes = sorted(y.unique())

    _show_readable_tree(Xtr, Xte, ytr, yte)
    _show_first_split_gini(Xtr, ytr)
    _plot_depth_vs_overfitting(X, y, Xtr, Xte, ytr, yte)
    _show_cost_complexity_pruning(Xtr, Xte, ytr, yte)
    _plot_decision_regions(df)
    _show_confusion_and_linear_baseline(Xtr, Xte, ytr, yte, postes)

    print("\nFigures ecrites : arbre_profondeur.png, arbre_regions.png")


def _show_readable_tree(
    Xtr: pd.DataFrame, Xte: pd.DataFrame, ytr: pd.Series, yte: pd.Series
) -> None:
    """Step 1: fit a depth-2 tree and print its test accuracy and rules."""
    print("=== 1. Arbre lisible (profondeur 2) ===")
    arbre2 = DecisionTreeClassifier(max_depth=2, random_state=RANDOM_STATE).fit(Xtr, ytr)
    print("exactitude test :", round(arbre2.score(Xte, yte), 3))
    print(export_text(arbre2, feature_names=FEATS))


def _show_first_split_gini(Xtr: pd.DataFrame, ytr: pd.Series) -> None:
    """Step 2: recompute by hand the Gini gain of the split height <= 78.5."""
    print("=== 2. Impurete de Gini de la premiere coupure (taille <= 78,5) ===")
    # NOTE(review): the threshold is hard-coded here rather than read back
    # from a fitted tree -- confirm it still matches the root split whenever
    # the dataset changes.
    seuil = 78.5
    parent = gini(ytr)
    gauche = ytr[Xtr["height_in"] <= seuil]
    droite = ytr[Xtr["height_in"] > seuil]
    # Each child's impurity is weighted by the share of samples it receives.
    wg, wd = len(gauche) / len(ytr), len(droite) / len(ytr)
    enfants = wg * gini(gauche) + wd * gini(droite)
    print("Gini parent          :", round(parent, 3))
    print("Gini gauche (petits) :", round(gini(gauche), 3), " n =", len(gauche))
    print("Gini droite (grands) :", round(gini(droite), 3), " n =", len(droite))
    print("Gini enfants (moyen) :", round(enfants, 3))
    print("Gain de la coupure   :", round(parent - enfants, 3))


def _plot_depth_vs_overfitting(
    X: pd.DataFrame,
    y: pd.Series,
    Xtr: pd.DataFrame,
    Xte: pd.DataFrame,
    ytr: pd.Series,
    yte: pd.Series,
) -> None:
    """Step 3: print and plot train / test / CV accuracy per maximum depth.

    Writes ``arbre_profondeur.png``.
    """
    print("=== 3. Profondeur et sur-apprentissage ===")
    profondeurs = [1, 2, 3, 4, 5, 6, 8, 12]
    tr, te, cv = [], [], []
    for d in profondeurs:
        a = DecisionTreeClassifier(max_depth=d, random_state=RANDOM_STATE).fit(Xtr, ytr)
        tr.append(a.score(Xtr, ytr))
        te.append(a.score(Xte, yte))
        # The 5-fold score is computed on the full dataset (X, y), whereas
        # the pruning step cross-validates on the training split only.
        cv.append(
            cross_val_score(
                DecisionTreeClassifier(max_depth=d, random_state=RANDOM_STATE),
                X, y, cv=5,
            ).mean()
        )
        print(f"profondeur {d:2} : train={tr[-1]:.3f} test={te[-1]:.3f} cv={cv[-1]:.3f}")

    plt.figure(figsize=(8, 5))
    plt.plot(profondeurs, tr, "o-", label="entrainement")
    plt.plot(profondeurs, te, "s-", label="test")
    plt.plot(profondeurs, cv, "^--", label="validation croisee (5 plis)")
    plt.xlabel("profondeur maximale")
    plt.ylabel("exactitude")
    plt.title("Profondeur et sur-apprentissage")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig("arbre_profondeur.png", dpi=100)
    plt.close()


def _show_cost_complexity_pruning(
    Xtr: pd.DataFrame, Xte: pd.DataFrame, ytr: pd.Series, yte: pd.Series
) -> None:
    """Step 4: pick ``ccp_alpha`` by 5-fold CV on the training split.

    Prints the unpruned tree, the retained alpha and the pruned tree.
    """
    print("=== 4. Elagage par cost-complexity pruning ===")
    brut = DecisionTreeClassifier(random_state=RANDOM_STATE).fit(Xtr, ytr)
    print("brut   : feuilles =", brut.get_n_leaves(), "| test =", round(brut.score(Xte, yte), 3))
    alphas = brut.cost_complexity_pruning_path(Xtr, ytr).ccp_alphas
    # Floating-point rounding can leave a tiny negative alpha in the path,
    # which DecisionTreeClassifier rejects. Clip at zero, then drop duplicates
    # (np.unique keeps the ascending order) to avoid redundant CV runs.
    alphas = np.unique(np.clip(alphas, 0.0, None))
    meilleur_a, meilleur_cv = 0.0, -1.0
    for alpha in alphas:
        score = cross_val_score(
            DecisionTreeClassifier(random_state=RANDOM_STATE, ccp_alpha=alpha),
            Xtr, ytr, cv=5,
        ).mean()
        # Strict comparison: on a tie the smallest alpha seen first is kept.
        if score > meilleur_cv:
            meilleur_cv, meilleur_a = score, alpha
    elague = DecisionTreeClassifier(random_state=RANDOM_STATE, ccp_alpha=meilleur_a).fit(Xtr, ytr)
    print("ccp_alpha retenu     :", round(meilleur_a, 5))
    print(
        "elague : feuilles =", elague.get_n_leaves(),
        "| profondeur =", elague.get_depth(),
        "| test =", round(elague.score(Xte, yte), 3),
    )


def _plot_decision_regions(df: pd.DataFrame) -> None:
    """Step 5: draw the decision regions of a depth-3 tree on height x weight.

    The filled regions, the data points and the legend all take their colours
    from one shared palette, so a given position has the same colour in every
    layer. Writes ``arbre_regions.png``.
    """
    from matplotlib.colors import ListedColormap

    print("=== 5. Regions de decision (taille x poids) ===")
    X2 = df[["height_in", "weight"]].values
    # Integer codes 0..n-1 in sorted label order; labels[i] names code i.
    codes, labels = pd.factorize(df["position"], sort=True)
    clf2 = DecisionTreeClassifier(max_depth=3, random_state=RANDOM_STATE).fit(X2, codes)
    x0, x1 = X2[:, 0].min() - 1, X2[:, 0].max() + 1
    y0, y1 = X2[:, 1].min() - 10, X2[:, 1].max() + 10
    xx, yy = np.meshgrid(np.linspace(x0, x1, 300), np.linspace(y0, y1, 300))
    Z = clf2.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # One colour per class, cycling through tab10 beyond ten classes. With
    # n colours and the value range [-0.5, n - 0.5], code k lands exactly on
    # colour k for contourf and scatter alike. Letting each artist autoscale
    # on its own made regions, points and legend disagree on the colours.
    n_classes = len(labels)
    palette = [plt.cm.tab10(i % 10) for i in range(n_classes)]
    cmap = ListedColormap(palette)
    # n + 1 band edges: -0.5, 0.5, ..., n - 0.5 (one band per class code).
    niveaux = np.arange(-0.5, n_classes, 1)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.contourf(xx, yy, Z, alpha=0.25, cmap=cmap, levels=niveaux)
    ax.scatter(
        X2[:, 0], X2[:, 1], c=codes, cmap=cmap,
        vmin=-0.5, vmax=n_classes - 0.5, edgecolor="k", s=25,
    )
    ax.set_xlabel("taille (pouces)")
    ax.set_ylabel("poids (livres)")
    ax.set_title("Les boites de decision d'un arbre (taille x poids)")
    poignees = [
        plt.Line2D([0], [0], marker="o", color="w",
                   markerfacecolor=couleur, markeredgecolor="k",
                   markersize=8, label=lab)
        for couleur, lab in zip(palette, labels)
    ]
    ax.legend(handles=poignees, title="poste")
    plt.tight_layout()
    plt.savefig("arbre_regions.png", dpi=100)
    plt.close()


def _show_confusion_and_linear_baseline(
    Xtr: pd.DataFrame,
    Xte: pd.DataFrame,
    ytr: pd.Series,
    yte: pd.Series,
    postes: list,
) -> None:
    """Step 6: depth-4 tree diagnostics, then a logistic-regression baseline."""
    print("=== 6. Matrice de confusion + comparaison lineaire ===")
    arbre4 = DecisionTreeClassifier(max_depth=4, random_state=RANDOM_STATE).fit(Xtr, ytr)
    print("Arbre (profondeur 4)  | test =", round(arbre4.score(Xte, yte), 3))
    print("Matrice de confusion (lignes = vrai, colonnes = predit) :", postes)
    # labels=postes fixes the row/column order to the one printed just above.
    print(confusion_matrix(yte, arbre4.predict(Xte), labels=postes))
    print("Variables les plus utiles :")
    print(
        pd.Series(arbre4.feature_importances_, index=FEATS)
        .sort_values(ascending=False)
        .head(4)
        .round(3)
    )
    # Features are standardised before the logistic regression; the trees
    # above are fitted on the raw columns.
    logit = make_pipeline(
        StandardScaler(), LogisticRegression(max_iter=2000, random_state=RANDOM_STATE)
    ).fit(Xtr, ytr)
    print("Regression logistique | test =", round(logit.score(Xte, yte), 3))


if __name__ == "__main__":
    # Default to the path used inside the Pyodide web app; a first
    # command-line argument, when given, overrides it.
    csv_path = "/data/player_season.csv"
    if len(sys.argv) > 1:
        csv_path = sys.argv[1]
    main(csv_path)
