import sys
import time
import plotly.express as px
import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output


class Logger:
    """Print training progress to stdout and keep a history of ``'loss'`` values."""

    def __init__(self):
        # Wall-clock time at construction; basis for the remaining-time estimate.
        self.start_time = time.time()
        # Time of the most recent print_training_log call. It is reset on every
        # call, so despite the name it marks the last log, not an epoch start.
        self.epoch_start_time = self.start_time
        # Every value logged under the key 'loss', in call order.
        self.losses = []

    def print_training_log(self, current_epoch, total_epoch, current_batch, total_batch, losses):
        """Write a progress report for the current training step to stdout.

        Args:
            current_epoch: Epoch index. The time estimate treats it as 0-based,
                i.e. ``current_epoch + 1`` epochs are counted as elapsed.
            total_epoch: Total number of epochs.
            current_batch: Batch counter; only echoed in the report.
            total_batch: Number of batches; only echoed in the report.
            losses: Mapping of metric name to value. Each pair is printed on
                its own line, and the value under ``'loss'`` (if present) is
                appended to ``self.losses``.

        Raises:
            TypeError: If ``losses`` is not a dict (or dict subclass).
        """
        # Explicit check rather than ``assert``: asserts are removed under -O.
        if not isinstance(losses, dict):
            raise TypeError(f"losses must be a dict, got {type(losses).__name__}")

        current_time = time.time()
        total_time = current_time - self.start_time

        # Linear extrapolation from whole epochs elapsed; progress within the
        # current epoch (current_batch / total_batch) is not taken into account.
        estimated_total_time = total_time * total_epoch / (current_epoch + 1)
        remaining_time = estimated_total_time - total_time

        self.epoch_start_time = current_time

        if "loss" in losses:
            self.losses.append(losses["loss"])

        terminal_logging_string = "".join(
            f"{loss_name} : {loss_value}\n" for loss_name, loss_value in losses.items()
        )

        sys.stdout.write(
            f"epoch : {current_epoch}/{total_epoch}\n"
            f"batch : {current_batch}/{total_batch}\n"
            f"estimated time remaining : {remaining_time}\n"
            f"{terminal_logging_string}\n"
        )
        # Flush so progress appears promptly even when stdout is a pipe or file.
        sys.stdout.flush()
