drqn.rst APIs (#1)
Ykizi committed Nov 30, 2023
1 parent b285574 commit c9a2368
Showing 1 changed file with 96 additions and 10 deletions.
docs/source/documents/api/learners/drl/qrdqn.rst
DRQN_Learner
=====================================

.. raw:: html

<br><hr>

**PyTorch:**

.. py:class::
  xuance.torch.learners.qlearning_family.drqn_learner.DRQN_Learner(policy, optimizer, scheduler, device, model_dir, gamma, sync_frequency)

  :param policy: The recurrent Q-network policy to be trained, which also holds its target network.
  :type policy: nn.Module
  :param optimizer: The optimizer that updates the parameters of the policy.
  :type optimizer: torch.optim.Optimizer
  :param scheduler: The learning-rate scheduler (optional).
  :type scheduler: torch.optim.lr_scheduler._LRScheduler
  :param device: The device on which the model is calculated, e.g. "cpu" or "cuda:0".
  :type device: str, int, or torch.device
  :param model_dir: The directory in which the model is saved.
  :type model_dir: str
  :param gamma: The discount factor for future rewards.
  :type gamma: float
  :param sync_frequency: The number of updates between hard synchronizations of the target network.
  :type sync_frequency: int
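
A minimal construction sketch for the PyTorch learner is shown below. It assumes a recurrent Q-network ``policy`` (an ``nn.Module``) has already been built elsewhere; the optimizer, scheduler, and hyper-parameter values are illustrative choices rather than prescribed settings.

.. code-block:: python

    import torch
    from xuance.torch.learners.qlearning_family.drqn_learner import DRQN_Learner

    # `policy` is assumed to be a recurrent Q-network (nn.Module) built elsewhere.
    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0,
                                                  end_factor=0.1, total_iters=100000)

    learner = DRQN_Learner(policy=policy,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           device="cuda:0",
                           model_dir="./models/drqn/",
                           gamma=0.99,
                           sync_frequency=100)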

.. py:function::
  xuance.torch.learners.qlearning_family.drqn_learner.DRQN_Learner.update(obs_batch, act_batch, rew_batch, terminal_batch)

  :param obs_batch: A batch of observation sequences sampled from the replay buffer, carrying one extra time step so that next observations can be sliced from it.
  :type obs_batch: np.ndarray
  :param act_batch: A batch of action sequences sampled from the replay buffer.
  :type act_batch: np.ndarray
  :param rew_batch: A batch of reward sequences sampled from the replay buffer.
  :type rew_batch: np.ndarray
  :param terminal_batch: A batch of terminal flags sampled from the replay buffer.
  :type terminal_batch: np.ndarray
  :return: The training information, including the Q-loss, the current learning rate, and the mean predicted Q-value.
  :rtype: dict
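
Continuing the sketch above, a single training step could look like the following. The batch arrays are hypothetical and would normally be sampled from a recurrent replay buffer that stores episode segments.

.. code-block:: python

    # Hypothetical shapes of a sampled batch (obs_batch carries one extra step
    # so that obs_batch[:, 1:] provides the next observations):
    #   obs_batch:      (batch_size, seq_length + 1, obs_dim)
    #   act_batch:      (batch_size, seq_length)
    #   rew_batch:      (batch_size, seq_length)
    #   terminal_batch: (batch_size, seq_length)
    info = learner.update(obs_batch, act_batch, rew_batch, terminal_batch)
    print(info["Qloss"], info["learning_rate"], info["predictQ"])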

.. raw:: html

<br><hr>

**TensorFlow:**


.. raw:: html

<br><hr>
Source Code
-----------------

.. tabs::

    .. group-tab:: PyTorch

        .. code-block:: python

            from xuance.torch.learners import *


            class DRQN_Learner(Learner):
                def __init__(self,
                             policy: nn.Module,
                             optimizer: torch.optim.Optimizer,
                             scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
                             device: Optional[Union[int, str, torch.device]] = None,
                             model_dir: str = "./",
                             gamma: float = 0.99,
                             sync_frequency: int = 100):
                    self.gamma = gamma
                    self.sync_frequency = sync_frequency
                    super(DRQN_Learner, self).__init__(policy, optimizer, scheduler, device, model_dir)

                def update(self, obs_batch, act_batch, rew_batch, terminal_batch):
                    self.iterations += 1
                    act_batch = torch.as_tensor(act_batch, device=self.device)
                    rew_batch = torch.as_tensor(rew_batch, device=self.device)
                    ter_batch = torch.as_tensor(terminal_batch, device=self.device, dtype=torch.float)
                    batch_size = obs_batch.shape[0]

                    # Q-values of the evaluation network over the first T steps of each sequence.
                    rnn_hidden = self.policy.init_hidden(batch_size)
                    _, _, evalQ, _ = self.policy(obs_batch[:, 0:-1], *rnn_hidden)
                    # Q-values of the target network over the next observations (shifted by one step).
                    target_rnn_hidden = self.policy.init_hidden(batch_size)
                    _, targetA, targetQ, _ = self.policy.target(obs_batch[:, 1:], *target_rnn_hidden)
                    # targetQ = targetQ.max(dim=-1).values

                    # Greedy target action gives max_a Q_target(o', a); then form the one-step TD target.
                    targetA = F.one_hot(targetA, targetQ.shape[-1])
                    targetQ = (targetQ * targetA).sum(dim=-1)
                    targetQ = rew_batch + self.gamma * (1 - ter_batch) * targetQ

                    # Q-values of the actions actually taken, regressed onto the target with MSE.
                    predictQ = (evalQ * F.one_hot(act_batch.long(), evalQ.shape[-1])).sum(dim=-1)
                    loss = F.mse_loss(predictQ, targetQ)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    if self.scheduler is not None:
                        self.scheduler.step()

                    # hard update for target network
                    if self.iterations % self.sync_frequency == 0:
                        self.policy.copy_target()
                    lr = self.optimizer.state_dict()['param_groups'][0]['lr']

                    info = {
                        "Qloss": loss.item(),
                        "learning_rate": lr,
                        "predictQ": predictQ.mean().item()
                    }

                    return info
    .. group-tab:: TensorFlow

        .. code-block:: python

    .. group-tab:: MindSpore

        .. code-block:: python

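For reference, the PyTorch implementation above builds a per-time-step TD target from the target network and regresses the predicted Q-values onto it with a mean-squared error:

.. math::

    y_t = r_t + \gamma \, (1 - d_t) \, \max_{a'} Q_{\text{target}}(o_{t+1}, a'),
    \qquad
    \mathcal{L} = \frac{1}{|B|} \sum_{t \in B} \big(Q(o_t, a_t) - y_t\big)^2,

where :math:`d_t` is the terminal flag and the sum runs over all time steps of all sequences in the sampled batch. Every ``sync_frequency`` updates, the target network is hard-copied from the online network via ``copy_target()``.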