Skip to content

Commit 65433e9

Browse files
Plots results (#37)
1 parent 2d48629 commit 65433e9

12 files changed

+253
-20
lines changed

environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies:
1111
- cudatoolkit =10.1
1212
- plotly >=4.8
1313
- plotly-orca >=1.3
14+
- matplotlib
1415
- numpy
1516
- pandas
1617
- pyyaml

mcp/argparser.py

+4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ def parse_eval_arguments():
3232
return parser.parse_args()
3333

3434

35+
# For now both are the same
36+
parse_viz_arguments = parse_eval_arguments
37+
38+
3539
def _default_parser():
3640
parser = argparse.ArgumentParser()
3741
parser.add_argument(

mcp/context/base.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from mcp.task.supervised import SupervisedTask
4141
from mcp.training.loop import TrainingLoop
4242
from mcp.training.trainer import Trainer, TrainerLoggers
43+
from mcp.viz.base import Vizualization
4344

4445
ValidFewShotDataLoaderFactory = NewType(
4546
"ValidFewShotDataLoaderFactory", FewShotDataLoaderFactory
@@ -96,10 +97,10 @@ def provide_evaluation_loggers(self) -> EvaluationLoggers:
9697

9798
return EvaluationLoggers(
9899
support=ResultLogger(
99-
"Evaluation Support", os.path.join(output_dir, "support")
100+
"Evaluation Support", os.path.join(output_dir, "test-support")
100101
),
101102
evaluation=ResultLogger(
102-
"Evaluation Support", os.path.join(output_dir, "eval")
103+
"Evaluation Support", os.path.join(output_dir, "test-eval")
103104
),
104105
)
105106

@@ -136,6 +137,12 @@ def provide_evaluation(
136137
def provide_experiment_result(self) -> ExperimentResult:
137138
return ExperimentResult(self.config, self.output_dir)
138139

140+
@provider
141+
@inject
142+
@singleton
143+
def provide_vizualization(self, result: ExperimentResult) -> Vizualization:
144+
return Vizualization(result)
145+
139146

140147
class TrainerModule(Module):
141148
def __init__(self, config: ExperimentConfig, output_dir: str, device: torch.device):

mcp/main.py

+9
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from mcp.evaluation import Evaluation
88
from mcp.result.experiment import ExperimentResult
99
from mcp.training.trainer import Trainer
10+
from mcp.viz.base import Vizualization
1011

1112

1213
def run_train(
@@ -33,3 +34,11 @@ def run_eval(config: ExperimentConfig, result_dir: str, device_str: str):
3334
evaluation = injector.get(Evaluation)
3435
result = injector.get(ExperimentResult)
3536
evaluation.eval(result.best_epoch())
37+
38+
39+
def run_viz(config: ExperimentConfig, result_dir: str, device_str: str):
40+
device = torch.device(device_str)
41+
injector = create_injector(config, result_dir, device)
42+
43+
viz = injector.get(Vizualization)
44+
viz.plot()

mcp/result/experiment.py

+81-18
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
2-
from typing import List
2+
import sys
3+
from typing import Callable, List, Optional
34

45
import numpy as np
56

@@ -10,27 +11,56 @@
1011
logger = create_logger(__name__)
1112

1213

14+
class EpochResult(object):
15+
def __init__(self, file_name: str):
16+
self.file_name = file_name
17+
18+
def load(self) -> List[List[ResultRecord]]:
19+
return load_records_from_file(self.file_name)
20+
21+
@staticmethod
22+
def losses(records: List[List[ResultRecord]]) -> List[List[float]]:
23+
return [[r.loss for r in rec] for rec in records]
24+
25+
@staticmethod
26+
def metric(records: List[List[ResultRecord]]) -> List[List[float]]:
27+
return [[r.metric for r in rec] for rec in records]
28+
29+
@staticmethod
30+
def task_name(records: List[List[ResultRecord]]) -> List[str]:
31+
return [r.name for r in records[0]]
32+
33+
@staticmethod
34+
def metric_name(records: List[List[ResultRecord]]) -> List[str]:
35+
return [r.metric_name for r in records[0]]
36+
37+
@staticmethod
38+
def reduce(
39+
values: List[List[float]],
40+
reduce_task: Optional[Callable] = np.mean,
41+
reduce_iter: Optional[Callable] = np.mean,
42+
) -> np.ndarray:
43+
if reduce_task is None and reduce_iter is None:
44+
raise ValueError("Must reduce on something")
45+
46+
if reduce_task is not None:
47+
values = [reduce_task(np.asarray(vv), axis=-1) for vv in values]
48+
49+
if reduce_iter is not None:
50+
values = reduce_iter(np.asarray(values), axis=0)
51+
52+
return values
53+
54+
1355
class ExperimentResult(object):
1456
def __init__(self, config: ExperimentConfig, output_dir: str):
1557
self.config = config
1658
self.output_dir = output_dir
17-
self._records_dir = os.path.join(self.output_dir, "train")
59+
self._records_dir_train = os.path.join(self.output_dir, "train")
60+
self._records_dir_eval = os.path.join(self.output_dir, "evaluation")
1861

1962
def best_epoch(self) -> int:
20-
losses = []
21-
for epoch in range(1, self.config.trainer.epochs + 1):
22-
try:
23-
file_name = os.path.join(self._records_dir, f"eval-{epoch}")
24-
records_valid = load_records_from_file(file_name)
25-
loss = np.asarray(
26-
[self._records_loss(rs) for rs in records_valid]
27-
).mean()
28-
losses.append(loss)
29-
except FileNotFoundError:
30-
logger.warning(
31-
f"Training did not complete {epoch-1}/{self.config.trainer.epochs}"
32-
)
33-
break
63+
losses = self.metric("train", EpochResult.losses)
3464

3565
indexes = np.argsort(np.asarray(losses))
3666
index = indexes[0]
@@ -40,5 +70,38 @@ def best_epoch(self) -> int:
4070
logger.info(f"Found the best epoch to be {epoch} with valid loss {valid_loss}")
4171
return epoch
4272

43-
def _records_loss(self, records: List[ResultRecord]) -> float:
44-
return np.asarray([r.loss for r in records]).mean()
73+
def records(self, tag: str, train: bool = True) -> List[EpochResult]:
74+
records_dir = self._records_dir_train if train else self._records_dir_eval
75+
76+
results: List[EpochResult] = []
77+
for epoch in range(1, sys.maxsize):
78+
file_name = os.path.join(records_dir, f"{tag}-{epoch}")
79+
if not os.path.exists(file_name):
80+
break
81+
82+
results.append(EpochResult(file_name))
83+
84+
return results
85+
86+
def task_names(self, tag: str, train: bool = True) -> List[str]:
87+
e_records = self.records(tag, train=train)[0]
88+
return EpochResult.task_name(e_records.load())
89+
90+
def metric_names(self, tag: str, train: bool = True) -> List[str]:
91+
e_records = self.records(tag, train)[0]
92+
return EpochResult.metric_name(e_records.load())
93+
94+
def metric(
95+
self, tag: str, metric, reduce_task=np.mean, reduce_iter=np.mean, train=True
96+
) -> np.ndarray:
97+
e_records = self.records(tag, train=train)
98+
return np.asarray(
99+
[
100+
EpochResult.reduce(
101+
metric(records.load()),
102+
reduce_task=reduce_task,
103+
reduce_iter=reduce_iter,
104+
)
105+
for records in e_records
106+
]
107+
)

mcp/result/logger.py

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ def load_records_from_file(file_path: str) -> List[List[ResultRecord]]:
3333

3434

3535
def load_records(line: str) -> List[ResultRecord]:
36+
"""Load all records for an iteration.
37+
38+
The number of records is determined by the number of tasks.
39+
"""
3640
objs = json.loads(line)
3741
return [
3842
ResultRecord(

mcp/viz/__init__.py

Whitespace-only changes.

mcp/viz/base.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import os
2+
3+
from mcp.result.experiment import ExperimentResult
4+
from mcp.viz.loss import plot_loss
5+
from mcp.viz.metric import plot_metric
6+
7+
8+
class Vizualization(object):
9+
def __init__(self, results: ExperimentResult):
10+
self.results = results
11+
12+
def plot(self):
13+
output_dir = os.path.join(self.results.output_dir, "viz")
14+
os.makedirs(output_dir, exist_ok=True)
15+
16+
plot_loss(output_dir, self.results)
17+
plot_metric(output_dir, self.results)

mcp/viz/line_plot.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import List
2+
3+
import numpy as np
4+
from matplotlib import pyplot as plt
5+
from matplotlib.ticker import MaxNLocator
6+
7+
8+
def line_plot(
9+
task_names_train: List[str],
10+
task_names_eval: List[str],
11+
values_train: np.ndarray,
12+
values_eval: np.ndarray,
13+
y_label: str,
14+
x_label: str = "Epoch",
15+
bbox_to_anchor=(0.90, 0.88),
16+
y_int: bool = False,
17+
x_int: bool = True,
18+
) -> plt.Figure:
19+
20+
fig = plt.figure()
21+
ax = fig.subplots()
22+
for i, name in enumerate(task_names_train):
23+
x = list(range(len(values_train)))
24+
ax.plot(
25+
x, values_train[:, i], label=f"Train - {name}",
26+
)
27+
28+
for i, name in enumerate(task_names_eval):
29+
x = list(range(len(values_eval)))
30+
ax.plot(
31+
x, values_eval[:, i], label=f"Valid - {name}", linestyle=":",
32+
)
33+
34+
ax.set_xlabel(x_label)
35+
ax.set_ylabel(y_label)
36+
37+
if x_int:
38+
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
39+
if y_int:
40+
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
41+
42+
fig.legend(bbox_to_anchor=bbox_to_anchor)
43+
return fig

mcp/viz/loss.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import os
2+
3+
from mcp.result.experiment import EpochResult, ExperimentResult
4+
from mcp.viz.line_plot import line_plot
5+
6+
7+
def plot_loss(output_dir: str, results: ExperimentResult):
8+
losses_train = results.metric("train", EpochResult.losses, reduce_task=None)
9+
task_names_train = results.task_names("train")
10+
11+
losses_eval = results.metric("eval", EpochResult.losses, reduce_task=None)
12+
task_names_eval = results.task_names("eval")
13+
14+
fig = line_plot(
15+
task_names_train, task_names_eval, losses_train, losses_eval, "Loss"
16+
)
17+
18+
file_name = os.path.join(output_dir, "losses.png")
19+
fig.savefig(file_name)

mcp/viz/metric.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import os
2+
3+
import numpy as np
4+
5+
from mcp.result.experiment import EpochResult, ExperimentResult
6+
from mcp.utils import logging
7+
from mcp.viz.line_plot import line_plot
8+
9+
logger = logging.create_logger(__name__)
10+
11+
12+
def plot_metric(output_dir: str, results: ExperimentResult):
13+
metric_train = results.metric("train", EpochResult.metric, reduce_task=None)
14+
task_names_train = results.task_names("train")
15+
16+
metric_eval = results.metric("eval", EpochResult.metric, reduce_task=None)
17+
task_names_eval = results.task_names("eval")
18+
19+
metric_names_train = results.metric_names("train")
20+
metric_names_eval = results.metric_names("eval")
21+
22+
for i, (task, metric) in enumerate(zip(task_names_train, metric_names_train)):
23+
fig = line_plot([task], [], metric_train[:, i : i + 1], np.array([]), metric)
24+
file_name = os.path.join(output_dir, f"metric-{task}-{metric}-train.png")
25+
fig.savefig(file_name)
26+
27+
for i, (task, metric) in enumerate(zip(task_names_eval, metric_names_eval)):
28+
fig = line_plot([], [task], np.array([]), metric_eval[:, i : i + 1], metric)
29+
file_name = os.path.join(output_dir, f"metric-{task}-{metric}-eval.png")
30+
fig.savefig(file_name)
31+
32+
metric_test_eval = results.metric("eval", EpochResult.metric, train=False)
33+
34+
if len(metric_test_eval) > 0:
35+
metric_name_test_eval = results.metric_names("eval", train=False)[0]
36+
task_name_test_eval = results.task_names("eval", train=False)[0]
37+
38+
values = np.asarray(metric_test_eval)
39+
mean = np.mean(values)
40+
std = np.std(values, ddof=1)
41+
42+
logger.info(
43+
f"Test {task_name_test_eval}: {mean} +- {std} {metric_name_test_eval}"
44+
)

scripts/viz.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/usr/bin/env python
2+
import os
3+
4+
from mcp.argparser import initialize_logging, parse_viz_arguments
5+
6+
7+
def run(args):
8+
from mcp import main
9+
from mcp.config.loader import load
10+
from mcp.config.parser import parse
11+
12+
config_path = os.path.join(args.result, "config_full.yml")
13+
configs = load(config_path)
14+
config_experiment = parse([configs])
15+
16+
main.run_viz(config_experiment, args.result, args.device) # type: ignore
17+
18+
19+
if __name__ == "__main__":
20+
args = parse_viz_arguments()
21+
initialize_logging(args.logging, args.result, args.debug)
22+
run(args)

0 commit comments

Comments
 (0)