import numpy as np
from scipy.interpolate import CubicSpline
from datetime import datetime, timedelta
import pandas as pd
import matplotlib.pyplot as plt


def interpolate_value(df, timestamp, col_key, k=12, show_graph=False):
    """
    Interpolate an hourly series at an arbitrary time with a cubic spline.

    The spline is fitted on the on-the-hour observations found within ``k``
    hours before and after the hour containing the target time. Hours that are
    absent from ``df`` are skipped and the spline is fitted on the remaining
    points at their true hour offsets.

    :param df: DataFrame with a '관측시각' column holding observation times as
        integers in YYYYMMDDHHMM form, plus the column named by ``col_key``.
    :param timestamp: Target time, either a string in format '%Y%m%d%H%M' or a
        datetime object.
    :param col_key: The key/column name in df for which value needs to be interpolated.
    :param k: Number of hours before and after the target's hour to be considered
        for interpolation.
    :param show_graph: If True, plot the data points, the fitted spline and the
        target position with matplotlib before returning.
    :return: Interpolated value for the given timestamp (the result of calling
        the CubicSpline at a scalar, i.e. a 0-d NumPy array).
    :raises ValueError: If fewer than two observations fall inside the window.
    """
    # Convert timestamp to datetime object
    if isinstance(timestamp, str):
        target_time = datetime.strptime(timestamp, '%Y%m%d%H%M')
    else:
        target_time = timestamp

    # Nearest 'o clock' at or before the target. Seconds and microseconds are
    # cleared too, otherwise they would be silently dropped from x_target.
    left_o_clock = target_time.replace(minute=0, second=0, microsecond=0)

    # Map each integer YYYYMMDDHHMM key in the window to its hour offset (-k..k)
    offset_by_key = {
        int((left_o_clock + timedelta(hours=i)).strftime('%Y%m%d%H%M')): i
        for i in range(-k, k + 1)
    }

    # Extract relevant rows from DataFrame; a duplicated observation time would
    # break the strictly-increasing x requirement of CubicSpline.
    in_window = df['관측시각'].isin(list(offset_by_key))
    relevant_df = (
        df[in_window]
        .drop_duplicates(subset='관측시각')
        .sort_values(by='관측시각')
    )

    # Build x from the rows that actually exist, so x and y always line up
    # even when some hours are missing from the data.
    x = [offset_by_key[int(key)] for key in relevant_df['관측시각']]
    y = relevant_df[col_key].to_numpy()

    if len(x) < 2:
        raise ValueError(
            f"Need at least 2 observations within {k} hours of "
            f"{left_o_clock:%Y%m%d%H%M} to interpolate, found {len(x)}"
        )

    # Find the x value for the target timestamp (fraction of an hour past 0)
    x_target = (target_time - left_o_clock).total_seconds() / 3600

    # Create a cubic spline interpolation function.
    # NOTE: if the surviving points do not bracket x_target (e.g. target beyond
    # the end of the data), CubicSpline extrapolates by default.
    cs = CubicSpline(x, y)

    if show_graph:
        # For visualization
        x_dense = np.linspace(min(x), max(x), 400)  # Densely sampled x values
        y_dense = cs(x_dense)  # Interpolated values

        fig, ax = plt.subplots(figsize=(6.5, 4))
        ax.plot(x, y, 'o', label='Data')
        ax.plot(x_dense, y_dense, label='Cubic Spline Interpolation')
        ax.axvline(x_target, color='red', linestyle='--', label='Target Timestamp')  # Marking the target timestamp
        ax.legend()
        plt.show()

    # Return the interpolated value at x_target
    return cs(x_target)


if __name__ == '__main__':
    # Demo run: load the weather CSV from a fixed local path, interpolate the
    # '기온' (temperature) column at 2023-07-27 10:15 and show the fitted spline.
    csv_path = "/home/juni/PycharmProjects/failure_analysis/data/weather/202007010000_202308310000_f.csv"
    weather = pd.read_csv(csv_path, encoding='utf-8')
    result = interpolate_value(weather, '202307271015', '기온', show_graph=True)
    print(result)