import os
import pickle
import time
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.tsa.statespace.sarimax import SARIMAX

from tools.algo.humidity import absolute_humidity


def sarima(file, save_name, col_key='상대습도', future_hours=24):
    """Fit a SARIMA model on recent observations and forecast the hold-out window.

    Only the last 600 rows of the CSV are used. The final ``future_hours`` rows
    are withheld from fitting; the model then forecasts ``future_hours`` steps
    starting right after the last *fitted* observation, so the forecast rows
    line up with the withheld rows and can be compared against them.

    Side effects:
        * prints the fitted model summary,
        * writes ``<file without extension>_forecast.csv``,
        * pickles the fitted results to ``sarima_model_<save_name>.pkl`` in
          the current working directory.

    Args:
        file: Path to a CSV containing the columns '관측시각' (values formatted
            as ``%Y%m%d%H%M``), '기온' and ``col_key``.
        save_name: Tag used in the pickle file name.
        col_key: Column to model.
        future_hours: Number of trailing rows to withhold and number of steps
            to forecast. Timestamps are generated at a one-hour spacing.

    Returns:
        DataFrame with the columns '관측시각' and 'forecast'.

    Raises:
        ValueError: If ``future_hours`` is not positive.
    """
    if future_hours <= 0:
        # df.iloc[:-0] would silently select an empty training set.
        raise ValueError("future_hours must be a positive integer")

    # Copy the tail so the column assignments below do not write into a view
    # of the full frame (avoids pandas' SettingWithCopy behaviour).
    df = pd.read_csv(file).iloc[-600:].copy()

    # NOTE(review): absolute_humidity is a project helper; the arguments are
    # presumably (relative humidity, temperature) - confirm against its
    # definition. The resulting column is not used by the model below.
    df["절대습도"] = absolute_humidity(df[col_key], df["기온"])
    df['관측시각'] = pd.to_datetime(df['관측시각'].astype(str), format='%Y%m%d%H%M')

    # Fit on everything except the last `future_hours` rows.
    train = df.iloc[:-future_hours]
    model = SARIMAX(train[col_key], order=(1, 0, 2), seasonal_order=(0, 1, 2, 24))
    model_fit = model.fit()
    print(model_fit.summary())

    # The forecast starts one step after the end of the fitted sample, so the
    # timestamps are anchored at the last training row, not the last row of df.
    forecast_values = model_fit.forecast(steps=future_hours)
    last_fitted = train['관측시각'].iloc[-1]
    forecast_index = [last_fitted + timedelta(hours=step) for step in range(1, future_hours + 1)]
    forecast_df = pd.DataFrame({
        '관측시각': forecast_index,
        'forecast': forecast_values
    })

    # splitext strips only the final extension, so dots elsewhere in the path
    # (e.g. "./data/x.csv") no longer truncate the output location.
    forecast_df.to_csv(f"{os.path.splitext(file)[0]}_forecast.csv", index=False)

    with open(f'sarima_model_{save_name}.pkl', 'wb') as pkl_file:
        pickle.dump(model_fit, pkl_file)

    return forecast_df

def forecast_from_saved_model(file, model_file, future_hours=24):
    """Forecast with a previously pickled model and write the result to CSV.

    The forecast rows are stamped at one-hour steps following the last
    '관측시각' value found in ``file``.

    NOTE(review): the timestamps are only meaningful if the pickled model was
    fitted on data ending at the last row of ``file`` - confirm with how the
    model was produced.

    Side effects:
        * prints progress and timing information,
        * writes ``<file without extension>_forecast.csv``.

    Args:
        file: Path to a CSV with a '관측시각' column formatted as
            ``%Y%m%d%H%M``.
        model_file: Path to a pickle holding an object that exposes
            ``forecast(steps=...)``.
        future_hours: Number of steps to forecast.

    Returns:
        DataFrame with the columns '관측시각' and 'forecast'.
    """
    # SECURITY: pickle.load can execute arbitrary code contained in the file.
    # Only pass model files that come from a trusted source.
    with open(model_file, 'rb') as pkl_file:
        loaded_model = pickle.load(pkl_file)

    df = pd.read_csv(file)
    print("files loaded")
    # perf_counter is monotonic, unlike time.time(), so the elapsed value
    # cannot go negative if the wall clock is adjusted.
    started = time.perf_counter()

    # Only the final timestamp is needed; parsing the entire column would add
    # unrelated work to the timed section.
    last_observed = datetime.strptime(str(df['관측시각'].iloc[-1]), '%Y%m%d%H%M')

    # Forecast the next 'future_hours' using the loaded model
    forecast_values = loaded_model.forecast(steps=future_hours)
    forecast_index = [last_observed + timedelta(hours=step) for step in range(1, future_hours + 1)]
    forecast_df = pd.DataFrame({
        '관측시각': forecast_index,
        'forecast': forecast_values
    })

    # splitext strips only the final extension, so dots in directory names do
    # not truncate the output path.
    forecast_df.to_csv(f"{os.path.splitext(file)[0]}_forecast.csv", index=False)

    elapsed = time.perf_counter() - started
    # Guard against a zero reading on very coarse timers.
    rate = future_hours / elapsed if elapsed > 0 else float("inf")
    print(f"{elapsed} seconds per {future_hours}\n"
          f"that is {rate} per seconds")
    return forecast_df

if __name__ == "__main__":
    # Demo run: fit on the local weather export, then plot the forecast CSV
    # that sarima() writes next to the input file.
    data_file = "/home/juni/PycharmProjects/failure_analysis/data/weather/weather_data.csv"
    sarima(data_file, "test1", "상대습도")
    # Alternative: reuse an already fitted model instead of refitting.
    # forecast = forecast_from_saved_model("/home/juni/PycharmProjects/failure_analysis/data/weather/weather_data.csv",
    #                           "/home/juni/PycharmProjects/failure_analysis/tools/algo/sarima_model_test1.pkl",
    #                           24)
    forecast = pd.read_csv(f"{os.path.splitext(data_file)[0]}_forecast.csv", parse_dates=["관측시각"])
    # px.bar() only builds the figure; .show() is needed to render it.
    px.bar(forecast, x="관측시각", y="forecast").show()
    # df = pd.read_csv("/home/juni/PycharmProjects/failure_analysis/data/weather/202007010000_202308310000_f.csv")
    # ah = absolute_humidity(df["상대습도"], df["기온"])
    # df['관측시각'] = df['관측시각'].apply(lambda x: datetime.strptime(f"{x}", '%Y%m%d%H%M'))
    # df["절대습도"] = ah
    # # fig = go.Figure()
    # #
    # # fig.add_trace(
    # #     go.Scatter(x=df["관측시각"], y=df["절대습도"])
    # # )
    # # fig.add_trace(
    # #     go.Scatter(x=df["관측시각"], y=signal.savgol_filter(
    # #         df["절대습도"],72,3)
    # #     ))
    # # fig.show()
    # log_df = np.log(df["상대습도"])
    # diff_1 = (log_df.diff(periods=1).iloc[1:])
    # diff_2 = diff_1.diff(periods=1).iloc[1:]
    # plot_acf(diff_2)
    # plot_pacf(diff_2)
    # plt.show()
    # model = SARIMAX(df["상대습도"], order=(2,0,2), seasonal_order=(1,1,2,24))
    # model_fit = model.fit()
    # # ARIMA_model = pm.auto_arima(df['절대습도'],
    # #                             start_p=1,
    # #                             start_q=1,
    # #                             test='adf',  # use adftest to find optimal 'd'
    # #                             max_p=3, max_q=3,  # maximum p and q
    # #                             m=24,  # frequency of series (if m==1, seasonal is set to FALSE automatically)
    # #                             d=None,  # let model determine 'd'
    # #                             D=2, #order of the seasonal differencing
    # #                             seasonal=True,  # No Seasonality for standard ARIMA
    # #                             trace=False,  # logs
    # #                             error_action='warn',  # shows errors ('ignore' silences these)
    # #                             suppress_warnings=False,
    # #                             stepwise=True)
    # print(model_fit.summary())
    # df['forecast'] = model_fit.predict(start=-100, end=-1, dynamic=True)
    # # df[['절대습도', 'forecast']].plot(figsize=(12, 8))
    # fig = px.line(df[['상대습도', 'forecast']])
    # fig.show()