import psycopg2 # driver 임포트
import json
import bcrypt
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.backends import default_backend
import re
import os
from io import StringIO
from datetime import datetime, timedelta

# Path of the JSON file holding the PostgreSQL connection settings; it is
# resolved relative to the process working directory and read by DB.__init__.
config_file_path: str = "database/db_config.json"

class DB:
    """Data-access layer over a single PostgreSQL connection.

    On construction it opens one autocommit connection plus a shared cursor,
    pins ``search_path`` to the configured schema and loads the AES key.
    Request-handling methods return ``(body, http_status)`` tuples.
    """

    # Key file used when the JSON config has no "encryption_key_path" entry.
    DEFAULT_ENCRYPTION_KEY_PATH = "database/keys/encryption_key2024-09-05_14:27:02"

    # Column order of the SELECT in get_history(); also the key order of the JSON it returns.
    _HISTORY_COLUMNS = (
        "trip_id",
        "timestamp",
        "total_distance_m",
        "total_time_s",
        "abrupt_start_count",
        "abrupt_stop_count",
        "abrupt_acceleration_count",
        "abrupt_deceleration_count",
        "helmet_on",
        "final_score",
    )

    # Values accepted for the "sex" registration field.
    _ALLOWED_SEX_VALUES = ('Male', 'Female', 'Non-binary', 'Other')

    def __init__(self):
        """Load the config, connect to PostgreSQL and load the AES key.

        Reads the module-level ``config_file_path`` JSON file. The AES key is
        read from ``DEFAULT_ENCRYPTION_KEY_PATH`` unless the config provides an
        ``encryption_key_path`` entry.
        """
        self.db_config = self.load_db_config(config_file_path)

        self.conn = psycopg2.connect(
            host=self.db_config['host'],
            dbname=self.db_config['dbname'],
            user=self.db_config['user'],
            password=self.db_config['password'],
            port=self.db_config['port'],
            # options=self.db_config['options']
        )
        self.schema = self.db_config["schema"]
        self.conn.autocommit = True
        self.cur = self.conn.cursor()
        # The double quotes are required so a mixed-case schema name is not
        # folded to lower case; embedded quotes are doubled (SQL identifier
        # quoting) so the name cannot break out of the quoted identifier.
        quoted_schema = '"' + self.schema.replace('"', '""') + '"'
        self.cur.execute("SET search_path TO " + quoted_schema)

        key_path = self.db_config.get('encryption_key_path', self.DEFAULT_ENCRYPTION_KEY_PATH)
        with open(key_path, "rb") as key_file:
            self.encryption_key = key_file.read()

    def load_db_config(self, config_file_path):
        """Load the database configuration from a JSON file.

        :param config_file_path: path of the UTF-8 JSON file to read.
        :return: the parsed JSON value (expected to be a dict of settings).
        """
        with open(config_file_path, 'r', encoding='utf-8') as config_file:
            return json.load(config_file)

    def encrypt_aes(self, plain_text):
        """Encrypt a string with AES-CBC under ``self.encryption_key``.

        A fresh random 16-byte IV is generated per call and PKCS7 padding is
        applied. Note: CBC alone gives confidentiality only; nothing here
        authenticates the ciphertext (no MAC).

        :param plain_text: text to encrypt (encoded as UTF-8).
        :return: ``(ciphertext_bytes, iv_bytes)``.
        """
        iv = os.urandom(16)  # AES block size is 16 bytes
        cipher = Cipher(algorithms.AES(self.encryption_key), modes.CBC(iv), backend=default_backend())
        encryptor = cipher.encryptor()

        # Pad the plaintext to a multiple of the block size.
        padder = padding.PKCS7(algorithms.AES.block_size).padder()
        padded_data = padder.update(plain_text.encode('utf-8')) + padder.finalize()

        encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
        return encrypted_data, iv

    def decrypt_aes(self, encrypted_data, iv):
        """Reverse :meth:`encrypt_aes`: AES-CBC decrypt, strip PKCS7 padding, decode UTF-8.

        :param encrypted_data: ciphertext bytes.
        :param iv: the IV that was returned alongside the ciphertext.
        :return: the decrypted text.
        """
        cipher = Cipher(algorithms.AES(self.encryption_key), modes.CBC(iv), backend=default_backend())
        decryptor = cipher.decryptor()

        decrypted_data = decryptor.update(encrypted_data) + decryptor.finalize()

        # Remove padding after decryption.
        unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
        unpadded_data = unpadder.update(decrypted_data) + unpadder.finalize()

        return unpadded_data.decode('utf-8')

    @staticmethod
    def _get_str(data, key):
        """Return ``data[key]`` with surrounding whitespace removed.

        Missing keys, JSON ``null`` and non-string values all yield ``''``, so
        callers report them as a validation error instead of crashing on
        ``.strip()``.
        """
        value = data.get(key)
        return value.strip() if isinstance(value, str) else ''

    def cleanse_and_validate_input(self, data):
        """Strip and validate registration fields.

        Passwords are stripped as well; :meth:`login_user` strips them the
        same way, so stored hashes and login attempts stay consistent.

        :param data: mapping with username/password/email/phone/sex.
        :return: ``(cleaned_dict, None)`` on success, or
            ``(None, error_message)`` on the first failed check.
        """
        username = self._get_str(data, 'username')
        password = self._get_str(data, 'password')
        email = self._get_str(data, 'email')
        phone = self._get_str(data, 'phone')
        sex = self._get_str(data, 'sex')

        if not username:
            return None, "Username is required."
        if len(username) > 255:
            return None, "Username must not exceed 255 characters."

        if not password:
            return None, "Password is required."
        if len(password) < 8:
            return None, "Password must be at least 8 characters long."

        if not email or not re.fullmatch(r"[^@]+@[^@]+\.[^@]+", email):
            return None, "Invalid email address."

        # [0-9] instead of \d: in Python 3 str patterns, \d also matches
        # non-ASCII Unicode digits.
        if not re.fullmatch(r'010[0-9]{8}', phone):
            return None, "Phone number must be in the format 010XXXXXXXX where X are digits."

        if not sex:
            return None, "Sex is required."
        if sex not in self._ALLOWED_SEX_VALUES:
            return None, "Invalid value for sex."

        return {
            'username': username,
            'password': password,
            'email': email,
            'phone': phone,
            'sex': sex
        }, None

    def register_user(self, data):
        """Validate, hash/encrypt and insert a new row into ``users``.

        The password is stored as a bcrypt hash; email, phone and sex are
        AES-CBC encrypted, each with its own random IV.

        :param data: mapping with username/password/email/phone/sex.
        :return: ``(body, status)``: 200 on success, 400 on a validation or
            database error (the database error text is returned as message).
        """
        cleaned, error = self.cleanse_and_validate_input(data)
        if error:
            return {'status': 'error', 'message': error}, 400

        username = cleaned['username']

        # bcrypt generates a salt and embeds it in the resulting hash.
        hashed_pw = bcrypt.hashpw(cleaned['password'].encode('utf-8'), bcrypt.gensalt())

        encrypted_email, email_iv = self.encrypt_aes(cleaned['email'])
        encrypted_phone, phone_iv = self.encrypt_aes(cleaned['phone'])
        # NOTE(review): the IV for sex is generated but never stored (the
        # INSERT has no column for it), so user_sex cannot be decrypted later.
        # Fixing this needs a matching column in the users table; confirm schema.
        encrypted_sex, _sex_iv = self.encrypt_aes(cleaned['sex'])

        try:
            self.cur.execute("""
                INSERT INTO users (username, user_pw, user_email, email_iv, user_phone, phone_iv, user_sex, user_time_stamp)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            """, (
                username,
                psycopg2.Binary(hashed_pw),
                psycopg2.Binary(encrypted_email),
                psycopg2.Binary(email_iv),
                psycopg2.Binary(encrypted_phone),
                psycopg2.Binary(phone_iv),
                psycopg2.Binary(encrypted_sex),
                # Naive local datetime (no tzinfo). NOTE(review): how this is
                # stored depends on the column type (timestamp vs timestamptz).
                datetime.now(),
            ))
            self.conn.commit()
        except psycopg2.Error as e:
            self.conn.rollback()
            return {'status': 'error', 'message': str(e)}, 400
        return {'status': 'success', 'message': f'user {username} registered successfully'}, 200

    def login_user(self, data):
        """Check a username/password pair against the stored bcrypt hash.

        :param data: mapping with username and password.
        :return: ``(body, status)``: 200 on match, 400 when a field is
            missing, 401 for an unknown user or wrong password (same message).
        """
        username = self._get_str(data, 'username')
        password = self._get_str(data, 'password')

        if not username or not password:
            return {'status': 'error', 'message': 'Username and password are required.'}, 400

        self.cur.execute("SELECT user_pw FROM users WHERE username = %s", (username,))
        user = self.cur.fetchone()

        if user is None:
            return {'status': 'error', 'message': 'Invalid username or password'}, 401

        # bytea columns come back from psycopg2 as memoryview; bcrypt wants bytes.
        hashed_pw = bytes(user[0])

        if bcrypt.checkpw(password.encode('utf-8'), hashed_pw):
            return {'status': 'success', 'message': 'Logged in successfully'}, 200
        return {'status': 'error', 'message': 'Invalid username or password'}, 401

    def _fetch_decrypted_field(self, data, query, response_keys):
        """Run ``query`` for ``data['username']`` and decrypt the returned column.

        :param query: SELECT returning exactly ``(ciphertext, iv)`` with one
            ``%s`` placeholder for the username.
        :param response_keys: body keys under which the plaintext is returned.
        :return: ``(body, status)``: 200, 400 (no username) or 404 (no user).
        """
        username = self._get_str(data, 'username')
        if not username:
            return {'status': 'error', 'message': 'Username is required.'}, 400

        self.cur.execute(query, (username,))
        row = self.cur.fetchone()
        if row is None:
            return {'status': 'error', 'message': 'User not found'}, 404

        encrypted_value, iv = row
        # Normalise bytea values (memoryview from psycopg2) to bytes.
        plain_text = self.decrypt_aes(bytes(encrypted_value), bytes(iv))

        body = {'status': 'success'}
        for key in response_keys:
            body[key] = plain_text
        return body, 200

    def get_phone_number(self, data):
        """Return the decrypted phone number of ``data['username']`` under ``'phone_number'``."""
        return self._fetch_decrypted_field(
            data,
            "SELECT user_phone, phone_iv FROM users WHERE username = %s",
            ('phone_number',),
        )

    def get_email(self, data):
        """Return the decrypted email of ``data['username']`` under ``'email'``.

        The value is also returned under ``'phone_number'``, the key this
        method historically (mistakenly) used, so existing clients keep working.
        """
        return self._fetch_decrypted_field(
            data,
            "SELECT user_email, email_iv FROM users WHERE username = %s",
            ('email', 'phone_number'),
        )

    def insert_gps_data(self, csv_block, columns):
        """Bulk-load comma-separated rows into ``gps_data`` with COPY.

        COPY is used instead of INSERT to reduce per-row overhead. The
        temporary cursor is closed even if the COPY fails; errors propagate.

        :param csv_block: text with one comma-separated row per line.
        :param columns: column names matching the field order of each row.
        :return: True on success.
        """
        with self.conn.cursor() as cur:
            cur.copy_from(StringIO(csv_block), 'gps_data', sep=',', columns=columns)
        self.conn.commit()
        return True

    def insert_trip_data(
            self,
            username,
            trip_id,
            total_distance_m,
            total_time_s,
            abrupt_start_count,
            abrupt_stop_count,
            abrupt_acceleration_count,
            abrupt_deceleration_count,
            helmet_on,
            final_score
    ):
        """Insert one ``trip_log`` row stamped with the current naive local time.

        Returns None; database errors (psycopg2.Error) propagate to the caller.
        """
        self.cur.execute("""
            INSERT INTO trip_log (username, trip_id, timestamp, total_distance_m, total_time_s, abrupt_start_count, abrupt_stop_count,
             abrupt_acceleration_count, abrupt_deceleration_count, helmet_on, final_score)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """, (
            username,
            trip_id,
            datetime.now(),
            total_distance_m,
            total_time_s,
            abrupt_start_count,
            abrupt_stop_count,
            abrupt_acceleration_count,
            abrupt_deceleration_count,
            helmet_on,
            final_score
        ))

    def get_history(self, user_name):
        """
        Retrieves all trip logs for the specified user within the last month,
        oldest first, and returns them as a JSON string with status 200.
        On a database error returns an error dict with status 500.
            [
              {
                "trip_id": "trip_001",
                "timestamp": "2024-09-01 12:45:00",
                "total_distance_m": 1000.5,
                "total_time_s": 600,
                "abrupt_start_count": 3,
                "abrupt_stop_count": 2,
                "abrupt_acceleration_count": 1,
                "abrupt_deceleration_count": 1,
                "helmet_on": true,
                "final_score": 85.5
              },
              {
                "trip_id": "trip_002",
                "timestamp": "2024-09-02 14:30:00",
                "total_distance_m": 1500.0,
                "total_time_s": 720,
                "abrupt_start_count": 2,
                "abrupt_stop_count": 3,
                "abrupt_acceleration_count": 1,
                "abrupt_deceleration_count": 2,
                "helmet_on": false,
                "final_score": 90.0
              }
            ]
        """
        try:
            # Select order must match _HISTORY_COLUMNS.
            self.cur.execute("""
                SELECT trip_id, timestamp, total_distance_m, total_time_s, abrupt_start_count,
                       abrupt_stop_count, abrupt_acceleration_count, abrupt_deceleration_count,
                       helmet_on, final_score
                FROM trip_log
                WHERE username = %s
                AND timestamp >= NOW() - INTERVAL '1 month'
                ORDER BY timestamp
            """, (user_name,))
            rows = self.cur.fetchall()
        except psycopg2.Error as e:
            self.conn.rollback()
            return {'status': 'error', 'message': str(e)}, 500

        result = []
        for row in rows:
            trip_log = dict(zip(self._HISTORY_COLUMNS, row))
            trip_log["timestamp"] = trip_log["timestamp"].strftime("%Y-%m-%d %H:%M:%S")
            # helmet_on may be stored as an integer flag; expose it as a JSON boolean.
            trip_log["helmet_on"] = bool(trip_log["helmet_on"])
            result.append(trip_log)

        # NOTE(review): if numeric columns are NUMERIC, psycopg2 yields Decimal,
        # which json.dumps cannot serialize; confirm the column types.
        return json.dumps(result), 200

    def close_connection(self):
        """Close the shared cursor and the underlying database connection.

        :return: True.
        """
        self.cur.close()
        self.conn.close()
        return True


    
