diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index 11cb01fb..093a76b8 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -37,7 +37,7 @@ from langchain_core.tools import BaseTool from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass from pydantic import BaseModel, ConfigDict, model_validator - +import logging from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse from langchain_aws.function_calling import ( ToolsOutputParser, @@ -53,7 +53,9 @@ get_num_tokens_anthropic, get_token_ids_anthropic, ) +from langchain_aws.logger_util import get_logger +logger = get_logger(__name__) def _convert_one_message_to_text_llama(message: BaseMessage) -> str: if isinstance(message, ChatMessage): @@ -520,6 +522,7 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: + logger.info(f"The input message sent by user: {messages}") if self.beta_use_converse_api: return self._as_converse._generate( messages, stop=stop, run_manager=run_manager, **kwargs @@ -595,7 +598,7 @@ def _generate( tool_calls=cast(List[ToolCall], tool_calls), usage_metadata=usage_metadata, ) - + logger.info(f"The output message sent by user: {msg}") return ChatResult( generations=[ ChatGeneration( diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index 4748d8b6..85111756 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -19,6 +19,7 @@ cast, ) +import logging import boto3 from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import BaseChatModel, LanguageModelInput @@ -52,10 +53,11 @@ from typing_extensions import Self from langchain_aws.function_calling import ToolsOutputParser - +from langchain_aws.logger_util import get_logger _BM = TypeVar("_BM", bound=BaseModel) _DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type] +logger = get_logger(__name__) class ChatBedrockConverse(BaseChatModel): """Bedrock chat model integration built on the Bedrock converse API. @@ -495,6 +497,7 @@ def _generate( ) -> ChatResult: """Top Level call""" bedrock_messages, system = _messages_to_bedrock(messages) + logger.debug(f"input message: {bedrock_messages}") params = self._converse_params( stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema"}) ) @@ -502,6 +505,7 @@ def _generate( messages=bedrock_messages, system=system, **params ) response_message = _parse_response(response) + logger.info(f"output message: {response_message}") return ChatResult(generations=[ChatGeneration(message=response_message)]) def _stream( diff --git a/libs/aws/langchain_aws/logger_util.py b/libs/aws/langchain_aws/logger_util.py new file mode 100644 index 00000000..3371d14f --- /dev/null +++ b/libs/aws/langchain_aws/logger_util.py @@ -0,0 +1,70 @@ +import logging +import os +import sys + +__DEBUG = True if os.environ.get('LANGCHAIN_AWS_DEBUG') else False +__ROOT_DEBUG = __DEBUG if os.environ.get('LANGCHAIN_AWS_DEBUG_ROOT') else False + +if __DEBUG: + DEFAULT_LOG_LEVEL: int = logging.DEBUG +else: + DEFAULT_LOG_LEVEL: int = logging.ERROR + +DEFAULT_LOG_FILE = os.environ.get('LANGCHAIN_AWS_LOG_OUTPUT', '-') +if DEFAULT_LOG_FILE == '-': + DEFAULT_LOG_HANDLER: logging.Handler = None +else: + DEFAULT_LOG_HANDLER: logging.Handler = logging.FileHandler(DEFAULT_LOG_FILE) + +if __DEBUG: + DEFAULT_LOG_FORMAT: str = '%(asctime)s %(levelname)s | [%(filename)s:%(lineno)s] | - %(name)s - %(message)s' +else: + DEFAULT_LOG_FORMAT: str = '%(asctime)s %(levelname)s | %(name)s - %(message)s' + +try: + import colorama + import coloredlogs + + colorama.init() + + if DEFAULT_LOG_HANDLER: + DEFAULT_LOG_FORMATTER: logging.Formatter = logging.Formatter(DEFAULT_LOG_FORMAT) + else: + DEFAULT_LOG_FORMATTER: logging.Formatter = coloredlogs.ColoredFormatter(DEFAULT_LOG_FORMAT) + +except ImportError: + colorama = None + coloredlogs = None + + DEFAULT_LOG_FORMATTER: logging.Formatter = logging.Formatter(DEFAULT_LOG_FORMAT) + +def get_logger(logger_name: str=None, log_handler: logging.Handler = DEFAULT_LOG_HANDLER, +log_formatter: logging.Formatter = DEFAULT_LOG_FORMATTER, +log_level: int = DEFAULT_LOG_LEVEL) -> logging.Logger: + ''' + Define a logger with the passed module_name at module level or function level + ''' + logger = logging.getLogger(logger_name) + + # Do not customize root logger. + if logger_name or __ROOT_DEBUG: + if not log_handler: + log_handler = logging.StreamHandler() + + # add formatter to handler + log_handler.setFormatter(log_formatter) + + if log_handler: + try: + logger.removeHandler(log_handler) + except: + pass + logger.addHandler(log_handler) + + if logger_name: + logger.propagate = False + + logger.setLevel(log_level) + return logger + +ROOT_LOGGER = get_logger()