import os
import pickle
import time
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import psycopg2
from flask import request
# NOTE(review): Flask itself does not appear to export `marshal_with` (it lives
# in flask_restx / flask_restful), so the next line likely raises ImportError --
# confirm and move that name to the flask_restx import.
from flask import Flask, render_template, request, jsonify, Response, marshal_with
from flask_restful import reqparse
from flask_restx import Resource, Api, Namespace, fields
from scipy.stats import f_oneway
from scipy.stats import pointbiserialr
from scipy.stats import stats
from statsmodels.tsa.statespace.sarimax import SARIMAX

# flask-restx namespace under which every endpoint in this module is registered.
# The description is a runtime string (Korean for "various analyses and DB
# query features") and is passed to the namespace unchanged.
Action = Namespace(
    name="Action",
    description="다양한 분석과 DB 조회 기능",
)

# PostgreSQL connection settings, passed as keyword arguments to
# psycopg2.connect().  Every value can be overridden with an environment
# variable so that credentials do not have to live in source control; the
# literals are kept only as backward-compatible fallbacks.
# NOTE(review): the fallback password is still committed here -- rotate it and
# drop the default once every deployment sets WELDING_DB_PASSWORD.
db_config = {
    'dbname': os.environ.get('WELDING_DB_NAME', 'welding'),
    'user': os.environ.get('WELDING_DB_USER', 'postgres'),
    'password': os.environ.get('WELDING_DB_PASSWORD', 'ts4430!@'),
    'host': os.environ.get('WELDING_DB_HOST', 'localhost'),  # e.g., 'localhost'
    'port': os.environ.get('WELDING_DB_PORT', '5432'),  # e.g., '5432'
}

# Columns of the weather data that are modelled / forecast one by one.
key_columns = ["temperature", "relative_humidity", "absolute_humidity"]

def buck_equation(temperature): # temp in Celsius
    """Return the saturation vapour pressure in pascals (Buck equation).

    Args:
        temperature: Temperature in degrees Celsius.  Only arithmetic
            operators and ``np.exp`` are applied, so a Python float or a
            NumPy array both work.

    Returns:
        The saturation vapour pressure, converted from kPa to Pa.
    """
    linear_term = 18.678 - temperature / 234.5
    ratio_term = temperature / (257.14 + temperature)
    pressure_kpa = 0.61121 * np.exp(linear_term * ratio_term)
    return pressure_kpa * 1000  # kPa -> Pa

def absolute_humidity(relative_humidity, temperature):
    """Compute absolute humidity from relative humidity and temperature.

    Args:
        relative_humidity: Relative humidity in percent (scalar or array-like);
            it is scaled by 0.01 below.
        temperature: Temperature in degrees Celsius (scalar or array-like);
            273.15 is added to obtain kelvin.

    Returns:
        Partial vapour pressure divided by ``461.5 * T_kelvin`` as a NumPy
        value/array.
        NOTE(review): the previous inline comment labelled the result g/m^3,
        but with the pressure in Pa and 461.5 J/(kg*K) the expression appears
        to evaluate to kg/m^3 -- confirm the intended unit with the consumers.
    """
    humidity_pct = np.array(relative_humidity)
    temp_celsius = np.array(temperature)

    # Partial pressure of water vapour: saturation pressure scaled by RH%.
    vapor_pressure = buck_equation(temp_celsius) * humidity_pct * 0.01

    # 461.5 J/(kg*K) is the specific gas constant used for water vapour.
    gas_term = 461.5 * (temp_celsius + 273.15)
    return vapor_pressure / gas_term

# @sched.scheduler
# def weather_update

@Action.route('/forecast')
class forecast(Resource):
    """Report a humidity risk level derived from the stored forecast CSV."""

    # CSV that holds the forecast; it must contain a 'forecast' column.
    FORECAST_CSV = "data/weather/weather_data_forecast.csv"
    # Number of leading forecast rows averaged into the reported humidity.
    WINDOW = 6

    @staticmethod
    def _risk_level(humidity):
        """Map a humidity value to 'warn' (>90), 'caution' (>80) or 'safe'."""
        if humidity > 90:
            return "warn"
        if humidity > 80:
            return "caution"
        return "safe"

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Return ``{'report': <level>}`` for the upcoming forecast window.

        The previous implementation only ran its body when the request method
        was GET, which can never be true inside ``post`` (so it returned
        nothing), and it evaluated ``df['forecast'].value[6:] / 6``, which
        fails because a Series has no ``value`` attribute and would in any case
        have produced an array rather than a single number.

        NOTE(review): averaging the first ``WINDOW`` (6) forecast rows is the
        assumed intent of that expression -- confirm with the feature owner.

        Returns:
            ``({'report': 'warn' | 'caution' | 'safe'}, 200)`` on success, or
            ``({'message': <reason>}, 500)`` when the CSV cannot be read or
            holds no usable values.
        """
        try:
            df = pd.read_csv(self.FORECAST_CSV)
            window = df['forecast'].to_numpy(dtype=float)[:self.WINDOW]
            if window.size == 0:
                return {'message': 'No forecast data available'}, 500
            humidity = float(window.mean())
            # A NaN mean would silently compare as "safe"; report it instead.
            if np.isnan(humidity):
                return {'message': 'Forecast data contains missing values'}, 500
        except Exception as e:
            return {'message': str(e)}, 500

        return {'report': self._risk_level(humidity)}, 200

@Action.route('/train_sarima', methods=['POST'])
class TrainSARIMA(Resource):
    """Fit one SARIMA model per weather column and return history plus forecast."""

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Train on the most recent rows of ``weather_data`` and forecast ahead.

        Form fields:
            past_data_for_prediction: number of most recent rows used for
                fitting (default 72).
            future_hours: number of hourly steps to forecast (default 24).
            save_name: still accepted for compatibility but currently unused,
                because model persistence is disabled.

        Returns:
            A JSON ``Response`` (``orient='columns'``) containing the fetched
            rows in chronological order followed by the forecast rows, or
            ``({"error": <message>}, 500)`` on any failure.
        """
        query = "SELECT * FROM weather_data ORDER BY time DESC LIMIT 600"
        # Every column uses the same seasonal component with a period of 24 steps.
        seasonal_order = {col_key: (0, 1, 1, 24) for col_key in key_columns}

        try:
            view_past = int(request.form.get('past_data_for_prediction', 72))
            future_hours = int(request.form.get('future_hours', 24))
            if view_past <= 0 or future_hours <= 0:
                raise ValueError("past_data_for_prediction and future_hours must be positive")

            # psycopg2's connection context manager only ends the transaction;
            # it does not close the connection, so close it explicitly.
            conn = psycopg2.connect(**db_config)
            try:
                df = pd.read_sql_query(query, conn)
            finally:
                conn.close()

            # Rows arrive newest-first; the models need chronological order.
            df_sarima = df.iloc[:view_past].iloc[::-1].reset_index(drop=True)
            df = df.iloc[::-1].reset_index(drop=True)

            last_time = df_sarima['time'].iloc[-1]
            forecast_return = pd.DataFrame(None, columns=["time", *key_columns])
            forecast_return['time'] = [
                last_time + timedelta(hours=i) for i in range(1, future_hours + 1)
            ]

            started = time.time()
            for col_key in key_columns:
                model = SARIMAX(df_sarima[col_key], order=(1, 0, 2),
                                seasonal_order=seasonal_order[col_key])
                model_fit = model.fit(disp=False)
                forecast_return[col_key] = model_fit.forecast(steps=future_hours).values
            elapsed = time.time() - started

            total_steps = future_hours * len(key_columns)
            rate = total_steps / elapsed if elapsed > 0 else float('inf')
            print(f"{elapsed} seconds per {total_steps}\n"
                  f"that is {rate} per seconds")

            combined = pd.concat((df, forecast_return)).reset_index(drop=True)
            return Response(combined.to_json(orient='columns'), mimetype='application/json')

        except Exception as e:
            # Return a plain dict: flask-restx serialises (dict, status) tuples
            # itself, whereas a (jsonify(...), status) tuple cannot be encoded.
            return {"error": str(e)}, 500


def forecast_from_saved_model(df, trained_weight="predictions" , future_hours=24):
    """Forecast ``future_hours`` steps with a pickled, already-fitted model.

    The earlier version re-opened the same ``trained_weight`` file once per
    entry of ``key_columns`` and kept only the last (identical) result; the
    model is now loaded and evaluated exactly once.

    Args:
        df: DataFrame whose ``'time'`` column's last entry is the timestamp the
            forecast continues from.
        trained_weight: Path of the pickle file holding an object that exposes
            ``forecast(steps=...)``.
        future_hours: Number of hourly steps to forecast.

    Returns:
        DataFrame with a ``'time'`` column (last timestamp + 1h, +2h, ...) and
        a ``'forecast'`` column holding the model output.

    Raises:
        OSError: if ``trained_weight`` cannot be opened.
    """
    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only load model files produced by this application.
    with open(trained_weight, 'rb') as pkl_file:
        loaded_model = pickle.load(pkl_file)
    print("files loaded")

    started = time.time()
    # Forecast the next 'future_hours' using the loaded model
    forecast_values = loaded_model.forecast(steps=future_hours)
    last_time = df['time'].iloc[-1]
    forecast_index = [last_time + timedelta(hours=i) for i in range(1, future_hours + 1)]
    forecast_df = pd.DataFrame({
        'time': forecast_index,
        'forecast': forecast_values
    })
    elapsed = time.time() - started

    # Guard the throughput figure against a zero-length timing interval.
    rate = future_hours / elapsed if elapsed > 0 else float('inf')
    print(f"{elapsed} seconds per {future_hours}\n"
          f"that is {rate} per seconds")
    return forecast_df

@Action.route('/fetch_sensor')
class FetchSensorData(Resource):
    """Return the most recent rows of ``weather_data`` as column lists."""

    @Action.doc(responses={200: 'Success', 500: 'Failed'})
    def post(self):
        """Fetch up to 600 newest weather rows.

        Returns:
            ``(dict of column name -> list of values, 200)`` on success, with
            datetime columns rendered as strings so the payload can be
            JSON-encoded, or ``({"message": <error>}, 500)`` on failure.
        """
        query = "SELECT * FROM weather_data ORDER BY time DESC LIMIT 600"
        try:
            # psycopg2's connection context manager does not close the
            # connection, so release it explicitly once the query has run.
            conn = psycopg2.connect(**db_config)
            try:
                df = pd.read_sql_query(query, conn)
            finally:
                conn.close()

            # Convert every datetime-like column (including timezone-aware
            # ones, which the former exact dtype comparison missed) to strings.
            for column in df.columns:
                if pd.api.types.is_datetime64_any_dtype(df[column]):
                    df[column] = df[column].astype(str)

            return df.to_dict(orient='list'), 200
        except Exception as e:
            return {"message": str(e)}, 500

def get_manufacturing_data():
    """Load every row of ``Welding_Jobs`` ordered by ``welding_job_number``.

    Returns:
        A pandas DataFrame with one row per welding job.

    Raises:
        Whatever ``psycopg2.connect`` or ``pandas.read_sql_query`` raises; the
        connection is closed even when the query fails.
    """
    connection = psycopg2.connect(**db_config)
    try:
        return pd.read_sql_query(
            "SELECT * FROM Welding_Jobs ORDER BY welding_job_number ASC;",
            connection,
        )
    finally:
        # Previously the connection leaked whenever the query raised.
        connection.close()


# Output schema of one welding job as returned by /Request_Manufacturing_Data.
resource_fields = {
    'welding_job_number': fields.Integer,
    'mold_name': fields.String,
    'work_start_time': fields.DateTime,
    'defect_status': fields.String,
    'temperature': fields.Float,
    'relative_humidity': fields.Float,
    'absolute_humidity': fields.Float
}

# Registered on the namespace so the response schema can be attached to the
# endpoint through the namespace's own marshalling decorator.
_manufacturing_model = Action.model('ManufacturingData', resource_fields)


@Action.route('/Request_Manufacturing_Data')
class ManufacturingData(Resource):
    """List every welding job stored in the database."""

    # Use the flask-restx namespace decorator rather than a bare
    # ``marshal_with`` name, so this endpoint does not depend on where that
    # helper is imported from.
    @Action.marshal_list_with(_manufacturing_model)
    def get(self):
        """Return all welding jobs as a list of records shaped by ``resource_fields``."""
        return get_manufacturing_data().to_dict(orient="records")


@Action.route('/correlation')
class Correlation(Resource):
    """Point-biserial correlation between defect status and each weather measure."""

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Correlate each measurement column with ``defect_status``.

        Returns:
            ``({"status": "success", "correlations": {...}}, 200)`` where each
            entry is the result object returned by ``pointbiserialr``, or
            ``({"status": "failure", "message": <error>}, 500)`` on failure.
        """
        # (response label, Welding_Jobs column), in the order they are reported.
        labelled_columns = (
            ('Absolute Humidity', 'absolute_humidity'),
            ('Relative Humidity', 'relative_humidity'),
            ('Temperature', 'temperature'),
        )

        try:
            jobs = get_manufacturing_data()
            correlations = {
                label: pointbiserialr(jobs[column], jobs['defect_status'])
                for label, column in labelled_columns
            }
        except Exception as e:
            return {"status": "failure", "message": str(e)}, 500

        return {"status": "success", "correlations": correlations}, 200


@Action.route('/anova')
class AnovaAnalysis(Resource):
    """One-way ANOVA of the weather measures between defect and non-defect jobs."""

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Analysis Failed'})
    def post(self):
        """Compare jobs with ``defect_status == 0`` against ``defect_status == 1``.

        Returns:
            ``({"status": "success", "results": {"F_statistic": [...],
            "pVal": [...]}}, 200)`` with one value per column in the order
            relative_humidity, temperature, absolute_humidity, or
            ``({"status": "failure", "message": <error>}, 500)`` on failure.
        """
        measured = ['relative_humidity', 'temperature', 'absolute_humidity']
        try:
            df_failure = get_manufacturing_data()
            defect_status = df_failure['defect_status']

            passed = df_failure.loc[defect_status == 0, measured]
            failed = df_failure.loc[defect_status == 1, measured]

            # With an empty group the test is undefined and would put NaN
            # values into the JSON response; report that explicitly instead.
            if passed.empty or failed.empty:
                return {"status": "failure",
                        "message": "Both defect_status groups (0 and 1) need at least one row"}, 500

            # Call the public scipy.stats.f_oneway rather than reaching it
            # through the deprecated scipy.stats.stats namespace.
            F_statistic, pVal = f_oneway(passed, failed)

            results = {
                'F_statistic': F_statistic.tolist(),
                'pVal': pVal.tolist()
            }

            return {"status": "success", "results": results}, 200

        except Exception as e:
            return {"status": "failure", "message": str(e)}, 500


# Request-body model expected by /upload_manufacturing_data.  Despite the
# variable name this is a model created with Action.model(), not a reqparse
# parser.  The model name is a runtime string (Korean for "process information
# upload") and is left unchanged.
parser = Action.model('공정정보 업로드', {
    'mold_name': fields.String(required=True, description='Mold name'),
    'work_start_time': fields.DateTime(required=True, description='Start time of work'),
    'defect_status': fields.String(required=True, description='Defect status')
})

@Action.route('/upload_manufacturing_data')
class UploadData(Resource):
    """Store a welding job together with the latest weather reading."""

    @Action.doc(responses={200: 'Success', 400: 'Invalid Payload', 500: 'Analysis Failed'})
    @Action.expect(parser)
    def post(self):
        """Insert one row into ``Welding_Jobs``.

        The JSON body must provide ``mold_name``, ``work_start_time`` and
        ``defect_status``.  Temperature and humidity values are copied from the
        newest ``weather_data`` row, and the job number is one more than the
        current maximum (1 when the table is empty).

        Returns:
            ``({"status": "success", ...}, 200)`` when the row was committed,
            ``({"status": "failure", ...}, 400)`` for a missing or incomplete
            JSON body, and ``({"status": "failure", ...}, 500)`` when no
            weather data exists or the database operation fails.
        """
        required_keys = ('mold_name', 'work_start_time', 'defect_status')

        weather_query = """
            SELECT temperature, relative_humidity, absolute_humidity
            FROM weather_data
            ORDER BY time DESC
            LIMIT 1;
        """
        # COALESCE makes the query return 1 for an empty table instead of no
        # row, which previously crashed on ``fetchone()[0]``.
        # NOTE(review): two concurrent uploads can still compute the same
        # number; a database sequence would avoid that -- confirm if needed.
        job_number_query = """
            SELECT COALESCE(MAX(welding_job_number), 0) + 1
            FROM Welding_Jobs;
        """
        insert_query = """
            INSERT INTO Welding_Jobs (welding_job_number, mold_name, work_start_time, defect_status, temperature, relative_humidity, absolute_humidity)
            VALUES (%s, %s, %s, %s, %s, %s, %s);
        """

        try:
            # Extract data from POST request; silent=True yields None instead
            # of raising when the body is not valid JSON.
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                return {"status": "failure", "message": "Request body must be a JSON object"}, 400
            missing = [key for key in required_keys if key not in data]
            if missing:
                return {"status": "failure",
                        "message": "Missing required field(s): " + ", ".join(missing)}, 400

            connection = psycopg2.connect(**db_config)
            try:
                with connection.cursor() as cursor:
                    # Query the latest weather data
                    cursor.execute(weather_query)
                    weather_data = cursor.fetchone()

                    # If no weather data is found, return an error message
                    if not weather_data:
                        return {"status": "failure", "message": "No weather data found"}, 500

                    cursor.execute(job_number_query)
                    next_job_number = cursor.fetchone()[0]

                    cursor.execute(insert_query, (
                        next_job_number,
                        data['mold_name'],
                        data['work_start_time'],
                        data['defect_status'],
                        weather_data[0],
                        weather_data[1],
                        weather_data[2],
                    ))
                connection.commit()
            finally:
                # Always release the connection, including on the early
                # return above and on errors (uncommitted work is discarded).
                connection.close()

            return {"status": "success", "message": "Data uploaded successfully"}, 200

        except Exception as e:
            return {"status": "failure", "message": str(e)}, 500


# When the module is run directly, query the Welding_Jobs table once; the
# returned DataFrame is discarded.
if __name__ == "__main__":
    get_manufacturing_data()