Initial commit: shared DB, fetcher, and CSV import utilities
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Central candle store — data/candles.db (SQLite, WAL mode).
|
||||
|
||||
Usage:
|
||||
from shared.db import upsert_candles, get_candles
|
||||
|
||||
upsert_candles("EURUSD", "M15", df) # df: time,open,high,low,close,tick_volume
|
||||
df = get_candles("EURUSD", "M15") # all candles, sorted ascending
|
||||
df = get_candles("EURUSD", "M15", n=300) # last 300 candles
|
||||
df = get_candles("EURUSD", "M15", start="2024-01-01")
|
||||
"""
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
DB_PATH = Path(__file__).parent.parent / "data" / "candles.db"
|
||||
|
||||
_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS candles (
|
||||
symbol TEXT NOT NULL,
|
||||
timeframe TEXT NOT NULL,
|
||||
time TEXT NOT NULL,
|
||||
open REAL NOT NULL,
|
||||
high REAL NOT NULL,
|
||||
low REAL NOT NULL,
|
||||
close REAL NOT NULL,
|
||||
tick_volume INTEGER,
|
||||
PRIMARY KEY (symbol, timeframe, time)
|
||||
)
|
||||
"""
|
||||
_CREATE_INDEX = "CREATE INDEX IF NOT EXISTS idx_sym_tf_time ON candles(symbol, timeframe, time)"
|
||||
|
||||
|
||||
def _connect(path: Path = DB_PATH) -> sqlite3.Connection:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(path, check_same_thread=False)
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute(_CREATE_TABLE)
|
||||
conn.execute(_CREATE_INDEX)
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
|
||||
def upsert_candles(symbol: str, timeframe: str, df: pd.DataFrame, path: Path = DB_PATH) -> int:
|
||||
"""Insert or replace candles. Returns number of rows written."""
|
||||
conn = _connect(path)
|
||||
rows = [
|
||||
(
|
||||
symbol, timeframe,
|
||||
str(row.time)[:19], # truncate to seconds
|
||||
float(row.open), float(row.high), float(row.low), float(row.close),
|
||||
int(row.tick_volume) if pd.notna(row.tick_volume) else None,
|
||||
)
|
||||
for row in df.itertuples(index=False)
|
||||
]
|
||||
conn.executemany("INSERT OR REPLACE INTO candles VALUES (?,?,?,?,?,?,?,?)", rows)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return len(rows)
|
||||
|
||||
|
||||
def get_candles(
|
||||
symbol: str,
|
||||
timeframe: str,
|
||||
n: int = None,
|
||||
start: str = None,
|
||||
end: str = None,
|
||||
path: Path = DB_PATH,
|
||||
) -> pd.DataFrame:
|
||||
"""Return candles as DataFrame sorted ascending by time."""
|
||||
conn = _connect(path)
|
||||
where = "symbol=? AND timeframe=?"
|
||||
params: list = [symbol, timeframe]
|
||||
|
||||
if start:
|
||||
where += " AND time >= ?"
|
||||
params.append(start)
|
||||
if end:
|
||||
where += " AND time <= ?"
|
||||
params.append(end)
|
||||
|
||||
if n and not start and not end:
|
||||
sql = f"SELECT time,open,high,low,close,tick_volume FROM candles WHERE {where} ORDER BY time DESC LIMIT ?"
|
||||
params.append(n)
|
||||
df = pd.read_sql_query(sql, conn, params=params)
|
||||
df = df.iloc[::-1].reset_index(drop=True)
|
||||
else:
|
||||
sql = f"SELECT time,open,high,low,close,tick_volume FROM candles WHERE {where} ORDER BY time"
|
||||
df = pd.read_sql_query(sql, conn, params=params)
|
||||
|
||||
conn.close()
|
||||
df["time"] = pd.to_datetime(df["time"])
|
||||
return df
|
||||
|
||||
|
||||
def list_available(path: Path = DB_PATH) -> pd.DataFrame:
|
||||
"""Return a summary of all (symbol, timeframe) pairs with candle counts."""
|
||||
conn = _connect(path)
|
||||
df = pd.read_sql_query(
|
||||
"SELECT symbol, timeframe, COUNT(*) as candles, MIN(time) as first, MAX(time) as last "
|
||||
"FROM candles GROUP BY symbol, timeframe ORDER BY symbol, timeframe",
|
||||
conn,
|
||||
)
|
||||
conn.close()
|
||||
return df
|
||||
+127
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Dukascopy public CDN fetcher — no account required.
|
||||
Downloads bi5 tick files, aggregates to OHLC at the requested timeframe.
|
||||
|
||||
Usage:
|
||||
from shared.fetcher import fetch
|
||||
from datetime import datetime
|
||||
|
||||
df = fetch("EURUSD", datetime(2024, 1, 1), datetime(2025, 1, 1), "M15")
|
||||
"""
|
||||
import io
|
||||
import lzma
|
||||
import struct
|
||||
import urllib.request
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pandas as pd
|
||||
|
||||
TIMEFRAME_RESAMPLE = {
|
||||
"M1": "1min", "M5": "5min", "M10": "10min", "M15": "15min", "M30": "30min",
|
||||
"H1": "1h", "H4": "4h", "D": "1D", "D1": "1D",
|
||||
}
|
||||
|
||||
_POINT_DIVISOR = {"JPY": 1000}
|
||||
_DEFAULT_DIVISOR = 100_000
|
||||
|
||||
_TICK_FMT = ">IIIff" # ms_offset, ask, bid, ask_vol, bid_vol
|
||||
_TICK_SIZE = struct.calcsize(_TICK_FMT)
|
||||
|
||||
|
||||
def _divisor(symbol: str) -> int:
|
||||
return _POINT_DIVISOR.get(symbol[-3:], _DEFAULT_DIVISOR)
|
||||
|
||||
|
||||
def _cdn_url(symbol: str, dt: datetime) -> str:
|
||||
return (
|
||||
f"https://datafeed.dukascopy.com/datafeed/{symbol}/"
|
||||
f"{dt.year}/{dt.month - 1:02d}/{dt.day:02d}/{dt.hour:02d}h_ticks.bi5"
|
||||
)
|
||||
|
||||
|
||||
def _download_hour(symbol: str, dt: datetime) -> list[tuple[int, float, float]]:
|
||||
"""Return list of (unix_ms, ask, bid) for one hour. Empty on any failure."""
|
||||
url = _cdn_url(symbol, dt)
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=15) as resp:
|
||||
raw = resp.read()
|
||||
if not raw:
|
||||
return []
|
||||
data = lzma.decompress(raw, format=lzma.FORMAT_AUTO)
|
||||
divisor = _divisor(symbol)
|
||||
hour_ms = int(dt.replace(tzinfo=timezone.utc).timestamp() * 1000)
|
||||
n = len(data) // _TICK_SIZE
|
||||
ticks = []
|
||||
for i in range(n):
|
||||
ms_off, ask_raw, bid_raw, _, _ = struct.unpack_from(_TICK_FMT, data, i * _TICK_SIZE)
|
||||
ticks.append((hour_ms + ms_off, ask_raw / divisor, bid_raw / divisor))
|
||||
return ticks
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _trading_hours(start: datetime, end: datetime) -> list[datetime]:
|
||||
"""All hourly UTC timestamps Mon–Fri in the date range."""
|
||||
hours = []
|
||||
dt = start.replace(minute=0, second=0, microsecond=0)
|
||||
while dt < end:
|
||||
if dt.weekday() < 5:
|
||||
hours.append(dt)
|
||||
dt += timedelta(hours=1)
|
||||
return hours
|
||||
|
||||
|
||||
def fetch(
|
||||
symbol: str,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
timeframe: str,
|
||||
max_workers: int = 24,
|
||||
progress_cb=None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Download tick data from Dukascopy and resample to OHLC candles.
|
||||
|
||||
Args:
|
||||
symbol: e.g. "EURUSD"
|
||||
start/end: naive UTC datetimes
|
||||
timeframe: one of M1, M5, M15, M30, H1, H4, D1
|
||||
max_workers: parallel HTTP threads
|
||||
progress_cb: optional callable(completed, total) for progress reporting
|
||||
|
||||
Returns:
|
||||
DataFrame with columns: time, open, high, low, close, tick_volume
|
||||
"""
|
||||
if timeframe not in TIMEFRAME_RESAMPLE:
|
||||
raise ValueError(f"Unknown timeframe {timeframe!r}. Choose from {list(TIMEFRAME_RESAMPLE)}")
|
||||
|
||||
hours = _trading_hours(start, end)
|
||||
all_ticks: list[tuple[int, float, float]] = []
|
||||
completed = 0
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = {pool.submit(_download_hour, symbol, h): h for h in hours}
|
||||
for fut in as_completed(futures):
|
||||
all_ticks.extend(fut.result())
|
||||
completed += 1
|
||||
if progress_cb:
|
||||
progress_cb(completed, len(hours))
|
||||
|
||||
if not all_ticks:
|
||||
raise RuntimeError(f"No tick data returned for {symbol} ({start} – {end})")
|
||||
|
||||
all_ticks.sort(key=lambda t: t[0])
|
||||
ts, asks, bids = zip(*all_ticks)
|
||||
mids = [(a + b) / 2 for a, b in zip(asks, bids)]
|
||||
|
||||
idx = pd.to_datetime(ts, unit="ms", utc=True).tz_localize(None)
|
||||
s = pd.Series(mids, index=idx, name="mid")
|
||||
|
||||
freq = TIMEFRAME_RESAMPLE[timeframe]
|
||||
ohlc = s.resample(freq).ohlc().dropna()
|
||||
vol = s.resample(freq).count().rename("tick_volume")
|
||||
|
||||
df = ohlc.join(vol).reset_index()
|
||||
df.columns = ["time", "open", "high", "low", "close", "tick_volume"]
|
||||
return df
|
||||
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
One-shot migration: import all NewBellCurve CSV candle files into data/candles.db.
|
||||
|
||||
Usage:
|
||||
python -m shared.import_csv
|
||||
python -m shared.import_csv --csv-dir /path/to/other/csvs
|
||||
python -m shared.import_csv --dry-run
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
# Allow running from repo root or any subdirectory
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from shared.db import upsert_candles, list_available, DB_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_CSV_DIR = Path(__file__).parent.parent / "NewBellCurve" / "simulation" / "data"
|
||||
_FILENAME_RE = re.compile(r"^([A-Z]+)_([A-Z0-9]+)\.csv$")
|
||||
|
||||
|
||||
def import_directory(csv_dir: Path, dry_run: bool = False) -> dict[str, int]:
|
||||
csv_files = sorted(csv_dir.glob("*.csv"))
|
||||
if not csv_files:
|
||||
logger.warning("No CSV files found in %s", csv_dir)
|
||||
return {}
|
||||
|
||||
results = {}
|
||||
for path in csv_files:
|
||||
m = _FILENAME_RE.match(path.name)
|
||||
if not m:
|
||||
logger.warning("Skipping unrecognised filename: %s", path.name)
|
||||
continue
|
||||
|
||||
symbol, timeframe = m.group(1), m.group(2)
|
||||
df = pd.read_csv(path, parse_dates=["time"])
|
||||
|
||||
if dry_run:
|
||||
logger.info("DRY RUN %-10s %-5s %d candles", symbol, timeframe, len(df))
|
||||
results[path.name] = len(df)
|
||||
continue
|
||||
|
||||
n = upsert_candles(symbol, timeframe, df)
|
||||
logger.info("%-10s %-5s %d candles written", symbol, timeframe, n)
|
||||
results[path.name] = n
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
|
||||
parser = argparse.ArgumentParser(description="Import CSV candle files into shared candles.db")
|
||||
parser.add_argument("--csv-dir", default=str(DEFAULT_CSV_DIR), help="Directory of CSV files")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Parse only, do not write to DB")
|
||||
args = parser.parse_args()
|
||||
|
||||
csv_dir = Path(args.csv_dir)
|
||||
if not csv_dir.exists():
|
||||
logger.error("CSV directory not found: %s", csv_dir)
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("Importing from: %s", csv_dir)
|
||||
logger.info("Database: %s", DB_PATH)
|
||||
|
||||
results = import_directory(csv_dir, dry_run=args.dry_run)
|
||||
total = sum(results.values())
|
||||
logger.info("Done — %d files, %d candles total", len(results), total)
|
||||
|
||||
if not args.dry_run:
|
||||
print("\nAvailable data in candles.db:")
|
||||
print(list_available().to_string(index=False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user