import sys
import time

import dash
import plotly.express as px
from dash import dcc
from dash import html
from dash.dependencies import Input, Output


class Logger:
    """Print training progress to stdout and serve a live loss plot via Dash.

    Values passed under the key ``'loss'`` to :meth:`print_training_log` are
    accumulated in ``self.losses`` and plotted by a Dash app whose page polls
    for a redraw once per second.
    """

    def __init__(self):
        self.start_time = time.time()
        # Timestamp of the most recent print_training_log call (reset on every call).
        self.epoch_start_time = self.start_time
        # Values logged under the 'loss' key, one entry per print_training_log call.
        self.losses = []

        self.app = dash.Dash(__name__)
        self.app.layout = html.Div([
            dcc.Graph(id='live-update-graph'),
            dcc.Interval(
                id='interval-component',
                interval=1 * 1000,  # in milliseconds
                n_intervals=0,
            ),
        ])

        @self.app.callback(Output('live-update-graph', 'figure'),
                           [Input('interval-component', 'n_intervals')])
        def update_graph_live(n):
            """Redraw the loss curve; triggered on every Interval tick."""
            # Snapshot so x and y are built from the same state of the list.
            losses = list(self.losses)
            # NOTE(review): x is the index of each logged 'loss' value; it only
            # equals the epoch number if print_training_log is called once per epoch.
            return px.line(x=list(range(len(losses))), y=losses,
                           labels={'x': 'Epoch', 'y': 'Loss'})

    def print_training_log(self, current_epoch, total_epoch, current_batch,
                           total_batch, losses):
        """Write a progress report to stdout and record the ``'loss'`` value.

        Args:
            current_epoch: Current epoch index. The time estimate divides by
                ``current_epoch + 1``, i.e. it treats this index as 0-based.
            total_epoch: Total number of epochs.
            current_batch: Current batch index (printed only).
            total_batch: Total number of batches (printed only).
            losses: Mapping of loss name to value. Every entry is printed; the
                value under ``'loss'`` is also appended to ``self.losses``.

        Raises:
            TypeError: If ``losses`` is not a dict.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(losses, dict):
            raise TypeError(
                f"losses must be a dict, got {type(losses).__name__}")

        current_time = time.time()
        total_time = current_time - self.start_time
        # Linear extrapolation from the average wall time per epoch so far.
        estimated_total_time = total_time * total_epoch / (current_epoch + 1)
        remaining_time = estimated_total_time - total_time
        self.epoch_start_time = current_time

        loss_lines = []
        for loss_name, loss_value in losses.items():
            loss_lines.append(f"{loss_name} : {loss_value}\n")
            if loss_name == 'loss':
                self.losses.append(loss_value)
        loss_text = "".join(loss_lines)

        # The time-estimate line needs its own trailing newline; without it the
        # first loss entry was printed on the same line.
        sys.stdout.write(
            f"epoch : {current_epoch}/{total_epoch}\n"
            f"batch : {current_batch}/{total_batch}\n"
            f"estimated time remaining : {remaining_time}\n"
            f"{loss_text}"
        )
        sys.stdout.flush()

    def print_training_history(self):
        """Start the Dash server that serves the loss plot; blocks until stopped.

        The reloader is disabled. With ``debug=True``, Flask's reloader would
        otherwise re-execute the whole script in a child process. That would
        re-run training instead of serving the losses recorded by this Logger.
        """
        # `run_server` is deprecated in Dash 2 and removed in Dash 3; prefer `run`.
        run = getattr(self.app, 'run', None) or self.app.run_server
        run(debug=True, use_reloader=False)