import os
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
from statsmodels.tsa.statespace.sarimax import SARIMAX

from tools.algo.humidity import absolute_humidity


def sarima(file, col_key='상대습도', future_hours=24):
    """Fit a SARIMA model on a weather CSV and write a hold-out forecast.

    The last ``future_hours`` rows of ``col_key`` are held out, a
    SARIMAX(1, 0, 2)x(0, 1, 2, 24) model is fitted on the remaining rows,
    and ``future_hours`` steps are forecast from the end of the fitted
    data.  The result is written next to the input as
    ``<input path without extension>_forecast.csv``.

    Args:
        file: Path of the input CSV.  It must contain the columns
            ``관측시각`` (timestamps formatted as ``%Y%m%d%H%M``), ``기온``
            and ``col_key``.
        col_key: Name of the column to model.
        future_hours: Number of trailing rows to hold out and forecast.
            Forecast timestamps advance in one-hour steps, so the rows are
            assumed to be hourly.

    Returns:
        A DataFrame with the columns ``관측시각`` and ``forecast`` (the same
        content that is written to the CSV).

    Raises:
        ValueError: If ``future_hours`` is smaller than 1 or leaves no rows
            to fit the model on.
    """
    # Reject bad horizons up front: ``iloc[:-0]`` would silently yield an
    # empty training set and a negative value would flip the slice.
    if future_hours < 1:
        raise ValueError(f"future_hours must be >= 1, got {future_hours}")

    df = pd.read_csv(file)
    if future_hours >= len(df):
        raise ValueError(
            f"future_hours ({future_hours}) must be smaller than the number "
            f"of rows in the input ({len(df)})"
        )

    # NOTE(review): this derived column is not used by the model below,
    # which fits ``df[col_key]`` directly -- confirm whether the model was
    # meant to be fitted on absolute humidity instead.
    df["절대습도"] = absolute_humidity(df[col_key], df["기온"])
    df['관측시각'] = pd.to_datetime(df['관측시각'].astype(str), format='%Y%m%d%H%M')

    # Use the data up to the last 'future_hours' for fitting
    train = df[col_key].iloc[:-future_hours]
    model = SARIMAX(train, order=(1, 0, 2), seasonal_order=(0, 1, 2, 24))
    model_fit = model.fit()
    print(model_fit.summary())

    forecast_values = model_fit.forecast(steps=future_hours)

    # The forecast continues from the last *fitted* observation, so the
    # timestamps are anchored there rather than at the last row of the file.
    last_fitted_time = df['관측시각'].iloc[-future_hours - 1]
    forecast_index = [
        last_fitted_time + timedelta(hours=step)
        for step in range(1, future_hours + 1)
    ]

    forecast_df = pd.DataFrame({
        '관측시각': forecast_index,
        # Plain values: avoids aligning on the forecast's own index.
        'forecast': list(forecast_values),
    })

    # splitext only strips the final extension, so dots elsewhere in the
    # path (e.g. "./data/x.csv") do not truncate the output name.
    output_stem = os.path.splitext(os.fspath(file))[0]
    forecast_df.to_csv(f"{output_stem}_forecast.csv", index=False)
    return forecast_df

if __name__ == "__main__":
    # Exploratory run: load the weather CSV, derive absolute humidity, fit a
    # SARIMAX model on it and plot an in-sample dynamic prediction.
    weather = pd.read_csv("/home/juni/PycharmProjects/failure_analysis/data/weather/202007010000_202308310000_f.csv")
    abs_humidity = absolute_humidity(weather["상대습도"], weather["기온"])
    weather['관측시각'] = weather['관측시각'].map(
        lambda x: datetime.strptime(f"{x}", '%Y%m%d%H%M')
    )
    weather["절대습도"] = abs_humidity

    # Disabled experiment: overlay the raw series with a savgol_filter-smoothed one.
    # fig = go.Figure()
    #
    # fig.add_trace(
    #     go.Scatter(x=weather["관측시각"], y=weather["절대습도"])
    # )
    # fig.add_trace(
    #     go.Scatter(x=weather["관측시각"], y=signal.savgol_filter(
    #         weather["절대습도"],72,3)
    #     ))
    # fig.show()
    # log_df = np.log(weather["절대습도"])

    # First and second differences of the series; only consumed by the
    # disabled ACF/PACF plots below.
    first_diff = weather["절대습도"].diff(periods=1).iloc[1:]
    second_diff = first_diff.diff(periods=1).iloc[1:]
    # plot_acf(second_diff)
    # plot_pacf(second_diff)
    plt.show()

    sarimax_spec = SARIMAX(
        weather["절대습도"],
        order=(1, 0, 2),
        seasonal_order=(0, 1, 2, 24),
    )
    fitted = sarimax_spec.fit()

    # Disabled experiment: automatic order selection with pmdarima.
    # ARIMA_model = pm.auto_arima(weather['절대습도'],
    #                             start_p=1,
    #                             start_q=1,
    #                             test='adf',  # use adftest to find optimal 'd'
    #                             max_p=3, max_q=3,  # maximum p and q
    #                             m=24,  # frequency of series (if m==1, seasonal is set to FALSE automatically)
    #                             d=None,  # let model determine 'd'
    #                             D=2, #order of the seasonal differencing
    #                             seasonal=True,  # No Seasonality for standard ARIMA
    #                             trace=False,  # logs
    #                             error_action='warn',  # shows errors ('ignore' silences these)
    #                             suppress_warnings=False,
    #                             stepwise=True)
    print(fitted.summary())

    # Dynamic prediction requested with start=-100 and end=-1.
    in_sample_prediction = fitted.predict(start=-100, end=-1, dynamic=True)
    weather['forecast'] = in_sample_prediction

    # weather[['절대습도', 'forecast']].plot(figsize=(12, 8))
    chart = px.line(weather[['절대습도', 'forecast']])
    chart.show()