From c07e353be4096e4bbd7c9cfe5b3ce36055ac5461 Mon Sep 17 00:00:00 2001 From: ypogorelova Date: Wed, 28 Jun 2023 15:35:14 -0700 Subject: [PATCH 1/4] Added live trading support --- livetrading/__init__.py | 0 livetrading/broker.py | 84 +++++++++++++++ livetrading/config.py | 4 + livetrading/converter.py | 17 ++++ livetrading/env | 5 + livetrading/event.py | 214 +++++++++++++++++++++++++++++++++++++++ setup.py | 5 +- 7 files changed, 328 insertions(+), 1 deletion(-) create mode 100644 livetrading/__init__.py create mode 100644 livetrading/broker.py create mode 100644 livetrading/config.py create mode 100644 livetrading/converter.py create mode 100644 livetrading/env create mode 100644 livetrading/event.py diff --git a/livetrading/__init__.py b/livetrading/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/livetrading/broker.py b/livetrading/broker.py new file mode 100644 index 00000000..b49acfc2 --- /dev/null +++ b/livetrading/broker.py @@ -0,0 +1,84 @@ +from decimal import Decimal +from typing import Any, Dict, Optional + +from livetrading.event import KLinesEventSource, Pair, PairInfo, TickersEventSource +from livetrading.rest_cli import RestClient +from livetrading.websocket_client import WSClient + + +class Broker: + """A client for crypto currency exchange. + + :param dispatcher: The event dispatcher. + :param config: Config settings for exchange. + """ + def __init__( + self, dispatcher, config + ): + self.dispatcher = dispatcher + self.config = config + self.api_cli = RestClient(self.config) + self.cli: Optional[Any] = None # external libs as ccxt + self.ws_cli = WSClient(config) + self._cached_pairs: Dict[Pair] = {} + + def subscribe_to_ticker_events( + self, pair: Pair, event_handler + ): + """Registers a callable that will be called every ticker. + + :param pair: The trading pair. + :param event_handler: A callable that receives an TickerEvent. + """ + event_source = TickersEventSource(pair, self.ws_cli) + channel = "ticker" + + self._subscribe_to_ws_channel_events( + channel, + event_handler, + event_source + ) + + def subscribe_to_bar_events( + self, pair: Pair, event_handler, interval + ): + """Registers a callable that will be called every bar. + + :param pair: The trading pair. + :param event_handler: A callable that receives an BarEvent. + """ + event_source = KLinesEventSource(pair, self.ws_cli) + channel = event_source.ws_channel(interval) + + self._subscribe_to_ws_channel_events( + channel, + event_handler, + event_source + ) + + def get_pair_info(self, pair: Pair) -> PairInfo: + """Returns information about a trading pair. + + :param pair: The trading pair. + """ + ret = self._cached_pairs.get(pair) + api_path = '/'.join(['products', pair]) + if not ret: + pair_info = self.api_cli.call(method='GET', apipath=api_path) + self._cached_pairs[pair] = PairInfo(Decimal(pair_info['base_increment']), + Decimal(pair_info['quote_increment'])) + return self._cached_pairs + + def get_data_df(self, event_source): + data_source = self.ws_cli.event_sources[event_source] + return list(data_source.events) + + def _subscribe_to_ws_channel_events( + self, channel: str, event_handler, event_source + ): + # Set the event source for the channel. + self.ws_cli.set_channel_event_source(channel, event_source) + # self.ws_cli.subscribe_to_channels() + + # Subscribe the event handler to the event source. + self.dispatcher.subscribe(event_source, event_handler) diff --git a/livetrading/config.py b/livetrading/config.py new file mode 100644 index 00000000..148f9033 --- /dev/null +++ b/livetrading/config.py @@ -0,0 +1,4 @@ +from configloader import ConfigLoader + +config = ConfigLoader() +config.update_from_json_file('path_to_json_file') diff --git a/livetrading/converter.py b/livetrading/converter.py new file mode 100644 index 00000000..666cf2fa --- /dev/null +++ b/livetrading/converter.py @@ -0,0 +1,17 @@ +import pandas as pd + +DEFAULT_DATAFRAME_COLUMNS = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume'] + +def ohlcv_to_dataframe(historical_data: list) -> pd.DataFrame: + """ + Converts historical data to a Dataframe + :param historical_data: list with candle (OHLCV) data + :return: DataFrame + """ + df = pd.DataFrame( + [{fn: getattr(f, fn) for fn in DEFAULT_DATAFRAME_COLUMNS} for f in historical_data] + ) + df['Date'] = pd.to_datetime(df['Date'], unit='ms', utc=True, ) + df = df.set_index('Date') + df = df.sort_index(ascending=True) + return df.head() diff --git a/livetrading/env b/livetrading/env new file mode 100644 index 00000000..21f8be0a --- /dev/null +++ b/livetrading/env @@ -0,0 +1,5 @@ +{ +"ws_url": "wss://ws-feed.exchange.coinbase.com", +"api_url": "https://api.exchange.coinbase.com/", +"ws_timeout": 5} +} diff --git a/livetrading/event.py b/livetrading/event.py new file mode 100644 index 00000000..d7284d82 --- /dev/null +++ b/livetrading/event.py @@ -0,0 +1,214 @@ +import abc +import dataclasses +import datetime + +from collections import deque +from dateutil.parser import isoparse +from typing import Optional + + +@dataclasses.dataclass +class Bar: + """A Bar, aka candlestick, is the summary of the trading activity in a given period. + + :param date: The beginning of the period. It must have timezone information set. + :param pair: The trading pair. + :param open: The opening price. + :param high: The highest traded price. + :param low: The lowest traded price. + :param close: The closing price. + :param volume: The volume traded. + """ + date: datetime + pair: str + Open: float + High: float + Low: float + Close: float + Volume: float + + +@dataclasses.dataclass +class Pair: + """A trading pair. + + :param base_symbol: The base symbol. + :param quote_symbol: The quote symbol. + """ + base_symbol: str + quote_symbol: str + + def __str__(self): + # change format here to reflect corresponding exchange + return "{}-{}".format(self.base_symbol, self.quote_symbol) + + +@dataclasses.dataclass +class PairInfo: + """Information about a trading pair. + + :param base_increment: The increment for the base symbol. + :param quote_increment: The increment for the quote symbol. + """ + base_increment: float + quote_increment: float + + +class Ticker: + """A Ticker constantly updating stream of information about a stock. + :param datetime: The beginning of the period. It must have timezone information set. + :param pair: The trading pair. + :param open: The opening price. + :param high: The highest traded price. + :param low: The lowest traded price. + :param price: The price. + :param volume: The volume traded. + """ + def __init__(self, pair: Pair, json: dict): + self.pair: Pair = pair + self.json: dict = json + self.Date = isoparse(json['time']) + self.Volume = float(json["volume_24h"]) + self.Open = float(json["open_24h"]) + self.High = float(json["high_24h"]) + self.Low = float(json["low_24h"]) + self.Close = float(json["price"]) + + +class KlineBar(Bar): + """ + K-line, aka candlestick, is a chart marked with the opening price, closing price, + highest price, and lowest price to reflect price changes. + :param pair: The trading pair. + :param json: Message json. + """ + def __init__(self, pair: Pair, json: dict): + super().__init__( + datetime.utcfromtimestamp( + int(json["t"] / 1e3).replace(tzinfo=datetime.timezone.utc)), + pair, float(json["o"]), float(json["h"]), + float(json["l"]), float(json["c"]), float(json["v"]) + ) + self.pair: Pair = pair + self.json: dict = json + + +class EventProducer: + """Base class for event producers. + .. note:: + + Main method is for main functions that should be performed for an event producer. + Finalize method is called on error or stop. + """ + def main(self): + """Override to run the loop that produces events.""" + pass + + def finalize(self): + """Override to perform task and transaction cancellation.""" + pass + + +class Event: + """Base class for events. + + :param when: The datetime when the event occurred. + Used to calculate the datetime for the next event. + It must have timezone information set. + """ + + def __init__(self, when: datetime.datetime): + self.when: datetime.datetime = when + + +class EventSource(metaclass=abc.ABCMeta): + """Base class for events storage. + + :param producer: EventProducer. + """ + + def __init__(self, producer: Optional[EventProducer] = None): + self.producer = producer + self.events = deque() + + +class ChannelEventSource(EventSource): + """Base class for websockets channels. + + :param producer: EventProducer. + """ + def __init__(self, producer: EventProducer): + super().__init__(producer=producer) + + @abc.abstractmethod + async def push_to_queue(self, message: dict): + raise NotImplementedError() + + +class TickersEventSource(ChannelEventSource): + """An event source for :class:`Ticker` instances. + + :param pair: The trading pair. + """ + def __init__(self, pair: Pair, producer: EventProducer): + super().__init__(producer=producer) + self.pair: Pair = pair + + def push_to_queue(self, message: dict): + timestamp = message["time"] + self.events.append(TickerEvent( + isoparse(timestamp), + Ticker(self.pair, message))) + + +class KLinesEventSource(EventSource): + """An event source for :class:`KLineBar` instances. + + :param pair: The trading pair.. + """ + def __init__(self, pair: Pair, producer: EventProducer): + super().__init__(producer=producer) + self.pair: Pair = pair + + def push_to_queue(self, message: dict): + kline_event = message["data"] + kline = kline_event["k"] + # Wait for the last update to the kline. + if kline["x"] is False: + return + self.events.append(BarEvent( + datetime.utcfromtimestamp( + int(kline_event["E"] / 1e3).replace(tzinfo=datetime.timezone.utc)), + KlineBar(self.pair, kline))) + + def ws_channel(self, interval: str) -> str: + """ + Generate websocket channel + """ + return "{}@kline_{}".format( + "{}{}".format(self.pair.base_symbol.upper(), self.pair.quote_symbol.upper()).lower(), + interval) + + +class BarEvent(Event): + """An event for :class:`Bar` instances. + + :param when: The datetime when the event occurred. It must have timezone information set. + :param bar: The bar. + """ + def __init__(self, when, bar: Bar): + super().__init__(when) + + self.bar = bar + + +class TickerEvent(Event): + """An event for :class:`Ticker` instances. + + :param when: The datetime when the event occurred. It must have timezone information set. + :param ticker: The Ticker. + """ + def __init__(self, when, ticker: Ticker): + super().__init__(when) + + self.ticker = ticker diff --git a/setup.py b/setup.py index 60fa15ea..242c6f87 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,10 @@ 'numpy >= 1.17.0', 'pandas >= 0.25.0, != 0.25.0', 'bokeh >= 1.4.0', - ], + 'configloader >= 1.0.1', + 'websocket-client >= 1.6.0', + 'urllib3 >= 2.0.3' + ], extras_require={ 'doc': [ 'pdoc3', From 568b966cf7d829be955276ddc1e3746c5992185a Mon Sep 17 00:00:00 2001 From: ypogorelova Date: Wed, 28 Jun 2023 15:36:40 -0700 Subject: [PATCH 2/4] Added live trading support --- livetrading/executor.py | 115 ++++++++++++++++++++++++++++++++ livetrading/live_trading.py | 64 ++++++++++++++++++ livetrading/rest_cli.py | 37 ++++++++++ livetrading/websocket_client.py | 64 ++++++++++++++++++ 4 files changed, 280 insertions(+) create mode 100644 livetrading/executor.py create mode 100644 livetrading/live_trading.py create mode 100644 livetrading/rest_cli.py create mode 100644 livetrading/websocket_client.py diff --git a/livetrading/executor.py b/livetrading/executor.py new file mode 100644 index 00000000..2c071cc3 --- /dev/null +++ b/livetrading/executor.py @@ -0,0 +1,115 @@ +import time +import datetime +import logging + +from typing import Any, Dict, List, Set, Optional + +from .event import Event, EventSource, EventProducer + +logger = logging.getLogger(__name__) + + +class EventDispatcher: + """Responsible for connecting event sources to event handlers and dispatching events + in the right order. + """ + def __init__(self): + self._event_handlers: Dict[EventSource, List[Any]] = {} + self._prefetched_events: Dict[EventSource, Optional[Event]] = {} + self._prev_events: Dict[EventSource, datetime.datetime] = {} + self._producers: Set[EventProducer] = set() + self._running = False + self._stopped = False + self._current_event_dt = None + + def set_strategy(self, strategy): + self.strategy = strategy + + def subscribe(self, source: EventSource, event_handler: Any): + """Registers an callable that will be called when an event source has new events. + + :param source: An event source. + :param event_handler: An callable that receives an event. + """ + assert not self._running + handlers = self._event_handlers.setdefault(source, []) + if event_handler not in handlers: + handlers.append(event_handler) + if source.producer: + self._producers.add(source.producer) + + def run(self): + assert not self._running, "Running or already ran" + + self._running = True + try: + # Run producers and dispatch loop. + for producer in self._producers: + producer.main() + self._dispatch_loop() + except Exception as error: + logger.error(error) + finally: + for producer in self._producers: + producer.finalize() + + def on_error(self, error: Any): + logger.error(error) + + def _dispatch_next(self, ge_or_assert: Optional[datetime.datetime]): + # Pre-fetch events from all sources. + sources_to_pop = [ + source for source in self._event_handlers.keys() if + self._prefetched_events.get(source) is None + ] + for source in sources_to_pop: + if source.events: + event = source.events.pop() + # Check that events from the same source are returned in order. + prev_event = self._prev_events.get(source) + if prev_event is not None and event.when < prev_event.when: + continue + + self._prev_events[source] = event + self._prefetched_events[source] = event + + # Calculate the datetime for the next event using the prefetched events. + next_dt = None + prefetched_events = [e for e in self._prefetched_events.values() if e] + if prefetched_events: + next_dt = min(map(lambda e: e.when, prefetched_events)) + assert ge_or_assert is None or next_dt is None or next_dt >= ge_or_assert, \ + f"{next_dt} can't be dispatched after {ge_or_assert}" + + # Dispatch events matching the desired datetime. + event_handlers = [] + for source, e in self._prefetched_events.items(): + if e is not None and e.when == next_dt: + # Collect event handlers for the event source. + event_handlers += [event_handler(e) for event_handler in + self._event_handlers.get(source, [])] + # Consume the event. + self._prefetched_events[source] = None + + self._current_event_dt = None + self.strategy.next() + self.strategy.init() + return next_dt + + def stop(self): + """Requests the event dispatcher to stop the event processing loop.""" + self._stopped = True + + for producer in self._producers: + producer.finalize() + + def _dispatch_loop(self): + last_dt = None + + while not self._stopped: + dispatched_dt = self._dispatch_next(last_dt) + print('dispatched_dt', dispatched_dt) + if dispatched_dt is None: + time.sleep(0.01) + else: + last_dt = dispatched_dt diff --git a/livetrading/live_trading.py b/livetrading/live_trading.py new file mode 100644 index 00000000..545413a8 --- /dev/null +++ b/livetrading/live_trading.py @@ -0,0 +1,64 @@ +import websocket + +from backtesting import Strategy +from backtesting._util import _Data +from livetrading import executor +from livetrading.broker import Broker, Pair +from livetrading.config import config +from livetrading.converter import ohlcv_to_dataframe + + +class LiveStrategy(Strategy): + + def __init__(self, broker): + super().__init__(broker=broker, data=[], params={}) + self.event_data = [] + + def init(self): + super().init() + self.set_atr_periods() + + def set_atr_periods(self): + if len(self.data) > 1: + print(self.data.High, self.data.Low) + + def on_bar_event(self, event): + self.event_data.append(event.ticker) + event_df = ohlcv_to_dataframe(self.event_data) + self._data = _Data(event_df.copy(deep=False)) + + def next(self): + print(self.data) + + +class PositionManager: + def __init__(self, exchange, position_amount): + assert position_amount > 0 + self.exchange = exchange + self.position_amount = position_amount + + def on_event(self, bar_event): + # react on event from websocket + pass + + +if __name__ == '__main__': + + websocket.enableTrace(False) + + event_dis = executor.EventDispatcher() + + exchange = Broker(event_dis, config=config) + + pair_info = exchange.get_pair_info('BTC-USD') + + position_mgr = PositionManager(exchange, 0.8) + + strategy = LiveStrategy(exchange) + + exchange.subscribe_to_ticker_events(Pair(base_symbol="UTC", quote_symbol="SDT"), + strategy.on_bar_event) + + event_dis.set_strategy(strategy) + + event_dis.run() diff --git a/livetrading/rest_cli.py b/livetrading/rest_cli.py new file mode 100644 index 00000000..2b9faf79 --- /dev/null +++ b/livetrading/rest_cli.py @@ -0,0 +1,37 @@ +import json +import logging +import requests + +from typing import Optional +from urllib.parse import urljoin + +logger = logging.getLogger(__name__) + + +class RestClient: + """"Class for REST API. + :param config: Config settings for exchange. + """ + def __init__(self, config): + self.url = config['api_url'] + self.session = requests.Session() + self.session.auth = (config.get('username'), config.get('password')) + + def call(self, method, apipath, params: Optional[dict] = None, data=None): + + if str(method).upper() not in ('GET', 'POST', 'PUT', 'DELETE'): + raise ValueError(f'invalid method <{method}>') + + headers = {"Accept": "application/json", + "Content-Type": "application/json" + } + url = urljoin(self.url, apipath) + + try: + resp = self.session.request(method, url, headers=headers, data=json.dumps(data), + params=params) + if resp.status_code == 200: + return resp.json() + return resp.text + except ConnectionError: + logger.warning("Connection error") diff --git a/livetrading/websocket_client.py b/livetrading/websocket_client.py new file mode 100644 index 00000000..6a50fcf8 --- /dev/null +++ b/livetrading/websocket_client.py @@ -0,0 +1,64 @@ +import logging +import websocket, json, _thread + +from typing import Dict, List, Set + +from livetrading.event import EventSource, EventProducer + +logger = logging.getLogger(__name__) + + +class WSClient(EventProducer, websocket.WebSocketApp): + """"Class for channel based web socket clients. + :param config: Config settings for exchange. + """ + def __init__(self, config): + super(WSClient, self).__init__(config['ws_url']) + self.event_sources: Dict[str, EventSource] = {} + self.pending_subscriptions: Set[str] = set() + self.timeout = config['ws_timeout'] + self.on_open = lambda ws: self.subscribe_msg() + self.on_message = lambda ws, msg: self.handle_message(json.loads(msg)) + self.on_error = lambda ws, e: logger.warning(f"Error: {e}") + self.on_close = self.on_close + self._running = False + self.thread = None + + def set_channel_event_source(self, channel: str, event_source: EventSource): + assert channel not in self.event_sources, "channel already registered" + self.event_sources[channel] = event_source + self.pending_subscriptions.add(channel) + + def subscribe_msg(self): + self.pending_subscriptions.update(self.event_sources.keys()) + channels = list(self.pending_subscriptions) + self.subscribe_to_channels(channels) + + def on_close(self): + self.pending_subscriptions = set() + + def main(self): + if not self._running: + self.thread = _thread.start_new_thread(self.run_forever, ()) + self._running = True + + def subscribe_to_channels( + self, channels: List[str] + ): + sub_msg = { + "type": "subscribe", + "product_ids": [ + "ETH-USD", + "BTC-USD" + ], + "channels": channels + } + self.send(json.dumps(sub_msg)) + logger.info(f"Subscribed to channels: {channels}") + + def handle_message(self, message: dict) -> None: + print(message) + channel = message.get("type") + event_source = self.event_sources.get(channel) + if event_source: + event_source.push_to_queue(message) \ No newline at end of file From e63d3a0a1f45afdd75ea842e649f9224d8dcca07 Mon Sep 17 00:00:00 2001 From: ypogorelova Date: Fri, 30 Jun 2023 17:01:16 -0700 Subject: [PATCH 3/4] Added live trading support --- livetrading/event.py | 2 +- livetrading/executor.py | 1 - livetrading/websocket_client.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/livetrading/event.py b/livetrading/event.py index d7284d82..ab3c4acd 100644 --- a/livetrading/event.py +++ b/livetrading/event.py @@ -141,7 +141,7 @@ def __init__(self, producer: EventProducer): super().__init__(producer=producer) @abc.abstractmethod - async def push_to_queue(self, message: dict): + def push_to_queue(self, message: dict): raise NotImplementedError() diff --git a/livetrading/executor.py b/livetrading/executor.py index 2c071cc3..5c85be8b 100644 --- a/livetrading/executor.py +++ b/livetrading/executor.py @@ -108,7 +108,6 @@ def _dispatch_loop(self): while not self._stopped: dispatched_dt = self._dispatch_next(last_dt) - print('dispatched_dt', dispatched_dt) if dispatched_dt is None: time.sleep(0.01) else: diff --git a/livetrading/websocket_client.py b/livetrading/websocket_client.py index 6a50fcf8..6a6c4f8d 100644 --- a/livetrading/websocket_client.py +++ b/livetrading/websocket_client.py @@ -57,7 +57,6 @@ def subscribe_to_channels( logger.info(f"Subscribed to channels: {channels}") def handle_message(self, message: dict) -> None: - print(message) channel = message.get("type") event_source = self.event_sources.get(channel) if event_source: From 066a58ac170e1f1abfc121793e646e7b7bbad34c Mon Sep 17 00:00:00 2001 From: ypogorelova Date: Fri, 30 Jun 2023 21:29:29 -0700 Subject: [PATCH 4/4] Added live trading support --- livetrading/broker.py | 7 ++++--- livetrading/env | 2 +- livetrading/event.py | 30 ++++++++++++++++++++++++++---- livetrading/executor.py | 26 +++++++++++++++++++++----- livetrading/live_trading.py | 34 +++++++++++++++++++--------------- 5 files changed, 71 insertions(+), 28 deletions(-) diff --git a/livetrading/broker.py b/livetrading/broker.py index b49acfc2..65c13488 100644 --- a/livetrading/broker.py +++ b/livetrading/broker.py @@ -23,14 +23,16 @@ def __init__( self._cached_pairs: Dict[Pair] = {} def subscribe_to_ticker_events( - self, pair: Pair, event_handler + self, pair: Pair, interval: str, event_handler ): """Registers a callable that will be called every ticker. + :param bar_duration: The bar duration. One of 1s, 1m, 3m, 5m, 15m, 30m, 1h, 2h, 4h, 6h, 8h, 12h, 1d, 3d, 1w, 1M. :param pair: The trading pair. :param event_handler: A callable that receives an TickerEvent. """ - event_source = TickersEventSource(pair, self.ws_cli) + + event_source = TickersEventSource(pair, interval, self.ws_cli) channel = "ticker" self._subscribe_to_ws_channel_events( @@ -78,7 +80,6 @@ def _subscribe_to_ws_channel_events( ): # Set the event source for the channel. self.ws_cli.set_channel_event_source(channel, event_source) - # self.ws_cli.subscribe_to_channels() # Subscribe the event handler to the event source. self.dispatcher.subscribe(event_source, event_handler) diff --git a/livetrading/env b/livetrading/env index 21f8be0a..059df011 100644 --- a/livetrading/env +++ b/livetrading/env @@ -1,5 +1,5 @@ { "ws_url": "wss://ws-feed.exchange.coinbase.com", "api_url": "https://api.exchange.coinbase.com/", -"ws_timeout": 5} +"ws_timeout": 5 } diff --git a/livetrading/event.py b/livetrading/event.py index ab3c4acd..3c755b62 100644 --- a/livetrading/event.py +++ b/livetrading/event.py @@ -7,6 +7,26 @@ from typing import Optional +intervals = { + "1s": 1, + "1m": 60, + "3m": 3 * 60, + "5m": 5 * 60, + "15m": 15 * 60, + "30m": 30 * 60, + "1h": 3600, + "2h": 2 * 3600, + "4h": 4 * 3600, + "6h": 6 * 3600, + "8h": 8 * 3600, + "12h": 12 * 3600, + "1d": 86400, + "3d": 3 * 86400, + "1w": 7 * 86400, + "1M": 31 * 86400 +} + + @dataclasses.dataclass class Bar: """A Bar, aka candlestick, is the summary of the trading activity in a given period. @@ -150,14 +170,16 @@ class TickersEventSource(ChannelEventSource): :param pair: The trading pair. """ - def __init__(self, pair: Pair, producer: EventProducer): + def __init__(self, pair: Pair, when: datetime, producer: EventProducer): super().__init__(producer=producer) self.pair: Pair = pair + self.when = intervals.get(when) def push_to_queue(self, message: dict): timestamp = message["time"] + dt = isoparse(timestamp) + datetime.timedelta(seconds=self.when) self.events.append(TickerEvent( - isoparse(timestamp), + dt, Ticker(self.pair, message))) @@ -199,7 +221,7 @@ class BarEvent(Event): def __init__(self, when, bar: Bar): super().__init__(when) - self.bar = bar + self.data = bar class TickerEvent(Event): @@ -211,4 +233,4 @@ class TickerEvent(Event): def __init__(self, when, ticker: Ticker): super().__init__(when) - self.ticker = ticker + self.data = ticker diff --git a/livetrading/executor.py b/livetrading/executor.py index 5c85be8b..3fe4f5d0 100644 --- a/livetrading/executor.py +++ b/livetrading/executor.py @@ -1,9 +1,11 @@ import time import datetime import logging - +from functools import partial from typing import Any, Dict, List, Set, Optional +from backtesting import Backtest +from .converter import ohlcv_to_dataframe from .event import Event, EventSource, EventProducer logger = logging.getLogger(__name__) @@ -13,7 +15,7 @@ class EventDispatcher: """Responsible for connecting event sources to event handlers and dispatching events in the right order. """ - def __init__(self): + def __init__(self, strategy): self._event_handlers: Dict[EventSource, List[Any]] = {} self._prefetched_events: Dict[EventSource, Optional[Event]] = {} self._prev_events: Dict[EventSource, datetime.datetime] = {} @@ -21,9 +23,21 @@ def __init__(self): self._running = False self._stopped = False self._current_event_dt = None + self.strategy = strategy + self.backtesting = None def set_strategy(self, strategy): - self.strategy = strategy + self._strategy = strategy + + def set_backtesting_partial(self, cash: float = 10_000, + commission: float = .0, + margin: float = 1., + trade_on_close=False, + hedging=False, + exclusive_orders=False): + self.backtesting = partial(Backtest, strategy=self.strategy, cash=cash, commission=commission, + margin=margin, trade_on_close=trade_on_close, + hedging=hedging, exclusive_orders=exclusive_orders) def subscribe(self, source: EventSource, event_handler: Any): """Registers an callable that will be called when an event source has new events. @@ -64,6 +78,10 @@ def _dispatch_next(self, ge_or_assert: Optional[datetime.datetime]): ] for source in sources_to_pop: if source.events: + df = ohlcv_to_dataframe([event.data for event in source.events]) + bt = self.backtesting(data=df) + bt.run() + event = source.events.pop() # Check that events from the same source are returned in order. prev_event = self._prev_events.get(source) @@ -92,8 +110,6 @@ def _dispatch_next(self, ge_or_assert: Optional[datetime.datetime]): self._prefetched_events[source] = None self._current_event_dt = None - self.strategy.next() - self.strategy.init() return next_dt def stop(self): diff --git a/livetrading/live_trading.py b/livetrading/live_trading.py index 545413a8..b8a7c0a7 100644 --- a/livetrading/live_trading.py +++ b/livetrading/live_trading.py @@ -1,32 +1,34 @@ +import pandas as pd import websocket from backtesting import Strategy -from backtesting._util import _Data from livetrading import executor from livetrading.broker import Broker, Pair from livetrading.config import config -from livetrading.converter import ohlcv_to_dataframe + + +def SMA(arr: pd.Series, n: int) -> pd.Series: + """ + Returns `n`-period simple moving average of array `arr`. + """ + return pd.Series(arr).rolling(n).mean() class LiveStrategy(Strategy): + n1 = 10 + n2 = 20 - def __init__(self, broker): - super().__init__(broker=broker, data=[], params={}) - self.event_data = [] + def __init__(self, broker, data, params): + super().__init__(broker=broker, data=data, params=params) def init(self): - super().init() - self.set_atr_periods() + sma1 = self.I(SMA, self.data.Close, self.n1) + sma2 = self.I(SMA, self.data.Close, self.n2) def set_atr_periods(self): if len(self.data) > 1: print(self.data.High, self.data.Low) - def on_bar_event(self, event): - self.event_data.append(event.ticker) - event_df = ohlcv_to_dataframe(self.event_data) - self._data = _Data(event_df.copy(deep=False)) - def next(self): print(self.data) @@ -46,7 +48,7 @@ def on_event(self, bar_event): websocket.enableTrace(False) - event_dis = executor.EventDispatcher() + event_dis = executor.EventDispatcher(LiveStrategy) exchange = Broker(event_dis, config=config) @@ -54,11 +56,13 @@ def on_event(self, bar_event): position_mgr = PositionManager(exchange, 0.8) - strategy = LiveStrategy(exchange) + strategy = LiveStrategy(exchange, [], {}) exchange.subscribe_to_ticker_events(Pair(base_symbol="UTC", quote_symbol="SDT"), - strategy.on_bar_event) + '3m', position_mgr.on_event) event_dis.set_strategy(strategy) + event_dis.set_backtesting_partial(cash=100000) + event_dis.run()