Files

98 lines
3.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Equity curve and trade distribution charts.
Charts are saved as PNG files in the directory containing this module (reports/).
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg") # non-interactive backend
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from engine.backtest import Trade
from engine.metrics import Metrics
# Output directory for generated charts: the directory that contains this module.
REPORTS_DIR = Path(__file__).parent
def save_equity_chart(
    equity: np.ndarray,
    trades: list[Trade],
    metrics: Metrics,
    label: str,
    filename: str = "equity.png",
) -> Path:
    """Render a four-panel performance chart and save it as a PNG.

    Panels, top to bottom:
      1. equity curve with a dashed reference line at the first value,
      2. drawdown in percent below the running peak,
      3. (left) R per closed trade as a bar chart, (right) histogram of R
         summed per exit day.
    A summary line built from ``metrics`` is drawn at the bottom of the figure.

    Args:
        equity: 1-D series of balances, plotted against its positional index.
        trades: Trades to summarise. Only those with a truthy ``closed``
            attribute are used; ``exit_time`` and ``pnl_r`` are read from them.
        metrics: Source of ``n_trades``, ``win_rate``, ``profit_factor``,
            ``sharpe``, ``total_return`` and ``max_drawdown``. The ratio-like
            fields are multiplied by 100 for display (presumably fractions --
            confirm against ``engine.metrics``).
        label: Figure title.
        filename: Output file name, resolved relative to ``REPORTS_DIR``.

    Returns:
        Path of the written PNG file.
    """
    balance = np.asarray(equity, dtype=float)

    # Drawdown in percent relative to the running peak. Where the peak is 0
    # the ratio is undefined; report 0 there instead of NaN/inf.
    peak = np.maximum.accumulate(balance)
    dd = np.zeros_like(balance)
    np.divide(balance - peak, peak, out=dd, where=peak != 0)
    dd *= 100

    closed = [t for t in trades if t.closed]
    r_vals = [t.pnl_r for t in closed]

    # Sum R per calendar day of exit (insertion order = first exit seen).
    by_day: dict[str, float] = {}
    for t in closed:
        day = t.exit_time.date().isoformat()  # type: ignore[union-attr]
        by_day[day] = by_day.get(day, 0) + t.pnl_r
    daily_r = np.array(list(by_day.values()))

    fig = plt.figure(figsize=(14, 10))
    # try/finally so the figure is released even if plotting or saving fails;
    # pyplot keeps every open figure alive until it is explicitly closed.
    try:
        fig.suptitle(label, fontsize=13, fontweight="bold")
        gs = gridspec.GridSpec(3, 2, figure=fig, hspace=0.45, wspace=0.3)

        # ── equity curve ──────────────────────────────────────────────────
        ax1 = fig.add_subplot(gs[0, :])
        ax1.plot(balance, color="#2196F3", linewidth=1.2, label="Equity")
        if balance.size:
            # Reference line at the starting balance; skipped for an empty
            # series, where there is no first value to anchor it.
            ax1.axhline(balance[0], color="grey", linewidth=0.6, linestyle="--")
        ax1.set_ylabel("Balance ($)")
        ax1.set_title("Equity Curve")
        ax1.legend(fontsize=8)
        ax1.grid(True, alpha=0.3)

        # ── drawdown ──────────────────────────────────────────────────────
        ax2 = fig.add_subplot(gs[1, :])
        ax2.fill_between(np.arange(dd.size), dd, 0, color="#F44336", alpha=0.6)
        ax2.set_ylabel("Drawdown (%)")
        ax2.set_title(f"Drawdown (max {metrics.max_drawdown*100:.1f}%)")
        ax2.grid(True, alpha=0.3)

        # ── R distribution ────────────────────────────────────────────────
        ax3 = fig.add_subplot(gs[2, 0])
        # Green for strictly positive R; zero and negative R are drawn red.
        colors = ["#4CAF50" if r > 0 else "#F44336" for r in r_vals]
        ax3.bar(range(len(r_vals)), r_vals, color=colors, width=0.8, alpha=0.8)
        ax3.axhline(0, color="black", linewidth=0.6)
        ax3.set_xlabel("Trade #")
        ax3.set_ylabel("R")
        ax3.set_title("R per Trade")
        ax3.grid(True, alpha=0.3, axis="y")

        # ── daily R histogram ─────────────────────────────────────────────
        ax4 = fig.add_subplot(gs[2, 1])
        if daily_r.size > 0:
            ax4.hist(daily_r, bins=30, color="#9C27B0", alpha=0.75, edgecolor="white")
            ax4.axvline(0, color="black", linewidth=0.8)
            mu, sigma = daily_r.mean(), daily_r.std()
            ax4.set_title(f"Daily R μ={mu:.3f} σ={sigma:.3f}")
        ax4.set_xlabel("Daily R")
        ax4.set_ylabel("Days")
        ax4.grid(True, alpha=0.3)

        # stats annotation
        stats = (
            f"Trades: {metrics.n_trades} WR: {metrics.win_rate*100:.1f}%\n"
            f"PF: {metrics.profit_factor:.2f} Sharpe: {metrics.sharpe:.2f}\n"
            f"Return: {metrics.total_return*100:+.1f}% MDD: {metrics.max_drawdown*100:.1f}%"
        )
        fig.text(
            0.5, 0.01, stats, ha="center", fontsize=9,
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.4),
        )

        out = REPORTS_DIR / filename
        # Allow ``filename`` to contain a sub-directory that does not exist yet.
        out.parent.mkdir(parents=True, exist_ok=True)
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
    return out