Files
2026-04-28 21:09:13 +02:00

107 lines
3.3 KiB
Python

"""
Central candle store — data/candles.db (SQLite, WAL mode).

Usage:
    from shared.db import upsert_candles, get_candles

    upsert_candles("EURUSD", "M15", df)       # df: time,open,high,low,close,tick_volume
    df = get_candles("EURUSD", "M15")         # all candles, sorted ascending
    df = get_candles("EURUSD", "M15", n=300)  # last 300 candles
    df = get_candles("EURUSD", "M15", start="2024-01-01")
"""
import sqlite3
from pathlib import Path
import pandas as pd
# Default database file: "data/candles.db" under the parent of the directory
# that contains this module. Every public function accepts `path` to override it.
DB_PATH = Path(__file__).parent.parent / "data" / "candles.db"
# Schema for the single `candles` table. One row per (symbol, timeframe, time);
# the composite primary key is what makes "INSERT OR REPLACE" act as an upsert.
# `time` is stored as TEXT (upsert_candles writes str(value)[:19]), so range
# filters and ORDER BY on it are plain string comparisons.
# `tick_volume` is the only nullable column.
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS candles (
symbol TEXT NOT NULL,
timeframe TEXT NOT NULL,
time TEXT NOT NULL,
open REAL NOT NULL,
high REAL NOT NULL,
low REAL NOT NULL,
close REAL NOT NULL,
tick_volume INTEGER,
PRIMARY KEY (symbol, timeframe, time)
)
"""
# Secondary index over the lookup columns used by get_candles.
# NOTE(review): the PRIMARY KEY above covers the same columns in the same
# order, so this index looks redundant — confirm before removing it.
_CREATE_INDEX = "CREATE INDEX IF NOT EXISTS idx_sym_tf_time ON candles(symbol, timeframe, time)"
def _connect(path: Path = DB_PATH) -> sqlite3.Connection:
    """Open the candle database at *path*, creating it if necessary.

    The parent directory is created when missing, WAL journaling is
    requested, and the ``candles`` table and its index are created if they
    do not exist yet.

    Args:
        path: Location of the SQLite database file.

    Returns:
        An open connection. The caller is responsible for closing it.

    Raises:
        Any error raised while configuring the database or creating the
        schema is re-raised after the connection has been closed.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # check_same_thread=False disables sqlite3's check that the connection is
    # only used from the thread that created it.
    conn = sqlite3.connect(path, check_same_thread=False)
    try:
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute(_CREATE_TABLE)
        conn.execute(_CREATE_INDEX)
        conn.commit()
    except Exception:
        # Setup failed (e.g. the database is locked): release the handle
        # instead of leaking a connection nobody holds a reference to.
        conn.close()
        raise
    return conn
def upsert_candles(symbol: str, timeframe: str, df: pd.DataFrame, path: Path = DB_PATH) -> int:
    """Insert or replace candles for one symbol/timeframe.

    A row whose (symbol, timeframe, time) key already exists is overwritten.
    The write is a single transaction: it is committed on success and rolled
    back if any row fails, and the connection is closed in both cases.

    Args:
        symbol: Value stored in the ``symbol`` column for every row.
        timeframe: Value stored in the ``timeframe`` column for every row.
        df: Frame with columns time, open, high, low, close, tick_volume.
            A missing (NA) tick_volume is stored as NULL.
        path: SQLite database file to write to.

    Returns:
        Number of rows submitted for writing, i.e. ``len(df)``.
    """
    conn = _connect(path)
    try:
        rows = [
            (
                symbol,
                timeframe,
                str(row.time)[:19],  # truncate to seconds
                float(row.open),
                float(row.high),
                float(row.low),
                float(row.close),
                int(row.tick_volume) if pd.notna(row.tick_volume) else None,
            )
            for row in df.itertuples(index=False)
        ]
        # "with conn" commits on success and rolls back on error, so a failing
        # row cannot leave a half-written batch pending on the connection.
        with conn:
            conn.executemany("INSERT OR REPLACE INTO candles VALUES (?,?,?,?,?,?,?,?)", rows)
    finally:
        # Always release the handle, including when conversion or the insert raises.
        conn.close()
    return len(rows)
def get_candles(
    symbol: str,
    timeframe: str,
    n: int = None,
    start: str = None,
    end: str = None,
    path: Path = DB_PATH,
) -> pd.DataFrame:
    """Return candles for one symbol/timeframe, sorted ascending by time.

    Args:
        symbol: Symbol to select.
        timeframe: Timeframe to select.
        n: Return only the last *n* candles. Applied only when neither
            *start* nor *end* is given; ignored when falsy (``None`` or 0).
        start: Inclusive lower bound on ``time``.
        end: Inclusive upper bound on ``time``.
        path: SQLite database file to read from.

    Returns:
        DataFrame with columns time, open, high, low, close, tick_volume,
        where ``time`` has been converted with ``pd.to_datetime``.

    Note:
        ``time`` is a TEXT column, so *start* and *end* are compared as
        strings against the stored values.
    """
    where = "symbol=? AND timeframe=?"
    params: list = [symbol, timeframe]
    if start:
        where += " AND time >= ?"
        params.append(start)
    if end:
        where += " AND time <= ?"
        params.append(end)
    select = f"SELECT time,open,high,low,close,tick_volume FROM candles WHERE {where}"

    conn = _connect(path)
    try:
        if n and not start and not end:
            # Take the newest n rows via DESC + LIMIT, then flip them back so
            # the result is ascending like every other code path.
            params.append(n)
            df = pd.read_sql_query(f"{select} ORDER BY time DESC LIMIT ?", conn, params=params)
            df = df.iloc[::-1].reset_index(drop=True)
        else:
            df = pd.read_sql_query(f"{select} ORDER BY time", conn, params=params)
    finally:
        # Release the handle even when the query itself raises.
        conn.close()
    df["time"] = pd.to_datetime(df["time"])
    return df
def list_available(path: Path = DB_PATH) -> pd.DataFrame:
    """Summarise the stored candles per (symbol, timeframe) pair.

    Args:
        path: SQLite database file to read from.

    Returns:
        DataFrame with columns symbol, timeframe, candles (row count),
        first (MIN of time) and last (MAX of time), ordered by symbol and
        then timeframe. ``first`` and ``last`` are returned as stored,
        without datetime conversion.
    """
    conn = _connect(path)
    try:
        return pd.read_sql_query(
            "SELECT symbol, timeframe, COUNT(*) as candles, MIN(time) as first, MAX(time) as last "
            "FROM candles GROUP BY symbol, timeframe ORDER BY symbol, timeframe",
            conn,
        )
    finally:
        # Close the handle whether or not the query succeeded.
        conn.close()