From 3bbbdcfef8376d33f3d233c90726c690aa1d83f5 Mon Sep 17 00:00:00 2001 From: Mo Li <82895469+DseidLi@users.noreply.github.com> Date: Sun, 14 Apr 2024 22:24:25 +0800 Subject: [PATCH] Fix `overrides` training cfg bug (#10002) Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher Co-authored-by: Ultralytics AI Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> --- ultralytics/engine/model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index 09048bbf8e8..677361e2501 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -641,7 +641,12 @@ def train( checks.check_pip_update_available() overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides - custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]} # method defaults + custom = { + # NOTE: handle the case when 'cfg' includes 'data'. + "data": overrides.get("data") or DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task], + "model": self.overrides["model"], + "task": self.task, + } # method defaults args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right if args.get("resume"): args["resume"] = self.ckpt_path