-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathbook_summary.py
157 lines (117 loc) · 5.02 KB
/
book_summary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import sys
import argparse
from os.path import join
from tools import *
import logging
from core.api import set_api_logger, KEY_MANAGER
from core.book import SummaryBot, SummaryTurn, set_chat_logger
from utils.spliter import BookSpliter
from prompts.book import *
# Parsed command-line options. Stays None until the `__main__` block below
# assigns the result of `parser.parse_args()`.
args: argparse.Namespace = None
# Shared summarization bot used by `summarize_book`. Stays None until the
# `__main__` block below constructs a `SummaryBot`.
bot: SummaryBot = None
def get_concat_input(user_str, pre_sre, hist_str=None):
    """Build the memory-aware (SCM) prompt for the current piece of text.

    A template is selected from the ``LANG_EN`` / ``LANG_ZH`` SCM agent prompts
    by ``choose_language_template`` (presumably based on the language of
    ``user_str`` -- the helper is defined elsewhere) and then filled in.

    Args:
        user_str: The current text; fills the template's ``current_text`` slot.
        pre_sre: Previous content; fills the ``previous_content`` slot.
        hist_str: Optional extra history. When truthy it is placed in front of
            ``pre_sre``, separated by a blank line.

    Returns:
        The formatted prompt string.
    """
    templates_by_lang = {LANG_EN: en_agent_scm_prompt, LANG_ZH: zh_agent_scm_prompt}
    prompt_template: str = choose_language_template(templates_by_lang, user_str)
    # Only prepend the history when there actually is some.
    context = f"{hist_str}\n\n{pre_sre}" if hist_str else pre_sre
    return prompt_template.format(previous_content=context, current_text=user_str)
def check_key_file(key_file):
    """Terminate the program unless ``key_file`` is an existing regular file.

    Args:
        key_file: Path of the text file expected to contain the API key.

    Raises:
        SystemExit: With status ``-1`` when the path does not exist or is not
            a regular file (for example, a directory).
    """
    # isfile() rather than exists(): a directory at this path would pass an
    # existence check but cannot be read as a key file later on.
    if os.path.isfile(key_file):
        return
    print(f'[{key_file}] not found! Please put your apikey in the txt file.')
    sys.exit(-1)
def get_first_prompt(user_text):
    """Build the prompt used for the first paragraph of a book.

    A template is selected from the ``LANG_EN`` / ``LANG_ZH`` start prompts by
    ``choose_language_template`` and its ``text`` slot is filled with
    ``user_text``.

    Args:
        user_text: The paragraph text to embed in the prompt.

    Returns:
        The formatted prompt string.
    """
    templates_by_lang = {LANG_EN: en_start_prompt, LANG_ZH: zh_start_prompt}
    chosen_template = choose_language_template(templates_by_lang, user_text)
    return chosen_template.format(text=user_text)
def get_paragragh_prompt(user_text):
    """Build the prompt for one paragraph when SCM memory is not used.

    A template is selected from the ``LANG_EN`` / ``LANG_ZH`` no-SCM agent
    prompts by ``choose_language_template`` and its ``text`` slot is filled
    with ``user_text``.

    Args:
        user_text: The paragraph text to embed in the prompt.

    Returns:
        The formatted prompt string.
    """
    templates_by_lang = {LANG_EN: en_agent_no_scm_prompt, LANG_ZH: zh_agent_no_scm_prompt}
    chosen_template = choose_language_template(templates_by_lang, user_text)
    return chosen_template.format(text=user_text)
def summarize_book(book_file, model_name, scm=True):
    """Summarize one book paragraph by paragraph, then export the history.

    The book is split with ``BookSpliter(model_name)``; each piece is turned
    into a prompt, sent to the module-level ``bot``, and the answer is stored
    as a ``SummaryTurn`` in the bot's history. After all pieces are processed
    the bot's final summary is logged and the history is exported and cleared.

    Args:
        book_file: Path of the book to summarize.
        model_name: Model identifier, passed to ``BookSpliter`` and logged.
        scm: When truthy, prompts carry the previous turn's information (the
            first piece uses the start prompt). When falsy, every piece uses
            the no-SCM prompt and the export gets the ``'no_scm'`` suffix.

    Returns:
        None.
    """
    # Resolve the logger by name instead of relying on a module-level `logger`
    # variable that only exists when this file is executed as a script. The
    # name matches the logger configured in the `__main__` block.
    log = logging.getLogger('summary_logger')

    # Start from a clean history so turns from a previous book do not leak in.
    bot.clear_history()
    paragraphs = BookSpliter(model_name).split(book_file)
    total = len(paragraphs)

    for i, text in enumerate(paragraphs):
        if not scm:
            concat_input = get_paragragh_prompt(text)
        elif i == 0:
            # Nothing has been summarized yet, so there is no previous turn.
            concat_input = get_first_prompt(text)
        else:
            pre_info = bot.get_turn_for_previous()
            concat_input = get_concat_input(text, pre_info)

        log.info(f'\n--------------\n[第{i+1}/{total}轮] book_file: {book_file} model_name:{model_name}; USE SCM: {scm} \n\nconcat_input:\n\n{concat_input}\n--------------\n')
        print(f'\n--------------\n[第{i+1}/{total}轮] book_file: {book_file} model_name:{model_name}; USE SCM: {scm}\n--------------\n')

        summary: str = bot.ask(concat_input).strip()
        log.info(f"model_name:{model_name}; USE SCM: {scm}; Summary:\n\n{summary}\n\n")

        # Plain book summarization does not need an embedding for the turn.
        cur_turn = SummaryTurn(paragraph=text, summary=summary, embedding=None)
        bot.add_turn_history(cur_turn)
        log.info(f"model_name:{model_name}; USE SCM: {scm}; Processing: {i+1}/{total}; add_turn_history is done!")

    log.info("First Level Summarization Done!")
    final_summary = bot.get_final_summary()
    log.info(f"final_summary:\n\n{final_summary}")

    # Use truthiness here, matching the prompt selection above, so any falsy
    # `scm` value (not only the literal False) is exported as a no-SCM run.
    suffix = '' if scm else 'no_scm'
    bot.export_history(book_file, suffix)
    bot.clear_history()
# Script entry point: parse options, set up file logging, then summarize each
# book given on the command line with one shared SummaryBot.
if __name__ == '__main__':
    # Command-line options.
    parser = argparse.ArgumentParser()
    model_choices = [ENGINE_TURBO, ENGINE_DAVINCI_003]
    parser.add_argument("--apikey_file", type=str, default="./config/apikey.txt")
    parser.add_argument("--model_name", type=str, default=ENGINE_DAVINCI_003, choices=model_choices)
    parser.add_argument("--book_files", nargs='+', type=str, required=True)
    parser.add_argument("--logfile", type=str, default="./logs/book.summary.log.txt")
    parser.add_argument('--no_scm', action='store_true', help='do not use historical memory, default is False')
    # Assigns the module-level `args`.
    args = parser.parse_args()

    # Exits the process when the API key file cannot be found.
    check_key_file(args.apikey_file)

    log_path = args.logfile
    # NOTE(review): `makedirs` comes from `tools`; presumably it creates the
    # directory that will hold the log file -- confirm.
    makedirs(log_path)

    # Configure logging: a named logger that writes INFO and above to the log
    # file. This module-level `logger` is also what `summarize_book` uses.
    logger = logging.getLogger('summary_logger')
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('【%(asctime)s - %(levelname)s】 - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    file_handler = logging.FileHandler(log_path, encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    # Hand the same logger to the chat and API layers.
    set_chat_logger(logger)
    set_api_logger(logger)

    # Visual separator marking the start of a new run in the log file.
    logger.info('\n\n\n')
    logger.info('#################################')
    logger.info('#################################')
    logger.info('#################################')
    logger.info('\n\n\n')
    logger.info(f"args: \n\n{args}\n")

    book_list = args.book_files

    # Whether to use SCM (historical memory); turned off by --no_scm.
    USE_SCM = False if args.no_scm else True
    model_name = args.model_name

    # Assigns the module-level `bot` that `summarize_book` operates on.
    bot = SummaryBot(model_name=model_name)

    for book_file in book_list:
        book_name = os.path.basename(book_file)
        logger.info(f'\n\n※※※ Begin Summarize Book : {book_name} ※※※\n\n')
        summarize_book(book_file, model_name, scm=USE_SCM)

    # NOTE(review): the extracted source lost its indentation; this call is
    # assumed to run once after all books are processed -- confirm against the
    # upstream file that it does not belong inside the loop.
    KEY_MANAGER.remove_deprecated_keys()