forked from uber-archive/plato-research-dialogue-system
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrunDSTC2DataParser.py
114 lines (82 loc) · 3.5 KB
/
runDSTC2DataParser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
Copyright (c) 2019 Uber Technologies, Inc.
Licensed under the Uber Non-Commercial License (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at the root directory of this project.
See the License for the specific language governing permissions and
limitations under the License.
"""
__author__ = "Alexandros Papangelis"
import sys
from Data import Parse_DSTC2
from DialogueManagement.DialoguePolicy.DeepLearning.SupervisedPolicy import \
SupervisedPolicy
from Domain.DataBase import SQLDataBase
from Domain.Ontology import Ontology
from Utilities.DialogueEpisodeRecorder import DialogueEpisodeRecorder
"""
This script runs the DSTC2 Data Parser and trains a Supervised DialoguePolicy
for the user and the system respectively.
"""
def parse_data_path(argv):
    """
    Extract the DSTC2 data path from the command-line arguments.

    Expected invocation: <script> -data_path <path to DSTC2 data>

    :param argv: the full argument list, laid out like sys.argv (the program
                 name first, followed by the option and its value)
    :return: the argument that follows the '-data_path' option
    :raises AttributeError: if fewer than three items are present in argv
    :raises TypeError: if the first option is not '-data_path'
    """

    if len(argv) < 3:
        # Note the trailing space after 'For': adjacent string literals are
        # concatenated verbatim.
        raise AttributeError('Please provide a path to the DSTC2 data. For '
                             'example: .../DSTC2/dstc2_traindev/data/')

    if argv[1] != '-data_path':
        raise TypeError(f'Incorrect option: {argv[1]}')

    return argv[2]


if __name__ == '__main__':
    # This script creates a DSTC2-specific data parser and runs it. It then
    # uses the parsed experience to train two supervised dialogue policies,
    # one for the system and one for the user.

    data_path = parse_data_path(sys.argv)

    parser = Parse_DSTC2.Parser()

    # Default values
    ontology_path = 'Domain/Domains/CamRestaurants-rules.json'
    database_path = 'Domain/Domains/CamRestaurants-dbase.db'

    ontology = Ontology(ontology_path)
    database = SQLDataBase(database_path)

    args = {'path': data_path,
            'ontology': ontology,
            'database': database}

    parser.initialize(**args)

    print('Parsing {0}'.format(args['path']))

    parser.parse_data()

    print('Data parsing complete.')

    # Save data
    parser.save('Logs')

    # Load data
    recorder_sys = DialogueEpisodeRecorder(path='Logs/DSTC2_system')
    recorder_usr = DialogueEpisodeRecorder(path='Logs/DSTC2_user')

    # Train Supervised Models using the recorded data
    system_policy_supervised = SupervisedPolicy(ontology,
                                                database,
                                                agent_role='system',
                                                agent_id=0,
                                                domain='CamRest')

    user_policy_supervised = SupervisedPolicy(ontology,
                                              database,
                                              agent_role='user',
                                              agent_id=1,
                                              domain='CamRest')

    # Set learning rate and number of epochs
    learning_rate = 0.02
    epochs = 100

    system_policy_supervised.initialize(**{
        'is_training': True,
        'policy_path': 'Models/CamRestPolicy/Sys/sys_supervised_data',
        'learning_rate': learning_rate})

    user_policy_supervised.initialize(**{
        'is_training': True,
        'policy_path': 'Models/CamRestPolicy/Usr/usr_supervised_data',
        'learning_rate': learning_rate})

    # Epochs are numbered from 1 for display; the upper bound is epochs + 1
    # so that exactly `epochs` training passes are performed.
    for epoch in range(1, epochs + 1):
        print(f'\nTraining epoch {epoch}\n')

        user_policy_supervised.train(recorder_usr.dialogues)
        system_policy_supervised.train(recorder_sys.dialogues)

    system_policy_supervised.save()
    user_policy_supervised.save()