Skip to content

Commit

Permalink
Addresses Issue langchain-ai#245
Browse files Browse the repository at this point in the history
Added `logger_util` to enable package and class wide logging in `langchain-aws`.
Added logging for `invoke` and `ainvoke`.
  • Loading branch information
Vishal Patil committed Jan 15, 2025
1 parent 94cb00b commit 61356da
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 3 deletions.
7 changes: 5 additions & 2 deletions libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from langchain_core.tools import BaseTool
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from pydantic import BaseModel, ConfigDict, model_validator

import logging
from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
from langchain_aws.function_calling import (
ToolsOutputParser,
Expand All @@ -53,7 +53,9 @@
get_num_tokens_anthropic,
get_token_ids_anthropic,
)
from langchain_aws.logger_util import get_logger

logger = get_logger(__name__)

def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
Expand Down Expand Up @@ -520,6 +522,7 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
logger.info(f"The input message sent by user: {messages}")
if self.beta_use_converse_api:
return self._as_converse._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
Expand Down Expand Up @@ -595,7 +598,7 @@ def _generate(
tool_calls=cast(List[ToolCall], tool_calls),
usage_metadata=usage_metadata,
)

logger.info(f"The output message sent by user: {msg}")
return ChatResult(
generations=[
ChatGeneration(
Expand Down
6 changes: 5 additions & 1 deletion libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
cast,
)

import logging
import boto3
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, LanguageModelInput
Expand Down Expand Up @@ -52,10 +53,11 @@
from typing_extensions import Self

from langchain_aws.function_calling import ToolsOutputParser

from langchain_aws.logger_util import get_logger
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]

logger = get_logger(__name__)

class ChatBedrockConverse(BaseChatModel):
"""Bedrock chat model integration built on the Bedrock converse API.
Expand Down Expand Up @@ -495,13 +497,15 @@ def _generate(
) -> ChatResult:
"""Top Level call"""
bedrock_messages, system = _messages_to_bedrock(messages)
logger.debug(f"input message: {bedrock_messages}")
params = self._converse_params(
stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema"})
)
response = self.client.converse(
messages=bedrock_messages, system=system, **params
)
response_message = _parse_response(response)
logger.info(f"output message: {response_message}")
return ChatResult(generations=[ChatGeneration(message=response_message)])

def _stream(
Expand Down
70 changes: 70 additions & 0 deletions libs/aws/langchain_aws/logger_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import logging
import os
import sys

__DEBUG = True if os.environ.get('LANGCHAIN_AWS_DEBUG') else False
__ROOT_DEBUG = __DEBUG if os.environ.get('LANGCHAIN_AWS_DEBUG_ROOT') else False

if __DEBUG:
DEFAULT_LOG_LEVEL: int = logging.DEBUG
else:
DEFAULT_LOG_LEVEL: int = logging.ERROR

DEFAULT_LOG_FILE = os.environ.get('LANGCHAIN_AWS_LOG_OUTPUT', '-')
if DEFAULT_LOG_FILE == '-':
DEFAULT_LOG_HANDLER: logging.Handler = None
else:
DEFAULT_LOG_HANDLER: logging.Handler = logging.FileHandler(DEFAULT_LOG_FILE)

if __DEBUG:
DEFAULT_LOG_FORMAT: str = '%(asctime)s %(levelname)s | [%(filename)s:%(lineno)s] | - %(name)s - %(message)s'
else:
DEFAULT_LOG_FORMAT: str = '%(asctime)s %(levelname)s | %(name)s - %(message)s'

try:
import colorama
import coloredlogs

colorama.init()

if DEFAULT_LOG_HANDLER:
DEFAULT_LOG_FORMATTER: logging.Formatter = logging.Formatter(DEFAULT_LOG_FORMAT)
else:
DEFAULT_LOG_FORMATTER: logging.Formatter = coloredlogs.ColoredFormatter(DEFAULT_LOG_FORMAT)

except ImportError:
colorama = None
coloredlogs = None

DEFAULT_LOG_FORMATTER: logging.Formatter = logging.Formatter(DEFAULT_LOG_FORMAT)

def get_logger(logger_name: str=None, log_handler: logging.Handler = DEFAULT_LOG_HANDLER,
log_formatter: logging.Formatter = DEFAULT_LOG_FORMATTER,
log_level: int = DEFAULT_LOG_LEVEL) -> logging.Logger:
'''
Define a logger with the passed module_name at module level or function level
'''
logger = logging.getLogger(logger_name)

# Do not customize root logger.
if logger_name or __ROOT_DEBUG:
if not log_handler:
log_handler = logging.StreamHandler()

# add formatter to handler
log_handler.setFormatter(log_formatter)

if log_handler:
try:
logger.removeHandler(log_handler)
except:
pass
logger.addHandler(log_handler)

if logger_name:
logger.propagate = False

logger.setLevel(log_level)
return logger

ROOT_LOGGER = get_logger()

0 comments on commit 61356da

Please sign in to comment.