Pantograph/experiments/dsp/plot.py

144 lines
3.9 KiB
Python
Raw Normal View History

2024-10-08 19:20:57 -07:00
import json
from pathlib import Path
from matplotlib import pyplot
from typing import Callable, Optional
from tqdm import tqdm
import seaborn
import pandas
from termcolor import colored
from solve.data import (
Datum,
SamplingParams,
SketchParseFailure,
SearchFailure,
DatumResult,
)
# Absolute directory containing this script; the command-line defaults for
# the result and plot-output locations are built relative to it.
experiment_dir = Path(__file__).resolve().parents[0]
def read_data(name: str, result_path: Path):
    """
    Load every result file in `result_path` and print summary metrics.

    Each entry yielded by `result_path.glob("*")` is opened, decoded as JSON
    and handed to `DatumResult.parse`. The experiment `name` is then printed
    (underlined), followed by the success rate, the mean number of hammer
    invocations and the mean duration. A metric that has no usable samples is
    reported as "cannot be calculated" instead of raising.

    :param name: Label printed as the heading of the summary.
    :param result_path: Directory holding one JSON result per file.
    :return: The list of parsed results, in glob order.
    :raises AssertionError: If `result_path` is not a directory.
    """
    assert result_path.is_dir(), f"Result path is not a directory: {result_path}"

    objs = []
    for file_name in tqdm(result_path.glob("*"), desc="Reading results"):
        # Use a context manager so each handle is closed as soon as it has
        # been parsed, rather than lingering until garbage collection.
        with open(file_name, "r") as result_file:
            objs.append(DatumResult.parse(json.load(result_file)))

    print(colored(f"{name}", attrs=["underline"]))

    # Calculate the metrics
    successes = sum(obj.success for obj in objs)
    if objs:
        success_rate = successes / len(objs)
        print(f"Success rate: {successes} / {len(objs)} = {success_rate:.3f}")
    else:
        # An empty result directory would otherwise divide by zero.
        print("Success rate cannot be calculated")

    # Calculate hammer rates. Results whose `hammer_invocations` is falsy
    # (e.g. None or 0) are excluded from the average.
    hammer_invocations = [
        obj.hammer_invocations
        for obj in objs
        if obj.hammer_invocations
    ]
    if hammer_invocations:
        avg_hammer_invocation = sum(hammer_invocations) / len(hammer_invocations)
        print(f"Hammer invocations: {avg_hammer_invocation:.3f}")
    else:
        print("Hammer invocations cannot be calculated")

    # Only strictly positive durations contribute to the average.
    durations = [obj.duration for obj in objs if obj.duration and obj.duration > 0]
    if durations:
        avg_duration = sum(durations) / len(durations)
        print(f"Durations: {avg_duration:.3f}")
    else:
        print("Durations cannot be calculated")

    return objs
def differentiated_histogram(
    path: Path,
    data: dict[str, list[DatumResult]],
    key: str,
    key_func: Callable[[DatumResult], Optional[float]],
    xticks: Optional[list[float]] = None,
    **kwargs,
):
    """
    Map objects using `key_func` and save a stacked histogram of the values,
    with one hue per experiment name.

    :param path: Output file for the figure (saved at 300 dpi).
    :param data: Mapping from experiment name to its list of results.
    :param key: Column name for the mapped values; also used as the x variable.
    :param key_func: Extracts the plotted value from a single result.
        NOTE(review): a `None` return is presumably treated as missing data
        and left out of the histogram by pandas/seaborn -- confirm.
    :param xticks: Optional explicit x tick positions; each tick is labelled
        with its `str()` form. Ignored when empty or None.
    :param kwargs: Forwarded unchanged to `seaborn.histplot`.
    """
    dfs = [
        pandas.DataFrame({'name': name, key: [key_func(obj) for obj in objs]})
        for name, objs in data.items()
    ]
    df = pandas.concat(dfs).reset_index()

    fig, ax = pyplot.subplots(figsize=(6, 4))
    try:
        seaborn.histplot(
            df,
            ax=ax,
            x=key, hue="name",
            multiple="stack",
            **kwargs,
        )
        if xticks:
            ax.set_xticks(xticks, labels=[str(t) for t in xticks])
        fig.savefig(path, dpi=300)
    finally:
        # Figures created through pyplot stay registered until closed;
        # release this one so repeated calls do not accumulate figures.
        pyplot.close(fig)
def plot(args):
    """
    Read every result directory named in `args` and write the comparison plots.

    Reads `args.result` (result directories), `args.names` (one label per
    directory) and `args.plot_output` (directory for the generated images,
    created if missing). Two histograms are written there: `hammer.jpg` for
    hammer invocations and `runtime.jpg` for durations on a log scale.

    :raises AssertionError: If the number of names and results differ.
    """
    assert len(args.result) == len(args.names), "Names must have a 1-1 correspondence with results"

    print(colored("Reading data ...", color="blue"))
    data = {
        name: read_data(name, Path(result_path))
        for name, result_path in zip(args.names, args.result)
    }

    path_plot_output = Path(args.plot_output)
    path_plot_output.mkdir(exist_ok=True, parents=True)

    # Generate plots
    differentiated_histogram(
        path_plot_output / "hammer.jpg",
        data,
        key="Hammered Goals",
        key_func=lambda obj: obj.hammer_invocations)
    differentiated_histogram(
        path_plot_output / "runtime.jpg",
        data,
        key="Runtime (s)",
        # Guard against a missing duration before comparing, mirroring the
        # filter used in `read_data`; non-positive values cannot be shown on
        # a log scale, so they are mapped to None as well.
        key_func=lambda obj: obj.duration if obj.duration and obj.duration > 0.0 else None,
        log_scale=True,
        xticks=[5, 10, 20, 40, 80, 160, 320],
    )
# Command-line entry point: parse the options, apply the colour palette
# globally, then generate the plots.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        prog='DSP Plot',
        description="Generates plots for DSP",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--result",
        help="Result directory",
        nargs="+",
        # With nargs="+" a supplied value is a list, but argparse uses the
        # default as-is. It must therefore be a list too, otherwise
        # `len(args.result)` in `plot` fails when the option is omitted.
        default=[experiment_dir / 'result'],
    )
    parser.add_argument(
        "--names",
        help="Experiment names",
        nargs="+",
        default=["unnamed"],
    )
    parser.add_argument(
        "--plot-output",
        help="Plot generation directory",
        default=experiment_dir / 'result-plot',
    )
    parser.add_argument(
        "--palette",
        help="Colour palette",
        default="Paired",
    )
    args = parser.parse_args()

    seaborn.set_palette(args.palette)
    plot(args)