"""Plot validation-MRR learning curves parsed from training log files.

Run with ``--task learning_curve``. ``--log_path`` is either a
comma-separated list of log files or a single directory whose regular
files are all treated as logs. One figure is written to ``--out_path``.
"""
import argparse
import os
import re
from datetime import datetime, timezone

import numpy as np

# Matches log lines shaped like
#   "2023-01-01 12:00:00,000 - [INFO] - [Epoch 3] ... Valid MRR: 0.3142 ..."
# group(1) is the epoch number, group(2) the validation MRR.
VALID_MRR_PATTERN = re.compile(
    r'[0-9\- :,]*\[INFO\] - \[Epoch ([0-9]+)\].*Valid MRR: ([0-9\.]+).*')


def _parse_learning_curve(lines, num_epochs):
    """Collect validation-MRR values, in log order, from one log's lines.

    Lines that do not match ``VALID_MRR_PATTERN`` are ignored. An entry is
    kept when its epoch exceeds every epoch seen so far. When the epoch
    counter falls back below that maximum (presumably a resumed run that
    restarted its numbering), entries with strictly increasing epochs are
    appended after the earlier ones. Repeated epochs are dropped.

    Parsing stops once the (stitched) epoch index reaches ``num_epochs``, so
    the result can hold up to ``num_epochs + 1`` entries; the caller trims.

    Args:
        lines: iterable of log lines (e.g. an open text file).
        num_epochs: epoch index at which to stop reading.

    Returns:
        list[float] of MRR values; empty if no line matched.
    """
    curve = []
    last_epoch = -1     # last epoch accepted after the counter restarted
    stacked_epoch = -1  # highest epoch reached via the "new maximum" branch
    max_epoch = -1      # highest (stitched) epoch index accepted so far
    for line in lines:
        matched = VALID_MRR_PATTERN.match(line)
        if not matched:
            continue
        this_epoch = int(matched.group(1))
        if this_epoch > max_epoch:
            curve.append(float(matched.group(2)))
            max_epoch = this_epoch
            stacked_epoch = this_epoch
        elif last_epoch < this_epoch < max_epoch:
            # NOTE(review): last_epoch is never reset, so only the first
            # restart is stitched cleanly; after a second restart, epochs up
            # to the previous last_epoch are skipped. Confirm whether logs
            # can contain more than one restart.
            last_epoch = this_epoch
            max_epoch = stacked_epoch + 1 + this_epoch
            curve.append(float(matched.group(2)))
        if max_epoch >= num_epochs:
            break
    return curve


def extract_learning_curves(args):
    """Parse validation-MRR learning curves from one or more log files.

    Args:
        args: namespace with ``log_path`` (comma-separated file paths, or a
            single directory whose regular files are all read, in sorted
            order) and ``num_epochs`` (length of every returned curve).

    Returns:
        dict mapping each log's base file name to a list of exactly
        ``num_epochs`` floats (for ``num_epochs >= 0``). Curves that end
        early are padded with their last value. Files with no matching line
        are reported and skipped. A later file with the same base name
        replaces an earlier one.
    """
    paths = args.log_path.split(',')
    if len(paths) == 1 and os.path.isdir(paths[0]):
        log_dir = paths[0]
        # os.listdir order is arbitrary; sort so the legend order is stable.
        paths = sorted(
            os.path.join(log_dir, name) for name in os.listdir(log_dir)
            if os.path.isfile(os.path.join(log_dir, name)))

    num_epochs = args.num_epochs
    learning_curves = {}
    for path in paths:
        print(path)
        # errors='replace' keeps a stray non-UTF-8 file in a log directory
        # from aborting the run; the fields the pattern captures are ASCII.
        with open(path, 'r', encoding='utf-8', errors='replace') as log_file:
            curve = _parse_learning_curve(log_file, num_epochs)
        if not curve:
            print(f'Warning: no "Valid MRR" entries in {path}; skipping.')
            continue
        # Trim the possible extra entry, then pad with the last value so
        # every curve has exactly num_epochs points.
        del curve[num_epochs:]
        if len(curve) < num_epochs:
            curve.extend([curve[-1]] * (num_epochs - len(curve)))
        learning_curves[os.path.basename(path)] = curve
    return learning_curves


def draw_learning_curves(args, learning_curves):
    """Plot all curves on one figure and save it under ``args.out_path``.

    Args:
        args: namespace with ``out_path`` (output directory; the current
            directory when None/empty, created if missing), ``legend_title``
            and ``fig_filetype`` (file extension, e.g. "svg").
        learning_curves: dict of name -> list of MRR values, as returned by
            ``extract_learning_curves``. The legend label is the name with
            its final extension removed.

    Returns:
        Path of the saved figure, named by the current Unix time in ms.
    """
    # Imported here so log parsing works without matplotlib installed.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    for name, curve in learning_curves.items():
        ax.plot(np.arange(len(curve)), curve,
                label=os.path.splitext(name)[0])
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Best Valid MRR")
    ax.legend(title=args.legend_title)

    out_dir = args.out_path or os.curdir
    os.makedirs(out_dir, exist_ok=True)
    # Use an aware datetime: naive utcnow().timestamp() is interpreted as
    # local time and shifts the value by the local UTC offset.
    millis = round(datetime.now(timezone.utc).timestamp() * 1000)
    out_file = os.path.join(out_dir, f'{millis}.{args.fig_filetype}')
    fig.savefig(out_file)
    plt.close(fig)
    return out_file


def main():
    """Command-line entry point."""
    parser = argparse.ArgumentParser(
        description="Plot learning curves from training logs.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--task', default=None, type=str,
                        choices=['learning_curve'],
                        help='task to run')
    parser.add_argument('--log_path', type=str, default=None,
                        help='comma-separated log files, or one directory '
                             'containing log files')
    parser.add_argument('--out_path', type=str, default=None,
                        help='output directory (current directory if unset)')
    parser.add_argument('--num_epochs', type=int, default=200,
                        help='number of points in every plotted curve')
    parser.add_argument('--legend_title', type=str, default="Learning rate",
                        help='title shown above the legend')
    parser.add_argument('--fig_filetype', type=str, default="svg",
                        help='figure file extension / format')
    args = parser.parse_args()

    if args.task is None:
        parser.print_help()
        return
    if args.task == 'learning_curve':
        if not args.log_path:
            parser.error('--log_path is required for --task learning_curve')
        if args.num_epochs < 1:
            parser.error('--num_epochs must be at least 1')
        draw_learning_curves(args, extract_learning_curves(args))


if __name__ == '__main__':
    main()