Velvet Star Monitor

How to extract loss and accuracy from logger by each epoch in pytorch lightning?

Writer Matthew Barrera

I want to extract all of the logged data so that I can make the plots myself, without TensorBoard. My understanding is that all the logged loss and accuracy values are stored in a defined directory, since TensorBoard draws its line graphs from them.

%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

However, I would like to know how to extract all of the logged values from the logger in PyTorch Lightning. Here is the training part of my code:

import pytorch_lightning as pl

# model
ssl_classifier = SSLImageClassifier(lr=lr)
# train
logger = pl.loggers.TensorBoardLogger(name=f'ssl-{lr}-{num_epoch}', save_dir='lightning_logs')
trainer = pl.Trainer(progress_bar_refresh_rate=20, gpus=1, max_epochs=max_epoch, logger=logger)
trainer.fit(ssl_classifier, train_loader, val_loader)

I confirmed that trainer.logger.log_dir returns the directory where the logs seem to be saved, and that trainer.logger.log_metrics returns <bound method TensorBoardLogger.log_metrics of <pytorch_lightning.loggers.tensorboard.TensorBoardLogger object at 0x7efcb89a3e50>>, i.e. the method itself rather than any logged values.

However, trainer.logged_metrics returns only the values from the final epoch, for example:

{'epoch': 19, 'train_acc': tensor(1.), 'train_loss': tensor(0.1038), 'val_acc': 0.6499999761581421, 'val_loss': 1.2171183824539185}

Do you know how to get the values for every epoch?

2 Answers

The accepted answer is not fundamentally wrong, but it does not follow the official (current) PyTorch Lightning guidelines.

As suggested here, the recommended approach is to write a custom logger class like:

from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment


class MyLogger(LightningLoggerBase):
    @property
    def name(self):
        return "MyLogger"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger.
        pass

    @property
    def version(self):
        # Return the experiment version, int or str.
        return "0.1"

    @rank_zero_only
    def log_hyperparams(self, params):
        # params is an argparse.Namespace
        # your code to record hyperparameters goes here
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        # your code to record metrics goes here
        pass

    @rank_zero_only
    def save(self):
        # Optional. Any code necessary to save logger data goes here
        # If you implement this, remember to call `super().save()`
        # at the start of the method (important for aggregation of metrics)
        super().save()

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to be run after training
        # finishes goes here
        pass

Looking inside the LightningLoggerBase class shows which other methods can be overridden.

Here is a minimal logger of my own. It is far from optimised, but it should be a good first attempt. I will edit this answer if I improve it.

import collections

from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only


class History_dict(LightningLoggerBase):
    def __init__(self):
        super().__init__()
        # copy not necessary here
        # The defaultdict, in contrast, simply creates any item you try to access
        self.history = collections.defaultdict(list)

    @property
    def name(self):
        return "Logger_custom_plot"

    @property
    def version(self):
        return "1.0"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger.
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        for metric_name, metric_value in metrics.items():
            if metric_name != 'epoch':
                self.history[metric_name].append(metric_value)
            # Case 'epoch': avoid adding the same epoch several times,
            # which happens when several losses are logged in one epoch.
            elif (not self.history['epoch']  # nothing recorded yet
                    or self.history['epoch'][-1] != metric_value):  # a new epoch
                self.history['epoch'].append(metric_value)

    def log_hyperparams(self, params):
        pass
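To see what the epoch check in `log_metrics` does, here is a framework-free sketch of the same bookkeeping. It assumes, as the comment in the class suggests, that `log_metrics` is called several times per epoch (for example once for training and once for validation metrics) and that each call carries an `epoch` key; the `record` helper and the metric values are made up for illustration.

```python
import collections


def record(history, metrics):
    """Hypothetical helper mirroring History_dict.log_metrics for one call."""
    for name, value in metrics.items():
        if name != "epoch":
            history[name].append(value)
        elif not history["epoch"] or history["epoch"][-1] != value:
            # Only add the epoch if it differs from the last one recorded,
            # so several calls within the same epoch produce one entry.
            history["epoch"].append(value)


history = collections.defaultdict(list)
# Made-up calls: training and validation metrics arrive separately.
record(history, {"train_loss": 0.9, "epoch": 0})
record(history, {"val_loss": 1.1, "epoch": 0})
record(history, {"train_loss": 0.5, "epoch": 1})
record(history, {"val_loss": 0.8, "epoch": 1})

print(dict(history))  # {'train_loss': [0.9, 0.5], 'epoch': [0, 1], 'val_loss': [1.1, 0.8]}
```

The `epoch` list only lines up one-to-one with a metric list if that metric is logged once per epoch; metrics logged at every step will produce longer lists.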

Lightning does not store all logs by itself. It only streams them to the logger instance, and the logger decides what to do with them.
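The difference this makes can be sketched without Lightning: a sink that overwrites one dict keeps only the most recent values (similar to what `trainer.logged_metrics` showed in the question), while a sink that appends keeps every epoch. Both classes below are hypothetical stand-ins, not Lightning classes, and the per-epoch values are invented.

```python
class LatestOnly:
    """Hypothetical sink that keeps only the most recent value of each metric."""

    def __init__(self):
        self.metrics = {}

    def log_metrics(self, metrics, step):
        self.metrics.update(metrics)  # older values are overwritten


class History:
    """Hypothetical sink that keeps one (step, metrics) entry per call."""

    def __init__(self):
        self.rows = []

    def log_metrics(self, metrics, step):
        self.rows.append((step, dict(metrics)))  # copy so later changes don't leak in


latest, history = LatestOnly(), History()
# Made-up per-epoch validation results, streamed to every sink.
for epoch, (loss, acc) in enumerate([(1.5, 0.4), (1.3, 0.55), (1.2, 0.65)]):
    for sink in (latest, history):
        sink.log_metrics({"val_loss": loss, "val_acc": acc, "epoch": epoch}, step=epoch)

print(latest.metrics)     # {'val_loss': 1.2, 'val_acc': 0.65, 'epoch': 2}
print(len(history.rows))  # 3
```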

The best way to retrieve all logged metrics is with a custom callback:

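from pytorch_lightning import Callback


class MetricTracker(Callback):
    def __init__(self):
        self.collection = []

    def on_validation_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx=0):
        # assumes validation_step returns a dict containing 'val_acc'
        vacc = outputs['val_acc']  # you can access them here
        self.collection.append(vacc)  # track them

    def on_validation_epoch_end(self, trainer, module):
        # copy, in case the trainer updates the same dict in place
        elogs = dict(trainer.logged_metrics)  # access it here
        self.collection.append(elogs)  # do whatever is needed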
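When storing `trainer.logged_metrics` in a list, it is safer to append a copy. If the trainer updates the same dict object in place between epochs (I have not verified this for every Lightning version), appending the dict itself stores several references to one object, and every entry ends up showing the last epoch. The plain-Python sketch below uses a hypothetical `logged` dict as a stand-in for `trainer.logged_metrics`, with made-up values.

```python
# `logged` is a hypothetical stand-in for trainer.logged_metrics,
# updated in place once per epoch.
logged = {}
by_reference, by_copy = [], []

for epoch, val_loss in enumerate([1.5, 1.3, 1.2]):  # made-up values
    logged.update({"epoch": epoch, "val_loss": val_loss})
    by_reference.append(logged)   # the same object appended every time
    by_copy.append(dict(logged))  # shallow snapshot of the current values

print([row["val_loss"] for row in by_reference])  # [1.2, 1.2, 1.2]
print([row["val_loss"] for row in by_copy])       # [1.5, 1.3, 1.2]
```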

You can then access everything that was logged through the callback instance:

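from pytorch_lightning import Trainer

cb = MetricTracker()
trainer = Trainer(callbacks=[cb])
trainer.fit(ssl_classifier, train_loader, val_loader)
cb.collection  # do your plotting and stuff here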
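Since the goal in the question is to plot the data, here is a sketch of turning a list of epoch-level dicts into one list per metric. It assumes you keep only the epoch-level dicts (shaped like the `trainer.logged_metrics` output in the question), not the per-batch values; `to_series` and the `rows` data are hypothetical. Values go through `float()`, which handles plain numbers as well as one-element tensors.

```python
from collections import defaultdict


def to_series(rows):
    """Hypothetical helper: turn a list of metric dicts into {name: [values]}."""
    series = defaultdict(list)
    for row in rows:
        for name, value in row.items():
            series[name].append(float(value))  # float() also unwraps 1-element tensors
    return dict(series)


# Made-up epoch-level snapshots, shaped like trainer.logged_metrics.
rows = [
    {"epoch": 0, "val_loss": 1.5, "val_acc": 0.40},
    {"epoch": 1, "val_loss": 1.3, "val_acc": 0.55},
    {"epoch": 2, "val_loss": 1.2, "val_acc": 0.65},
]
series = to_series(rows)
print(series["val_loss"])  # [1.5, 1.3, 1.2]
```

With matplotlib, `plt.plot(series["epoch"], series["val_loss"])` would then draw the validation-loss curve. If some metrics are missing from some rows, the lists will have different lengths, so in that case it is better to keep each value paired with its epoch.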
