diff --git a/src/3rdParty/LuaMinify.h b/src/3rdParty/LuaMinify.h index d01a8fe..5512f5b 100644 --- a/src/3rdParty/LuaMinify.h +++ b/src/3rdParty/LuaMinify.h @@ -319,8 +319,6 @@ R"lua_codes( -- ParseLua returns an AST, internally relying on LexLua. -- -local WhiteChars = lookupify{' ', '\n', '\t', '\r'} -local EscapeLookup = {['\r'] = '\\r', ['\n'] = '\\n', ['\t'] = '\\t', ['"'] = '\\"', ["'"] = "\\'"} local LowerChars = lookupify{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'} @@ -549,7 +547,7 @@ local function LexLua(src) --get the initial char local thisLine = line local thisChar = char - local errorAt = ":"..line..":"..char..":> " + --local errorAt = ":"..line..":"..char..":> " local c = peek() --symbol to emit @@ -822,9 +820,9 @@ local function ParseLua(src) return err end -- - local VarUid = 0 + -- local VarUid = 0 -- No longer needed: handled in Scopes now local GlobalVarGetMap = {} - local VarDigits = {'_', 'a', 'b', 'c', 'd'} + -- local VarDigits = {'_', 'a', 'b', 'c', 'd'} local function CreateScope(parent) --[[ local scope = {} @@ -1736,15 +1734,6 @@ R"lua_codes( -- - All local variables are renamed -- -local LowerChars = lookupify{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', - 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', - 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'} -local UpperChars = lookupify{'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', - 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', - 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'} -local Digits = lookupify{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'} -local Symbols = lookupify{'+', '-', '*', '/', '^', '%', ',', '{', '}', '[', ']', '(', ')', ';', '#'} - local function Format_Mini(ast) local formatStatlist, formatExpr; --local count = 0 @@ -2092,14 +2081,413 @@ local function Format_Mini(ast) ast.Scope:ObfuscateVariables() return formatStatlist(ast) end +)lua_codes" -return function(src) - local st, ast = ParseLua(src) - if st then - return Format_Mini(ast) - else - return nil, ast +R"lua_codes( +local function FormatYue(ast, lineMap) + local currentLine = 1 + local formatStatlist, formatExpr + + local function joinStatementsSafe(out, b, sep) + local aa = '' + for i = #out, 1, -1 do + local a = out[i] + aa = a:match("([^%s])%s*$") + if aa then + break + end + end + sep = sep or ' ' + if (out[#out] or ''):sub(-1,-1) == ' ' then + sep = '' + end + local bb = b:match("^%s*([^%s])") + if UpperChars[aa] or LowerChars[aa] or aa == '_' then + if not (UpperChars[bb] or LowerChars[bb] or bb == '_' or Digits[bb]) then + --bb is a symbol, can join without sep + out[#out + 1] = b + elseif bb == '(' then + --prevent ambiguous syntax + out[#out + 1] = sep + out[#out + 1] = b + else + out[#out + 1] = sep + out[#out + 1] = b + end + elseif Digits[aa] then + if bb == '(' then + --can join statements directly + out[#out + 1] = b + elseif Symbols[bb] then + out[#out + 1] = b + else + out[#out + 1] = sep + out[#out + 1] = b + end + elseif aa == '' then + out[#out + 1] = b + else + if bb == '(' then + --don't want to accidentally call last statement, can't join directly + out[#out + 1] = sep + out[#out + 1] = b + else + out[#out + 1] = b + end + end end + + formatExpr = function(expr) + local out = {string.rep('(', expr.ParenCount or 0)} + if expr.Tokens then + local line = expr.Tokens[1].Line + local targetLine = lineMap[line] + if targetLine and currentLine < targetLine then + out[#out + 1] = string.rep('\n', targetLine - currentLine) + currentLine = targetLine + end + elseif expr.Value then + local line = expr.Value.Line + local targetLine = lineMap[line] + if targetLine and currentLine < targetLine then + out[#out + 1] = string.rep('\n', targetLine - currentLine) + currentLine = targetLine + end + end + if expr.AstType == 'VarExpr' then + if expr.Variable then + out[#out + 1] = expr.Variable.Name + else + out[#out + 1] = expr.Name + end + + elseif expr.AstType == 'NumberExpr' then + out[#out + 1] = expr.Value.Data + + elseif expr.AstType == 'StringExpr' then + out[#out + 1] = expr.Value.Data + + elseif expr.AstType == 'BooleanExpr' then + out[#out + 1] = tostring(expr.Value) + + elseif expr.AstType == 'NilExpr' then + joinStatementsSafe(out, "nil", nil) + + elseif expr.AstType == 'BinopExpr' then + joinStatementsSafe(out, formatExpr(expr.Lhs), nil) + out[#out + 1] = " " + joinStatementsSafe(out, expr.Op, nil) + out[#out + 1] = " " + joinStatementsSafe(out, formatExpr(expr.Rhs), nil) + + elseif expr.AstType == 'UnopExpr' then + joinStatementsSafe(out, expr.Op, nil) + out[#out + 1] = (#expr.Op ~= 1 and " " or "") + joinStatementsSafe(out, formatExpr(expr.Rhs), nil) + + elseif expr.AstType == 'DotsExpr' then + out[#out + 1] = "..." + + elseif expr.AstType == 'CallExpr' then + out[#out + 1] = formatExpr(expr.Base) + out[#out + 1] = "(" + for i = 1, #expr.Arguments do + out[#out + 1] = formatExpr(expr.Arguments[i]) + if i ~= #expr.Arguments then + out[#out + 1] = ", " + end + end + out[#out + 1] = ")" + + elseif expr.AstType == 'TableCallExpr' then + out[#out + 1] = formatExpr(expr.Base) + out[#out + 1] = " " + out[#out + 1] = formatExpr(expr.Arguments[1]) + + elseif expr.AstType == 'StringCallExpr' then + out[#out + 1] = formatExpr(expr.Base) + out[#out + 1] = " " + out[#out + 1] = expr.Arguments[1].Data + + elseif expr.AstType == 'IndexExpr' then + out[#out + 1] = formatExpr(expr.Base) + out[#out + 1] = "[" + out[#out + 1] = formatExpr(expr.Index) + out[#out + 1] = "]" + + elseif expr.AstType == 'MemberExpr' then + out[#out + 1] = formatExpr(expr.Base) + out[#out + 1] = expr.Indexer + out[#out + 1] = expr.Ident.Data + + elseif expr.AstType == 'Function' then + -- anonymous function + out[#out + 1] = "function(" + if #expr.Arguments > 0 then + for i = 1, #expr.Arguments do + out[#out + 1] = expr.Arguments[i].Name + if i ~= #expr.Arguments then + out[#out + 1] = ", " + elseif expr.VarArg then + out[#out + 1] = ", ..." + end + end + elseif expr.VarArg then + out[#out + 1] = "..." + end + out[#out + 1] = ")" + joinStatementsSafe(out, formatStatlist(expr.Body), nil) + joinStatementsSafe(out, "end", nil) + elseif expr.AstType == 'ConstructorExpr' then + out[#out + 1] = "{ " + for i = 1, #expr.EntryList do + local entry = expr.EntryList[i] + if entry.Type == 'Key' then + out[#out + 1] = "[" + out[#out + 1] = formatExpr(entry.Key) + out[#out + 1] = "] = " + out[#out + 1] = formatExpr(entry.Value) + elseif entry.Type == 'Value' then + out[#out + 1] = formatExpr(entry.Value) + elseif entry.Type == 'KeyString' then + out[#out + 1] = entry.Key + out[#out + 1] = " = " + out[#out + 1] = formatExpr(entry.Value) + end + if i ~= #expr.EntryList then + out[#out + 1] = ", " + end + end + out[#out + 1] = " }" + + elseif expr.AstType == 'Parentheses' then + out[#out + 1] = "(" + out[#out + 1] = formatExpr(expr.Inner) + out[#out + 1] = ")" + + end + out[#out + 1] = string.rep(')', expr.ParenCount or 0) + return table.concat(out) + end + + local formatStatement = function(statement) + local out = {""} + if statement.Tokens and statement.Tokens[1] then + local line = statement.Tokens[1].Line + local targetLine = lineMap[line] + if targetLine and currentLine < targetLine then + out[#out + 1] = string.rep('\n', targetLine - currentLine) + currentLine = targetLine + end + end + if statement.AstType == 'AssignmentStatement' then + for i = 1, #statement.Lhs do + out[#out + 1] = formatExpr(statement.Lhs[i]) + if i ~= #statement.Lhs then + out[#out + 1] = ", " + end + end + if #statement.Rhs > 0 then + out[#out + 1] = " = " + for i = 1, #statement.Rhs do + out[#out + 1] = formatExpr(statement.Rhs[i]) + if i ~= #statement.Rhs then + out[#out + 1] = ", " + end + end + end + elseif statement.AstType == 'CallStatement' then + out[#out + 1] = formatExpr(statement.Expression) + elseif statement.AstType == 'LocalStatement' then + out[#out + 1] = "local " + for i = 1, #statement.LocalList do + out[#out + 1] = statement.LocalList[i].Name + if statement.AttrList[i] then + out[#out + 1] = " <" + out[#out + 1] = statement.AttrList[i] + out[#out + 1] = ">" + end + if i ~= #statement.LocalList then + out[#out + 1] = "," + end + end + if #statement.InitList > 0 then + out[#out + 1] = " = " + for i = 1, #statement.InitList do + out[#out + 1] = formatExpr(statement.InitList[i]) + if i ~= #statement.InitList then + out[#out + 1] = ", " + end + end + end + elseif statement.AstType == 'IfStatement' then + out[#out + 1] = "if " + joinStatementsSafe(out, formatExpr(statement.Clauses[1].Condition), nil) + joinStatementsSafe(out, " then", nil) + joinStatementsSafe(out, formatStatlist(statement.Clauses[1].Body), nil) + for i = 2, #statement.Clauses do + local st = statement.Clauses[i] + if st.Condition then + joinStatementsSafe(out, "elseif ", nil) + joinStatementsSafe(out, formatExpr(st.Condition), nil) + joinStatementsSafe(out, " then", nil) + else + joinStatementsSafe(out, "else", nil) + end + joinStatementsSafe(out, formatStatlist(st.Body), nil) + end + joinStatementsSafe(out, "end", nil) + elseif statement.AstType == 'WhileStatement' then + out[#out + 1] = "while " + joinStatementsSafe(out, formatExpr(statement.Condition), nil) + joinStatementsSafe(out, " do", nil) + joinStatementsSafe(out, formatStatlist(statement.Body), nil) + joinStatementsSafe(out, "end", nil) + elseif statement.AstType == 'DoStatement' then + joinStatementsSafe(out, "do", nil) + joinStatementsSafe(out, formatStatlist(statement.Body), nil) + joinStatementsSafe(out, "end", nil) + elseif statement.AstType == 'ReturnStatement' then + out[#out + 1] = "return " + for i = 1, #statement.Arguments do + joinStatementsSafe(out, formatExpr(statement.Arguments[i]), nil) + if i ~= #statement.Arguments then + out[#out + 1] = ", " + end + end + elseif statement.AstType == 'BreakStatement' then + out[#out + 1] = "break" + elseif statement.AstType == 'RepeatStatement' then + out[#out + 1] = "repeat" + joinStatementsSafe(out, formatStatlist(statement.Body), nil) + joinStatementsSafe(out, "until ", nil) + joinStatementsSafe(out, formatExpr(statement.Condition), nil) + elseif statement.AstType == 'Function' then + if statement.IsLocal then + out[#out + 1] = "local " + end + joinStatementsSafe(out, "function ", nil) + if statement.IsLocal then + out[#out + 1] = statement.Name.Name + else + out[#out + 1] = formatExpr(statement.Name) + end + out[#out + 1] = "(" + if #statement.Arguments > 0 then + for i = 1, #statement.Arguments do + out[#out + 1] = statement.Arguments[i].Name + if i ~= #statement.Arguments then + out[#out + 1] = ", " + elseif statement.VarArg then + out[#out + 1] = ",..." + end + end + elseif statement.VarArg then + out[#out + 1] = "..." + end + out[#out + 1] = ")" + joinStatementsSafe(out, formatStatlist(statement.Body), nil) + joinStatementsSafe(out, "end", nil) + elseif statement.AstType == 'GenericForStatement' then + out[#out + 1] = "for " + for i = 1, #statement.VariableList do + out[#out + 1] = statement.VariableList[i].Name + if i ~= #statement.VariableList then + out[#out + 1] = ", " + end + end + out[#out + 1] = " in " + for i = 1, #statement.Generators do + joinStatementsSafe(out, formatExpr(statement.Generators[i]), nil) + if i ~= #statement.Generators then + joinStatementsSafe(out, ', ', nil) + end + end + joinStatementsSafe(out, " do", nil) + joinStatementsSafe(out, formatStatlist(statement.Body), nil) + joinStatementsSafe(out, "end", nil) + elseif statement.AstType == 'NumericForStatement' then + out[#out + 1] = "for " + out[#out + 1] = statement.Variable.Name + out[#out + 1] = " = " + out[#out + 1] = formatExpr(statement.Start) + out[#out + 1] = ", " + out[#out + 1] = formatExpr(statement.End) + if statement.Step then + out[#out + 1] = ", " + out[#out + 1] = formatExpr(statement.Step) + end + joinStatementsSafe(out, " do", nil) + joinStatementsSafe(out, formatStatlist(statement.Body), nil) + joinStatementsSafe(out, "end", nil) + elseif statement.AstType == 'LabelStatement' then + out[#out + 1] = "::" + out[#out + 1] = statement.Label + out[#out + 1] = "::" + elseif statement.AstType == 'GotoStatement' then + out[#out + 1] = "goto " + out[#out + 1] = statement.Label + elseif statement.AstType == 'Comment' then + -- Ignore + elseif statement.AstType == 'Eof' then + -- Ignore + else + print("Unknown AST Type: ", statement.AstType) + end + return table.concat(out) + end + + formatStatlist = function(statList) + local out = {""} + for _, stat in pairs(statList.Body) do + joinStatementsSafe(out, formatStatement(stat), ';') + end + return table.concat(out) + end + + return formatStatlist(ast) +end + +local function GetYueLineMap(luaCodes) + local current = 1 + local lastLine = 1 + local lineMap = { } + for lineCode in luaCodes:gmatch("[^\n\r]*") do + local num = lineCode:match("--%s*(%d+)%s*$") + if num then + local line = tonumber(num) + if line > lastLine then + lastLine = line + end + end + lineMap[current] = lastLine + current = current + 1 + end + return lineMap end + +return { + FormatMini = function(src) + local st, ast = ParseLua(src) + if st then + return Format_Mini(ast) + else + return nil, ast + end + end, + + FormatYue = function(src) + local st, ast = ParseLua(src) + if st then + local lineMap = GetYueLineMap(src) + if #lineMap == 0 then + return src + end + return FormatYue(ast, lineMap) + else + return nil, ast + end + end +} )lua_codes"; diff --git a/src/yue.cpp b/src/yue.cpp index 98bf4b4..4dec0e4 100644 --- a/src/yue.cpp +++ b/src/yue.cpp @@ -246,6 +246,7 @@ int main(int narg, const char** args) { #ifndef YUE_COMPILER_ONLY " -e str Execute a file or raw codes\n" " -m Generate minified codes\n" + " -r Rewrite output to match original line numbers\n" #endif // YUE_COMPILER_ONLY " -t path Specify where to place compiled files\n" " -o file Write output to file\n" @@ -409,6 +410,7 @@ int main(int narg, const char** args) { return 0; } bool minify = false; + bool rewrite = false; #endif // YUE_COMPILER_ONLY yue::YueConfig config; config.implicitReturnRoot = true; @@ -522,6 +524,8 @@ int main(int narg, const char** args) { } } else if (arg == "-m"sv) { minify = true; + } else if (arg == "-r"sv) { + rewrite = true; #endif // YUE_COMPILER_ONLY } else if (arg == "-s"sv) { config.useSpaceOverTab = true; @@ -607,6 +611,16 @@ int main(int narg, const char** args) { std::cout << "Error: -o can not be used with multiple input files\n"sv; return 1; } +#ifndef YUE_COMPILER_ONLY + if (minify || rewrite) { + if (minify) { + rewrite = false; + } + if (rewrite) { + config.reserveLineNumber = true; + } + } +#endif // YUE_COMPILER_ONLY #ifndef YUE_NO_WATCHER if (watchFiles) { auto fullWorkPath = fs::absolute(fs::path(workPath)).string(); @@ -744,7 +758,7 @@ int main(int narg, const char** args) { DEFER({ if (L) lua_close(L); }); - if (minify) { + if (minify || rewrite) { L = luaL_newstate(); luaL_openlibs(L); pushLuaminify(L); @@ -766,7 +780,7 @@ int main(int narg, const char** args) { errs.push_back(msg); } else { #ifndef YUE_COMPILER_ONLY - if (minify) { + if (minify || rewrite) { std::ifstream input(file, std::ios::in); if (input) { std::string s; @@ -780,27 +794,24 @@ int main(int narg, const char** args) { input.close(); int top = lua_gettop(L); DEFER(lua_settop(L, top)); - lua_pushvalue(L, -1); + lua_getfield(L, -1, rewrite ? "FormatYue" : "FormatMini"); lua_pushlstring(L, s.c_str(), s.size()); if (lua_pcall(L, 1, 1, 0) != 0) { ret = 2; std::string err = lua_tostring(L, -1); - errs.push_back("Failed to minify: "s + file + '\n' + err + '\n'); + errs.push_back((rewrite ? "Failed to rewrite: "s : "Failed to minify: "s) + file + '\n' + err + '\n'); } else { size_t size = 0; - const char* minifiedCodes = lua_tolstring(L, -1, &size); + const char* transformedCodes = lua_tolstring(L, -1, &size); if (writeToFile) { std::ofstream output(file, std::ios::trunc | std::ios::out); - output.write(minifiedCodes, size); + output.write(transformedCodes, size); output.close(); - std::cout << "Minified built "sv << file << '\n'; + std::cout << (rewrite ? "Rewrited built "sv : "Minified built "sv) << file << '\n'; } else { - std::cout << minifiedCodes << '\n'; + std::cout << transformedCodes << '\n'; } } - } else { - ret = 2; - errs.push_back("Failed to minify: "s + file + '\n'); } } else { std::cout << msg; diff --git a/src/yuescript/yue_compiler.cpp b/src/yuescript/yue_compiler.cpp index 09da047..2a3cb8c 100644 --- a/src/yuescript/yue_compiler.cpp +++ b/src/yuescript/yue_compiler.cpp @@ -72,7 +72,7 @@ static std::unordered_set Metamethods = { "close"s // Lua 5.4 }; -const std::string_view version = "0.17.8"sv; +const std::string_view version = "0.17.9"sv; const std::string_view extension = "yue"sv; class CompileError : public std::logic_error { @@ -435,7 +435,7 @@ class YueCompilerImpl { struct ClassMember { std::string item; MemType type; - ast_node* node; + ast_ptr node; }; struct DestructItem {