Search code examples
Tags: java, c++, java-native-interface

Run jar file that only exists in memory, larger than what a java byte array can hold


To give a quick overview of my goal: I've built an application that injects a DLL into a Java process, which then attempts to load and run a JAR file from a remote host. Note that the JAR file is not, and should never be, present anywhere on disk. To achieve this I wrote a custom class loader and an InputStream implementation that reads a native buffer passed in from my native C++ code via JNI.

Since the buffer is larger than Integer.MAX_VALUE bytes (the maximum length of a Java array), I couldn't simply pass in a byte[]. Instead I came up with this UnsafeBufferInputStream, which reads the native buffer (allocated by my C++ code) that holds the JAR file.

MemoryClassLoader:

public class MemoryClassLoader extends ClassLoader {
    private final Map<String, byte[]> classes = new HashMap<>();
    private final Map<String, byte[]> resources = new HashMap<>();

    public MemoryClassLoader() {

    }
    
    // Called from my native c++ code, providing the start address and length of the buffer
    public void load(long memAddress, long bufSize) {
        try {
            String tempFolderPath = System.getProperty("java.io.tmpdir"); // Get the system's temporary folder
            String className = "net.sxlver.UnsafeBufferInputStream"; // Adjust with your class name

            Class<?> unsafeStream = null;
            try {
                // Read the class file bytes from the temp folder
                Path filePath = Paths.get(tempFolderPath, "UnsafeBufferInputStream.class");
                byte[] classBytes = Files.readAllBytes(filePath);
                unsafeStream = defineClass(className, classBytes, 0, classBytes.length);
            } catch (IOException e) {
                e.printStackTrace();
            }


            //UnsafeBufferInputStream stream = (UnsafeBufferInputStream) unsafeStream.getDeclaredConstructor(long.class, long.class).newInstance(memAddress, bufSize);
            UnsafeBufferInputStream stream = new UnsafeBufferInputStream(memAddress, bufSize);

            loadClassesAndResourcesFromJar(stream);
        }catch(final Exception exception) {
            JOptionPane.showMessageDialog(null, "");
        }

    }

    private static Unsafe getUnsafe() {
        try {
            java.lang.reflect.Field field = Unsafe.class.getDeclaredField("theUnsafe");
            field.setAccessible(true);
            return (Unsafe) field.get(null);
        } catch (Exception e) {
            throw new RuntimeException("Unsafe access error: " + e.getMessage());
        }
    }

    private void loadClassesAndResourcesFromJar(UnsafeBufferInputStream stream) throws IOException {
        JarInputStream jarStream = new JarInputStream(stream);
        JarEntry entry;
        while ((entry = jarStream.getNextJarEntry()) != null) {
            if (!entry.isDirectory()) {
                byte[] buffer = new byte[1024];
                int bytesRead;
                ByteArrayOutputStream baos = new ByteArrayOutputStream();

                while ((bytesRead = jarStream.read(buffer)) != -1) {
                    baos.write(buffer, 0, bytesRead);
                }

                if (entry.getName().endsWith(".class")) {
                    classes.put(entry.getName().replace(".class", "").replace("/", "."), baos.toByteArray());

                } else {
                    resources.put(entry.getName(), baos.toByteArray());
                }
            }
        }
    }

    // Override findClass to load classes from memory
    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
        byte[] classBytes = classes.get(name);
        if (classBytes == null) {
            throw new ClassNotFoundException(name);
        }
        return defineClass(name, classBytes, 0, classBytes.length);
    }

    @Override
    public Class<?> loadClass(String name) throws ClassNotFoundException {
        synchronized (getClassLoadingLock(name)) {
            Class<?> loadedClass = findLoadedClass(name);
            if (loadedClass == null) {
                try {
                    loadedClass = findClass(name);
                } catch (ClassNotFoundException e) {
                    loadedClass = super.loadClass(name);
                }
            }
            return loadedClass;
        }
    }

    @Override
    public InputStream getResourceAsStream(String name) {
        byte[] resourceBytes = resources.get(name);
        if (resourceBytes != null) {
            return new ByteArrayInputStream(resourceBytes);
        }
        return super.getResourceAsStream(name);
    }
}

UnsafeBufferInputStream:

/**
 * {@link InputStream} over a block of native (off-heap) memory, read through {@code sun.misc.Unsafe}.
 *
 * <p>Unlike a {@code byte[]}, the block may be longer than {@link Integer#MAX_VALUE} bytes. The
 * stream does not own the memory: the caller must keep it allocated and unchanged while the stream
 * is read. Addresses cannot be validated, so a wrong address or size reads arbitrary memory or
 * crashes the JVM.
 */
public class UnsafeBufferInputStream extends InputStream {
    private final Unsafe unsafe;
    /** Native address of the next byte to return. */
    private long address;
    /** Bytes left before end of stream; a value {@code <= 0} means the stream is exhausted. */
    private long remainingBytes;

    /**
     * @param bufferAddress native address of the first readable byte
     * @param size          number of readable bytes; zero or negative yields an empty stream
     */
    public UnsafeBufferInputStream(long bufferAddress, long size) {
        this.unsafe = getUnsafe();
        this.address = bufferAddress;
        this.remainingBytes = size;
    }

    /** Obtains the {@code Unsafe} singleton via reflection, keeping the original failure as cause. */
    private Unsafe getUnsafe() {
        try {
            java.lang.reflect.Field field = Unsafe.class.getDeclaredField("theUnsafe");
            field.setAccessible(true);
            return (Unsafe) field.get(null);
        } catch (Exception e) {
            throw new RuntimeException("Unsafe access error: " + e.getMessage(), e);
        }
    }

    @Override
    public int read() throws IOException {
        if (remainingBytes <= 0) {
            return -1;
        }
        byte value = unsafe.getByte(address++);
        remainingBytes--;
        return value & 0xFF;
    }

    /**
     * Copies up to {@code len} bytes into {@code b[off, off + len)} per the {@link InputStream}
     * contract: returns 0 when {@code len == 0}, -1 at end of stream, and throws before consuming
     * anything when the range is invalid.
     */
    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        if (b == null) {
            throw new NullPointerException("b");
        }
        if (off < 0 || len < 0 || len > b.length - off) {
            throw new IndexOutOfBoundsException("off=" + off + ", len=" + len + ", length=" + b.length);
        }
        if (len == 0) {
            return 0;
        }
        if (remainingBytes <= 0) {
            return -1;
        }
        int bytesRead = (int) Math.min(len, remainingBytes);
        // One bulk copy from native memory into the array instead of a getByte call per byte.
        unsafe.copyMemory(null, address, b, (long) Unsafe.ARRAY_BYTE_BASE_OFFSET + off, bytesRead);
        address += bytesRead;
        remainingBytes -= bytesRead;
        return bytesRead;
    }

    /** Advances past up to {@code n} bytes without copying them. */
    @Override
    public long skip(long n) {
        if (n <= 0 || remainingBytes <= 0) {
            return 0;
        }
        long skipped = Math.min(n, remainingBytes);
        address += skipped;
        remainingBytes -= skipped;
        return skipped;
    }

    /** Remaining byte count, capped at {@link Integer#MAX_VALUE}. */
    @Override
    public int available() {
        return (int) Math.min(Math.max(remainingBytes, 0L), Integer.MAX_VALUE);
    }

    /** Marks the stream exhausted; the native memory belongs to the caller and is not freed. */
    @Override
    public void close() {
        remainingBytes = 0;
    }
}

It now partly works — at least loading the class that contains the entry point succeeds — but I run into several issues when actually trying to run the JAR, and I feel this task is becoming far more complex than it needs to be. I don't even want to imagine how many hard-to-track-down issues this approach might cause at runtime.

I was wondering whether there are any pre-made solutions to run jar files that solely exist in-memory that are well tested and stable.

Or maybe I should take a different approach altogether.

I would be grateful if somebody smarter than me could point me in the right direction.


Solution

  • I went with a modified version of the CacheClassLoader I found in the thread @g00se linked, together with the UnsafeBufferInputStream implementation I wrote, and I've made some interesting observations.

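    First, if you load a single class file via JNI's DefineClass, make sure it doesn't depend on any inner (nested) classes. Each inner class is compiled into its own Outer$Inner.class file, and DefineClass defines only the one class you hand it. The inner classes therefore cannot be found unless you define them as well.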

    Spin up the class file in bytecodeviewer and you might see something like this:

    // The following inner classes couldn't be decompiled: java/lang/invoke/MethodHandles$Lookup

    In Java, when you use certain features introduced in Java 7 and later, such as the invokedynamic instruction or the MethodHandles API, the compiler may automatically include references to the java.lang.invoke.MethodHandles$Lookup class in the generated bytecode. This happens implicitly, and you don't need to explicitly include or import this class in your source code.

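    This might happen if the class includes any lambdas, so watch out for that. Note, though, that the reference to java.lang.invoke.MethodHandles$Lookup is harmless in itself, because it is a JDK class that is always available. The classes that actually cause trouble are your own nested or anonymous classes: each is compiled into a separate Outer$Inner.class file that must be defined as well.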

    For simplicity, I moved most of the class-loading logic into an Initializer Java class that my native code loads and calls.

    Let's begin with my native code, which still contains some debug output and, for simplicity, just loads the files from the %TMP% directory:

    includes.h

    #pragma once
    #include <Windows.h>
    #include <string>
    #include <iostream>
    #include <sstream>
    #include <vector>
    #include "jvmti.h"
    

    main.cpp

    #include "includes.h"
    #include <fstream>
    #include <cstdlib>
    #include <string>
    
    // Forward declaration of the DLL's worker routine, defined further below.
    // NOTE(review): despite its name this is not a process entry point here; calling a function named
    // main() from other code is ill-formed in ISO C++ even though MSVC accepts it — consider renaming.
    int main();
    
    // JNI handles shared by the functions below. g_jvm is filled in by JNI_GetCreatedJavaVMs inside
    // InitializeVMPointers. g_jenv is the JNIEnv of the thread that called InitializeVMPointers; a
    // JNIEnv is only valid on the thread it belongs to.
    JNIEnv* g_jenv;
    JavaVM* g_jvm;
    
    // Locates the already-running JVM and attaches the calling thread to it; returns JNI_OK on success.
    jint InitializeVMPointers();
    
    
    BOOL WINAPI DllMain(
        HINSTANCE hinstDLL,
        DWORD fdwReason,
        LPVOID lpvReserved)
    {
        if (fdwReason == DLL_PROCESS_ATTACH)
        {
            main();
        }
    }
    
    // Debug builds only (_DEBUG defined): creates a console window for the host process and reopens
    // stdin, stdout and stderr on it so that diagnostic output from this DLL is visible.
    // In other builds this function does nothing.
    void AllocDebugConsole()
    {
    #ifdef _DEBUG
        // Reopens `target` on the given console device; the FILE* handed back is not needed later.
        const auto attachStream = [](const char* device, const char* mode, FILE* target)
        {
            FILE* reopened = nullptr;
            freopen_s(&reopened, device, mode, target);
        };

        AllocConsole();
        attachStream("conin$", "r", stdin);
        attachStream("conout$", "w", stdout);
        attachStream("conout$", "w", stderr);
    #endif
    }
    
    // Reads the file %TMP%<path> completely into a newly allocated buffer.
    //
    //   path       - suffix appended verbatim to %TMP% (callers pass e.g. "\\myjar.jar").
    //   dest       - receives a new[]-allocated buffer the caller must delete[]; nullptr on failure.
    //   bufferSize - receives the file size in bytes; 0 on failure.
    //
    // Returns S_OK on success and E_FAIL otherwise (TMP unset, file missing or unreadable, or a file of
    // 4 GiB or more, whose size cannot be reported through a DWORD).
    HRESULT ReadLargeFileToMemory(const char* path, char** dest, PDWORD bufferSize)
    {
        *dest = nullptr;
        *bufferSize = 0;

        // Test-wise and for the sake of ease just load the file from the %TMP% dir
        const char* localappdata = getenv("TMP");
        if (localappdata == nullptr)
        {
            MessageBox(NULL, L"%TMP% env could not be fetched.", L"Error", NULL);
            // Must be a failure HRESULT: the former `1` equals S_FALSE, which SUCCEEDED() accepts.
            return E_FAIL;
        }

        std::string finalPath = std::string(localappdata).append(path);

        std::ifstream fStream(finalPath.c_str(), std::ios::binary);
        if (!fStream.is_open())
        {
            return E_FAIL;
        }

        fStream.seekg(0, std::ios::end);
        const std::streamoff fileSize = fStream.tellg();
        // tellg() yields -1 on failure; sizes above MAXDWORD would be truncated by the DWORD out-param.
        if (fileSize < 0 || static_cast<unsigned long long>(fileSize) > MAXDWORD)
        {
            return E_FAIL;
        }
        fStream.seekg(0, std::ios::beg);

        char* buffer = new char[static_cast<size_t>(fileSize)];
        if (!fStream.read(buffer, fileSize))
        {
            delete[] buffer;
            return E_FAIL;
        }

        *dest = buffer;
        *bufferSize = static_cast<DWORD>(fileSize);
        return S_OK;
    }
    
    // Defines the class `className` (slash-separated, e.g. "net/sxlver/Foo") in `classLoader` from
    // `bufferSize` bytes of class-file data. If DefineClass fails (e.g. malformed bytes or a duplicate
    // definition), its pending exception is cleared and the class is looked up with FindClass instead.
    // Returns NULL, with FindClass's exception pending, if both attempts fail.
    // NOTE(review): FindClass resolves against the loader JNI associates with the calling context, which
    // may not be `classLoader` — confirm the fallback finds the intended class.
    JNIEXPORT jclass JNICALL LoadClassFromBuffer(JNIEnv** env, jobject classLoader, const char* className, unsigned char* classBuffer, jint bufferSize)
    {
        jclass loadedClass = (*env)->DefineClass(className, classLoader, reinterpret_cast<jbyte*>(classBuffer), bufferSize);
        if (loadedClass == NULL)
        {
            // JNI forbids calling FindClass while an exception is pending.
            (*env)->ExceptionClear();
            return (*env)->FindClass(className);
        }
        return loadedClass;
    }
    
    // Prints "Exception Message: <Throwable.getMessage()>" for `throwable` to stdout.
    //
    // JNI forbids most calls while an exception is pending. Any pending exception is therefore set aside
    // while this function works and re-thrown before it returns, leaving the caller's exception state
    // unchanged. A null throwable or a null message prints a placeholder instead of crashing.
    void printExceptionMessage(JNIEnv* env, jthrowable throwable) {
        jthrowable pending = env->ExceptionOccurred();
        if (pending != nullptr) {
            env->ExceptionClear();
        }

        jclass cls_Throwable = nullptr;
        jstring messageObj = nullptr;
        const char* message = nullptr;
        if (throwable != nullptr) {
            // Each step runs only if the previous one succeeded, so no call is made with an exception pending.
            cls_Throwable = env->FindClass("java/lang/Throwable");
            jmethodID mid_getMessage = cls_Throwable != nullptr
                ? env->GetMethodID(cls_Throwable, "getMessage", "()Ljava/lang/String;")
                : nullptr;
            messageObj = mid_getMessage != nullptr
                ? static_cast<jstring>(env->CallObjectMethod(throwable, mid_getMessage))
                : nullptr;
            message = messageObj != nullptr ? env->GetStringUTFChars(messageObj, nullptr) : nullptr;
            // Anything thrown by the lookups above must not escape as a new pending exception.
            if (env->ExceptionCheck()) {
                env->ExceptionClear();
            }
        }

        if (throwable == nullptr) {
            std::cout << "Exception Message: <null throwable>" << std::endl;
        } else {
            std::cout << "Exception Message: " << (message != nullptr ? message : "<no message>") << std::endl;
        }

        if (message != nullptr) {
            env->ReleaseStringUTFChars(messageObj, message);
        }
        if (messageObj != nullptr) {
            env->DeleteLocalRef(messageObj);
        }
        if (cls_Throwable != nullptr) {
            env->DeleteLocalRef(cls_Throwable);
        }

        if (pending != nullptr) {
            env->Throw(pending);
            env->DeleteLocalRef(pending);
        }
    }
    
    // Calls `new targetClass(startAddr, bufLen)` through its (JJ)V constructor.
    // Returns a local reference to the new instance, or NULL on failure. On failure the Java exception
    // (if any) is printed and cleared, and a message box is shown when the constructor itself threw.
    jobject InvokeInitializerConstructor(JNIEnv** env, jclass targetClass, jlong startAddr, jlong bufLen)
    {
        if (targetClass == NULL) {
            std::cout << "Target class is null." << std::endl;
            return NULL;
        }

        jmethodID constructor = (*env)->GetMethodID(targetClass, "<init>", "(JJ)V");
        if (constructor == NULL) {
            // GetMethodID leaves a NoSuchMethodError pending.
            (*env)->ExceptionClear();
            std::cout << "Could not find constructor." << std::endl;
            return NULL;
        }

        std::cout << "found constructor" << std::endl;

        jobject instance = (*env)->NewObject(targetClass, constructor, startAddr, bufLen);
        if (instance == NULL) {
            // Take the exception off the pending slot before making any further JNI calls.
            jthrowable error = (*env)->ExceptionOccurred();
            (*env)->ExceptionClear();
            if (error != NULL) {
                printExceptionMessage(*env, error);
                (*env)->DeleteLocalRef(error);
            }
            MessageBox(NULL, L"An error occurred whilst constructing initializer class.", L"", NULL);
            std::cout << "Instance null" << std::endl;
            return NULL;
        }

        std::cout << "created instance" << std::endl;
        return instance;
    }
    
    // Returns a local reference to ClassLoader.getSystemClassLoader(), or nullptr on failure.
    // On failure a message is printed to stdout, and any pending Java exception is reported with
    // ExceptionDescribe, which also clears it.
    jobject GetSystemClassLoader(JNIEnv* env) 
    {
        jclass classLoaderClass = env->FindClass("java/lang/ClassLoader");
        if (classLoaderClass == nullptr) {
            if (env->ExceptionCheck()) {
                env->ExceptionDescribe();
            }
            std::cout << "classLoaderClass not found" << std::endl;
            return nullptr;
        }

        jmethodID getSystemClassLoader = env->GetStaticMethodID(classLoaderClass, "getSystemClassLoader", "()Ljava/lang/ClassLoader;");
        if (getSystemClassLoader == nullptr) {
            if (env->ExceptionCheck()) {
                env->ExceptionDescribe();
            }
            env->DeleteLocalRef(classLoaderClass);
            std::cout << "getSystemClassLoader not found" << std::endl;
            return nullptr;
        }

        jobject systemClassLoader = env->CallStaticObjectMethod(classLoaderClass, getSystemClassLoader);
        env->DeleteLocalRef(classLoaderClass);
        if (env->ExceptionCheck()) {
            // getSystemClassLoader() threw; the returned value is not meaningful.
            env->ExceptionDescribe();
            if (systemClassLoader != nullptr) {
                env->DeleteLocalRef(systemClassLoader);
            }
            systemClassLoader = nullptr;
        }
        if (systemClassLoader == nullptr) {
            std::cout << "getSystemClassLoader NULL" << std::endl;
            return nullptr;
        }

        return systemClassLoader;
    }
    
    // Defines `className` (slash-separated, e.g. "net/sxlver/Initializer") in the system class loader
    // from `size` bytes of class-file data. Returns the new class, or nullptr after showing a message box.
    // Any Java exception raised along the way is reported with ExceptionDescribe, which also clears it.
    jclass DefineClassFromByteArray(JNIEnv* env, const char* className, jbyte* classFileData, jsize size) 
    {
        jobject classLoader = GetSystemClassLoader(env);
        if (classLoader == nullptr) {
            if (env->ExceptionCheck()) {
                env->ExceptionDescribe();
            }
            MessageBox(NULL, L"Unable to get system class loader.", L"Error", NULL);
            return nullptr;
        }

        jclass newClass = env->DefineClass(className, classLoader, classFileData, size);
        // The defined class keeps its own reference to its loader; our local one is no longer needed.
        env->DeleteLocalRef(classLoader);
        if (newClass == nullptr) {
            // e.g. ClassFormatError, or a LinkageError when the class is already defined.
            if (env->ExceptionCheck()) {
                env->ExceptionDescribe();
            }
            MessageBox(NULL, L"Unable to load initializer class.", L"Error", NULL);
            return nullptr;
        }
        return newClass;
    }
    
    // Narrows a Win32 DWORD byte count to a JNI jsize (a signed 32-bit integer).
    // NOTE(review): values above 0x7FFFFFFF do not fit and become negative (implementation-defined
    // conversion), so callers must keep sizes below 2 GiB.
    jsize convertToJSize(DWORD size) {
        const jsize narrowed = (jsize)size;
        return narrowed;
    }
    
    // Returns a newly allocated (new[]) copy of the first `size` bytes of `rawData`, typed as jbyte.
    // The caller owns the returned buffer and must delete[] it.
    jbyte* convertToJByte(char* rawData, jlong size)
    {
        jbyte* const copy = new jbyte[size];
        const char* src = rawData;
        jbyte* dst = copy;
        for (jlong remaining = size; remaining > 0; --remaining) {
            *dst++ = static_cast<jbyte>(*src++);
        }
        return copy;
    }
    
    // Debug helper: prints the number of public constructors of `clazz` and each constructor's
    // Constructor.toString() to stdout. A Java exception raised along the way is reported with
    // ExceptionDescribe (which clears it). All local references created here are released.
    void printClassInfo(JNIEnv* env, jclass clazz)
    {
        if (clazz == nullptr) {
            std::cout << "printClassInfo: class is null" << std::endl;
            return;
        }

        // Releases a local reference when it is non-null.
        auto release = [env](jobject ref) {
            if (ref != nullptr) {
                env->DeleteLocalRef(ref);
            }
        };
        // Prints and clears a pending Java exception, if there is one.
        auto reportPending = [env]() {
            if (env->ExceptionCheck()) {
                env->ExceptionDescribe();
            }
        };

        // Reflection lookups happen once, outside the loop. Each step runs only if the previous one
        // succeeded, so no JNI call is made while an exception is pending.
        jclass cls_Class = env->FindClass("java/lang/Class");
        jclass cls_Constructor = cls_Class != nullptr
            ? env->FindClass("java/lang/reflect/Constructor")
            : nullptr;
        jmethodID mid_getConstructors = cls_Constructor != nullptr
            ? env->GetMethodID(cls_Class, "getConstructors", "()[Ljava/lang/reflect/Constructor;")
            : nullptr;
        jmethodID mid_toString = mid_getConstructors != nullptr
            ? env->GetMethodID(cls_Constructor, "toString", "()Ljava/lang/String;")
            : nullptr;
        jobjectArray constructors = mid_toString != nullptr
            ? static_cast<jobjectArray>(env->CallObjectMethod(clazz, mid_getConstructors))
            : nullptr;

        if (constructors != nullptr) {
            const jsize constructorsCount = env->GetArrayLength(constructors);
            std::cout << "Number of Constructors: " << constructorsCount << std::endl;

            for (jsize i = 0; i < constructorsCount; ++i) {
                jobject constructor = env->GetObjectArrayElement(constructors, i);
                jstring constructorStr = constructor != nullptr
                    ? static_cast<jstring>(env->CallObjectMethod(constructor, mid_toString))
                    : nullptr;
                if (constructorStr != nullptr) {
                    const char* str = env->GetStringUTFChars(constructorStr, nullptr);
                    if (str != nullptr) {
                        std::cout << "Constructor " << (i + 1) << ": " << str << std::endl;
                        env->ReleaseStringUTFChars(constructorStr, str);
                    }
                }
                reportPending();
                // Free per-iteration references so classes with many constructors don't exhaust the local frame.
                release(constructorStr);
                release(constructor);
            }
        }

        reportPending();
        release(constructors);
        release(cls_Constructor);
        release(cls_Class);
    }
    
    int main()
    {
        AllocDebugConsole();
        jint res = InitializeVMPointers();
        if (res != JNI_OK)
        {
            MessageBox(NULL, L"Unable to initialize VM pointers.", L"Error", NULL);
            return 1;
        }
    
        const char* localappdata = getenv("TMP");
        if (localappdata == nullptr)
        {
            MessageBox(NULL, L"%TMP% env could not be fetched.", L"Error", NULL);
            return 1;
        }
    
        std::string localAppdataPath = std::string(localappdata);
    
    
        char* buf;
        DWORD size;
        ReadLargeFileToMemory("\\Initializer.class", &buf, &size);
    
        std::cout << std::to_string(reinterpret_cast<__int64>(buf)) << std::endl;
        std::cout << std::to_string(size) << std::endl;
    
        jclass definedClass = DefineClassFromByteArray(g_jenv, "net/sxlver/Initializer", convertToJByte(buf, size), convertToJSize(size));
        if (definedClass == NULL)
        {
            std::cout << "definedClass is null" << std::endl;
        }
    
        printClassInfo(g_jenv, definedClass);
        char* jarFileBuf;
        DWORD jarFileSize = 0;
        ReadLargeFileToMemory("\\myjar.jar", &jarFileBuf, &jarFileSize);
    
        jobject result = InvokeInitializerConstructor(&g_jenv, definedClass, reinterpret_cast<__int64>(jarFileBuf), static_cast<jlong>(jarFileSize));
        return 0;
    }
    
    // Finds the JVM already running in this process (stored in g_jvm) and obtains a JNIEnv for the
    // calling thread (stored in g_jenv), attaching the thread first if it is not yet attached.
    // Returns JNI_OK on success and JNI_ERR otherwise.
    jint InitializeVMPointers() {
        jsize vmCount = 0;
        const bool haveVm = JNI_GetCreatedJavaVMs(&g_jvm, 1, &vmCount) == JNI_OK && vmCount != 0;
        if (!haveVm) {
            return JNI_ERR;
        }

        jint status = g_jvm->GetEnv(reinterpret_cast<void**>(&g_jenv), JNI_VERSION_1_6);
        if (status == JNI_EDETACHED) {
            status = g_jvm->AttachCurrentThread(reinterpret_cast<void**>(&g_jenv), nullptr);
        }

        return status == JNI_OK ? JNI_OK : JNI_ERR;
    }
    

    Now for the Java code: I modified the CacheClassLoader a little. It's still a mess, but it works.

    Note that all relevant .class files are currently located in the %TMP% dir and share the same package.

    Once again, defineClass defines only the single class it is given, and inner classes live in separate Outer$Inner.class files. I therefore split the inner classes of CacheClassLoader into their own top-level source files.

    Initializer.java

    package net.sxlver;
    
    
    import java.lang.reflect.Constructor;
    import java.lang.reflect.Method;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.Paths;
    import java.util.HashMap;
    import java.util.Map;
    
    /**
     * Bootstrap class that native code defines in the system class loader and instantiates with the
     * native address and length of an in-memory JAR.
     *
     * <p>The constructor defines the helper classes (stream, URL handler, class loader) from
     * {@code java.io.tmpdir}, copies the JAR into a {@link CacheClassLoader}, and loads and initializes
     * the JAR's entry-point class.
     */
    public class Initializer {

        private final long startAddr;
        private final long bufLen;

        /** Helper classes defined by {@link #loadClasses()}, keyed by binary class name. */
        private final Map<String, Class<?>> classCache = new HashMap<>();

        /**
         * @param startAddress native address of the JAR image
         * @param bufLen       length of the JAR image in bytes
         * @throws Exception if any step of {@link #load()} fails; the message is also printed to stdout
         */
        public Initializer(long startAddress, long bufLen) throws Exception {
            this.startAddr = startAddress;
            this.bufLen = bufLen;
            try {
                load();
            } catch (final Exception e) {
                System.out.println(String.format("Exception thrown: %s", e.getMessage()));
                // Re-throw so the native NewObject call fails with a pending exception instead of
                // silently returning a half-initialized instance.
                throw e;
            }
        }

        /** Defines the helper classes, loads the JAR into a fresh CacheClassLoader and initializes its entry point. */
        public void load() throws Exception {
            // Must run first. This class was defined in the system class loader, and its references to
            // CacheClassLoader / UnsafeBufferInputStream below are only resolved when first executed,
            // so they find the definitions made here.
            loadClasses();

            Class<?> classLoaderClass = classCache.get("net.sxlver.CacheClassLoader");
            Constructor<?> classLoaderConstructor = classLoaderClass.getDeclaredConstructor(ClassLoader.class);
            CacheClassLoader classLoader = (CacheClassLoader) classLoaderConstructor.newInstance(ClassLoader.getSystemClassLoader());

            System.out.println(String.format("start address: %d buf len: %d", startAddr, bufLen));

            classLoader.addJar("myjar.jar", new UnsafeBufferInputStream(startAddr, bufLen));

            // Load and initialize the entry point of the JAR we're loading.
            Class.forName("net.sxlver.myjar.Main", true, classLoader);
            System.out.println("class present");
        }

        /** Defines the helper classes from {@code java.io.tmpdir} in the system class loader. */
        private void loadClasses() {
            String tempFolderPath = System.getProperty("java.io.tmpdir");
            loadClass(tempFolderPath, "UnsafeBufferInputStream.class", "net.sxlver.UnsafeBufferInputStream");
            loadClass(tempFolderPath, "CacheURLConnection.class", "net.sxlver.CacheURLConnection");
            loadClass(tempFolderPath, "CacheURLStreamHandler.class", "net.sxlver.CacheURLStreamHandler");
            loadClass(tempFolderPath, "CacheClassLoader.class", "net.sxlver.CacheClassLoader");
        }

        /**
         * Reads {@code dir/fileName} and defines it as {@code className} in the system class loader by
         * invoking the protected {@code ClassLoader.defineClass} reflectively.
         *
         * @throws IllegalStateException if the file cannot be read or the class cannot be defined; failing
         *     here avoids storing {@code null} and hitting an opaque NullPointerException later
         */
        private void loadClass(final String dir, final String fileName, final String className) {
            Class<?> clazz;
            try {
                // Read the class file bytes from the temp folder
                Path filePath = Paths.get(dir, fileName);
                byte[] classBytes = Files.readAllBytes(filePath);

                final Method method = ClassLoader.class.getDeclaredMethod("defineClass", String.class, byte[].class, int.class, int.class);
                method.setAccessible(true);
                clazz = (Class<?>) method.invoke(ClassLoader.getSystemClassLoader(), className, classBytes, 0, classBytes.length);
            } catch (Exception e) {
                // Method.invoke wraps the real failure (e.g. LinkageError, ClassFormatError).
                Throwable cause = (e instanceof java.lang.reflect.InvocationTargetException && e.getCause() != null)
                        ? e.getCause()
                        : e;
                System.out.println(String.format("Exception: %s", cause.getMessage()));
                throw new IllegalStateException("Could not define " + className + " from " + fileName + " in " + dir, cause);
            }

            System.out.println(String.format("loaded class: %s", clazz));
            classCache.put(className, clazz);
        }
    }
    

    CacheClassLoader:

    /*
     * Licensed to the Apache Software Foundation (ASF) under one
     * or more contributor license agreements.  See the NOTICE file
     * distributed with this work for additional information
     * regarding copyright ownership.  The ASF licenses this file
     * to you under the Apache License, Version 2.0 (the
     * "License") +  you may not use this file except in compliance
     * with the License.  You may obtain a copy of the License at
     *
     *   http://www.apache.org/licenses/LICENSE-2.0
     *
     * Unless required by applicable law or agreed to in writing,
     * software distributed under the License is distributed on an
     * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
     * KIND, either express or implied.  See the License for the
     * specific language governing permissions and limitations
     * under the License.
     */
    package net.sxlver;
    
    import java.net.*;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.Paths;
    import java.util.*;
    import java.util.jar.*;
    import java.io.*;
    
    import sun.misc.Unsafe;
    
    import javax.swing.*;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.CopyOnWriteArrayList;
    
    /**
     * {@link URLClassLoader} that serves classes and resources from an in-memory cache.
     *
     * <p>The constructor appends a synthetic {@code x-mem-cache:/} URL, backed by
     * {@link CacheURLStreamHandler}, to the search path, so lookups that reach it are answered from
     * {@link #CACHE}. Cache keys are absolute entry paths such as {@code /net/foo/Bar.class}.
     *
     * <p>When {@link #loadAllJar} is {@code true}, JARs are copied into the cache eagerly. Otherwise
     * JAR file paths are only remembered, and single entries are copied on demand through
     * {@link #addCode(String)}; that lazy mode works only for JARs that exist on disk.
     */
    public class CacheClassLoader extends URLClassLoader {

        public static final int BUFFER_SIZE = 8192;
        public static final String CLASS_SUFFIX = ".class";
        public static final String JAVA_SUFFIX = ".java";
        public final static String host = null;
        public final static int port = -1;
        public final static String baseURI = "/";
        public final static String DOT = ".";
        public final static String EMPTY = "";

        /** Scheme of the synthetic URL through which cached entries are served. */
        private final static String protocol = "x-mem-cache";

        /** Cached entry bytes keyed by absolute entry path, e.g. {@code /net/foo/Bar.class}. */
        public final Map<String, byte[]> CACHE;
        private URL cacheURL = null;
        /** JAR paths remembered for on-demand loading when {@link #loadAllJar} is {@code false}. */
        private final List<String> jars = new CopyOnWriteArrayList<>();

        /** {@code true}: copy whole JARs eagerly; {@code false}: copy single entries on demand. */
        public boolean loadAllJar;

        public CacheClassLoader(URL[] urls, ClassLoader parent, boolean loadAllJar) throws MalformedURLException {
            super(urls, parent);

            CACHE = new ConcurrentHashMap<>();
            this.loadAllJar = loadAllJar;
            cacheURL = new URL(protocol, host, port, baseURI, new CacheURLStreamHandler(this));

            super.addURL(cacheURL);
        }

        /** Eager loader with no extra URLs and the system class loader as parent. */
        public CacheClassLoader() throws MalformedURLException {
            this(new URL[]{}, ClassLoader.getSystemClassLoader(), true);
        }

        /** Eager loader with no extra URLs and the given parent. */
        public CacheClassLoader(ClassLoader parent) throws MalformedURLException {
            this(new URL[]{}, parent, true);
        }

        public CacheClassLoader(ClassLoader parent, boolean loadAllJar) throws MalformedURLException {
            this(new URL[]{}, parent, loadAllJar);
        }

        /** Drops all cached bytes and remembered JARs, then closes the underlying URLClassLoader. */
        @Override
        public void close() throws IOException {
            CACHE.clear();
            jars.clear();
            super.close();
        }

        /**
         * Registers a JAR on disk. It is copied into the cache now when {@link #loadAllJar} is set;
         * otherwise it is remembered for on-demand loading.
         */
        public void addJAR(String jarFileName) throws IOException {
            if (loadAllJar) {
                // First argument is only an entry filter (unused in eager mode); the second is the file to open.
                addCode(EMPTY, jarFileName);
            } else {
                jars.add(jarFileName);
            }
        }

        /**
         * Copies every entry of a JAR supplied as an in-memory stream into the cache, then closes the stream.
         *
         * <p>Such a stream can be consumed only once and cannot be reopened by name later. Its entries are
         * therefore always cached eagerly, even when {@link #loadAllJar} is {@code false};
         * {@code jarFileName} is informational only.
         */
        public void addJar(String jarFileName, UnsafeBufferInputStream stream) throws IOException {
            cacheJarEntries(stream, null);
        }

        /**
         * Copies the JAR image of {@code bufSize} bytes at native address {@code startAddress} into the cache.
         *
         * @return {@code true} on success, {@code false} if reading or parsing the JAR failed
         */
        public boolean readJarFromNativeBuffer(String jarName, long startAddress, long bufSize) {
            UnsafeBufferInputStream stream = new UnsafeBufferInputStream(startAddress, bufSize);
            try {
                addJar(jarName, stream);
                return true;
            } catch (final Exception e) {
                // Best-effort API: the boolean result is the error channel.
                return false;
            }
        }

        /** Caches {@code path/className.class} under the key for {@code packageName.className}. */
        public void addClass(String path, String packageName, String className) throws IOException {
            String class_name = path + java.io.File.separatorChar + className + CLASS_SUFFIX;
            try (FileInputStream inputStream = new FileInputStream(class_name)) {
                add_class(inputStream, packageName, className);
            }
        }

        /** Caches the class bytes read from {@code inputStream}; the stream is closed unless the class is already cached. */
        public void addClass(InputStream inputStream, String packageName, String className) throws IOException {
            add_class(inputStream, packageName, className);
        }

        /**
         * Registers every file below {@code directory} (recursively) as a JAR via {@link #addJAR(String)}.
         * Paths that are not readable directories are ignored.
         */
        public void addDir(String directory) throws IOException {
            if (directory == null) {
                throw new FileNotFoundException("Directory name is empty.");
            }

            File[] children = new File(directory).listFiles();
            if (children == null) {
                return; // not a directory, or it could not be listed
            }
            for (File child : children) {
                if (child.isDirectory()) {
                    // Recurse, then keep going with the remaining siblings.
                    addDir(child.getAbsolutePath());
                } else {
                    addJAR(child.getAbsolutePath());
                }
            }
        }

        /**
         * Calls {@link #addDir(String)} for each non-empty element of {@code directory} split on
         * {@code directorySeparator}. Note that {@link String#split} treats the separator as a regular expression.
         */
        public void addDir(String directory, String directorySeparator) throws IOException {
            if (directory == null) {
                throw new FileNotFoundException("Directory name is empty.");
            }

            for (String dir : directory.split(directorySeparator)) {
                if (dir.isEmpty()) {
                    continue;
                }
                addDir(dir);
            }
        }

        /** Opens the JAR file {@code jar} and caches entries from it; see {@link #addCode(String, InputStream)}. */
        public boolean addCode(String fileName, String jar) throws IOException {
            try (final FileInputStream inputStream = new FileInputStream(jar)) {
                return addCode(fileName, inputStream);
            }
        }

        /**
         * Caches entries from the JAR read from {@code inputStream} (closed afterwards). In eager mode all
         * entries are copied. In lazy mode only the entry whose cache key equals {@code fileName} is copied.
         *
         * @return {@code true} if, in lazy mode, {@code fileName} was found and cached; otherwise {@code false}
         */
        public boolean addCode(String fileName, InputStream inputStream) throws IOException {
            return cacheJarEntries(inputStream, loadAllJar ? null : fileName);
        }

        /**
         * Lazy-mode lookup: searches the remembered JARs in registration order for the entry
         * {@code fileName} (a cache key such as {@code /net/foo/Bar.class}) and caches the first match.
         */
        public void addCode(String fileName) throws IOException {
            for (String jar : jars) {
                // addCode(entryKey, jarPath): the entry key first, then the JAR file to open.
                if (jar != null && addCode(fileName, jar)) {
                    return;
                }
            }
        }

        /**
         * Copies JAR entries into {@link #CACHE}, skipping directories and keys that are already cached.
         * The stream is always closed.
         *
         * @param wanted cache key of the single entry to copy, or {@code null} to copy every entry
         * @return {@code true} if {@code wanted} was non-null and has been copied; {@code false} otherwise
         */
        private boolean cacheJarEntries(InputStream inputStream, String wanted) throws IOException {
            byte[] buffer = new byte[BUFFER_SIZE];
            try (BufferedInputStream bis = new BufferedInputStream(inputStream);
                 JarInputStream jis = new JarInputStream(bis)) {

                // JarInputStream consumes a leading META-INF/MANIFEST.MF instead of returning it as an entry.
                Manifest manifest = jis.getManifest();
                String manifestKey = baseURI + JarFile.MANIFEST_NAME;
                if (manifest != null && !CACHE.containsKey(manifestKey)
                        && (wanted == null || wanted.equals(manifestKey))) {
                    ByteArrayOutputStream manifestBytes = new ByteArrayOutputStream();
                    manifest.write(manifestBytes);
                    CACHE.put(manifestKey, manifestBytes.toByteArray());
                    if (wanted != null) {
                        return true;
                    }
                }

                JarEntry jarEntry;
                while ((jarEntry = jis.getNextJarEntry()) != null) {
                    if (jarEntry.isDirectory()) {
                        continue;
                    }
                    String name = baseURI + jarEntry.getName();
                    if (CACHE.containsKey(name)) {
                        continue;
                    }
                    if (wanted != null && !wanted.equals(name)) {
                        continue;
                    }

                    CACHE.put(name, readFully(jis, buffer));
                    if (wanted != null) {
                        return true;
                    }
                }
            }
            return false;
        }

        /** Reads {@code in} to its end (for a JarInputStream: to the end of the current entry). */
        private static byte[] readFully(InputStream in, byte[] buffer) throws IOException {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            int len;
            while ((len = in.read(buffer)) != -1) {
                out.write(buffer, 0, len);
            }
            return out.toByteArray();
        }

        /**
         * Caches class bytes under {@code /<package path>/<className>.class}; an empty
         * {@code packageName} denotes the default package.
         */
        private void add_class(InputStream classStream, String packageName, String className) throws IOException {
            if (classStream == null) {
                throw new FileNotFoundException("Class file is empty.");
            }
            if (packageName == null) {
                throw new FileNotFoundException("Package name is empty.");
            }

            // Without this special case the default package would produce the key "//X.class".
            String packagePath = packageName.isEmpty() ? EMPTY : packageName.replace(DOT, baseURI) + baseURI;
            String name = baseURI + packagePath + className + CLASS_SUFFIX;
            if (CACHE.containsKey(name)) {
                return;
            }

            try (BufferedInputStream bis = new BufferedInputStream(classStream)) {
                CACHE.put(name, readFully(bis, new byte[BUFFER_SIZE]));
            }
        }

    }
    

    CacheURLConnection

    package net.sxlver;
    
    import java.io.ByteArrayInputStream;
    import java.io.FileNotFoundException;
    import java.io.IOException;
    import java.io.InputStream;
    import java.net.URL;
    import java.net.URLConnection;
    
    /**
     * Connection for {@code x-mem-cache:} URLs that returns bytes from a {@link CacheClassLoader}'s
     * in-memory cache. In lazy mode ({@code loadAllJar == false}) a cache miss first asks the loader to
     * copy the entry in from its remembered JARs.
     */
    public class CacheURLConnection extends URLConnection {

        private final CacheClassLoader classLoader;

        public CacheURLConnection(CacheClassLoader classLoader, URL url) {
            super(url);
            this.classLoader = classLoader;
        }

        /** Nothing to open: data comes straight from the loader's cache. */
        @Override
        public void connect() throws IOException {
            connected = true;
        }

        /**
         * @return a fresh stream over the cached bytes for this URL
         * @throws FileNotFoundException if the entry is not (and cannot be) cached
         */
        @Override
        public InputStream getInputStream() throws IOException {
            String file_name = cacheKey();

            byte[] data = classLoader.CACHE.get(file_name);
            if (data == null && !classLoader.loadAllJar) {
                classLoader.addCode(file_name);
                data = classLoader.CACHE.get(file_name);
            }

            if (data == null) {
                throw new FileNotFoundException(file_name);
            }

            connected = true;
            return new ByteArrayInputStream(data);
        }

        /**
         * Maps this URL to a cache key. URLClassLoader percent-encodes resource names when it builds
         * lookup URLs (e.g. a space becomes {@code %20}), while cache keys hold the raw entry names.
         * The undecoded path is therefore tried first, then the decoded one.
         */
        private String cacheKey() {
            String rawName = url.getFile();
            if (classLoader.CACHE.containsKey(rawName)) {
                return rawName;
            }
            try {
                String decoded = url.toURI().getPath();
                return decoded != null ? decoded : rawName;
            } catch (java.net.URISyntaxException e) {
                return rawName;
            }
        }
    }
    

    CacheURLStreamHandler

    package net.sxlver;
    
    import java.io.IOException;
    import java.net.URL;
    import java.net.URLConnection;
    import java.net.URLStreamHandler;
    
    /**
     * Stream handler for the synthetic {@code x-mem-cache:} URL: every URL it opens yields a
     * {@link CacheURLConnection} backed by the owning loader's in-memory cache.
     */
    public class CacheURLStreamHandler extends URLStreamHandler {

        /** Loader whose cache answers every connection opened through this handler. */
        private final CacheClassLoader owner;

        public CacheURLStreamHandler(CacheClassLoader classLoader) {
            owner = classLoader;
        }

        @Override
        protected URLConnection openConnection(URL url) throws IOException {
            final CacheURLConnection connection = new CacheURLConnection(owner, url);
            return connection;
        }

    }
    
    

    UnsafeBufferInputStream

    /**
     * {@link InputStream} over a block of native (off-heap) memory, read through {@code sun.misc.Unsafe}.
     *
     * <p>Unlike a {@code byte[]}, the block may be longer than {@link Integer#MAX_VALUE} bytes. The
     * stream does not own the memory: the caller must keep it allocated and unchanged while the stream
     * is read. Addresses cannot be validated, so a wrong address or length reads arbitrary memory or
     * crashes the JVM.
     */
    public class UnsafeBufferInputStream extends InputStream {
        private final Unsafe unsafe;
        /** Native address of the next byte to return. */
        private long address;
        /** Bytes left before end of stream; a value {@code <= 0} means the stream is exhausted. */
        private long remainingBytes;

        /**
         * @param startAddr native address of the first readable byte
         * @param bufLen    number of readable bytes; zero or negative yields an empty stream
         */
        public UnsafeBufferInputStream(long startAddr, long bufLen) {
            this.unsafe = getUnsafe();
            this.address = startAddr;
            this.remainingBytes = bufLen;
        }

        /** Obtains the {@code Unsafe} singleton via reflection, keeping the original failure as cause. */
        private Unsafe getUnsafe() {
            try {
                java.lang.reflect.Field field = Unsafe.class.getDeclaredField("theUnsafe");
                field.setAccessible(true);
                return (Unsafe) field.get(null);
            } catch (Exception e) {
                throw new RuntimeException("Unsafe access error: " + e.getMessage(), e);
            }
        }

        @Override
        public int read() throws IOException {
            if (remainingBytes <= 0) {
                return -1;
            }
            byte value = unsafe.getByte(address++);
            remainingBytes--;
            return value & 0xFF;
        }

        /**
         * Copies up to {@code len} bytes into {@code b[off, off + len)} per the {@link InputStream}
         * contract: returns 0 when {@code len == 0}, -1 at end of stream, and throws before consuming
         * anything when the range is invalid.
         */
        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            if (b == null) {
                throw new NullPointerException("b");
            }
            if (off < 0 || len < 0 || len > b.length - off) {
                throw new IndexOutOfBoundsException("off=" + off + ", len=" + len + ", length=" + b.length);
            }
            if (len == 0) {
                return 0;
            }
            if (remainingBytes <= 0) {
                return -1;
            }
            int bytesRead = (int) Math.min(len, remainingBytes);
            // One bulk copy from native memory into the array instead of a getByte call per byte.
            unsafe.copyMemory(null, address, b, (long) Unsafe.ARRAY_BYTE_BASE_OFFSET + off, bytesRead);
            address += bytesRead;
            remainingBytes -= bytesRead;
            return bytesRead;
        }

        /** Advances past up to {@code n} bytes without copying them. */
        @Override
        public long skip(long n) {
            if (n <= 0 || remainingBytes <= 0) {
                return 0;
            }
            long skipped = Math.min(n, remainingBytes);
            address += skipped;
            remainingBytes -= skipped;
            return skipped;
        }

        /** Remaining byte count, capped at {@link Integer#MAX_VALUE}. */
        @Override
        public int available() {
            return (int) Math.min(Math.max(remainingBytes, 0L), Integer.MAX_VALUE);
        }

        /** Marks the stream exhausted; the native memory belongs to the caller and is not freed. */
        @Override
        public void close() {
            remainingBytes = 0;
        }
    }
    

    Also note that main.cpp in my native code is actually the entry point of a DLL that has been injected into a Java process.

    This code will:

    • Fetch the JNIEnv and JavaVM pointers
    • Load an Initializer.class file into memory
    • Define the class using the system classloader from ClassLoader#getSystemClassLoader
    • Call the constructor of Initializer.class with the start address and length of the native buffer that holds the .jar file
    • Load all contents from the buffer representing the .jar file into a cache
    • get the Class<?> object of the entry point of the .jar file.