-
Notifications
You must be signed in to change notification settings - Fork 272
/
Copy path07_custom.py
132 lines (109 loc) · 3.74 KB
/
07_custom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
Optional: change the directory where pretrained Hugging Face models are downloaded (cached) by setting:
export TRANSFORMERS_CACHE=/whatever/path/you/want
"""
# import os
# os.environ["TRANSFORMERS_CACHE"] = "/media/samuel/UDISK1/transformers_cache"
import functools
import os
import time

import torch
from dotenv import load_dotenv
from langchain.llms.base import LLM
from llama_index import (
    GPTListIndex,
    LLMPredictor,
    PromptHelper,
    ServiceContext,
    SimpleDirectoryReader,
)
from transformers import pipeline
# To read real credentials from a .env file instead, uncomment load_dotenv().
# load_dotenv()
# Unconditionally overwrite OPENAI_API_KEY with a dummy value. Presumably the
# llama_index/langchain stack expects the variable to exist even though text
# generation below runs locally (assumption — confirm).
os.environ.update(OPENAI_API_KEY="random")
def timeit():
    """Return a decorator that reports how long each call of a function takes.

    Use it as ``@timeit()`` (note the call). After every call the wrapper
    prints ``[<elapsed> seconds]: <name>(<args>) -> <result>`` and then
    returns the wrapped function's result unchanged.
    """

    def decorator(func):
        # Keep func's __name__, __doc__, etc. on the wrapper.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic and high-resolution; time.time() can
            # jump when the system clock is adjusted, skewing the measurement.
            start = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed = time.perf_counter() - start
            shown = [str(arg) for arg in args]
            shown += [f"{key}={value}" for key, value in kwargs.items()]
            print(
                f"[{elapsed:.8f} seconds]: "
                f"{func.__name__}({', '.join(shown)}) -> {result}"
            )
            return result

        return wrapper

    return decorator
# Prompt-sizing limits handed to llama_index through ServiceContext.
prompt_helper = PromptHelper(
    num_output=256,  # output tokens; keep in sync with max_new_tokens in LocalOPT._call
    max_chunk_overlap=20,  # largest overlap allowed between neighbouring chunks
    max_input_size=2048,  # maximum input (prompt) size
)
class LocalOPT(LLM):
    """LangChain ``LLM`` adapter around a local Hugging Face OPT-IML model.

    The transformers ``pipeline`` below is built while this class body
    executes (i.e. when the module is imported). The model is therefore
    downloaded/loaded once, before any instance exists, and shared by all
    instances.
    """

    # model_name = "facebook/opt-iml-max-30b" (this is a 60gb model)
    model_name = "facebook/opt-iml-1.3b"  # ~2.63gb model
    # https://huggingface.co/docs/transformers/main_classes/pipelines
    # The right-hand side still resolves to transformers.pipeline: the class
    # attribute of the same name does not exist until this assignment ends.
    pipeline = pipeline(
        "text-generation",
        model=model_name,
        # Fall back to CPU instead of crashing on a hard-coded "cuda:0" when
        # no GPU is available; use float32 there as the safe CPU dtype.
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        model_kwargs={
            "torch_dtype": (
                torch.bfloat16 if torch.cuda.is_available() else torch.float32
            )
        },
    )

    def _call(self, prompt: str, stop=None) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: Text fed to the text-generation pipeline.
            stop: Optional list of stop sequences. The completion is cut at
                the earliest occurrence of any non-empty one.

        Returns:
            The newly generated text, without the prompt.
        """
        # return_full_text=False makes the pipeline return only the newly
        # generated text, so we do not rely on the decoded output beginning
        # with a byte-exact copy of ``prompt`` (as slicing by len() did).
        completion = self.pipeline(
            prompt, max_new_tokens=256, return_full_text=False
        )[0]["generated_text"]
        if stop:
            # Honour LangChain's stop sequences; previously they were ignored.
            hits = [completion.find(token) for token in stop if token]
            hits = [pos for pos in hits if pos != -1]
            if hits:
                completion = completion[: min(hits)]
        return completion

    @property
    def _identifying_params(self) -> dict:
        """Parameters LangChain uses to identify this LLM."""
        return {"name_of_model": self.model_name}

    @property
    def _llm_type(self) -> str:
        """Short type tag reported to LangChain."""
        return "custom"
@timeit()
def create_index(data_dir="news"):
    """Build a GPTListIndex over every document in ``data_dir`` using LocalOPT.

    Args:
        data_dir: Directory read by SimpleDirectoryReader. Defaults to
            "news", the directory this script has always indexed.

    Returns:
        The newly built GPTListIndex. It is not saved to disk here.
    """
    print("Creating index")
    # LLMPredictor wraps the LangChain-compatible LocalOPT model for llama_index.
    llm = LLMPredictor(llm=LocalOPT())
    # ServiceContext bundles the LLM predictor and prompt helper that are
    # passed to the index.
    # https://gpt-index.readthedocs.io/en/latest/reference/service_context.html
    service_context = ServiceContext.from_defaults(
        llm_predictor=llm, prompt_helper=prompt_helper
    )
    docs = SimpleDirectoryReader(data_dir).load_data()
    index = GPTListIndex.from_documents(docs, service_context=service_context)
    print("Done creating index", index)
    return index
@timeit()
def execute_query():
    """Ask the module-level ``index`` a fixed question about coal exports.

    Relies on a global ``index`` (assigned in the ``__main__`` block) and
    returns whatever ``index.query`` returns.
    """
    question = "Who does Indonesia export its coal to in 2023?"
    # Nodes lacking required_keywords or containing exclude_keywords are
    # filtered out up front, shrinking the search space and hence the
    # time/number of LLM calls/cost.
    # Alternatives: required_keywords=["coal"], or
    # exclude_keywords=["oil", "gas", "petroleum"].
    skip_terms = ["petroleum"]
    return index.query(question, exclude_keywords=skip_terms)
if __name__ == "__main__":
    # Reuse a previously built index saved as JSON. Otherwise build one from
    # the documents (create_index defaults to the "news" directory) and save
    # it. The OPT model itself is loaded when the LocalOPT class body runs at
    # import time, independently of this check.
    index_path = "7_custom_opt.json"
    if not os.path.exists(index_path):
        print("No saved index found, building a new one")
        index = create_index()
        index.save_to_disk(index_path)
    else:
        print("Loading saved index from disk")
        llm = LLMPredictor(llm=LocalOPT())
        service_context = ServiceContext.from_defaults(
            llm_predictor=llm, prompt_helper=prompt_helper
        )
        index = GPTListIndex.load_from_disk(
            index_path, service_context=service_context
        )
    # execute_query() reads the module-level ``index`` assigned above.
    response = execute_query()
    print(response)
    print(response.source_nodes)