import pandas as pd
import os
import pyproj
import numpy as np

# Shared coordinate transformer: EPSG:2154 (Lambert-93) -> EPSG:4326 (WGS84).
# Built once at import time and reused by every conversion call.
# always_xy=True keeps the axis order as (x, y) on input and (lon, lat) on output.
transformer = pyproj.Transformer.from_crs(
    "EPSG:2154",
    "EPSG:4326",
    always_xy=True,
)

def lambert93_to_wgs84_batch(x, y):
    """Convert Lambert-93 coordinates to WGS84 wherever both values are present.

    Args:
        x: pandas Series of Lambert-93 x coordinates (values castable to
            float; missing values allowed).
        y: pandas Series of Lambert-93 y coordinates, aligned with ``x``.

    Returns:
        A ``(x, y)`` tuple of new Series. Rows where either input is missing
        are returned unchanged; every other row holds the converted value
        formatted as a string. The inputs themselves are not modified.

    NOTE(review): the first returned Series receives the *latitude* and the
    second one the *longitude*, i.e. the usual x=lon / y=lat convention is
    swapped. This ordering is kept as-is because code outside this function
    may rely on it -- confirm that it is intentional.
    """
    # Work on copies: the arguments are typically DataFrame columns, and
    # writing into them in place would silently alter the caller's data.
    x, y = x.copy(), y.copy()

    # Only rows with both coordinates can be converted.
    mask = x.notna() & y.notna()
    if not mask.any():
        return x, y

    x_valid = x[mask].astype(float)
    y_valid = y[mask].astype(float)
    lon, lat = transformer.transform(x_valid.to_numpy(), y_valid.to_numpy())

    x.loc[mask] = lat.astype(str)
    y.loc[mask] = lon.astype(str)
    return x, y

def charger_donnees(date):
    """Extract the useful columns of the raw RES and PLV files for ``date``.

    Reads ``data/raw/UDI_RES_{date}.txt`` and ``data/raw/UDI_PLV_{date}.txt``,
    keeps a fixed subset of columns and writes them to
    ``data/processed/res{date}.csv`` and ``data/processed/plv{date}.csv``.
    The PLV coordinates go through ``lambert93_to_wgs84_batch`` before being
    written.

    Args:
        date: Suffix identifying the files to load (interpolated into paths).

    Returns:
        ``True`` once both processed files are written; ``False`` (after
        printing a warning) when either raw file is missing.
    """
    file_path_res = f"data/raw/UDI_RES_{date}.txt"
    file_path_plv = f"data/raw/UDI_PLV_{date}.txt"

    if not os.path.exists(file_path_res) or not os.path.exists(file_path_plv):
        print(f"Attention fichiers manquants pour {date}.")
        return False

    # Make sure the output directory exists; to_csv does not create it.
    output_dir = "data/processed"
    os.makedirs(output_dir, exist_ok=True)

    # Columns of interest in UDI_RES (results of the sample analyses).
    columns_res = [
        "cddept", "referenceprel", "cdparametre", "rsana", "cdunitereferencesiseeaux", "cdunitereference", "rqana",
        "rssigne"
    ]
    dtype_res = {col: "string" for col in columns_res}
    data_res = pd.read_csv(file_path_res, sep=",", dtype=dtype_res, usecols=columns_res)
    # Keep the first run of digits found in rqana as a number; values without
    # any digit become NaN.
    # NOTE(review): this drops signs and decimal parts -- confirm it is intended.
    data_res['rqana'] = pd.to_numeric(data_res['rqana'].str.extract(r'(\d+)', expand=False), errors='coerce')

    output_res = f"{output_dir}/res{date}.csv"
    data_res.to_csv(output_res, index=False)

    # Columns of interest in UDI_PLV (metadata of the samples).
    columns_plv = [
        "cddept", "inseecommune", "nomcommune", "cdreseau", "cdpointsurv", "nompointsurv",
        "referenceprel", "dateprel", "coord_x", "coord_y"
    ]
    dtype_plv = {col: "string" for col in columns_plv}
    data_plv = pd.read_csv(file_path_plv, sep=",", dtype=dtype_plv, usecols=columns_plv)

    # Coordinate conversion (Lambert-93 -> WGS84).
    data_plv["coord_x"], data_plv["coord_y"] = lambert93_to_wgs84_batch(data_plv["coord_x"], data_plv["coord_y"])

    output_plv = f"{output_dir}/plv{date}.csv"
    data_plv.to_csv(output_plv, index=False)

    return True

def jointure(date):
    """Outer-join the processed RES and PLV files of ``date`` on ``referenceprel``.

    Reads ``data/processed/res{date}.csv`` and ``data/processed/plv{date}.csv``,
    strips surrounding whitespace from the join key, merges the two tables and
    writes the result to ``data/processed/Table{date}.csv``. The number of
    rows found on the left only, the right only and both sides is printed.

    Args:
        date: Suffix identifying the files to join (interpolated into paths).

    Returns:
        ``True`` when the joined table was written; ``False`` (after printing
        a warning) when one of the two input files is missing.
    """
    res_path = f"data/processed/res{date}.csv"
    plv_path = f"data/processed/plv{date}.csv"

    # Same missing-file handling as charger_donnees, so the caller's failure
    # branch is actually reachable instead of the script crashing.
    if not os.path.exists(res_path) or not os.path.exists(plv_path):
        print(f"Attention fichiers manquants pour {date}.")
        return False

    data_res = pd.read_csv(res_path, dtype=str)
    data_plv = pd.read_csv(plv_path, dtype=str)

    # Normalize the sample reference codes (the join key). Both files are
    # already read as strings; casting again would turn missing keys into the
    # literal text "nan", so only whitespace is stripped here.
    data_res['referenceprel'] = data_res['referenceprel'].str.strip()
    data_plv['referenceprel'] = data_plv['referenceprel'].str.strip()

    data = pd.merge(data_res, data_plv, on="referenceprel", how="outer", indicator=True)

    # Report how many rows matched on each side.
    print(data["_merge"].value_counts())

    output_path = f"data/processed/Table{date}.csv"
    # The indicator column is only needed for the report above.
    data.drop(columns=['_merge'], inplace=True)
    data.to_csv(output_path, index=False)

    return True

def supprimer_donnees(date):
    """Delete the intermediate RES/PLV files of ``date`` once the join exists.

    Nothing is removed unless ``data/processed/Table{date}.csv`` is present,
    so the intermediate files survive a failed join. A temporary file that is
    already gone is ignored, which makes the cleanup safe to repeat.

    Args:
        date: Suffix identifying the files to clean up.
    """
    table_path = f"data/processed/Table{date}.csv"
    if not os.path.exists(table_path):
        return

    temp_paths = (
        f"data/processed/res{date}.csv",
        f"data/processed/plv{date}.csv",
    )
    for temp_path in temp_paths:
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            # Already removed (e.g. by a previous run): nothing left to do.
            pass

def normalize_cdparametre(date):
    """Give every parameter code of the joined table a trailing ``.0``.

    Reads ``data/processed/Table{date}.csv``, rewrites the ``cdparametre``
    column so that each code ends with ``.0`` (the codes come in mixed
    formats, e.g. ``"1340"`` and ``"1340.0"``), and writes the result to
    ``data/processed/Table{date}_normalized.csv``.

    Missing codes stay missing instead of becoming the text ``"nan.0"``.

    Args:
        date: Suffix identifying the table to normalize.

    Returns:
        The normalized DataFrame that was written to disk.
    """
    table_path = f"data/processed/Table{date}.csv"
    output_file = f"data/processed/Table{date}_normalized.csv"
    df = pd.read_csv(table_path, dtype={'cdparametre': str})

    # The column is read as strings, so only non-missing codes need the
    # suffix; na_action='ignore' leaves missing values untouched.
    df['cdparametre'] = df['cdparametre'].map(
        lambda code: code if code.endswith('.0') else code + '.0',
        na_action='ignore',
    )

    df.to_csv(output_file, index=False)

    return df

# Years to process: "2018" through "2024".
dates = [f"20{year:02d}" for year in range(18, 25)]
# Years for which the whole pipeline completed.
processed_dates = []

for date in dates:
    print(f"Processing de la data de {date}")

    # Step 1: extract the raw files; skip the year when loading fails.
    if not charger_donnees(date):
        print(f"Chargement échoué pour {date}")
        continue

    # Step 2: join RES and PLV; skip the year when the join fails.
    if not jointure(date):
        print(f"Jointure échouée pour {date}")
        continue

    # Step 3: drop the intermediate files, then normalize the parameter codes.
    supprimer_donnees(date)
    normalize_cdparametre(date)
    processed_dates.append(date)
    print(f"Data de {date} traitée")