server.lua
require 'torch'
require 'nn'
require 'LanguageModel'
local turbo = require("turbo")
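
-- Command-line options: checkpoint path, GPU selection, verbosity,
-- and the port for the HTTP server.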
local cmd = torch.CmdLine()
cmd:option('-checkpoint', '/nas/doc/nn/cp_3x128-0-1_0.01_400.t7.reset.t7')
cmd:option('-gpu', -1)
cmd:option('-gpu_backend', 'cuda')
cmd:option('-verbose', 0)
cmd:option('-port', 8888) -- HTTP port to listen on
local opt = cmd:parse(arg)
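
-- Load the trained model from the checkpoint file.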
local checkpoint = torch.load(opt.checkpoint)
local model = checkpoint.model
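
-- Move the model to the requested backend: CUDA, OpenCL, or CPU.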
local msg
if opt.gpu >= 0 and opt.gpu_backend == 'cuda' then
  require 'cutorch'
  require 'cunn'
  cutorch.setDevice(opt.gpu + 1)
  model:cuda()
  msg = string.format('Running with CUDA on GPU %d', opt.gpu)
elseif opt.gpu >= 0 and opt.gpu_backend == 'opencl' then
  require 'cltorch'
  require 'clnn'
  model:cl()
  msg = string.format('Running with OpenCL on GPU %d', opt.gpu)
else
  msg = 'Running in CPU mode'
end
if opt.verbose == 1 then print(msg) end
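
-- Put the model in evaluation mode (e.g. disables dropout).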
model:evaluate()
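
-- Handler for GET /sample: runs the language model's sampler with
-- parameters taken from the request's query string.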
local SampleHandler = class("SampleHandler", turbo.web.RequestHandler)
function SampleHandler:get()
  -- Read sampling parameters from the query string, falling back to
  -- defaults when an argument is missing. Query arguments arrive as
  -- strings, so convert the numeric ones before calling the sampler.
  opt['length'] = tonumber(self:get_argument("length", 2000))
  opt['start_text'] = self:get_argument("start_text", "")
  opt['sample'] = tonumber(self:get_argument("sample", 1))
  opt['temperature'] = tonumber(self:get_argument("temperature", 1))
  local sample = model:sample(opt)
  self:write(sample)
end
local app = turbo.web.Application:new({
  {"/sample", SampleHandler}
})
app:listen(opt.port)
turbo.ioloop.instance():start()
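
-- Example request (hypothetical parameter values), assuming the server
-- is running on the default port:
--   curl "http://localhost:8888/sample?length=500&start_text=The&temperature=0.9"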