forked from jsbaan/DPAC-DialogueGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhelpers.py
40 lines (27 loc) · 976 Bytes
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from torch.autograd import Variable
from math import ceil
import sys
def prepare_discriminator_data(pos_samples, neg_samples, gpu=False):
    """
    Build a shuffled, labelled batch for the discriminator.

    Positive (target) samples are labelled 1 and negative (generator)
    samples are labelled 0. Both sets are stacked along dimension 0 and
    then reordered with a single random permutation, so every row keeps
    its own label.

    Inputs: pos_samples, neg_samples, gpu
        - pos_samples: pos_size x seq_len
        - neg_samples: neg_size x seq_len
        - gpu: when true, both returned values are moved with .cuda()

    Returns: inp, target
        - inp: (pos_size + neg_size) x seq_len, cast to torch.LongTensor
        - target: pos_size + neg_size labels (1 = positive, 0 = negative)
    """
    num_pos = pos_samples.size(0)
    num_total = num_pos + neg_samples.size(0)

    # Positives come first in the stacked batch, so the first num_pos
    # labels are 1 and everything after them stays 0.
    stacked = torch.cat((pos_samples, neg_samples), 0).type(torch.LongTensor)
    labels = torch.zeros(num_total)
    labels[:num_pos] = 1

    # One permutation indexes both tensors to keep rows and labels aligned.
    order = torch.randperm(num_total)
    inp = Variable(stacked[order])
    target = Variable(labels[order])

    if not gpu:
        return inp, target
    return inp.cuda(), target.cuda()