import sys
import time
import plotly.express as px
import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output


class Logger:
    """Print training progress to stdout and serve a live loss plot via Dash.

    Every value logged under the ``'loss'`` key by ``print_training_log`` is
    appended to ``self.losses``. The Dash app built in ``__init__`` re-plots
    that list once per second while ``print_training_history`` is serving it.
    """

    def __init__(self):
        self.start_time = time.time()
        # Timestamp of the most recent print_training_log call (or construction).
        self.epoch_start_time = self.start_time
        # Values logged under the 'loss' key, one entry per print_training_log call.
        self.losses = []

        self.app = dash.Dash(__name__)

        self.app.layout = html.Div([
            dcc.Graph(id='live-update-graph'),
            dcc.Interval(
                id='interval-component',
                interval=1 * 1000,  # refresh period in milliseconds
                n_intervals=0
            )
        ])

        @self.app.callback(Output('live-update-graph', 'figure'),
                           [Input('interval-component', 'n_intervals')])
        def update_graph_live(n):
            """Rebuild the loss figure on each interval tick (``n`` is unused)."""
            # Snapshot once so x and y always have the same length, even if
            # training appends to self.losses while this callback runs.
            losses = list(self.losses)
            # NOTE(review): one point is added per print_training_log call, so
            # the 'Epoch' label is only accurate if it is called once per epoch.
            return px.line(x=list(range(len(losses))), y=losses,
                           labels={'x': 'Epoch', 'y': 'Loss'})

    def print_training_log(self, current_epoch, total_epoch, current_batch,
                           total_batch, losses: dict):
        """Write a progress report to stdout and record the ``'loss'`` value.

        Args:
            current_epoch: Epoch index, used for the remaining-time estimate.
            total_epoch: Total number of epochs.
            current_batch: Batch index (only printed).
            total_batch: Total number of batches (only printed).
            losses: Mapping of metric name to value. Every entry is printed;
                the value under ``'loss'`` (if present) is appended to
                ``self.losses`` for the live plot.

        Raises:
            TypeError: If ``losses`` is not a dict.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(losses, dict):
            raise TypeError(
                f"losses must be a dict, got {type(losses).__name__}"
            )

        current_time = time.time()
        total_time = current_time - self.start_time

        # Linear extrapolation from elapsed time. Assumes current_epoch is
        # 0-based and ignores progress inside the epoch — TODO confirm with callers.
        estimated_total_time = total_time * total_epoch / (current_epoch + 1)
        # Clamp at zero: with 1-based epoch numbers the estimate would
        # otherwise go negative on the final epoch.
        remaining_time = max(0.0, estimated_total_time - total_time)

        self.epoch_start_time = current_time

        terminal_logging_string = "".join(
            f"{loss_name} : {loss_value}\n"
            for loss_name, loss_value in losses.items()
        )
        if 'loss' in losses:
            self.losses.append(losses['loss'])

        sys.stdout.write(
            f"epoch : {current_epoch}/{total_epoch}\n"
            f"batch : {current_batch}/{total_batch}\n"
            f"estimated time remaining : {remaining_time}\n"
            f"{terminal_logging_string}\n"
        )
        # Flush so progress shows up promptly even when stdout is piped/buffered.
        sys.stdout.flush()

    def print_training_history(self, debug=True, **run_kwargs):
        """Serve the live loss plot with Dash; blocks until the server stops.

        Args:
            debug: Forwarded to Dash's run method (default True, as before).
            **run_kwargs: Extra options forwarded to Dash's run method
                (e.g. ``host``, ``port``, or ``use_reloader=True`` to opt back
                into the reloader).
        """
        # With debug=True the Werkzeug reloader would otherwise be enabled. It
        # re-executes the whole script in a child process, restarting training.
        run_kwargs.setdefault("use_reloader", False)
        # Newer Dash versions provide app.run (run_server is deprecated/removed);
        # fall back to run_server for older versions.
        run = getattr(self.app, "run", None) or self.app.run_server
        run(debug=debug, **run_kwargs)
