Add first version of a Comet integration #1

Open · wants to merge 8 commits into base: main
2 changes: 1 addition & 1 deletion .isort.cfg
@@ -1,3 +1,3 @@
[settings]
profile=black
known_third_party=wandb
known_third_party=wandb,comet_ml
18 changes: 17 additions & 1 deletion README.md
@@ -14,7 +14,7 @@ Features:
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb or mlflow
- Log results and optionally checkpoints to wandb, mlflow, or Comet
- And more!

<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
@@ -515,6 +515,22 @@ wandb_name:
wandb_log_model:
```

##### Comet Logging

Make sure your `COMET_API_KEY` environment variable is set (recommended) or that you are logged in to Comet with `comet login`.

- Comet options
```yaml
use_comet:
comet_api_key:
comet_workspace:
comet_project_name:
comet_experiment_key:
comet_mode:
comet_online:
comet_experiment_config:
```
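
A filled-in example (the values below are purely illustrative, not defaults) might look like:

```yaml
use_comet: true
comet_workspace: my-team              # hypothetical workspace
comet_project_name: axolotl-finetune  # hypothetical project name
comet_mode: get_or_create
comet_online: true
comet_experiment_config:
  log_code: true
  tags: ["axolotl", "lora"]
```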

##### Special Tokens

It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
12 changes: 12 additions & 0 deletions docs/config.qmd
@@ -267,6 +267,18 @@ mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry

# Comet configuration if you're using it
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or that you are logged in to Comet with `comet login`.
# Check out the Comet documentation for more details: https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
use_comet: # Enable or disable Comet integration.
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or to control the key of new experiments. Defaults to a random key.
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). The default ("get_or_create") auto-selects based on the configuration.
comet_online: # Set to True to log data to the Comet server, or False for offline storage. Default is True.
comet_experiment_config: # Dictionary of additional configuration settings; see the documentation linked above for more details.

# Where to save the full-finetuned model to
output_dir: ./completed-model

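
To make the `comet_mode` / `comet_experiment_key` interplay documented above concrete, here is a minimal sketch of a config that appends to an existing experiment (the key shown is a placeholder, not a real experiment key):

```yaml
use_comet: true
comet_experiment_key: a1b2c3placeholder  # placeholder key of an existing experiment
comet_mode: get                          # append to that experiment instead of creating a new one
```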
3 changes: 3 additions & 0 deletions src/axolotl/cli/__init__.py
@@ -31,6 +31,7 @@
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
@@ -421,6 +422,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):

setup_mlflow_env_vars(cfg)

setup_comet_env_vars(cfg)

return cfg


15 changes: 14 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -48,7 +48,7 @@

from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_mlflow_available
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
@@ -1104,6 +1104,12 @@ def get_callbacks(self) -> List[TrainerCallback]:
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback

callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)

return callbacks

@@ -1172,6 +1178,11 @@ def get_post_trainer_create_callbacks(self, trainer):
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))

if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
@@ -1423,6 +1434,8 @@ def build(self, total_num_steps):
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")

training_arguments_kwargs["report_to"] = report_to
training_arguments_kwargs["run_name"] = (
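
For context on the `report_to.append("comet_ml")` change above: once "comet_ml" is listed in `report_to`, the Transformers trainer enables its own Comet reporting alongside the axolotl-specific callbacks. A minimal sketch of that downstream consumption (not part of this diff; assumes `comet_ml` is installed and the `COMET_*` environment variables are already set):

```python
# Sketch: Transformers activates its built-in Comet reporting when "comet_ml"
# appears in report_to, in addition to the axolotl callbacks added above.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./completed-model",
    report_to=["comet_ml"],
)
```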
6 changes: 5 additions & 1 deletion src/axolotl/utils/__init__.py
@@ -1,8 +1,12 @@
"""
Basic utils for Axolotl
"""
import importlib
import importlib.util


def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None


def is_comet_available():
return importlib.util.find_spec("comet_ml") is not None
11 changes: 10 additions & 1 deletion src/axolotl/utils/callbacks/__init__.py
@@ -29,7 +29,7 @@
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

from axolotl.utils import is_mlflow_available
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -747,6 +747,15 @@ def log_table_from_dataloader(name: str, table_dataloader):
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)
elif logger == "comet_ml" and is_comet_available():
import comet_ml

experiment = comet_ml.get_running_experiment()
if experiment:
experiment.log_table(
f"{name} - Predictions vs Ground Truth.csv",
pd.DataFrame(table_data),
)

if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)
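
The hunk above pulls the active experiment with `comet_ml.get_running_experiment()` and attaches a table to it. A standalone sketch of that same pattern (the table contents are illustrative):

```python
# Sketch: fetch the experiment created by the Comet integration (if any) and
# log a small predictions table to it, mirroring the callback code above.
import comet_ml
import pandas as pd

experiment = comet_ml.get_running_experiment()
if experiment:
    table = pd.DataFrame(
        {"prompt": ["2+2="], "prediction": ["4"], "ground_truth": ["4"]}
    )
    experiment.log_table("Eval - Predictions vs Ground Truth.csv", table)
```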
43 changes: 43 additions & 0 deletions src/axolotl/utils/callbacks/comet_.py
@@ -0,0 +1,43 @@
"""Comet module for trainer callbacks"""

import logging
from typing import TYPE_CHECKING

import comet_ml
from transformers import TrainerCallback, TrainerControl, TrainerState

from axolotl.utils.distributed import is_main_process

if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments

LOG = logging.getLogger("axolotl.callbacks")


class SaveAxolotlConfigtoCometCallback(TrainerCallback):
"""Callback to save axolotl config to comet"""

def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path

def on_train_begin(
self,
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
comet_experiment = comet_ml.start(source="axolotl")
comet_experiment.log_other("Created from", "axolotl")
comet_experiment.log_asset(
self.axolotl_config_path,
file_name="axolotl-config",
)
LOG.info(
"The Axolotl config has been saved to the Comet Experiment under assets."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
return control
93 changes: 93 additions & 0 deletions src/axolotl/utils/comet_.py
@@ -0,0 +1,93 @@
"""Module for wandb utilities"""

import logging
import os

from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl.utils.comet_")

COMET_ENV_MAPPING_OVERRIDE = {
"comet_mode": "COMET_START_MODE",
"comet_online": "COMET_START_ONLINE",
}
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
"auto_log_co2": "COMET_AUTO_LOG_CO2",
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
"log_code": "COMET_AUTO_LOG_CODE",
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
"log_graph": "COMET_AUTO_LOG_GRAPH",
"name": "COMET_START_EXPERIMENT_NAME",
"offline_directory": "COMET_OFFLINE_DIRECTORY",
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
"tags": "COMET_START_EXPERIMENT_TAGS",
}


def python_value_to_environ_value(python_value):
if isinstance(python_value, bool):
if python_value is True:
return "true"

return "false"

if isinstance(python_value, int):
return str(python_value)

if isinstance(python_value, list):  # Comet only has a single list-of-strings parameter
return ",".join(map(str, python_value))

return python_value


def setup_comet_env_vars(cfg: DictDefault):
# NOTE: we need to convert the Axolotl configuration to environment variables
# because the Transformers Comet integration is called first and would
# otherwise create an Experiment before these settings are applied.

for key in cfg.keys():
if key.startswith("comet_") and key != "comet_experiment_config":
value = cfg.get(key, "")

if value is not None and value != "":
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
final_value = python_value_to_environ_value(value)
os.environ[env_variable_name] = final_value

if cfg.comet_experiment_config:
for key, value in cfg.comet_experiment_config.items():
if value is not None and value != "":
config_env_variable_name = (
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
)

if config_env_variable_name is None:
LOG.warning(
f"Unknown Comet Experiment Config name {key}, ignoring it"
)
continue

final_value = python_value_to_environ_value(value)
os.environ[config_env_variable_name] = final_value

# Enable comet if project name is present
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
cfg.use_comet = True
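
An illustrative sketch of how the mapping above behaves at runtime (the config values are hypothetical; `DictDefault` is the dict helper already imported by this module):

```python
# Sketch: feed a small config through setup_comet_env_vars and inspect the
# resulting COMET_* environment variables.
import os

from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "comet_project_name": "axolotl-finetune",  # hypothetical project name
        "comet_mode": "get_or_create",
        "comet_online": True,
        "comet_experiment_config": {"log_code": True, "tags": ["axolotl", "lora"]},
    }
)

setup_comet_env_vars(cfg)

assert os.environ["COMET_PROJECT_NAME"] == "axolotl-finetune"  # key.upper() fallback
assert os.environ["COMET_START_MODE"] == "get_or_create"       # override mapping
assert os.environ["COMET_START_ONLINE"] == "true"              # bool -> "true"/"false"
assert os.environ["COMET_AUTO_LOG_CODE"] == "true"
assert os.environ["COMET_START_EXPERIMENT_TAGS"] == "axolotl,lora"  # list -> comma-joined
assert cfg.use_comet is True  # enabled because a project name was provided
```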
14 changes: 14 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -484,6 +484,19 @@ def check_wandb_run(cls, data):
return data


class CometConfig(BaseModel):
"""Comet configuration subset"""

use_comet: Optional[bool] = None
comet_api_key: Optional[str] = None
comet_workspace: Optional[str] = None
comet_project_name: Optional[str] = None
comet_experiment_key: Optional[str] = None
comet_mode: Optional[str] = None
comet_online: Optional[bool] = None
comet_experiment_config: Optional[Dict[str, Any]] = None


class GradioConfig(BaseModel):
"""Gradio configuration subset"""

@@ -504,6 +517,7 @@ class AxolotlInputConfig(
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
CometConfig,
LISAConfig,
GradioConfig,
RemappedParameters,