シンプルな分散強化学習フレームワークを目指して作成しています。
以下の特徴があります。
- 分散強化学習のサポート
- 環境とアルゴリズム間のインタフェースの自動調整
- Gym/Gymnasiumの環境に対応
- カスタマイズ可能な環境クラスの提供
- カスタマイズ可能な強化学習アルゴリズムクラスの提供
- 有名な強化学習アルゴリズムの提供
- (新しいアルゴリズムへの対応)
ドキュメント
https://pocokhc.github.io/simple_distributed_rl/
アルゴリズムの解説記事(Qiita)
https://qiita.com/pocokhc/items/a2f1ba993c79fdbd4b4d
必須ライブラリ
pip install numpy
その他のライブラリ
使う機能によって以下のライブラリが必要になります。
- Tensorflow が必要なアルゴリズムを使用する場合
- tensorflow
- tensorflow-probability
- Torch が必要なアルゴリズムを使用する場合
- 画像関係の機能を使用する場合
- pillow
- opencv-python
- pygame
- historyによる統計情報を扱う場合
- pandas
- matplotlib
- OpenAI Gym の環境を使用する場合
- gym or gymnasium
- pygame
- ハードウェアの統計情報を表示する場合
- psutil
- pynvml
- クラウド/ネットワークによる分散学習を使用する場合
- redis
- pika
- paho-mqtt
- 学習を管理する場合
- mlflow
一括でインストールするコマンド例は以下です。(Tensorflow、Torch、クラウド分散学習用ライブラリ、mlflowを除く)
pip install matplotlib pillow opencv-python pygame pandas gymnasium psutil pynvml
本フレームワークはGitHubからインストールまたはダウンロードをして使う事ができます。
インストール
pip install git+https://github.com/pocokhc/simple_distributed_rl
or
git clone https://github.com/pocokhc/simple_distributed_rl.git
cd simple_distributed_rl
pip install .
ダウンロード
srlディレクトリに実行パスが通っていればダウンロードだけでも使えます。
# SRLのダウンロード
git clone https://github.com/pocokhc/simple_distributed_rl.git
# SRLのimport例
import os
import sys
assert os.path.isdir("./simple_distributed_rl/srl/") # SRLのダウンロード場所
sys.path.insert(0, "./simple_distributed_rl/")
import srl
print(srl.__version__)
簡単な使い方は以下です。
import srl
from srl.algorithms import ql # qlアルゴリズムのimport
def main():
# runnerの作成
runner = srl.Runner("Grid", ql.Config())
# 学習
runner.train(timeout=10)
# 学習結果の評価
rewards = runner.evaluate()
print(f"evaluate episodes: {rewards}")
# --- 可視化例
# (animation_save_gifの実行には "pip install opencv-python pillow pygame" が必要です)
runner.animation_save_gif("Grid.gif")
if __name__ == "__main__":
main()
その他の使い方は以下ドキュメントを見てください。
(試験導入)MLFlowによる学習管理
examples/sample_mlflow.py にサンプルがあります。
- 逐次処理フロー
- 分散処理フロー
- 疑似コード
※実装側の動作に関してはアルゴリズムの作成方法を参照
# 学習単位の初期化
env.setup()
worker.setup()
trainer.setup()
for episode in range(N)
# 1エピソードの初期化
env.reset()
worker.reset()
# 1エピソードのループ
while not env.done:
env.render() # 環境の描画
# アクションを取得
action = worker.policy()
worker.render() # workerの描画
# 環境の1stepを実行
env.step(action)
worker.on_step()
# 学習
trainer.train()
# 終了後の描画
env.render()
worker.render()
# 学習単位の終了
env.teardown()
worker.teardown()
trainer.teardown()
自作の環境とアルゴリズムの作成に関しては以下ドキュメントを参考にしてください。
Algorithm | Observation | Action | Tensorflow | Torch | ProgressRate | |
---|---|---|---|---|---|---|
QL | Discrete | Discrete | - | - | 100% | Basic Q Learning |
DQN | Continuous | Discrete | ✔ | ✔ | 100% | |
C51 | Continuous | Discrete | ✔ | - | 99% | CategoricalDQN |
Rainbow | Continuous | Discrete | ✔ | ✔ | 100% | |
R2D2 | Continuous | Discrete | ✔ | - | 100% | |
Agent57 | Continuous | Discrete | ✔ | ✔ | 100% | |
SND | Continuous | Discrete | ✔ | - | 100% | |
Go-Explore | Continuous | Discrete | ✔ | - | 100% | DQN base, R2D3 memory base |
Algorithm | Observation | Action | Tensorflow | Torch | ProgressRate | |
---|---|---|---|---|---|---|
VanillaPolicy | Discrete | Both | - | - | 100% | |
A3C/A2C | - | - | - | - | - | |
TRPO | - | - | - | - | - | |
PPO | Continuous | Both | ✔ | - | 100% | |
DDPG/TD3 | Continuous | Continuous | ✔ | - | 100% | |
SAC | Continuous | Both | ✔ | - | 100% |
Algorithm | Observation | Action | Tensorflow | Torch | ProgressRate | |
---|---|---|---|---|---|---|
MCTS | Discrete | Discrete | - | - | 100% | MDP base |
AlphaZero | Image | Discrete | ✔ | - | 100% | MDP base |
MuZero | Image | Discrete | ✔ | - | 100% | MDP base |
StochasticMuZero | Image | Discrete | ✔ | - | 100% | MDP base |
Algorithm | Observation | Action | Framework | ProgressRate |
---|---|---|---|---|
DynaQ | Discrete | Discrete | - | 100% |
Algorithm | Observation | Action | Tensorflow | Torch | ProgressRate | |
---|---|---|---|---|---|---|
WorldModels | Continuous | Discrete | ✔ | - | 100% | |
PlaNet | Continuous | Discrete | ✔(+tensorflow-probability) | - | 100% | |
Dreamer | Continuous | Both | - | - | merge DreamerV3 | |
DreamerV2 | Continuous | Both | - | - | merge DreamerV3 | |
DreamerV3 | Continuous | Both | ✔(+tensorflow-probability) | - | 100% |
Algorithm | Observation | Action | Framework | ProgressRate |
---|---|---|---|---|
CQL | Discrete | Discrete | 0% |
Algorithm | Observation | Action | Type | Tensorflow | Torch | ProgressRate | |
---|---|---|---|---|---|---|---|
QL_agent57 | Discrete | Discrete | ValueBase | - | - | 80% | QL + Agent57 |
Agent57_light | Continuous | Discrete | ValueBase | ✔ | ✔ | 100% | Agent57 - (LSTM,MultiStep) |
SearchDynaQ | Discrete | Discrete | ModelBase | - | - | 100% | original |
GoDynaQ | Discrete | Discrete | ModelBase | - | - | 99% | original |
GoDQN | Continuous | Discrete | ValueBase | ✔ | - | 90% | original |
ネットワーク経由での分散学習は以下のドキュメントを参照してください。
またクラウドサービスとの連携はQiita記事を参照
Look dockers folder
- PC1
- windows11
- CPUx1: Core i7-8700 3.2GHz
- GPUx1: NVIDIA GeForce GTX 1060 3GB
- memory 48GB
- PC2
- windows11
- CPUx1: Core i9-12900 2.4GHz
- GPUx1: NVIDIA GeForce RTX 3060 12GB
- memory 32GB