Search code examples
Tags: java, spring, spring-boot, multipartform-data

Spring POST multipart/form-data, request parts always empty


I have a simple REST controller that I use to accept a file uploaded from an HTML form. The project uses Spring Boot 2.6.1 and Java 17, but the same problem also occurs with Spring Boot 2.3.7 and Java 15.

/**
 * Accepts a single uploaded file from a {@code multipart/form-data} POST to {@code /file} and
 * passes its content and original file name to {@code fileService.upload}.
 *
 * @param file the request part named {@code file}
 * @throws IOException if the uploaded content cannot be opened or read
 */
@PostMapping(path = "/file", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
public void handleFileUpload(@RequestParam("file") MultipartFile file) throws IOException {
    // getInputStream() throws IOException, so it must be declared above for this to compile.
    // Closing the stream releases whatever backs the part once the upload call returns.
    // NOTE(review): assumes fileService.upload consumes the stream synchronously — confirm.
    try (var in = file.getInputStream()) {
        fileService.upload(in, file.getOriginalFilename());
    }
}

The problem is that `file` is always `null`. I found many different answers about setting a `MultipartResolver` bean or enabling `spring.http.multipart.enabled = true`, but nothing helped. I have a logging filter as one of the first filters in the chain. While debugging the filter chain, I found that calling `request.getParts()` made everything work. My filter looks like this:

public class LoggingFilter extends GenericFilterBean {

    /**
     * Wraps the request and response in buffering wrappers, runs the remainder of the filter
     * chain with those wrappers, and logs the request and then the response once the chain
     * has returned.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
            throws IOException, ServletException {
        final HttpServletRequest httpRequest = (HttpServletRequest) request;
        final BufferedRequestWrapper requestCopy = new BufferedRequestWrapper(httpRequest);
        final HttpServletResponse httpResponse = (HttpServletResponse) response;
        final BufferedResponseWrapper responseCopy = new BufferedResponseWrapper(httpResponse);

        filterChain.doFilter(requestCopy, responseCopy);

        // Logging runs only after downstream processing has returned.
        logRequest(httpRequest, requestCopy);
        logResponse(httpRequest, responseCopy);
    }

I changed the filter to:

public class LoggingFilter extends GenericFilterBean {

    /**
     * For {@code multipart/form-data} requests, asks the container for the parts first; then
     * wraps request and response in buffering wrappers, runs the remainder of the chain with
     * them, and logs the request and then the response.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
            throws IOException, ServletException {
        final HttpServletRequest httpRequest = (HttpServletRequest) request;

        final String contentType = request.getContentType();
        final boolean multipartForm = contentType != null && contentType.startsWith("multipart/form-data");
        if (multipartForm) {
            // Parse the parts before BufferedRequestWrapper is constructed.
            // NOTE(review): presumably required because the wrapper reads the input stream.
            httpRequest.getParts();
        }

        final BufferedRequestWrapper requestCopy = new BufferedRequestWrapper(httpRequest);
        final HttpServletResponse httpResponse = (HttpServletResponse) response;
        final BufferedResponseWrapper responseCopy = new BufferedResponseWrapper(httpResponse);

        filterChain.doFilter(requestCopy, responseCopy);

        logRequest(httpRequest, requestCopy);
        logResponse(httpRequest, responseCopy);
    }

and everything worked. My question is: why is this needed? And is there a better way of doing this?

Below is a complete example in which only the actual logging has been removed, because we use a custom logging framework.

package com.unwire.ticketing.filter.logging;

import lombok.Getter;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.output.TeeOutputStream;
import org.springframework.web.filter.GenericFilterBean;

import javax.servlet.*;
import javax.servlet.http.*;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Locale;
import java.util.stream.Collectors;

/**
 * Filter that captures the request and response bodies of every exchange so they can be logged
 * after the rest of the filter chain has run.
 *
 * <p>Request bodies are copied into memory before the chain runs, except for
 * {@code application/x-www-form-urlencoded} and {@code multipart/form-data}. The servlet
 * container parses those two content types itself, lazily, from the original request's input
 * stream (via {@code getParameter*()} and {@code getParts()}). Draining that stream up front
 * leaves the container nothing to parse: multipart parts then come back empty, which forced an
 * eager {@code getParts()} call as a workaround. Leaving those bodies untouched makes the
 * workaround unnecessary.
 *
 * <p>Response bytes written through {@code getOutputStream()} are teed: they go to the client
 * unchanged and are also copied into memory for logging.
 */
public class Log extends GenericFilterBean {

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        BufferedRequestWrapper bufferedRequest = new BufferedRequestWrapper(httpServletRequest);
        BufferedResponseWrapper bufferedResponse =
                new BufferedResponseWrapper((HttpServletResponse) response);

        try {
            filterChain.doFilter(bufferedRequest, bufferedResponse);
        } finally {
            // Log even when the chain fails, but never swallow the failure itself: exceptions
            // from controllers must keep propagating to Spring's error handling.
            logQuietly(httpServletRequest, bufferedRequest, bufferedResponse);
        }
    }

    /**
     * Logs the exchange on a best-effort basis. Failures are reported through the inherited
     * {@code logger} instead of being thrown, so a logging problem can neither fail the request
     * nor mask an exception already propagating from the filter chain.
     */
    private void logQuietly(HttpServletRequest request, BufferedRequestWrapper bufferedRequest,
                            BufferedResponseWrapper bufferedResponse) {
        try {
            logRequest(request, bufferedRequest);
            logResponse(request, bufferedResponse);
        } catch (IOException | RuntimeException e) {
            logger.warn("Failed to log HTTP exchange for " + request.getRequestURI(), e);
        }
    }

    private void logRequest(HttpServletRequest request, BufferedRequestWrapper bufferedRequest) throws IOException {
        String body = bufferedRequest.getRequestBody();
        // Log request
    }

    private void logResponse(HttpServletRequest httpServletRequest, BufferedResponseWrapper bufferedResponse) {
        // Log response
    }

    /**
     * Request wrapper that serves the request body from an in-memory copy, so the application
     * can read it and the filter can read it again afterwards for logging.
     */
    private static final class BufferedRequestWrapper extends HttpServletRequestWrapper {

        private static final String FORM_URLENCODED = "application/x-www-form-urlencoded";
        private static final String MULTIPART_FORM_DATA = "multipart/form-data";

        /** Copy of the request body; empty when the body was left for the container to parse. */
        private final byte[] buffer;

        BufferedRequestWrapper(HttpServletRequest req) throws IOException {
            super(req);
            // Form and multipart bodies must stay in the original stream: the container parses
            // them lazily from it, and draining it here would leave getParameter()/getParts()
            // with nothing to parse.
            this.buffer = isParsedByContainer(req.getContentType())
                    ? new byte[0]
                    : req.getInputStream().readAllBytes();
        }

        /**
         * Returns whether bodies of this content type are parsed by the container itself
         * (form-urlencoded or multipart/form-data, matched case-insensitively).
         */
        private static boolean isParsedByContainer(String contentType) {
            if (contentType == null) {
                return false;
            }
            String normalized = contentType.toLowerCase(Locale.ROOT);
            return normalized.startsWith(FORM_URLENCODED) || normalized.startsWith(MULTIPART_FORM_DATA);
        }

        @Override
        public ServletInputStream getInputStream() {
            return new BufferedServletInputStream(new ByteArrayInputStream(this.buffer));
        }

        /**
         * Returns a reader over the same buffered bytes as {@link #getInputStream()}, so both
         * views of the body agree; the inherited implementation would read the original stream,
         * which the constructor may already have drained. Decodes with the request's character
         * encoding, or UTF-8 if none is set.
         */
        @Override
        public BufferedReader getReader() throws IOException {
            String encoding = getCharacterEncoding();
            String charsetName = encoding != null ? encoding : StandardCharsets.UTF_8.name();
            return new BufferedReader(new InputStreamReader(getInputStream(), charsetName));
        }

        /** Returns the buffered body as UTF-8 text, each line trimmed and all lines concatenated. */
        String getRequestBody() {
            return new String(this.buffer, StandardCharsets.UTF_8).lines()
                    .map(String::trim)
                    .collect(Collectors.joining());
        }
    }

    /** Blocking {@link ServletInputStream} backed by an in-memory buffer. */
    private static final class BufferedServletInputStream extends ServletInputStream {

        private final ByteArrayInputStream bais;

        BufferedServletInputStream(ByteArrayInputStream bais) {
            this.bais = bais;
        }

        @Override
        public int available() {
            return this.bais.available();
        }

        @Override
        public int read() {
            return this.bais.read();
        }

        @Override
        public int read(byte[] buf, int off, int len) {
            return this.bais.read(buf, off, len);
        }

        @Override
        public boolean isFinished() {
            // Finished once every buffered byte has been consumed.
            return this.bais.available() == 0;
        }

        @Override
        public boolean isReady() {
            // The data is already in memory, so a read never blocks.
            return true;
        }

        @Override
        public void setReadListener(ReadListener readListener) {
            // Non-blocking reads are not supported; the listener is ignored.
        }
    }

    /** {@link ServletOutputStream} that writes every byte to two underlying streams. */
    public static class TeeServletOutputStream extends ServletOutputStream {

        private final TeeOutputStream targetStream;

        TeeServletOutputStream(OutputStream one, OutputStream two) {
            targetStream = new TeeOutputStream(one, two);
        }

        @Override
        public void write(int arg0) throws IOException {
            this.targetStream.write(arg0);
        }

        /** Forwards bulk writes as a unit instead of byte by byte. */
        @Override
        public void write(byte[] b, int off, int len) throws IOException {
            this.targetStream.write(b, off, len);
        }

        @Override
        public void flush() throws IOException {
            super.flush();
            this.targetStream.flush();
        }

        @Override
        public void close() throws IOException {
            super.close();
            this.targetStream.close();
        }

        @Override
        public boolean isReady() {
            // Writes are blocking, so the stream is always ready to accept data.
            return true;
        }

        @Override
        public void setWriteListener(WriteListener writeListener) {
            // Non-blocking writes are not supported; the listener is ignored.
        }
    }

    /**
     * Response wrapper that copies everything written through {@link #getOutputStream()} into
     * memory so the body can be logged after it has been sent. All other methods, including
     * {@code getWriter()}, are inherited from {@link HttpServletResponseWrapper} and delegate
     * to the wrapped response.
     */
    public static class BufferedResponseWrapper extends HttpServletResponseWrapper {

        /** Receives a copy of every byte written through {@link #getOutputStream()}. */
        private ByteArrayOutputStream bos;
        /** Stream handed to the application; writes to the client and to {@link #bos}. */
        private TeeServletOutputStream tee;
        /** Value of {@code System.currentTimeMillis()} when this wrapper was created. */
        @Getter
        private final Long startTime;

        BufferedResponseWrapper(HttpServletResponse response) {
            super(response);
            this.startTime = System.currentTimeMillis();
        }

        /**
         * Returns the bytes written so far through {@link #getOutputStream()}, decoded as UTF-8
         * (independent of the platform default charset), or {@code ""} if the output stream
         * was never requested.
         */
        String getContent() {
            return bos != null ? bos.toString(StandardCharsets.UTF_8) : "";
        }

        @Override
        public ServletOutputStream getOutputStream() throws IOException {
            if (tee == null) {
                bos = new ByteArrayOutputStream();
                tee = new TeeServletOutputStream(super.getOutputStream(), bos);
            }
            return tee;
        }

        // NOTE: getWriter() is inherited and delegates to the wrapped response, so text written
        // through the writer reaches the client but is not captured by getContent().

        /** Flushes the tee (if any) and then lets the wrapped response flush and commit. */
        @Override
        public void flushBuffer() throws IOException {
            if (tee != null) {
                tee.flush();
            }
            super.flushBuffer();
        }

        @Override
        public void resetBuffer() {
            super.resetBuffer();
            clearCapturedContent();
        }

        @Override
        public void reset() {
            super.reset();
            clearCapturedContent();
        }

        /** Drops captured bytes that the wrapped response has just discarded. */
        private void clearCapturedContent() {
            if (bos != null) {
                bos.reset();
            }
        }
    }
}

Solution

  • Please consider using ContentCachingRequestWrapper.

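    It is built into Spring and caches all content that is read from the request's input stream and reader, so the body can be read again (for example, for logging) after the request has been processed. Unlike the custom `BufferedRequestWrapper`, it does not read the body eagerly in its constructor. It only records bytes as downstream code consumes them, so the servlet container can still parse `multipart/form-data` parts from the untouched stream.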

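    Be aware that for multipart requests Spring already has its own wrapper, `MultipartHttpServletRequest`.

    As for why the `getParts()` call is needed: `BufferedRequestWrapper` reads the whole input stream for every content type except `application/x-www-form-urlencoded`, including `multipart/form-data`. The container parses multipart parts lazily from that same stream, so once the filter has drained it there is nothing left to parse and the parts are empty. Calling `getParts()` first forces the container to parse the parts before the stream is consumed. A cleaner fix is to skip buffering for `multipart/form-data` requests, just as the filter already does for form-urlencoded ones.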

    Please refer: https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/web/util/ContentCachingRequestWrapper.html