diff --git a/Lua.Wrapper.pas b/Lua.Wrapper.pas index 0878c30..789d84b 100644 --- a/Lua.Wrapper.pas +++ b/Lua.Wrapper.pas @@ -245,6 +245,7 @@ type procedure SetAutoOpenLibraries(const Value: TLuaLibraries); virtual; protected procedure CheckState; virtual; + procedure CheckIsFunction; virtual; procedure AfterLoad; virtual; function GetRegisteredFunctionCookie: Integer; virtual; @@ -262,16 +263,17 @@ type procedure LoadFromFile(const AFileName: string; AAutoRun: Boolean = True; const AChunkName: string = ''); virtual; procedure LoadFromScript(AScript: TLuaScript; AOwnership: TStreamOwnership = soReference; AAutoRun: Boolean = True; const AChunkName: string = ''); virtual; - function GetGlobalVariable(const AName: string): ILuaVariable; - procedure SetGlobalVariable(const AName: string; AVariable: TLuaImplicitVariable); - procedure RegisterFunction(const AName: string; AFunction: TLuaCFunction); + function GetGlobalVariable(const AName: string): ILuaVariable; virtual; + procedure SetGlobalVariable(const AName: string; AVariable: TLuaImplicitVariable); virtual; + procedure RegisterFunction(const AName: string; AFunction: TLuaCFunction); virtual; procedure OpenLibraries(ALibraries: TLuaLibraries); virtual; - { Run or GetByteCode should only be called right after one of the + { These methods should only be called right after one of the LoadFrom methods, which must have AutoRun set to False. } procedure Run; virtual; procedure GetByteCode(AStream: TStream; APop: Boolean = False); virtual; + procedure Capture(const AName: string); virtual; function Call(const AFunctionName: string): ILuaReadParameters; overload; virtual; function Call(const AFunctionName: string; AParameters: array of const): ILuaReadParameters; overload; virtual; @@ -1178,42 +1180,41 @@ begin else begin if TLuaLibrary.Base in ALibraries then - luaopen_base(State); + luaL_requiref(State, 'base', luaopen_base, 1); if TLuaLibrary.Coroutine in ALibraries then - luaopen_coroutine(State); + luaL_requiref(State, 'coroutine', luaopen_coroutine, 1); if TLuaLibrary.Table in ALibraries then - luaopen_table(State); + luaL_requiref(State, 'table', luaopen_table, 1); if TLuaLibrary.IO in ALibraries then - luaopen_io(State); + luaL_requiref(State, 'io', luaopen_io, 1); if TLuaLibrary.OS in ALibraries then - luaopen_os(State); + luaL_requiref(State, 'os', luaopen_os, 1); if TLuaLibrary.StringLib in ALibraries then - luaopen_string(State); + luaL_requiref(State, 'string', luaopen_string, 1); if TLuaLibrary.Bit32 in ALibraries then - luaopen_bit32(State); + luaL_requiref(State, 'bit32', luaopen_bit32, 1); if TLuaLibrary.Math in ALibraries then - luaopen_math(State); + luaL_requiref(State, 'math', luaopen_math, 1); if TLuaLibrary.Debug in ALibraries then - luaopen_debug(State); + luaL_requiref(State, 'debug', luaopen_debug, 1); if TLuaLibrary.Package in ALibraries then - luaopen_package(State); + luaL_requiref(State, 'package', luaopen_package, 1); end; end; procedure TLua.Run; begin - if not lua_isfunction(State, -1) then - raise ELuaNoFunctionException.Create('No function on top of the stack, use the LoadFrom methods first'); + CheckIsFunction; if lua_pcall(State, 0, 0, 0) <> 0 then TLuaHelpers.RaiseLastLuaError(State); @@ -1225,9 +1226,9 @@ end; procedure TLua.GetByteCode(AStream: TStream; APop: Boolean); var returnCode: Integer; + begin - if not lua_isfunction(State, -1) then - raise ELuaNoFunctionException.Create('No function on top of the stack, use the LoadFrom methods first'); + CheckIsFunction; try returnCode := lua_dump(State, LuaWrapperWriter, @AStream); @@ -1240,6 +1241,42 @@ begin end; +procedure TLua.Capture(const AName: string); +var + name: PAnsiChar; + +begin + CheckIsFunction; + + // Create a new table to serve as the environment + lua_newtable(State); + + // Set the global AName to the new table + lua_pushvalue(State, -1); + name := TLuaHelpers.AllocLuaString(AName); + try + lua_setglobal(State, name); + finally + TLuaHelpers.FreeLuaString(name); + end; + + // Set the global environment as the table's metatable index, so calls to + // global functions and variables still work + lua_newtable(State); + TLuaHelpers.PushString(State, '__index'); + lua_pushglobaltable(State); + lua_settable(State, -3); + + lua_setmetatable(State, -2); + + // Set the new table as the environment (upvalue at index 1) + lua_setupvalue(State, -2, 1); + + if lua_pcall(State, 0, 0, 0) <> 0 then + TLuaHelpers.RaiseLastLuaError(State); +end; + + function TLua.Call(const AFunctionName: string): ILuaReadParameters; begin Result := Call(AFunctionName, nil); @@ -1283,6 +1320,13 @@ begin end; +procedure TLua.CheckIsFunction; +begin + if not lua_isfunction(State, -1) then + raise ELuaNoFunctionException.Create('No function on top of the stack, use the LoadFrom methods first'); +end; + + procedure TLua.AfterLoad; var cookie: Integer; diff --git a/Lua.pas b/Lua.pas index 20ab667..6a9c694 100644 --- a/Lua.pas +++ b/Lua.pas @@ -394,6 +394,8 @@ var { open all previous libraries } luaL_openlibs: procedure(L: lua_State); cdecl; + luaL_requiref: procedure(L: lua_State; modname: PAnsiChar; openf: lua_CFunction; glb: Integer); cdecl; + type @@ -582,6 +584,7 @@ begin Load(@luaopen_package, 'luaopen_package'); Load(@luaL_openlibs, 'luaL_openlibs'); + Load(@luaL_requiref, 'luaL_requiref'); Load(@luaL_setfuncs, 'luaL_setfuncs'); end; diff --git a/UnitTests/source/TestWrapper.pas b/UnitTests/source/TestWrapper.pas index 879c8f4..39e2b08 100644 --- a/UnitTests/source/TestWrapper.pas +++ b/UnitTests/source/TestWrapper.pas @@ -25,6 +25,7 @@ type procedure LoadAndRunFromString; procedure LoadAndRunFromStream; procedure LoadMultiple; + procedure LoadMultipleSharedVariable; procedure ChunkNameInException; procedure Input; @@ -45,6 +46,8 @@ type procedure VariableFunction; procedure ByteCode; + procedure Capture; + procedure DenyRequire; end; @@ -65,6 +68,7 @@ begin FPrinted := TStringBuilder.Create; FLua := TLua.Create; + FLua.AutoOpenLibraries := [StringLib]; FLua.RegisterFunction('print', procedure(AContext: ILuaContext) begin @@ -112,6 +116,14 @@ begin end; +procedure TTestWrapper.LoadMultipleSharedVariable; +begin + Lua.LoadFromString('message = "Hello world!"', True, 'Script1'); + Lua.LoadFromString('print(message)', True, 'Script2'); + + CheckEquals('Hello world!', Printed.ToString); +end; + procedure TTestWrapper.Input; begin Lua.SetGlobalVariable('thingy', 'world'); @@ -377,6 +389,41 @@ begin end; +procedure TTestWrapper.Capture; +begin + // Capture is a convenience method which puts a script's variables and + // functions in a global table variable. Useful for example when + // implementing a sandboxed API. + Lua.LoadFromString('message = "Hello world!"'#13#10 + + 'function outputMessage()'#13#10 + + ' print(message)'#13#10 + + 'end', False, 'Script1'); + Lua.Capture('Captured'); + + Lua.LoadFromString('print(Captured.message)'#13#10 + + 'Captured.message = "Goodbye world!"'#13#10 + + 'Captured.outputMessage()', True, 'Script2'); + + CheckEquals('Hello world!Goodbye world!', Printed.ToString); +end; + + +procedure TTestWrapper.DenyRequire; +begin + try + // This should fail, since we're not loading the Package library which + // adds the require function that can be considered a security risk. + Lua.LoadFromString('require("Test")'); + Fail('ELuaException expected'); + except + on E:Exception do + begin + CheckIs(E, ELuaException); + end; + end; +end; + + initialization RegisterTest(TTestWrapper.Suite);