
Cannot access symbols across modules in LLVM OrcJIT


I'm writing a JIT compiler in Haskell using llvm-hs and OrcJIT. Here's my main file, which compiles each module, adds it to the JIT, then looks up and runs the module's internal main function:

main :: IO ()
main =
    withContext $ \ctx ->
        withExecutionSession $ \es ->
            withHostTargetMachine Reloc.PIC CodeModel.Default CodeGenOpt.None $ \tm ->
                withSymbolResolver es myResolver $ \psr ->
                    withObjectLinkingLayer es (\_ -> return psr) $ \oll ->
                        withIRCompileLayer oll tm $ \ircl -> do
                            loadLibraryPermanently Nothing
                            repl ctx es tm ircl

    where
        myResolver :: SymbolResolver
        myResolver = SymbolResolver $ \mangled -> do
            ptr <- getSymbolAddressInProcess mangled
            return $ Right $ JITSymbol
                { jitSymbolAddress = ptr 
                , jitSymbolFlags   = defaultJITSymbolFlags { jitSymbolExported = True }
                }


repl :: Context -> ExecutionSession -> TargetMachine -> IRCompileLayer ObjectLinkingLayer ->  IO ()
repl ctx es tm cl = runInputT defaultSettings (loop C.initCmpState)
    where
        loop :: C.CmpState -> InputT IO ()
        loop state =
            getInputLine "% " >>= \minput -> case minput of
                Nothing    -> return ()
                Just "q"   -> return ()
                Just input -> liftIO (process state input) >>= loop

        process :: C.CmpState -> String -> IO C.CmpState
        process state source =
            case L.alexScanner source of
                Left  errStr -> putStrLn errStr >> return state
                Right tokens -> case (P.parseTokens tokens) 0 of
                    P.ParseOk ast ->
                        let (res, state') = C.codeGen state (head ast) in
                        case res of
                            Left err -> print err >> return state
                            Right () -> runDefinition (state' { C.externs = C.externs state }) >> return state'
                                { C.globals      = Map.empty
                                , C.instructions = []
                                }
                    -- Any other parse result leaves the REPL state unchanged
                    _ -> putStrLn "Parse error" >> return state

        runDefinition :: C.CmpState -> IO ()
        runDefinition state = do
            let globals = Map.elems (C.globals state)
            let externs = Map.elems (C.externs state)
            let instructions = reverse (C.instructions state)

            let mainName = mkBSS "main.0"
            let mainFn = GlobalDefinition $ functionDefaults
                { returnType  = void
                , name        = Name mainName
                , basicBlocks = [BasicBlock (mkName "entry") instructions (Do $ Ret Nothing [])]
                }

            case instructions of
                [] -> do
                    let astmod = defaultModule
                        { moduleDefinitions = externs ++ globals 
                        }
                    M.withModuleFromAST ctx astmod $ \mod -> do
                        BS.putStrLn =<< M.moduleLLVMAssembly mod
                        withModuleKey es $ \modKey ->
                            addModule cl modKey mod
                _ -> do
                    let astmod = defaultModule
                        { moduleDefinitions = externs ++ globals ++ [mainFn]
                        }
                    M.withModuleFromAST ctx astmod $ \mod -> do
                        BS.putStrLn =<< M.moduleLLVMAssembly mod
                        withModuleKey es $ \modKey ->
                            withModule cl modKey mod $ do
                                res <- (\mangled -> findSymbol cl mangled False) =<< mangleSymbol cl mainName
                                case res of
                                    Left _ -> putStrLn ("Couldn't find: " ++ show mainName)
                                    Right (JITSymbol fn _)-> do
                                        run $ castPtrToFunPtr (wordPtrToPtr fn)

Self-contained modules, such as the one generated for this print statement, run correctly. A module that contains a main function is removed from the JIT once it has been executed:

print(234);

; ModuleID = '<string>'
source_filename = "<string>"

@0 = constant [4 x i8] c"%d\0A\00"

declare i32 @printf(i8*, ...)

define void @main.0() {
entry:
  %0 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i32 0, i32 0), i32 234)
  ret void
}

234

Assigning 4 to the symbol 'x' produces a module that contains only a global variable. This module is not removed from the JIT:

x := 4;

; ModuleID = '<string>'
source_filename = "<string>"

@x = global i32 4

But attempting to print 'x' in the next statement results in a lookup failure for the main function:

print(x);

; ModuleID = '<string>'
source_filename = "<string>"

@x = external global i32
@0 = constant [4 x i8] c"%d\0A\00"

declare i32 @printf(i8*, ...)

define void @main.0() {
entry:
  %0 = load i32, i32* @x
  %1 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i32 0, i32 0), i32 %0)
  ret void
}

Couldn't find: "main.0"

It appears there's a problem accessing symbols across modules.

Things I've tried:

  • Accessing functions instead of variables
  • Changing my symbol resolver to use 'findSymbol' instead of 'getSymbolAddressInProcess' as in the llvm-hs-examples repo. This prevented any modules from running.
  • Downloading the llvm-hs-examples repo and running the 'orc' example. This also resulted in a symbol error!
  • Re-downloading the Haskell toolchain and llvm/llvm-hs (9.0.1) on a fresh Linux install.

I'd be extremely grateful for any help!


Solution

  • Solved! I had misunderstood the role of the symbol resolver. It isn't consulted when you call 'findSymbol'; the JIT uses it during the compile and link stage to resolve a module's external references. 'getSymbolAddressInProcess' searches only for symbols within the host process (such as printf), not for symbols defined within the JIT (such as x).

    For a module in the JIT that references an external symbol 'x' from another module as well as 'printf' from the host process, the symbol resolver must search the JIT compile layer first and fall back to the host process:

    myResolver :: IRCompileLayer ObjectLinkingLayer -> SymbolResolver
    myResolver ircl = SymbolResolver $ \mangled -> do
        -- First look for the symbol among the modules already added to the JIT
        symbol <- findSymbol ircl mangled False
        case symbol of
            Right _ -> return symbol
            -- Otherwise fall back to the host process (e.g. printf)
            Left _ -> do
                ptr <- getSymbolAddressInProcess mangled
                return $ Right $ JITSymbol
                    { jitSymbolAddress = ptr
                    , jitSymbolFlags   = defaultJITSymbolFlags { jitSymbolExported = True }
                    }
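
To make the lookup order concrete, here is a small self-contained model that does not use llvm-hs at all. Everything in it (Address, SymbolTable, Resolver, Linker, hostOnly, jitThenHost, resolveVia, and the addresses) is a hypothetical stand-in chosen for illustration. It sketches why a host-only resolver cannot find 'x', how the fallback resolver finds both 'x' and 'printf', and one possible way to hand a resolver to a linking layer that was created before the resolver could be built. The IORef slot used for that last part is an assumption, not something the code above shows:

```haskell
import Control.Applicative ((<|>))
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.Word (Word64)

-- Toy stand-ins, not llvm-hs types: a symbol table maps names to addresses.
type Address     = Word64
type SymbolTable = [(String, Address)]
type Resolver    = String -> Maybe Address

-- Models a resolver built only on getSymbolAddressInProcess:
-- it sees host-process symbols and nothing else.
hostOnly :: SymbolTable -> Resolver
hostOnly host name = lookup name host

-- Models the fixed resolver: JIT-defined symbols first, host process second.
jitThenHost :: SymbolTable -> SymbolTable -> Resolver
jitThenHost jit host name = lookup name jit <|> lookup name host

-- Models a linking layer that exists before the final resolver does: it holds
-- a mutable slot and reads it only when a symbol actually has to be resolved.
newtype Linker = Linker (IORef Resolver)

resolveVia :: Linker -> String -> IO (Maybe Address)
resolveVia (Linker slot) name = do
    resolver <- readIORef slot
    return (resolver name)

main :: IO ()
main = do
    let host = [("printf", 0x1000)]  -- made-up addresses
        jit  = [("x", 0x2000)]
    slot <- newIORef (hostOnly host)
    let linker = Linker slot
    print =<< resolveVia linker "x"       -- Nothing: host-only lookup misses x
    writeIORef slot (jitThenHost jit host)
    print =<< resolveVia linker "x"       -- Just 8192
    print =<< resolveVia linker "printf"  -- Just 4096
```

In the real program the same idea means that myResolver can only be constructed once ircl is in scope, since it now takes the compile layer as an argument.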