2023-06-08 06:40:16 +00:00
|
|
|
import argparse
|
|
|
|
import re
|
|
|
|
import os
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import numpy as np
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
|
|
def extract_learning_curves(args):
    """Parse per-epoch validation MRR values out of one or more training logs.

    ``args.log_path`` is either a comma-separated list of log files or a single
    directory. For a directory, every regular file directly inside it is read,
    in sorted order so the result (and the legend order) is deterministic.

    A line is an epoch record when it matches the regex below, e.g.
    ``2023-06-08 06:40:16,123 - [INFO] - [Epoch 3] ... Valid MRR: 0.4213``.
    Group 1 is the epoch number and group 2 the MRR value.

    When the epoch counter drops inside one log (a resumed/restarted run writing
    to the same file), later records are still appended. Their epoch is re-based
    onto the highest epoch seen before the restart, so the curve keeps growing
    instead of being overwritten.

    Args:
        args: namespace with ``log_path`` (str) and ``num_epochs`` (int).

    Returns:
        dict mapping each log file's base name to a list of floats. Curves with
        fewer than ``num_epochs`` values are padded by repeating their last
        value. Files without any epoch record are skipped with a message instead
        of aborting the whole run.

    Raises:
        ValueError: if ``args.log_path`` is not set.
    """
    if args.log_path is None:
        raise ValueError("log_path is required (comma-separated files or a directory)")

    paths = args.log_path.split(',')
    if len(paths) == 1 and os.path.isdir(paths[0]):
        log_dir = paths[0]
        # os.listdir order is arbitrary; sort for reproducible output.
        paths = sorted(
            os.path.join(log_dir, f)
            for f in os.listdir(log_dir)
            if os.path.isfile(os.path.join(log_dir, f))
        )

    # Compiled once instead of on every line of every file.
    epoch_line = re.compile(
        r'[0-9\- :,]*\[INFO\] - \[Epoch ([0-9]+)\].*Valid MRR: ([0-9\.]+).*')

    learning_curves = {}
    print(paths)
    for path in paths:
        print(path)
        learning_curve = []
        last_epoch = -1     # highest epoch number seen since the last restart
        stacked_epoch = -1  # highest epoch number seen before a restart
        max_epoch = -1      # highest (re-based) epoch recorded so far
        with open(path, 'r') as log_file:
            for line in log_file:
                matched = epoch_line.match(line)
                if not matched:
                    continue
                this_epoch = int(matched.group(1))
                if this_epoch > max_epoch:
                    # Normal, monotonically increasing progress.
                    learning_curve.append(float(matched.group(2)))
                    max_epoch = this_epoch
                    stacked_epoch = this_epoch
                elif last_epoch < this_epoch < max_epoch:
                    # Counter went back down: a restarted run. Map its epoch
                    # after the last epoch of the previous run.
                    # NOTE(review): last_epoch is never reset, so a second
                    # restart in the same log skips epochs <= the first
                    # restart's last epoch — confirm whether that can occur.
                    last_epoch = this_epoch
                    max_epoch = stacked_epoch + 1 + this_epoch
                    learning_curve.append(float(matched.group(2)))
                # NOTE(review): this stops only after epoch num_epochs itself
                # has been appended, so a 0-indexed log can yield
                # num_epochs + 1 points — confirm the intended length.
                if max_epoch >= args.num_epochs:
                    break

        if not learning_curve:
            # Nothing to pad from; e.g. a non-log file in the log directory.
            print(f"skipping {path}: no epoch records found")
            continue

        # Pad short (unfinished) runs with their last value.
        learning_curve.extend(
            [learning_curve[-1]] * (args.num_epochs - len(learning_curve)))
        learning_curves[os.path.basename(path)] = learning_curve
    return learning_curves
|
|
|
|
|
|
|
|
def draw_learning_curves(args, learning_curves):
    """Plot all learning curves on one figure and save it under a unique name.

    Each curve is drawn against its 0-based index on the x axis. Its label is
    the file name with the last extension removed. The figure is written to
    ``<args.out_path>/<milliseconds since the Unix epoch>.<args.fig_filetype>``.

    Args:
        args: namespace with ``out_path`` (directory; the current directory is
            used when it is None), ``legend_title`` and ``fig_filetype``.
        learning_curves: mapping of file name -> list of per-epoch values, as
            returned by ``extract_learning_curves``.

    Returns:
        The path of the saved figure file.
    """
    import time

    # Use a dedicated figure and always close it, so repeated calls do not
    # accumulate lines on pyplot's implicit global figure or leak figures.
    fig, ax = plt.subplots()
    try:
        for name, curve in learning_curves.items():
            # splitext keeps dotfiles like ".run" intact (the old regex gave "").
            label = os.path.splitext(name)[0]
            ax.plot(np.arange(len(curve)), curve, label=label)
        ax.set_xlabel("Epochs")
        ax.set_ylabel("Best Valid MRR")
        ax.legend(title=args.legend_title)
        # time.time() is a true POSIX timestamp. datetime.utcnow().timestamp()
        # would interpret the naive UTC wall time as local time.
        stamp = str(round(time.time() * 1000))
        out_file = os.path.join(args.out_path or '.', stamp + '.' + args.fig_filetype)
        fig.savefig(out_file)
    finally:
        plt.close(fig)
    return out_file
|
|
|
|
|
|
|
|
|
|
|
|
def build_parser():
    """Build the command-line parser for this plotting script."""
    parser = argparse.ArgumentParser(
        description="Plot validation-MRR learning curves extracted from training logs.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--task', default=None, type=str, choices=['learning_curve'],
                        help="what to do; only 'learning_curve' is supported")
    parser.add_argument('--log_path', type=str, default=None,
                        help="comma-separated log files, or a directory of log files")
    parser.add_argument('--out_path', type=str, default='.',
                        help="directory the figure is written to")
    parser.add_argument('--num_epochs', type=int, default=200,
                        help="curves are padded to this length; reading a log "
                             "stops once an epoch >= this value is recorded")
    parser.add_argument('--legend_title', type=str, default="Learning rate",
                        help="title shown above the plot legend")
    parser.add_argument('--fig_filetype', type=str, default="svg",
                        help="figure file extension (savefig infers the format from it)")
    return parser


def main(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None) and run the requested task.

    Prints the help text when no task is given instead of silently exiting.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.task is None:
        parser.print_help()
        return
    if args.log_path is None:
        # Fail with a usage message rather than an AttributeError later on.
        parser.error("--log_path is required for --task learning_curve")
    draw_learning_curves(args, extract_learning_curves(args))


if __name__ == '__main__':
    main()
|