
File name
Commit message
Commit date
File name
Commit message
Commit date
import sys
import time
import plotly.express as px
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
class Logger():
    """Print per-epoch training progress and serve a loss plot through Dash.

    Every value passed to ``print_training_log`` under the key ``'loss'`` is
    appended to ``self.losses``; the Dash app built in ``__init__`` redraws a
    line plot of that list each time its interval component fires.
    """

    def __init__(self):
        # Wall-clock reference points used for the remaining-time estimate.
        self.start_time = time.time()
        self.epoch_start_time = self.start_time
        # Values logged under the 'loss' key, one entry per print_training_log call.
        self.losses = []

        self.app = dash.Dash(__name__)
        self.app.layout = html.Div([
            dcc.Graph(id='live-update-graph'),
            # Ticks once per second; each tick triggers update_graph_live below.
            dcc.Interval(
                id='interval-component',
                interval=1 * 1000,  # in milliseconds
                n_intervals=0,
            ),
        ])

        @self.app.callback(Output('live-update-graph', 'figure'),
                           [Input('interval-component', 'n_intervals')])
        def update_graph_live(n):
            # Redraw a single line plot of the recorded losses. `n` (the tick
            # count) only serves as the trigger and is not otherwise used.
            return px.line(
                x=list(range(len(self.losses))),
                y=self.losses,
                labels={'x': 'Epoch', 'y': 'Loss'},
            )

    def print_training_log(self, current_epoch, total_epoch, losses):
        """Write one epoch's progress report to stdout.

        Args:
            current_epoch: Index of the epoch just finished. The time estimate
                divides by ``current_epoch + 1``, so this presumably is
                0-based — TODO confirm against callers.
            total_epoch: Total number of epochs planned.
            losses: Mapping of metric name to value. Every pair is printed as
                ``"name : value"``; the value under ``'loss'`` is also appended
                to ``self.losses`` for plotting.

        Raises:
            TypeError: If ``losses`` is not a dict (or dict subclass).
        """
        if not isinstance(losses, dict):
            raise TypeError(
                f"losses must be a dict, got {type(losses).__name__}"
            )

        current_time = time.time()
        total_time = current_time - self.start_time
        # Linear extrapolation from the average time per epoch so far.
        estimated_total_time = total_time * total_epoch / (current_epoch + 1)
        remaining_time = estimated_total_time - total_time
        self.epoch_start_time = current_time

        loss_lines = []
        for loss_name, loss_value in losses.items():
            loss_lines.append(f"{loss_name} : {loss_value}\n")
            if loss_name == 'loss':
                self.losses.append(loss_value)
        terminal_logging_string = "".join(loss_lines)

        sys.stdout.write(
            f"epoch : {current_epoch}/{total_epoch}\n"
            # Trailing newline keeps the first loss entry on its own line.
            f"estimated time remaining : {remaining_time}\n"
            f"{terminal_logging_string}"
        )
        # Make progress visible immediately even when stdout is buffered.
        sys.stdout.flush()

    def print_training_history(self, debug=True, use_reloader=False):
        """Start the Dash server that plots the recorded losses.

        Args:
            debug: Forwarded to ``run_server`` (Dash debug / dev tools).
            use_reloader: Forwarded to the underlying Flask server. Disabled by
                default because, with ``debug=True``, Flask otherwise enables
                its code reloader, which re-executes the launching script in a
                child process — re-running training and starting with an
                empty loss history.
        """
        self.app.run_server(debug=debug, use_reloader=use_reloader)