From fe75d2bcb7859993fffea290bed15e6d0e1da01a Mon Sep 17 00:00:00 2001 From: JulioJerez Date: Fri, 27 Sep 2024 18:43:30 -0700 Subject: [PATCH] do no kill robot if it hit a limit, jut penalize the reward. --- .../demos/ndAdvancedIndustrialRobot.cpp | 50 +++++++++---------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/newton-4.00/applications/ndSandbox/demos/ndAdvancedIndustrialRobot.cpp b/newton-4.00/applications/ndSandbox/demos/ndAdvancedIndustrialRobot.cpp index 8667d5966..003b8d800 100644 --- a/newton-4.00/applications/ndSandbox/demos/ndAdvancedIndustrialRobot.cpp +++ b/newton-4.00/applications/ndSandbox/demos/ndAdvancedIndustrialRobot.cpp @@ -251,14 +251,12 @@ namespace ndAdvancedRobot //#pragma optimize( "", off ) void SaveTrajectory() { - if (IsTerminal()) + ndInt32 index = m_trajectory.GetCount() - 1; + if (m_trajectory.GetReward(index) == ND_DEAD_PENALTY) { - ndInt32 index = m_trajectory.GetCount() - 1; - if (m_trajectory.GetReward(index) != ND_DEAD_PENALTY) - { - m_trajectory.SetReward(index, ND_DEAD_PENALTY); - } + m_trajectory.SetReward(index, ND_DEAD_PENALTY * 4.0f); } + ndBrainAgentContinuePolicyGradient_Trainer::SaveTrajectory(); } @@ -445,26 +443,6 @@ namespace ndAdvancedRobot return true; } - if (m_leftGripper->GetJointHitLimits()) - { - return true; - } - - if (m_rightGripper->GetJointHitLimits()) - { - return true; - } - - if (m_arm_0->GetJointHitLimits()) - { - return true; - } - - if (m_arm_1->GetJointHitLimits()) - { - return true; - } - const ndModelArticulation* const model = GetModel()->GetAsModelArticulation(); for (ndModelArticulation::ndNode* node = model->GetRoot()->GetFirstIterator(); node; node = node->GetNextIterator()) { @@ -520,6 +498,26 @@ namespace ndAdvancedRobot return ND_DEAD_PENALTY; } + if (m_leftGripper->GetJointHitLimits()) + { + return ND_DEAD_PENALTY; + } + + if (m_rightGripper->GetJointHitLimits()) + { + return ND_DEAD_PENALTY; + } + + if (m_arm_0->GetJointHitLimits()) + { + return ND_DEAD_PENALTY; + } + + if (m_arm_1->GetJointHitLimits()) + { + return ND_DEAD_PENALTY; + } + const ndMatrix effectorMatrix(m_effectorLocalTarget * m_arm_4->GetBody0()->GetMatrix()); const ndMatrix baseMatrix(m_effectorLocalBase * m_base_rotator->GetBody1()->GetMatrix()); const ndMatrix currentEffectorMatrix(effectorMatrix * baseMatrix.OrthoInverse());