-
Notifications
You must be signed in to change notification settings - Fork 1
/
mwe_DSPy.py
62 lines (46 loc) · 1.84 KB
/
mwe_DSPy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# minimum working example for DSPy
import os
import sys
import json
import requests
from typing import AsyncGenerator, Literal, Dict, Generator, List, Optional, Union
from typing_extensions import Literal
import openai
from unify.clients import Unify as UnifyClient
import dsp
import dspy
from dspy.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShot
from dsp import Unify
from dspy.datasets.gsm8k import GSM8K, gsm8k_metric
# Set up the LM.
class Model_Unify(dsp.Unify):
def __call__(self, *args, **kwargs):
# Implement the method here
print("Called with", args, kwargs)
return "Result"
model = Model_Unify(endpoint='claude-3-haiku@antrophic', max_tokens=250, api_key="YOUR_API_KEY")
dspy.settings.configure(lm=model)
# Load math questions from the GSM8K dataset.
gsm8k = GSM8K()
gsm8k_trainset, gsm8k_devset = gsm8k.train[:10], gsm8k.dev[:10]
print(gsm8k_trainset)
print("--_"*20)
class CoT(dspy.Module):
def __init__(self):
super().__init__()
self.prog = dspy.ChainOfThought("question -> answer")
def forward(self, question):
return self.prog(question=question)
# Set up the optimizer: we want to "bootstrap" (i.e., self-generate) 4-shot examples of our CoT program.
config = dict(max_bootstrapped_demos=4, max_labeled_demos=4)
# Optimize! Use the `gsm8k_metric` here. In general, the metric is going to tell the optimizer how well it's doing.
teleprompter = BootstrapFewShot(metric=gsm8k_metric, **config)
optimized_cot = teleprompter.compile(CoT(), trainset=gsm8k_trainset)
print("--_"*20)
# Set up the evaluator, which can be used multiple times.
evaluate = Evaluate(devset=gsm8k_devset, metric=gsm8k_metric, num_threads=4, display_progress=True, display_table=0)
# Evaluate our `optimized_cot` program.
evaluate(optimized_cot)
print("--_"*20)
model.inspect_history(n=1)