import os
import pickle
import sched
import time
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import psycopg2
from flask import Flask, render_template, request, jsonify, Response
from flask import request
from flask_restx import Resource, Api, Namespace, fields
from statsmodels.tsa.statespace.sarimax import SARIMAX

# flask-restx namespace grouping the node-analysis endpoints below.
# (The Korean description translates to: "API used for node analysis.")
Action = Namespace(
    name="Action",
    description="노드 분석을 위해 사용하는 api.",
)

# PostgreSQL connection settings for the welding weather database.
# SECURITY: the original hard-coded a plaintext password in source; each
# value can now be overridden via an environment variable, with the original
# values kept as backward-compatible fallbacks.
db_config = {
    'dbname': os.environ.get('WELDING_DB_NAME', 'welding'),
    'user': os.environ.get('WELDING_DB_USER', 'postgres'),
    'password': os.environ.get('WELDING_DB_PASSWORD', 'ts4430!@'),
    'host': os.environ.get('WELDING_DB_HOST', 'localhost'),  # e.g., 'localhost'
    'port': os.environ.get('WELDING_DB_PORT', '5432'),  # e.g., '5432'
}

# Weather columns modelled by the SARIMA endpoints, in a fixed order.
key_columns = ["temperature", "relative_humidity", "absolute_humidity"]

def buck_equation(temperature):
    """Saturation vapor pressure in Pa via the Arden Buck equation.

    `temperature` is in degrees Celsius; accepts a scalar or a numpy array.
    """
    exponent = (18.678 - temperature / 234.5) * (temperature / (257.14 + temperature))
    svp_kpa = 0.61121 * np.exp(exponent)
    # Buck gives kPa; convert to Pa.
    return svp_kpa * 1000

def absolute_humidity(relative_humidity, temperature):
    """Absolute humidity from relative humidity (%) and temperature (°C).

    NOTE(review): the original comment labelled the result g/m^3, but with
    the saturation pressure in Pa and R_v = 461.5 J/(kg K) the ideal-gas
    form below yields kg/m^3 — confirm against callers.
    """
    rh_percent = np.array(relative_humidity)
    temp_c = np.array(temperature)
    svp_pa = buck_equation(temp_c)
    # Ideal-gas law with the specific gas constant of water vapour,
    # 461.5 J/(kg K); rh * 0.01 converts percent to a fraction.
    return svp_pa * rh_percent * 0.01 / (461.5 * (temp_c + 273.15))

# TODO: periodic weather update via `sched` was planned here but never implemented:
# @sched.scheduler
# def weather_update

@Action.route('/forecast')
class forecast(Resource):
    """Report a qualitative humidity risk level from the saved forecast CSV."""

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Return {'report': 'warn'|'caution'|'safe'} based on mean forecast humidity."""
        # BUG FIX: the original guarded this POST handler with
        # `if request.method == 'GET'`, which is never true here, so the
        # endpoint always fell through and returned None (HTTP 500).
        # flask-restx already dispatches only POST requests to this method,
        # so the guard is simply removed.
        df = pd.read_csv("data/weather/weather_data_forecast.csv")
        # BUG FIX: the original used `df['forecast'].value[6:] / 6` — `.value`
        # is a typo for `.values`, and the element-wise division produced an
        # array that cannot be compared with a scalar. Assuming the intent
        # was the average of the six hourly values after index 6 — TODO confirm.
        humidity = float(df['forecast'].values[6:].sum()) / 6
        if humidity > 90:
            return {
                'report': "warn"
            }, 200
        elif 80 < humidity <= 90:
            return {
                'report': "caution"
            }, 200
        else:
            return {
                'report': "safe"
            }, 200

@Action.route('/train_sarima', methods=['POST'])
class TrainSARIMA(Resource):
    """Fit a SARIMA model per weather column and return history plus forecast."""

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Train SARIMAX on recent sensor rows and return a combined JSON frame.

        Form parameters:
            past_data_for_prediction: rows of history to train on (default 72).
            future_hours: forecast horizon in hours (default 24).
            save_name: label for the (currently disabled) model dump.
        """
        query = "SELECT * FROM weather_data ORDER BY time DESC LIMIT 600"

        try:
            view_past = int(request.form.get('past_data_for_prediction', 72))
            future_hours = int(request.form.get('future_hours', 24))
            save_name = request.form.get('save_name', 'prediction')  # used only by the disabled dump below

            # NOTE(review): psycopg2's `with connect()` wraps a transaction
            # but does NOT close the connection — confirm whether an explicit
            # close/pool is wanted here.
            with psycopg2.connect(**db_config) as conn:
                df = pd.read_sql_query(query, conn)

            # Rows arrive newest-first; keep the `view_past` most recent for
            # training, then restore chronological order in both frames.
            df_sarima = df.iloc[:view_past].iloc[::-1].reset_index(drop=True)
            df = df.iloc[::-1].reset_index(drop=True)

            # Future hourly timestamps following the last training row.
            return_index = [df_sarima['time'].iloc[-1] + timedelta(hours=i)
                            for i in range(1, future_hours + 1)]
            forecast_return = pd.DataFrame(None, columns=["time", "temperature", "relative_humidity", "absolute_humidity"])
            forecast_return['time'] = return_index

            # Identical daily (24-hour) seasonal specification for every column.
            seasonal_order = {col: (0, 1, 1, 24) for col in key_columns}

            t1 = time.time()
            for col_key in key_columns:
                model = SARIMAX(df_sarima[col_key], order=(1, 0, 2),
                                seasonal_order=seasonal_order[col_key])
                model_fit = model.fit(disp=False)
                forecast_return[col_key] = model_fit.forecast(steps=future_hours).values
                # with open(f'predictions/sarima_model_{save_name}_{col_key}.pkl', 'wb') as pkl_file:
                #     pickle.dump(model_fit, pkl_file)
            elapsed = time.time() - t1  # clearer than the original -(t1 - time.time())
            print(f"{elapsed} seconds per {future_hours*3}\n"
                  f"that is {future_hours*3 / elapsed} per seconds")

            merged = pd.concat((df, forecast_return)).reset_index(drop=True)
            return Response(merged.to_json(orient='columns'), mimetype='application/json')

        except Exception as e:
            # BUG FIX: the original returned `jsonify(...), 500`; a Response
            # object inside a tuple is not handled by flask-restx's Resource
            # dispatch. Return a plain dict so restx serializes it itself.
            return {"error": str(e)}, 500


def forecast_from_saved_model(df, trained_weight="predictions", future_hours=24):
    """Forecast `future_hours` hourly values from a pickled SARIMA fit.

    Args:
        df: frame whose last 'time' entry anchors the forecast timestamps.
        trained_weight: path of the pickled fitted model.
            NOTE(review): the default "predictions" looks like a directory,
            which `open()` would reject — confirm the intended default path.
        future_hours: forecast horizon in hours.

    Returns:
        DataFrame with 'time' and 'forecast' columns.
    """
    # BUG FIX: the original looped over `key_columns`, re-loading the SAME
    # pickle every iteration and overwriting `forecast_df`, so only the last
    # (identical) result survived. The redundant loop is removed; the return
    # value is unchanged. Presumably the intent was one model file per
    # column (see the commented-out per-column dump in TrainSARIMA.post) —
    # confirm before restoring a loop.
    # SECURITY: pickle.load executes arbitrary code; only load trusted files.
    with open(trained_weight, 'rb') as pkl_file:
        loaded_model = pickle.load(pkl_file)
    print("files loaded")

    t1 = time.time()
    # Forecast the next 'future_hours' using the loaded model.
    forecast_values = loaded_model.forecast(steps=future_hours)
    forecast_index = [df['time'].iloc[-1] + timedelta(hours=i)
                      for i in range(1, future_hours + 1)]
    forecast_df = pd.DataFrame({
        'time': forecast_index,
        'forecast': forecast_values
    })
    elapsed = time.time() - t1
    print(f"{elapsed} seconds per {future_hours}\n"
          f"that is {future_hours/elapsed} per seconds")
    return forecast_df

@Action.route('/fetch_sensor')
class FetchSensorData(Resource):
    """Return the 600 most recent weather_data rows as JSON column lists."""

    @Action.doc(responses={200: 'Success', 500: 'Failed'})
    def get(self):
        """Fetch recent sensor rows from Postgres; datetimes are stringified."""
        sql = "SELECT * FROM weather_data ORDER BY time DESC LIMIT 600"
        try:
            with psycopg2.connect(**db_config) as connection:
                frame = pd.read_sql_query(sql, connection)
                # JSON cannot carry pandas Timestamps, so stringify every
                # datetime column before serializing.
                for column_name in frame.columns:
                    if frame[column_name].dtype == "datetime64[ns]":
                        frame[column_name] = frame[column_name].astype(str)
                return frame.to_dict(orient='list'), 200
        except Exception as exc:
            return {"message": str(exc)}, 500