diff --git a/lualib/skynet.lua b/lualib/skynet.lua index 908701878..0e938e18d 100644 --- a/lualib/skynet.lua +++ b/lualib/skynet.lua @@ -126,10 +126,11 @@ local function co_create(f) end local function dispatch_wakeup() - local co = table.remove(wakeup_queue,1) - if co then - local session = sleep_session[co] + local token = table.remove(wakeup_queue,1) + if token then + local session = sleep_session[token] if session then + local co = session_id_coroutine[session] session_id_coroutine[session] = "BREAK" return suspend(co, coroutine_resume(co, false, "BREAK")) end @@ -149,7 +150,7 @@ local function release_watching(address) end -- suspend is local function -function suspend(co, result, command, param, size) +function suspend(co, result, command, param, param2) if not result then local session = session_coroutine_id[co] if session then -- coroutine may fork by others (session is nil) @@ -167,13 +168,16 @@ function suspend(co, result, command, param, size) session_id_coroutine[param] = co elseif command == "SLEEP" then session_id_coroutine[param] = co - sleep_session[co] = param + if sleep_session[param2] then + error(debug.traceback(co, "token duplicative")) + end + sleep_session[param2] = param elseif command == "RETURN" then local co_session = session_coroutine_id[co] session_coroutine_id[co] = nil if co_session == 0 then - if size ~= nil then - c.trash(param, size) + if param2 ~= nil then + c.trash(param, param2) end return suspend(co, coroutine_resume(co, false)) -- send don't need ret end @@ -183,13 +187,13 @@ function suspend(co, result, command, param, size) end local ret if not dead_service[co_address] then - ret = c.send(co_address, skynet.PTYPE_RESPONSE, co_session, param, size) ~= nil + ret = c.send(co_address, skynet.PTYPE_RESPONSE, co_session, param, param2) ~= nil if not ret then -- If the package is too large, returns nil. so we should report error back c.send(co_address, skynet.PTYPE_ERROR, co_session, "") end - elseif size ~= nil then - c.trash(param, size) + elseif param2 ~= nil then + c.trash(param, param2) ret = false end return suspend(co, coroutine_resume(co, ret)) @@ -279,11 +283,12 @@ function skynet.timeout(ti, func) session_id_coroutine[session] = co end -function skynet.sleep(ti) +function skynet.sleep(ti, token) local session = c.intcommand("TIMEOUT",ti) assert(session) - local succ, ret = coroutine_yield("SLEEP", session) - sleep_session[coroutine.running()] = nil + token = token or coroutine.running() + local succ, ret = coroutine_yield("SLEEP", session, token) + sleep_session[token] = nil if succ then return end @@ -298,11 +303,11 @@ function skynet.yield() return skynet.sleep(0) end -function skynet.wait(co) +function skynet.wait(token) local session = c.genid() - local ret, msg = coroutine_yield("SLEEP", session) - co = co or coroutine.running() - sleep_session[co] = nil + token = token or coroutine.running() + local ret, msg = coroutine_yield("SLEEP", session, token) + sleep_session[token] = nil session_id_coroutine[session] = nil end @@ -429,9 +434,9 @@ function skynet.retpack(...) return skynet.ret(skynet.pack(...)) end -function skynet.wakeup(co) - if sleep_session[co] then - table.insert(wakeup_queue, co) +function skynet.wakeup(token) + if sleep_session[token] then + table.insert(wakeup_queue, token) return true end end diff --git a/test/testtimeout.lua b/test/testtimeout.lua index 94355ab23..88c3deff2 100644 --- a/test/testtimeout.lua +++ b/test/testtimeout.lua @@ -14,18 +14,15 @@ local function test_service() end local function timeout_call(ti, ...) - local co = coroutine.running() + local token = {} local ret skynet.fork(function(...) ret = table.pack(pcall(skynet.call, ...)) - if co then - skynet.wakeup(co) - end + skynet.wakeup(token) end, ...) - skynet.sleep(ti) - co = nil -- prevent wakeup after call + skynet.sleep(ti, token) if ret then if ret[1] then return table.unpack(ret, 1, ret.n) @@ -45,6 +42,8 @@ skynet.start(function() skynet.error("2", skynet.now()) skynet.error(timeout_call(50, test, "lua")) skynet.error("3", skynet.now()) + skynet.error(timeout_call(150, test, "lua")) + skynet.error("4", skynet.now()) skynet.exit() end)