diff --git a/lualib/skynet.lua b/lualib/skynet.lua index 11daa1e85..03699f953 100644 --- a/lualib/skynet.lua +++ b/lualib/skynet.lua @@ -95,6 +95,7 @@ end local coroutine_pool = {} local coroutine_yield = coroutine.yield +local coroutine_count = 0 local function co_create(f) local co = table.remove(coroutine_pool) @@ -109,6 +110,11 @@ local function co_create(f) f(coroutine_yield()) end end) + coroutine_count = coroutine_count + 1 + if coroutine_count > 1024 then + skynet.error("May overload, create 1024 task") + coroutine_count = 0 + end else coroutine.resume(co, f) end diff --git a/skynet-src/skynet_mq.c b/skynet-src/skynet_mq.c index 027c4dc01..8cc8f9daa 100644 --- a/skynet-src/skynet_mq.c +++ b/skynet-src/skynet_mq.c @@ -15,6 +15,7 @@ // 1 means mq is in global mq , or the message is dispatching. #define MQ_IN_GLOBAL 1 +#define MQ_OVERLOAD 1024 struct message_queue { uint32_t handle; @@ -24,6 +25,8 @@ struct message_queue { int lock; int release; int in_global; + int overload; + int overload_threshold; struct skynet_message *queue; struct message_queue *next; }; @@ -116,6 +119,8 @@ skynet_mq_create(uint32_t handle) { // If the service init success, skynet_context_new will call skynet_mq_force_push to push it to global queue. q->in_global = MQ_IN_GLOBAL; q->release = 0; + q->overload = 0; + q->overload_threshold = MQ_OVERLOAD; q->queue = skynet_malloc(sizeof(struct skynet_message) * q->cap); q->next = NULL; @@ -150,17 +155,42 @@ skynet_mq_length(struct message_queue *q) { return tail + cap - head; } +int +skynet_mq_overload(struct message_queue *q) { + if (q->overload) { + int overload = q->overload; + q->overload = 0; + return overload; + } + return 0; +} + int skynet_mq_pop(struct message_queue *q, struct skynet_message *message) { int ret = 1; LOCK(q) if (q->head != q->tail) { - *message = q->queue[q->head]; + *message = q->queue[q->head++]; ret = 0; - if ( ++ q->head >= q->cap) { - q->head = 0; + int head = q->head; + int tail = q->tail; + int cap = q->cap; + + if (head >= cap) { + q->head = head = 0; } + int length = tail - head; + if (length < 0) { + length += cap; + } + while (length > q->overload_threshold) { + q->overload = length; + q->overload_threshold *= 2; + } + } else { + // reset overload_threshold when queue is empty + q->overload_threshold = MQ_OVERLOAD; } if (ret) { diff --git a/skynet-src/skynet_mq.h b/skynet-src/skynet_mq.h index df1a99981..17178b4f5 100644 --- a/skynet-src/skynet_mq.h +++ b/skynet-src/skynet_mq.h @@ -30,6 +30,7 @@ void skynet_mq_push(struct message_queue *q, struct skynet_message *message); // return the length of message queue, for debug int skynet_mq_length(struct message_queue *q); +int skynet_mq_overload(struct message_queue *q); void skynet_mq_init(); diff --git a/skynet-src/skynet_server.c b/skynet-src/skynet_server.c index bdf0dcc1e..50978b10e 100644 --- a/skynet-src/skynet_server.c +++ b/skynet-src/skynet_server.c @@ -283,6 +283,10 @@ skynet_context_message_dispatch(struct skynet_monitor *sm, struct message_queue n = skynet_mq_length(q); n >>= weight; } + int overload = skynet_mq_overload(q); + if (overload) { + skynet_error(ctx, "May overload, message queue length = %d", overload); + } skynet_monitor_trigger(sm, msg.source , handle); diff --git a/test/testoverload.lua b/test/testoverload.lua new file mode 100644 index 000000000..10fe39a89 --- /dev/null +++ b/test/testoverload.lua @@ -0,0 +1,44 @@ +local skynet = require "skynet" + +local mode = ... + +if mode == "slave" then + +local CMD = {} + +function CMD.sum(n) + skynet.error("for loop begin") + local s = 0 + for i = 1, n do + s = s + i + end + skynet.error("for loop end") +end + +function CMD.blackhole() +end + +skynet.start(function() + skynet.dispatch("lua", function(_,_, cmd, ...) + local f = CMD[cmd] + f(...) + end) +end) + +else + +skynet.start(function() + local slave = skynet.newservice(SERVICE_NAME, "slave") + for step = 1, 20 do + skynet.error("overload test ".. step) + for i = 1, 512 * step do + skynet.send(slave, "lua", "blackhole") + end + skynet.sleep(step) + end + local n = 1000000000 + skynet.error(string.format("endless test n=%d", n)) + skynet.send(slave, "lua", "sum", n) +end) + +end