Tags: java, spring, encryption, random-access, chunking

How to efficiently read a chunk of bytes in a given range from a large encrypted file in Java?


I have a large encrypted file (10 GB+) on a server. I need to transfer the decrypted file to the client in small chunks. When a client requests a chunk of bytes (say 18 to 45), I have to seek into the file, read the specific bytes, decrypt them, and send them to the client through the servlet response's output stream.

But since the file is encrypted, I have to read it in blocks of 16 bytes in order to decrypt it correctly.

So if the client requests bytes 18 to 45, the server has to read the file in multiples of 16-byte blocks, which means accessing the file from byte 16 to 48 and then decrypting that range. After decryption I have to skip 2 bytes at the start and 3 bytes at the end to return exactly the chunk the client requested.

Here is what I am trying to do

Adjust start and end for encrypted files

long start = 18; // input from client
long end = 45; // input from client
long skipStart = 0; // bytes to skip at the start for an encrypted file
long skipEnd = 0;   // bytes to skip at the end for an encrypted file

// encrypted files must be accessed in blocks of 16 bytes
if(fileisEncrypted){
   skipStart = start % 16;  // skip 2 bytes at the start
   skipEnd = 16 - end % 16; // skip 3 bytes at the end
   start = start - skipStart; // start becomes 16
   end = end + skipEnd; // end becomes 48
}

Access the encrypted file data from start to end

try(final FileChannel channel = FileChannel.open(services.getPhysicalFile(datafile).toPath())){
    MappedByteBuffer mappedByteBuffer = channel.map(FileChannel.MapMode.READ_ONLY, start, end-start);

    // *** No idea how to convert the MappedByteBuffer into an input stream ***
    // InputStream is = (How do I get an input stream for bytes 16 to 48 here?)

    // the method I used earlier to decrypt the whole file at once; now I somehow need the input stream of a specific range
    is = new FileEncryptionUtil().getCipherInputStream(is,
                        EncodeUtil.decodeSeedValue(encryptionKeyRef), AESCipher.DECRYPT_MODE);

    // transferring the decrypted input stream to the servlet response
    OutputStream outputStream = response.getOutputStream();
    // *** now for chunk transfer, here I also need to
    //     skip 2 bytes at the start and 3 bytes at the end.
    //     How to do it? ***
    org.apache.commons.io.IOUtils.copy(is, outputStream);
}

I am missing a few steps in the code given above. I know I could read byte by byte and then ignore 2 bytes at the start and 3 bytes at the end, but I am not sure that would be efficient enough. Moreover, the client could request a large chunk, say from byte 18 up to an offset near 2 GB, which would require reading and decrypting almost two gigabytes of data. I am afraid that creating a large byte array will consume too much memory.

How can I efficiently do it without putting too much pressure on server processing or memory? Any ideas?


Solution

  • After researching for a while, this is how I solved it. First I created a ByteBufferInputStream class to read from the MappedByteBuffer:

    import java.io.IOException;
    import java.io.InputStream;
    import java.nio.ByteBuffer;

    public class ByteBufferInputStream extends InputStream {
        private ByteBuffer byteBuffer;

        /** Creates an uninitialized stream that cannot be used until {@link #setByteBuffer(ByteBuffer)} is called. */
        public ByteBufferInputStream () {
        }

        /** Creates a stream with a new non-direct buffer of the specified size. The position and limit of the buffer are zero. */
        public ByteBufferInputStream (int bufferSize) {
            this(ByteBuffer.allocate(bufferSize));
            byteBuffer.flip();
        }

        /** Creates a stream that reads from the given buffer. */
        public ByteBufferInputStream (ByteBuffer byteBuffer) {
            this.byteBuffer = byteBuffer;
        }
    
        public ByteBuffer getByteBuffer () {
            return byteBuffer;
        }
    
        public void setByteBuffer (ByteBuffer byteBuffer) {
            this.byteBuffer = byteBuffer;
        }
    
        @Override
        public int read () throws IOException {
            if (!byteBuffer.hasRemaining()) return -1;
            return byteBuffer.get() & 0xFF; // mask so bytes >= 0x80 are not returned as negative values (0xFF would look like end of stream)
        }

        @Override
        public int read (byte[] bytes, int offset, int length) throws IOException {
            if (length == 0) return 0;
            int count = Math.min(byteBuffer.remaining(), length);
            if (count == 0) return -1;
            byteBuffer.get(bytes, offset, count);
            return count;
        }
    
        public int available () throws IOException {
            return byteBuffer.remaining();
        }
    }
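
    As a rough illustration of the idea above, the following self-contained sketch maps one region of a file and reads it back through an InputStream. The names MappedRegionSketch, asStream and readRegion are made up for this example and are not part of the original answer; readAllBytes assumes Java 9 or later.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

// Hypothetical helper, for illustration only.
class MappedRegionSketch {

    /** Wraps a buffer as a stream; single bytes are returned unsigned (0..255). */
    static InputStream asStream(ByteBuffer buffer) {
        return new InputStream() {
            @Override
            public int read() {
                return buffer.hasRemaining() ? (buffer.get() & 0xFF) : -1;
            }

            @Override
            public int read(byte[] dst, int off, int len) {
                if (len == 0) return 0;
                if (!buffer.hasRemaining()) return -1;
                int n = Math.min(len, buffer.remaining());
                buffer.get(dst, off, n);
                return n;
            }
        };
    }

    /** Maps [offset, offset + length) of the file and returns those bytes. */
    static byte[] readRegion(Path file, long offset, int length) throws IOException {
        try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ)) {
            ByteBuffer mapped = channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
            try (InputStream in = asStream(mapped)) {
                return in.readAllBytes(); // Java 9+; fine for a small demo region
            }
        }
    }
}
```

    For example, readRegion(path, 16, 32) would return the 32 bytes starting at offset 16. In a real server the stream would be consumed incrementally instead of being collected into one array.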
    

    Then I created a BlockInputStream class that extends InputStream. It skips the extra leading bytes, limits the output to the requested length, and reads the wrapped stream in 16-byte blocks.

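    import java.io.BufferedInputStream;
    import java.io.ByteArrayInputStream;
    import java.io.IOException;
    import java.io.InputStream;

    public class BlockInputStream extends InputStream {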
        private final BufferedInputStream inputStream;
        private final long totalLength;
        private final long skip;
        private long read = 0;
        private byte[] buff = new byte[16];
        private ByteArrayInputStream blockInputStream;
    
        public BlockInputStream(InputStream inputStream, long skip, long length) throws IOException {
            this.inputStream = new BufferedInputStream(inputStream);
            this.skip = skip;
            this.totalLength = length + skip;
            if(skip > 0) {
                byte[] b = new byte[(int)skip];
                read(b);
                b = null;
            }
        }
    
    
        private int readBlock() throws IOException {
            int count = inputStream.read(buff);
            // expose only the bytes actually read; nothing if the stream has ended
            blockInputStream = new ByteArrayInputStream(buff, 0, Math.max(count, 0));
            return count;
        }
    
        @Override
        public int read () throws IOException {
            byte[] b = new byte[1];
            int count = read(b);
            // index 0 is the only valid index; mask so the result is in 0..255
            return count < 1 ? -1 : b[0] & 0xFF;
        }
    
        @Override
        public int read(byte[] b) throws IOException {
            return read(b, 0, b.length);
        }
    
        @Override
        public int read (byte[] bytes, int offset, int length) throws IOException {
            if(length == 0){
                return 0;
            }
            long remaining = totalLength - read;
            if(remaining < 1){
                return -1;
            }
            int bytesToRead = (int)Math.min(length, remaining);
            int n = 0;
            while(bytesToRead > 0){
                boolean blockEmpty = blockInputStream == null || blockInputStream.available() == 0;
                int count;
                if(blockEmpty && read % 16 == 0 && bytesToRead % 16 == 0){
                    // aligned request: read straight from the underlying stream
                    count = inputStream.read(bytes, offset, bytesToRead);
                } else if(!blockEmpty) {
                    // serve bytes left over from the current block
                    int len = Math.min(bytesToRead, blockInputStream.available());
                    count = blockInputStream.read(bytes, offset, len);
                } else {
                    // refill the block buffer; a negative result means end of stream
                    count = readBlock() < 0 ? -1 : 0;
                }
                if(count < 0){
                    break; // underlying stream ended early; avoid looping forever
                }
                read += count;
                offset += count;
                bytesToRead -= count;
                n += count;
            }
            return n == 0 ? -1 : n;
        }
    
        @Override
        public int available () throws IOException {
            long remaining = totalLength - read;
            if(remaining < 1){
                return 0; // available() must not return a negative value
            }
            return (int)Math.min(remaining, inputStream.available());
        }
    
        @Override
        public long skip(long n) throws IOException {
            return inputStream.skip(n);
        }
    
        @Override
        public void close() throws IOException {
            inputStream.close();
        }
    
        @Override
        public synchronized void mark(int readlimit) {
            inputStream.mark(readlimit);
        }
    
        @Override
        public synchronized void reset() throws IOException {
            inputStream.reset();
        }
    
        @Override
        public boolean markSupported() {
            return inputStream.markSupported();
        }
    }
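
    The trimming step described above (drop the leading bytes, then pass on only the requested number of bytes) can also be sketched without a custom stream class. This is a simplified alternative, not the answer's BlockInputStream: WindowSketch, skipFully and copyWindow are names invented for this example. It copies through a fixed 8 KB buffer, so memory use does not grow with the size of the requested range.

```java
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

// Hypothetical helper, for illustration only.
class WindowSketch {

    /** Discards exactly n bytes. InputStream.skip may skip fewer bytes than asked (even 0), so fall back to read(). */
    static void skipFully(InputStream in, long n) throws IOException {
        while (n > 0) {
            long skipped = in.skip(n);
            if (skipped > 0) {
                n -= skipped;
            } else if (in.read() >= 0) {
                n--;
            } else {
                throw new EOFException("stream ended before the requested offset");
            }
        }
    }

    /** Skips `skip` bytes, then copies at most `length` bytes to `out`. Returns the number of bytes copied. */
    static long copyWindow(InputStream in, OutputStream out, long skip, long length) throws IOException {
        skipFully(in, skip);
        byte[] buf = new byte[8192];
        long copied = 0;
        while (copied < length) {
            int want = (int) Math.min(buf.length, length - copied);
            int n = in.read(buf, 0, want);
            if (n < 0) break; // source ended early
            out.write(buf, 0, n);
            copied += n;
        }
        return copied;
    }
}
```

    With a decrypted stream covering bytes 16 to 47 and a request for bytes 18 to 45, the call would be copyWindow(decrypted, response.getOutputStream(), 2, 28).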
    

    This is my final working implementation using these two classes:

    private RangeData getRangeData(RangeInfo r) throws IOException, GeneralSecurityException, CryptoException {
    
        // used for encrypted files
        long blockStart = r.getStart();
        long blockEnd = r.getEnd();
        long blockLength = blockEnd - blockStart + 1;
    
        // encrypted files must be accessed in blocks of 16 bytes
        if(datafile.isEncrypted()){
            blockStart -= blockStart % 16; // round the start down to a block boundary
            blockEnd = blockEnd | 15; // last byte of the block containing the end; rounded-up length n = ((n−1)|15)+1
            blockLength = blockEnd - blockStart + 1;
        }
        }
    
        try ( final FileChannel channel = FileChannel.open(services.getPhysicalFile(datafile).toPath()) )
        {
            // a single mapping is limited to Integer.MAX_VALUE bytes; the mapping itself stays valid after the channel is closed
            MappedByteBuffer mappedByteBuffer = channel.map(FileChannel.MapMode.READ_ONLY, blockStart, blockLength);
            InputStream inputStream = new ByteBufferInputStream(mappedByteBuffer);
            if(datafile.isEncrypted()) {
                String encryptionKeyRef = (String) settingsManager.getSetting(AppSetting.DEFAULT_ENCRYPTION_KEY);
                inputStream = new FileEncryptionUtil().getCipherInputStream(inputStream,
                        EncodeUtil.decodeSeedValue(encryptionKeyRef), AESCipher.DECRYPT_MODE);
                long skipStart = r.getStart() - blockStart;
                inputStream = new BlockInputStream(inputStream, skipStart, r.getLength()); // skips the leading bytes and trims the output to the requested length
            }
            return new RangeData(r, inputStream);
        }
    }
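
    The answer does not say which cipher mode FileEncryptionUtil uses, and starting decryption in the middle of a file only works when the mode allows it. As one possible illustration, the sketch below assumes AES in CTR mode with no padding (an assumption, not something the original code confirms), where the counter for block k is the IV plus k. CtrRangeSketch, counterAt and decryptRange are hypothetical names, and the ciphertext is held in a byte array only to keep the demo short.

```java
import java.security.GeneralSecurityException;
import java.util.Arrays;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

// Sketch only: assumes AES/CTR/NoPadding, which the original answer does not confirm.
class CtrRangeSketch {
    static final int BLOCK = 16;

    /** Adds blockIndex to the IV, treating the IV as a big-endian 128-bit counter. */
    static byte[] counterAt(byte[] iv, long blockIndex) {
        byte[] out = iv.clone();
        long carry = blockIndex;
        for (int i = out.length - 1; i >= 0 && carry != 0; i--) {
            long sum = (out[i] & 0xFFL) + (carry & 0xFFL);
            out[i] = (byte) sum;
            carry = (carry >>> 8) + (sum >>> 8);
        }
        return out;
    }

    /** Returns plaintext bytes start..end (both inclusive) by decrypting only the blocks that cover them. */
    static byte[] decryptRange(byte[] ciphertext, SecretKey key, byte[] iv, long start, long end)
            throws GeneralSecurityException {
        long blockStart = start - (start % BLOCK);   // round down to a block boundary
        long blockEnd = end | (BLOCK - 1);           // last byte of the block containing end
        blockEnd = Math.min(blockEnd, ciphertext.length - 1L);

        Cipher cipher = Cipher.getInstance("AES/CTR/NoPadding");
        cipher.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(counterAt(iv, blockStart / BLOCK)));
        byte[] blocks = cipher.doFinal(ciphertext, (int) blockStart, (int) (blockEnd - blockStart + 1));

        int skip = (int) (start - blockStart);       // extra bytes at the front of the first block
        return Arrays.copyOfRange(blocks, skip, skip + (int) (end - start + 1));
    }

    public static void main(String[] args) throws Exception {
        SecretKey key = new SecretKeySpec(new byte[16], "AES"); // all-zero demo key, never use in production
        byte[] iv = new byte[16];
        byte[] plain = new byte[64];
        for (int i = 0; i < plain.length; i++) plain[i] = (byte) i;

        Cipher enc = Cipher.getInstance("AES/CTR/NoPadding");
        enc.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv));
        byte[] ciphertext = enc.doFinal(plain);

        byte[] part = decryptRange(ciphertext, key, iv, 18, 45);
        System.out.println(Arrays.equals(part, Arrays.copyOfRange(plain, 18, 46)));
    }
}
```

    For the request 18 to 45 this decrypts bytes 16 to 47 and drops 2 bytes at the front. Other modes need different handling: CBC, for instance, would need the preceding ciphertext block as the IV, and a padded mode must not treat a mid-file segment as the final block.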