diff --git a/src/eval.c b/src/eval.c index 5e6292d8bf..4bf0a2a05a 100644 --- a/src/eval.c +++ b/src/eval.c @@ -156,7 +156,7 @@ void freeEvalScriptsSync(dict *scripts, list *scripts_lru_list, list *engine_cal } } -static int resetEngineEvalEnvCallback(scriptingEngine *engine, void *context) { +static void resetEngineEvalEnvCallback(scriptingEngine *engine, void *context) { int async = context != NULL; callableLazyEvalReset *callback = scriptingEngineCallResetEvalEnvFunc(engine, async); @@ -164,8 +164,6 @@ static int resetEngineEvalEnvCallback(scriptingEngine *engine, void *context) { list *callbacks = context; listAddNodeTail(callbacks, callback); } - - return 1; } /* Release resources related to Lua scripting. @@ -173,12 +171,12 @@ static int resetEngineEvalEnvCallback(scriptingEngine *engine, void *context) { void evalRelease(int async) { if (async) { list *engine_callbacks = listCreate(); - engineManagerForEachEngine(resetEngineEvalEnvCallback, engine_callbacks); + scriptingEngineManagerForEachEngine(resetEngineEvalEnvCallback, engine_callbacks); freeEvalScriptsAsync(evalCtx.scripts, evalCtx.scripts_lru_list, engine_callbacks); } else { freeEvalScriptsSync(evalCtx.scripts, evalCtx.scripts_lru_list, NULL); - engineManagerForEachEngine(resetEngineEvalEnvCallback, NULL); + scriptingEngineManagerForEachEngine(resetEngineEvalEnvCallback, NULL); } } @@ -319,7 +317,7 @@ static void evalDeleteScript(client *c, sds sha) { scriptHolder *sh = dictGetVal(de); /* Delete the script from the engine. */ - engineCallFreeFunction(sh->engine, VMSE_EVAL, sh->script); + scriptingEngineCallFreeFunction(sh->engine, VMSE_EVAL, sh->script); evalCtx.scripts_mem -= sdsAllocSize(sha) + getStringObjectSdsUsedMemory(sh->body); dictFreeUnlinkedEntry(evalCtx.scripts, de); @@ -397,7 +395,7 @@ static int evalRegisterNewScript(client *c, robj *body, char **sha) { } serverAssert(engine_name != NULL); - scriptingEngine *engine = engineManagerFind(engine_name); + scriptingEngine *engine = scriptingEngineManagerFind(engine_name); if (!engine) { if (c != NULL) { addReplyErrorFormat(c, "Could not find scripting engine '%s'", engine_name); @@ -414,12 +412,12 @@ static int evalRegisterNewScript(client *c, robj *body, char **sha) { robj *_err = NULL; size_t num_compiled_functions = 0; compiledFunction **functions = - engineCallCompileCode(engine, - VMSE_EVAL, - (sds)body->ptr + shebang_len, - 0, - &num_compiled_functions, - &_err); + scriptingEngineCallCompileCode(engine, + VMSE_EVAL, + (sds)body->ptr + shebang_len, + 0, + &num_compiled_functions, + &_err); if (functions == NULL) { serverAssert(_err != NULL); if (c != NULL) { @@ -449,7 +447,7 @@ static int evalRegisterNewScript(client *c, robj *body, char **sha) { } sh->body = body; int retval = dictAdd(evalCtx.scripts, _sha, sh); - serverAssertWithInfo(c ? c : engineGetClient(engine), NULL, retval == DICT_OK); + serverAssertWithInfo(c ? c : scriptingEngineGetClient(engine), NULL, retval == DICT_OK); evalCtx.scripts_mem += sdsAllocSize(_sha) + getStringObjectSdsUsedMemory(body); incrRefCount(body); zfree(functions); @@ -501,21 +499,21 @@ static void evalGenericCommand(client *c, int evalsha) { int ro = c->cmd->proc == evalRoCommand || c->cmd->proc == evalShaRoCommand; scriptRunCtx rctx; - if (scriptPrepareForRun(&rctx, engineGetClient(sh->engine), c, sha, sh->flags, ro) != C_OK) { + if (scriptPrepareForRun(&rctx, scriptingEngineGetClient(sh->engine), c, sha, sh->flags, ro) != C_OK) { return; } rctx.flags |= SCRIPT_EVAL_MODE; /* mark the current run as EVAL (as opposed to FCALL) so we'll get appropriate error messages and logs */ - engineCallFunction(sh->engine, - &rctx, - c, - sh->script, - VMSE_EVAL, - c->argv + 3, - numkeys, - c->argv + 3 + numkeys, - c->argc - 3 - numkeys); + scriptingEngineCallFunction(sh->engine, + &rctx, + c, + sh->script, + VMSE_EVAL, + c->argv + 3, + numkeys, + c->argv + 3 + numkeys, + c->argc - 3 - numkeys); scriptResetRun(&rctx); if (sh->node) { @@ -654,16 +652,15 @@ void scriptCommand(client *c) { } } -static int getEngineUsedMemory(scriptingEngine *engine, void *context) { +static void getEngineUsedMemory(scriptingEngine *engine, void *context) { size_t *sum = (size_t *)context; - engineMemoryInfo mem_info = engineCallGetMemoryInfo(engine, VMSE_EVAL); + engineMemoryInfo mem_info = scriptingEngineCallGetMemoryInfo(engine, VMSE_EVAL); *sum += mem_info.used_memory; - return 1; } unsigned long evalMemory(void) { size_t memory = 0; - engineManagerForEachEngine(getEngineUsedMemory, &memory); + scriptingEngineManagerForEachEngine(getEngineUsedMemory, &memory); return memory; } diff --git a/src/lua/engine_lua.c b/src/lua/engine_lua.c index 5fa2a87bfa..046e301997 100644 --- a/src/lua/engine_lua.c +++ b/src/lua/engine_lua.c @@ -380,8 +380,8 @@ int luaEngineInitEngine(void) { .get_memory_info = luaEngineGetMemoryInfo, }; - return scriptingEngineManagerRegisterEngine(LUA_ENGINE_NAME, - NULL, - createEngineContext(), - &methods); + return scriptingEngineManagerRegister(LUA_ENGINE_NAME, + NULL, + createEngineContext(), + &methods); } diff --git a/src/module.c b/src/module.c index 0ca4855f07..01faeb6450 100644 --- a/src/module.c +++ b/src/module.c @@ -13191,6 +13191,21 @@ int VM_UnregisterScriptingEngine(ValkeyModuleCtx *ctx, const char *engine_name) return VALKEYMODULE_OK; } +/* Returns the state of the current function being executed by the scripting + * engine. + * + * `server_ctx` is the server runtime context. + * + * It will return VMSE_STATE_KILLED if the function was already killed either by + * a `SCRIPT KILL`, or `FUNCTION KILL`. + */ +ValkeyModuleScriptingEngineExecutionState VM_GetFunctionExecutionState( + ValkeyModuleScriptingEngineServerRuntimeCtx *server_ctx) { + int ret = scriptInterrupt(server_ctx); + serverAssert(ret == SCRIPT_CONTINUE || ret == SCRIPT_KILL); + return ret == SCRIPT_CONTINUE ? VMSE_STATE_EXECUTING : VMSE_STATE_KILLED; +} + /* MODULE command. * * MODULE LIST @@ -14063,4 +14078,5 @@ void moduleRegisterCoreAPI(void) { REGISTER_API(RdbSave); REGISTER_API(RegisterScriptingEngine); REGISTER_API(UnregisterScriptingEngine); + REGISTER_API(GetFunctionExecutionState); } diff --git a/src/scripting_engine.c b/src/scripting_engine.c index 763c8ddc75..e1bb14cbc7 100644 --- a/src/scripting_engine.c +++ b/src/scripting_engine.c @@ -142,7 +142,7 @@ int scriptingEngineManagerUnregister(const char *engine_name) { functionsRemoveLibFromEngine(e); - engineMemoryInfo mem_info = scriptingEngineCallGetMemoryInfo(e); + engineMemoryInfo mem_info = scriptingEngineCallGetMemoryInfo(e, VMSE_ALL); engineMgr.total_memory_overhead -= zmalloc_size(e) + sdsAllocSize(e->name) + mem_info.engine_memory_overhead; @@ -215,17 +215,17 @@ static void engineTeardownModuleCtx(scriptingEngine *e) { } } -compiledFunction **scriptinEngineCallCompileCode(scriptingEngine *engine, - subsystemType type, - const char *code, - size_t timeout, - size_t *out_num_compiled_functions, - robj **err) { +compiledFunction **scriptingEngineCallCompileCode(scriptingEngine *engine, + subsystemType type, + const char *code, + size_t timeout, + size_t *out_num_compiled_functions, + robj **err) { serverAssert(type == VMSE_EVAL || type == VMSE_FUNCTION); engineSetupModuleCtx(engine, NULL); - compiledFunction **functions = engine->impl->methods.compile_code( + compiledFunction **functions = engine->impl.methods.compile_code( engine->module_ctx, engine->impl.ctx, type, @@ -290,11 +290,11 @@ size_t scriptingEngineCallGetFunctionMemoryOverhead(scriptingEngine *engine, } callableLazyEvalReset *scriptingEngineCallResetEvalEnvFunc(scriptingEngine *engine, - int async) { + int async) { engineSetupModuleCtx(engine, NULL); callableLazyEvalReset *callback = engine->impl.methods.reset_eval_env( engine->module_ctx, - engine->impl->ctx, + engine->impl.ctx, async); engineTeardownModuleCtx(engine); return callback; diff --git a/src/scripting_engine.h b/src/scripting_engine.h index bd2f3076b3..75147100e2 100644 --- a/src/scripting_engine.h +++ b/src/scripting_engine.h @@ -31,14 +31,14 @@ typedef void (*engineIterCallback)(scriptingEngine *engine, void *context); * Engine manager API functions. */ int scriptingEngineManagerInit(void); -size_t scriptingEngineManagerGetCacheMemory(void); +size_t scriptingEngineManagerGetTotalMemoryOverhead(void); size_t scriptingEngineManagerGetNumEngines(void); size_t scriptingEngineManagerGetMemoryUsage(void); -int scriptingEngineManagerRegisterEngine(const char *engine_name, - ValkeyModule *engine_module, - engineCtx *engine_ctx, - engineMethods *engine_methods); -int scriptingEngineManagerUnregisterEngine(const char *engine_name); +int scriptingEngineManagerRegister(const char *engine_name, + ValkeyModule *engine_module, + engineCtx *engine_ctx, + engineMethods *engine_methods); +int scriptingEngineManagerUnregister(const char *engine_name); scriptingEngine *scriptingEngineManagerFind(const char *engine_name); void scriptingEngineManagerForEachEngine(engineIterCallback callback, void *context); diff --git a/src/server.c b/src/server.c index 66a9c1715a..3fae0750e5 100644 --- a/src/server.c +++ b/src/server.c @@ -1369,11 +1369,10 @@ void checkChildrenDone(void) { } } -static int sumEngineUsedMemory(scriptingEngine *engine, void *context) { +static void sumEngineUsedMemory(scriptingEngine *engine, void *context) { size_t *total_memory = (size_t *)context; - engineMemoryInfo mem_info = engineCallGetMemoryInfo(engine, VMSE_ALL); + engineMemoryInfo mem_info = scriptingEngineCallGetMemoryInfo(engine, VMSE_ALL); *total_memory += mem_info.used_memory; - return 1; } /* Called from serverCron and cronUpdateMemoryStats to update cached memory metrics. */ @@ -1402,7 +1401,7 @@ void cronUpdateMemoryStats(void) { * so we must deduct it in order to be able to calculate correct * "allocator fragmentation" ratio */ size_t engines_memory = 0; - engineManagerForEachEngine(sumEngineUsedMemory, &engines_memory); + scriptingEngineManagerForEachEngine(sumEngineUsedMemory, &engines_memory); server.cron_malloc_stats.allocator_resident = server.cron_malloc_stats.process_rss - engines_memory; } if (!server.cron_malloc_stats.allocator_active) diff --git a/src/valkeymodule.h b/src/valkeymodule.h index 25fa588178..84eaaff6f8 100644 --- a/src/valkeymodule.h +++ b/src/valkeymodule.h @@ -837,6 +837,11 @@ typedef enum ValkeyModuleScriptingEngineSubsystemType { VMSE_ALL } ValkeyModuleScriptingEngineSubsystemType; +typedef enum ValkeyModuleScriptingEngineExecutionState { + VMSE_STATE_EXECUTING, + VMSE_STATE_KILLED, +} ValkeyModuleScriptingEngineExecutionState; + typedef struct ValkeyModuleScriptingEngineCallableLazyEvalReset { void *context; @@ -1868,6 +1873,8 @@ VALKEYMODULE_API int (*ValkeyModule_RegisterScriptingEngine)(ValkeyModuleCtx *mo VALKEYMODULE_API int (*ValkeyModule_UnregisterScriptingEngine)(ValkeyModuleCtx *module_ctx, const char *engine_name) VALKEYMODULE_ATTR; +VALKEYMODULE_API ValkeyModuleScriptingEngineExecutionState (*ValkeyModule_GetFunctionExecutionState)(ValkeyModuleScriptingEngineServerRuntimeCtx *server_ctx) VALKEYMODULE_ATTR; + #define ValkeyModule_IsAOFClient(id) ((id) == UINT64_MAX) /* This is included inline inside each Valkey module. */ @@ -2237,6 +2244,7 @@ static int ValkeyModule_Init(ValkeyModuleCtx *ctx, const char *name, int ver, in VALKEYMODULE_GET_API(RdbSave); VALKEYMODULE_GET_API(RegisterScriptingEngine); VALKEYMODULE_GET_API(UnregisterScriptingEngine); + VALKEYMODULE_GET_API(GetFunctionExecutionState); if (ValkeyModule_IsModuleNameBusy && ValkeyModule_IsModuleNameBusy(name)) return VALKEYMODULE_ERR; ValkeyModule_SetModuleAttribs(ctx, name, ver, apiver); diff --git a/tests/modules/helloscripting.c b/tests/modules/helloscripting.c index 0a342d92e8..f0960b4a96 100644 --- a/tests/modules/helloscripting.c +++ b/tests/modules/helloscripting.c @@ -3,6 +3,7 @@ #include #include #include +#include /* * This module implements a very simple stack based scripting language. @@ -28,6 +29,15 @@ * CONSTI 432 # pushes the value 432 to the top of the stack * RETURN # returns the current value on the top of the stack and marks * # the end of the function declaration. + * + * FUNCTION sleep # declaration of function 'sleep' + * ARGS 0 # pushes the value in the first argument to the top of the + * # stack + * SLEEP # Pops the current value in the stack and sleeps for `value` + * # seconds + * CONSTI 0 # pushes the value 0 to the top of the stack + * RETURN # returns the current value on the top of the stack and marks + * # the end of the function declaration. * ``` */ @@ -38,6 +48,7 @@ typedef enum HelloInstKind { FUNCTION = 0, CONSTI, ARGS, + SLEEP, RETURN, _NUM_INSTRUCTIONS, // Not a real instruction. } HelloInstKind; @@ -49,6 +60,7 @@ const char *HelloInstKindStr[] = { "FUNCTION", "CONSTI", "ARGS", + "SLEEP", "RETURN", }; @@ -185,6 +197,10 @@ static int helloLangParseCode(const char *code, ValkeyModule_Assert(currentFunc != NULL); helloLangParseArgs(currentFunc); break; + case SLEEP: + ValkeyModule_Assert(currentFunc != NULL); + currentFunc->num_instructions++; + break; case RETURN: ValkeyModule_Assert(currentFunc != NULL); currentFunc->num_instructions++; @@ -204,13 +220,40 @@ static int helloLangParseCode(const char *code, return 0; } +static ValkeyModuleScriptingEngineExecutionState executeSleepInst(ValkeyModuleScriptingEngineServerRuntimeCtx *server_ctx, + uint32_t seconds) { + uint32_t elapsed_milliseconds = 0; + ValkeyModuleScriptingEngineExecutionState state = VMSE_STATE_EXECUTING; + while(1) { + state = ValkeyModule_GetFunctionExecutionState(server_ctx); + if (state != VMSE_STATE_EXECUTING) { + break; + } + + if (elapsed_milliseconds >= (seconds * 1000)) { + break; + } + + usleep(1000); + elapsed_milliseconds++; + } + + return state; +} + /* * Executes an HELLO function. */ -static uint32_t executeHelloLangFunction(HelloFunc *func, - ValkeyModuleString **args, int nargs) { +static ValkeyModuleScriptingEngineExecutionState executeHelloLangFunction(ValkeyModuleScriptingEngineServerRuntimeCtx *server_ctx, + HelloFunc *func, + ValkeyModuleString **args, + int nargs, + uint32_t *result) { + ValkeyModule_Assert(result != NULL); uint32_t stack[64]; + uint32_t val = 0; int sp = 0; + ValkeyModuleScriptingEngineExecutionState state = VMSE_STATE_EXECUTING; for (uint32_t pc = 0; pc < func->num_instructions; pc++) { HelloInst instr = func->instructions[pc]; @@ -226,21 +269,27 @@ static uint32_t executeHelloLangFunction(HelloFunc *func, uint32_t arg = str2int(argStr); stack[sp++] = arg; break; - } + } + case SLEEP: { + val = stack[--sp]; + state = executeSleepInst(server_ctx, val); + break; + } case RETURN: { ValkeyModule_Assert(sp > 0); - uint32_t val = stack[--sp]; + val = stack[--sp]; ValkeyModule_Assert(sp == 0); - return val; - } + *result = val; + return state; + } case FUNCTION: - default: + case _NUM_INSTRUCTIONS: ValkeyModule_Assert(0); } } ValkeyModule_Assert(0); - return 0; + return state; } static ValkeyModuleScriptingEngineMemoryInfo engineGetMemoryInfo(ValkeyModuleCtx *module_ctx, @@ -353,19 +402,30 @@ static ValkeyModuleScriptingEngineCompiledFunction **createHelloLangEngine(Valke static void callHelloLangFunction(ValkeyModuleCtx *module_ctx, ValkeyModuleScriptingEngineCtx *engine_ctx, - ValkeyModuleScriptingEngineFunctionCtx *func_ctx, + ValkeyModuleScriptingEngineServerRuntimeCtx *server_ctx, ValkeyModuleScriptingEngineCompiledFunction *compiled_function, ValkeyModuleScriptingEngineSubsystemType type, ValkeyModuleString **keys, size_t nkeys, ValkeyModuleString **args, size_t nargs) { VALKEYMODULE_NOT_USED(engine_ctx); - VALKEYMODULE_NOT_USED(func_ctx); - VALKEYMODULE_NOT_USED(type); VALKEYMODULE_NOT_USED(keys); VALKEYMODULE_NOT_USED(nkeys); + ValkeyModule_Assert(type == VMSE_EVAL || type == VMSE_FUNCTION); + HelloFunc *func = (HelloFunc *)compiled_function->function; - uint32_t result = executeHelloLangFunction(func, args, nargs); + uint32_t result; + ValkeyModuleScriptingEngineExecutionState state = executeHelloLangFunction(server_ctx, func, args, nargs, &result); + ValkeyModule_Assert(state == VMSE_STATE_KILLED || state == VMSE_STATE_EXECUTING); + + if (state == VMSE_STATE_KILLED) { + if (type == VMSE_EVAL) { + ValkeyModule_ReplyWithError(module_ctx, "ERR Script killed by user with SCRIPT KILL."); + } + if (type == VMSE_FUNCTION) { + ValkeyModule_ReplyWithError(module_ctx, "ERR Script killed by user with FUNCTION KILL"); + } + } ValkeyModule_ReplyWithLongLong(module_ctx, result); } diff --git a/tests/unit/moduleapi/scriptingengine.tcl b/tests/unit/moduleapi/scriptingengine.tcl index 3a37339ea8..78f6572904 100644 --- a/tests/unit/moduleapi/scriptingengine.tcl +++ b/tests/unit/moduleapi/scriptingengine.tcl @@ -123,6 +123,22 @@ start_server {tags {"modules"}} { assert_equal $result 432 } + test {Test function kill} { + set rd [valkey_deferring_client] + r config set busy-reply-threshold 10 + r function load REPLACE "#!hello name=mylib\nFUNCTION sleep\nARGS 0\nSLEEP\nARGS 0\nRETURN" + $rd fcall sleep 0 100 + after 1000 + catch {r ping} e + assert_match {BUSY*} $e + assert_match {running_script {name sleep command {fcall sleep 0 100} duration_ms *} engines {*}} [r FUNCTION STATS] + r function kill + after 1000 ; + assert_equal [r ping] "PONG" + assert_error {ERR Script killed by user with FUNCTION KILL*} {$rd read} + $rd close + } + test {Unload scripting engine module} { set result [r module unload helloengine] assert_equal $result "OK"