from scipy.spatial.transform import Rotation
import numpy as np
import pandas as pd
from os import listdir
from os.path import isfile, join
from ahrs.filters.aqua import AQUA
from ahrs.filters.madgwick import Madgwick
from plotly import express as px
import time

# Helper: reshape a 1-D array into an (n, 1) column so several columns can be concatenated side by side.
def reshape_columnvec(vec):
    """Return ``vec`` as a single column of shape (len(vec), 1).

    Used throughout the script to turn 1-D arrays into columns that can be
    joined along axis 1. ``np.reshape`` raises if ``vec`` holds more than
    ``len(vec)`` elements (e.g. a 2-D array with several columns).
    """
    column_shape = (len(vec), 1)
    return np.reshape(vec, column_shape)

# [w, x, y, z] or [w, i, j, k]
def quantarion_to_pitch_roll_yaw(quantarion):
    """Convert quaternions [w, x, y, z] to Euler angles in radians.

    Args:
        quantarion: array of shape (n, 4), or a single quaternion of shape
            (4,), with components ordered [w, x, y, z].

    Returns:
        ndarray of shape (n, 3) whose columns are [roll, pitch, yaw], i.e. the
        rotations about the x, y and z axes. Note that the column order does
        not follow the function name.
    """
    # Accept a single quaternion as well as a batch.
    q = np.atleast_2d(quantarion)
    w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]

    # y-axis rotation. Filter output is rarely exactly unit-norm, so near
    # +/-90 deg the argument can drift just outside [-1, 1], where arcsin
    # returns NaN. Clip it back into the valid domain.
    pitch = np.arcsin(np.clip(2 * (w * y - x * z), -1.0, 1.0))
    # x-axis rotation
    roll = np.arctan2(2 * (w * x + y * z), 1 - 2 * (x ** 2 + y ** 2))
    # z-axis rotation
    yaw = np.arctan2(2 * (w * z + x * y), 1 - 2 * (y ** 2 + z ** 2))

    return np.stack([roll, pitch, yaw], axis=1)

def quantarion_multiplication(q1, q2):
    """Return the row-wise Hamilton product q1 * q2 of quaternions [w, x, y, z].

    Args:
        q1, q2: arrays of shape (..., 4), typically (n, 4). Leading dimensions
            broadcast, so a single quaternion (1, 4) can multiply a batch.

    Returns:
        ndarray of shape (..., 4) holding the products.
    """
    q1 = np.asarray(q1)
    q2 = np.asarray(q2)
    # Keep the scalar parts as (..., 1) so they broadcast against (..., 3).
    w1, v1 = q1[..., :1], q1[..., 1:]
    w2, v2 = q2[..., :1], q2[..., 1:]

    # Scalar part: w1*w2 - v1.v2. The dot product must use all three vector
    # components (x, y, z); dropping z gives wrong results (e.g. k*k != -1).
    scalar = w1 * w2 - np.sum(v1 * v2, axis=-1, keepdims=True)
    # Vector part: w1*v2 + w2*v1 + v1 x v2
    vector = w1 * v2 + w2 * v1 + np.cross(v1, v2)
    return np.concatenate([scalar, vector], axis=-1)

def quantarion_rotation(vec, q):
    """Return the conjugation q * vec * conj(q), row by row.

    Both arguments are (n, 4) arrays ordered [w, x, y, z]; ``vec`` is
    presumably a pure quaternion [0, x, y, z]. When ``q`` is unit length,
    conj(q) is its inverse and the result is ``vec`` rotated by ``q``.
    """
    # Conjugate: keep the scalar column, negate the vector columns.
    conjugate = np.concatenate([q[:, :1], -q[:, 1:]], axis=1)
    left = quantarion_multiplication(q, vec)
    return quantarion_multiplication(left, conjugate)

def read_axis_sensor(file_path, rolling_value=1, columns=('x', 'y', 'z')):
    """Load a 3-axis sensor CSV and return its rolling-mean smoothed samples.

    Args:
        file_path: path (or file-like object) accepted by ``pd.read_csv``.
        rolling_value: rolling-mean window size in samples; must be >= 1.
            A value of 1 means no smoothing.
        columns: names of the CSV columns to read, in output column order.

    Returns:
        float ndarray of shape (n_rows - rolling_value + 1, len(columns)).

    Raises:
        ValueError: if ``rolling_value`` is smaller than 1.
    """
    if rolling_value < 1:
        raise ValueError(f"rolling_value must be >= 1, got {rolling_value}")
    df = pd.read_csv(file_path)
    smoothed = df.loc[:, list(columns)].rolling(rolling_value).mean()
    # A window of size k yields exactly k - 1 leading NaN rows; drop only
    # those so no valid sample is lost.
    return smoothed.to_numpy()[rolling_value - 1:]

def _show_angles(angles):
    """Display an (n, 3) array of angles as a Plotly bar chart.

    Columns are labelled 'x', 'y', 'z' (roll, pitch, yaw for the output of
    ``quantarion_to_pitch_roll_yaw``) so the plots can be compared directly.
    """
    px.bar(pd.DataFrame(angles, columns=['x', 'y', 'z'])).show()


# Standard gravity in m/s^2, used to scale the gravity vector to ~unit length.
STANDARD_GRAVITY = 9.80665

# Every sub-directory of ./data is treated as one recording. Sort the names
# for a reproducible processing order, because listdir order is arbitrary.
dirs = sorted(f for f in listdir('data') if not isfile(join('data', f)))
rolling_value = 1
for path in dirs:
    now = time.perf_counter()
    recording = join('data', path)

    # Gyro and Acc must be synced: row i of every file is assumed to describe
    # the same instant. CSV layout is time,seconds_elapsed,z,y,x; columns are
    # selected by name, so their order in the file does not matter.
    gy = read_axis_sensor(join(recording, 'Gyroscope.csv'), rolling_value)
    ac = read_axis_sensor(join(recording, 'Accelerometer.csv'), rolling_value)
    mg = read_axis_sensor(join(recording, 'Magnetometer.csv'), rolling_value)
    grv_xyz = read_axis_sensor(join(recording, 'Gravity.csv'), rolling_value)

    # Embed gravity as a pure quaternion [0, gx, gy, gz] and divide by Earth's
    # gravitation, which gives an approximately unit quaternion.
    grv = np.hstack((np.zeros((len(grv_xyz), 1)), grv_xyz))
    down_direction = grv / STANDARD_GRAVITY

    # NOTE(review): q0 is the first gravity pure-quaternion, not an attitude
    # estimate; confirm this is the intended initial orientation.
    aqua = AQUA(gy, ac, mg, q0=down_direction[0])
    # Use the gravitational info to find where the phone is looking.
    q_aqua = quantarion_rotation(aqua.Q, down_direction)
    _show_angles(quantarion_to_pitch_roll_yaw(q_aqua))

    # Second estimate: the Madgwick filter on the same samples. It must use
    # its own quaternions (madgwick.Q), not the already-rotated AQUA result.
    madgwick = Madgwick(gy, ac, mg, q0=down_direction[0])
    q_madgwick = quantarion_rotation(madgwick.Q, down_direction)
    q_madgwick = quantarion_multiplication(q_madgwick, -down_direction)
    _show_angles(quantarion_to_pitch_roll_yaw(q_madgwick))

    # Orientation reported by the device, smoothed the same way. Drop only
    # the rolling_value - 1 leading NaN rows produced by the rolling mean.
    df_ori = pd.read_csv(join(recording, 'Orientation.csv'))
    ori = (
        df_ori.loc[:, ['roll', 'pitch', 'yaw']]
        .rolling(rolling_value)
        .mean()
        .to_numpy()[rolling_value - 1:]
    )
    _show_angles(ori)

    elpt = time.perf_counter() - now
    print(elpt)
    # Gravity samples processed per second of wall-clock time.
    print(len(grv) / elpt)