-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmake_env.py
51 lines (42 loc) · 2.38 KB
/
make_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from envs.wrappers import *
def make_train_env(env_name, hier=False, num_training_tasks = 100, rng_seed = 0):
    """Build a training environment restricted to a fixed pool of task seeds.

    Args:
        env_name: gym id of a supported environment.
        hier: if True, add the hierarchical/wait wrappers (see Returns).
        num_training_tasks: upper bound of the seed pool; passed as
            ``min_seed=1, max_seed=num_training_tasks`` to ``FixedSeedsWrapper``.
        rng_seed: forwarded to ``FixedSeedsWrapper`` as ``rng_seed``.

    Returns:
        Push/Goal ids: ``FixedSeedsWrapper(env)``, or
        ``WaitWrapper(HierWrapper(FixedSeedsWrapper(env)))`` when ``hier``.
        TSP/ColourMatch ids: ``ZoneWrapper(FixedSeedsWrapper(env))``, or
        ``WaitWrapper(ZoneWrapper(FixedSeedsWrapper(env)))`` when ``hier``.

    Raises:
        RuntimeError: if ``env_name`` is not a supported id.
    """
    env = gym.make(env_name)
    if env_name in {'PointPush-v0', 'CarPush-v0', 'DoggoPush-v0', 'CarPush-v1',
                    'PointGoal-v0', 'CarGoal-v0'}:
        env = FixedSeedsWrapper(env, min_seed=1, max_seed=num_training_tasks, rng_seed=rng_seed)
        return WaitWrapper(HierWrapper(env)) if hier else env
    # Same TSP id set that make_test_env / make_fixed_env accept (incl. v4/v5).
    if env_name in {'PointTSP-v0', 'PointTSP-v1', 'PointTSP-v2', 'PointTSP-v3',
                    'PointTSP-v4', 'PointTSP-v5', 'PointTTSP-v0', 'PointTTSP-v1',
                    'CarTSP-v0', 'DoggoTSP-v0', 'ColourMatch-v0'}:
        env = ZoneWrapper(FixedSeedsWrapper(env, min_seed=1, max_seed=num_training_tasks, rng_seed=rng_seed))
        return WaitWrapper(env) if hier else env
    # The env was already constructed by gym.make; release it before rejecting the id.
    env.close()
    raise RuntimeError("Unknown environment: {!r}".format(env_name))
def make_test_env(env_name, hier=False, seed=1000):
    """Build an evaluation environment seeded once with ``env.seed(seed)``.

    Unlike ``make_train_env``, no ``FixedSeedsWrapper`` or ``WaitWrapper`` is
    applied.

    Args:
        env_name: gym id of a supported environment.
        hier: for Push/Goal ids, wrap the env in ``HierWrapper``. Ignored for
            TSP/ColourMatch ids, whose branch does not consult it.
        seed: value passed to ``env.seed``.

    Returns:
        Push/Goal ids: the raw env, or ``HierWrapper(env)`` when ``hier``.
        TSP/ColourMatch ids: ``ZoneWrapper(env)``.

    Raises:
        RuntimeError: if ``env_name`` is not a supported id.
    """
    env = gym.make(env_name)
    env.seed(seed)
    if env_name in {'PointPush-v0', 'CarPush-v0', 'DoggoPush-v0', 'CarPush-v1',
                    'PointGoal-v0', 'CarGoal-v0'}:
        return HierWrapper(env) if hier else env
    if env_name in {'PointTSP-v0', 'PointTSP-v1', 'PointTSP-v2', 'PointTSP-v3',
                    'PointTSP-v4', 'PointTSP-v5', 'PointTTSP-v0', 'PointTTSP-v1',
                    'CarTSP-v0', 'DoggoTSP-v0', 'ColourMatch-v0'}:
        return ZoneWrapper(env)
    # The env was already constructed by gym.make; release it before rejecting the id.
    env.close()
    raise RuntimeError("Unknown environment: {!r}".format(env_name))
def make_fixed_env(env_name, hier=False, seed=1000, env_seed=0):
    """Build an environment pinned to the single task seed ``env_seed``.

    ``FixedSeedsWrapper`` is given ``min_seed == max_seed == env_seed``, so the
    task seed pool has exactly one entry.

    Args:
        env_name: gym id of a supported environment.
        hier: for Push/Goal ids, wrap in ``HierWrapper``. Ignored for
            TSP/ColourMatch ids, whose branch does not consult it.
        seed: passed to ``env.seed`` and as ``FixedSeedsWrapper``'s ``rng_seed``.
        env_seed: the single task seed.

    Returns:
        Push/Goal ids: ``FixedSeedsWrapper(env)``, or
        ``HierWrapper(FixedSeedsWrapper(env))`` when ``hier``.
        TSP/ColourMatch ids: ``ZoneWrapper(FixedSeedsWrapper(env))``.

    Raises:
        RuntimeError: if ``env_name`` is not a supported id.
    """
    env = gym.make(env_name)
    env.seed(seed)
    if env_name in {'PointPush-v0', 'CarPush-v0', 'DoggoPush-v0', 'CarPush-v1',
                    'PointGoal-v0', 'CarGoal-v0'}:
        env = FixedSeedsWrapper(env, min_seed=env_seed, max_seed=env_seed, rng_seed=seed)
        return HierWrapper(env) if hier else env
    if env_name in {'PointTSP-v0', 'PointTSP-v1', 'PointTSP-v2', 'PointTSP-v3',
                    'PointTSP-v4', 'PointTSP-v5', 'PointTTSP-v0', 'PointTTSP-v1',
                    'CarTSP-v0', 'DoggoTSP-v0', 'ColourMatch-v0'}:
        return ZoneWrapper(FixedSeedsWrapper(env, min_seed=env_seed, max_seed=env_seed, rng_seed=seed))
    # The env was already constructed by gym.make; release it before rejecting the id.
    env.close()
    raise RuntimeError("Unknown environment: {!r}".format(env_name))