
How to display max file size and allowed extensions in swagger?


How to display max file size and allowed file extensions in swagger?

from fastapi import FastAPI, UploadFile, File

app = FastAPI()


@app.post('/upload/')
async def upload(file: UploadFile = File()):
    ...

Solution

  • If you only want to display the information, you can wrap the existing File() function and append the details to its description (any description you pass to File yourself is kept):

    from typing import Any, Optional
    from fastapi import FastAPI, UploadFile, File
    
    app = FastAPI()
    
    def EnhancedFile(
        *args: Any,
        max_size: Optional[int] = None,
        allowed_extensions: Optional[str] = None,
        **kwargs: Any
    ) -> Any:
        description_details = ""
        if max_size is not None:
            description_details += "\n- max size: {} bytes".format(max_size)
        if allowed_extensions is not None:
            description_details += "\n- allowed extensions: {}".format(allowed_extensions)
        kwargs["description"] = "{}{}".format(
            kwargs.get("description", ""), description_details
        )
    
        return File(*args, **kwargs)
    
    
    @app.post("/upload1/")
    async def upload1(
        file: UploadFile = EnhancedFile(max_size=100, allowed_extensions="jpeg|pdf")
    ):
        ...
    

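    The code above only appends the specified values to the description parameter of File, so Swagger UI lists the max size and allowed extensions in the description of the file field. Nothing is validated, though: a larger file or a different extension is still accepted.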
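The wrapper's only real work is string building, so that step can be tried without FastAPI. This is a stdlib-only sketch that mirrors EnhancedFile; `build_description` is a hypothetical helper, not a FastAPI function. Each `\n- ` prefix starts a new line beginning with `- `, which Swagger UI should render as a Markdown bullet, since OpenAPI descriptions allow CommonMark.

```python
from typing import Optional


def build_description(
    description: str = "",
    max_size: Optional[int] = None,
    allowed_extensions: Optional[str] = None,
) -> str:
    """Mirror EnhancedFile: append one Markdown bullet per constraint."""
    details = ""
    if max_size is not None:
        details += "\n- max size: {} bytes".format(max_size)
    if allowed_extensions is not None:
        details += "\n- allowed extensions: {}".format(allowed_extensions)
    return "{}{}".format(description, details)


print(build_description("Upload a document.", max_size=100, allowed_extensions="jpeg|pdf"))
```

Passing only `max_size` yields just the one bullet, and passing neither leaves the original description unchanged.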

    But if you want to implement the checking logic along with the descriptions, you should use Depends():

    import re
    from abc import ABC, abstractmethod
    from typing import Any

    from fastapi import Depends, FastAPI, File, UploadFile
    from fastapi.exceptions import HTTPException

    app = FastAPI()


    class Validator(ABC):
        """Each validator describes its rule and checks one uploaded file."""

        @property
        @abstractmethod
        def description(self) -> str:
            ...

        @abstractmethod
        def __call__(self, file: UploadFile) -> Any:
            ...


    class AllowedExtensions(Validator):
        def __init__(self, allowed_extensions: str) -> None:
            # A regex alternation such as "jpeg|pdf".
            self._allowed_extensions = allowed_extensions

        @property
        def description(self) -> str:
            return "Allowed extensions: {}".format(self._allowed_extensions)

        def __call__(self, file: UploadFile) -> Any:
            # Raw string avoids the invalid "\." escape warning; filename may be None.
            pattern = r"^.*\.({})$".format(self._allowed_extensions)
            if not re.match(pattern, file.filename or ""):
                raise HTTPException(status_code=400, detail="Invalid file extension.")
            return file


    class LimitedSize(Validator):
        def __init__(self, size_limit: int) -> None:
            self._size_limit = size_limit  # in bytes

        @property
        def description(self) -> str:
            return "Size limit: {} bytes".format(self._size_limit)

        def __call__(self, file: UploadFile) -> Any:
            # The upload is already received and spooled at this point, so
            # measure its real size by seeking to the end, then rewind.
            file.file.seek(0, 2)
            size = file.file.tell()
            file.file.seek(0)
            if size > self._size_limit:
                raise HTTPException(status_code=413, detail="File too large.")
            return file


    def checked(initial_file: Any, *validators: Validator):
        # initial_file is the File() parameter; append every validator's
        # description to it so that Swagger UI shows the rules.
        initial_file.description = "{}\n- {}".format(
            initial_file.description or "",
            "\n- ".join(validator.description for validator in validators)
        )

        def _checked(file: UploadFile = initial_file):
            # Run the validators in order; the first failure raises an HTTPException.
            for validator in validators:
                validator(file)
            return file

        return _checked

    @app.post("/upload2/")
    async def upload2(
        file: UploadFile = Depends(
            checked(
                File(),
                AllowedExtensions("jpeg|pdf"),
                LimitedSize(100)
            )
        )
    ):
        ...
    

    The function checked is a dependency factory: it accepts the default value (File()) and some validators. First, it appends the validators' descriptions to the file's description, so they appear in Swagger UI. Then, in the actual dependency function, it runs each validator on the uploaded file. For example, uploading a file whose extension is neither jpeg nor pdf returns a 400 response with the detail "Invalid file extension.".
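The factory pattern itself does not depend on FastAPI. Below is a stdlib-only restatement under some assumptions: a hypothetical `ValidationError` stands in for `HTTPException`, a hypothetical `Upload` dataclass stands in for `UploadFile`, and `checked` returns the combined description together with the check function instead of mutating a File() object. It also swaps the regex for a `pathlib` suffix check with lowercasing, so `.PDF` is accepted; that is a deliberate change, not what the code above does.

```python
import io
from dataclasses import dataclass
from pathlib import PurePath
from typing import BinaryIO, Callable, Tuple


class ValidationError(Exception):
    """Hypothetical stand-in for FastAPI's HTTPException."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


@dataclass
class Upload:
    """Hypothetical stand-in for UploadFile: a filename plus a binary file object."""

    filename: str
    file: BinaryIO


class AllowedExtensions:
    """Accept only the given extensions, compared case-insensitively."""

    def __init__(self, *extensions: str) -> None:
        self.extensions = {ext.lower().lstrip(".") for ext in extensions}
        self.description = "Allowed extensions: {}".format(", ".join(sorted(self.extensions)))

    def __call__(self, upload: Upload) -> None:
        suffix = PurePath(upload.filename).suffix.lower().lstrip(".")
        if suffix not in self.extensions:
            raise ValidationError(400, "Invalid file extension.")


def checked(description: str, *validators) -> Tuple[str, Callable[[Upload], Upload]]:
    """Build the combined description once and return it with a check function."""
    full = "{}\n- {}".format(description, "\n- ".join(v.description for v in validators))

    def run(upload: Upload) -> Upload:
        for validator in validators:  # the first failing validator raises
            validator(upload)
        return upload

    return full, run


description, run = checked("Upload a file.", AllowedExtensions("jpeg", "pdf"))
print(description)
run(Upload("report.PDF", io.BytesIO(b"...")))  # accepted: ".PDF" lowercases to "pdf"
```

Keeping description and validation in the same objects means the documented rules and the enforced rules cannot drift apart, which is the main point of the factory.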

    You can also implement the file size limit using the Content-Length header, but that does not stop an attacker from sending a false Content-Length value, so the header cannot be trusted on its own.
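The usual alternative is to count the bytes that actually arrive instead of trusting the header, and to stop reading as soon as the limit is passed. This is a stdlib-only sketch: `read_limited`, `PayloadTooLarge`, and the 64 KiB default chunk size are hypothetical choices, and `io.BytesIO` stands in for the request body. In FastAPI this would mean reading from Starlette's `request.stream()` rather than an `UploadFile`, which is already fully received by the time the endpoint runs; that wiring is not shown here.

```python
import io
from typing import BinaryIO


class PayloadTooLarge(Exception):
    """Hypothetical error raised once more than `limit` bytes have arrived."""


def read_limited(stream: BinaryIO, limit: int, chunk_size: int = 64 * 1024) -> bytes:
    """Read the whole stream, failing as soon as it exceeds `limit` bytes.

    The count comes from the bytes actually read, so a false
    Content-Length value cannot bypass the check.
    """
    chunks = []
    received = 0
    while True:
        chunk = stream.read(chunk_size)
        if not chunk:  # end of stream
            break
        received += len(chunk)
        if received > limit:
            raise PayloadTooLarge("received more than {} bytes".format(limit))
        chunks.append(chunk)
    return b"".join(chunks)


# A client may claim "Content-Length: 10" but actually send 500 bytes.
body = io.BytesIO(b"x" * 500)
try:
    read_limited(body, limit=100, chunk_size=64)
except PayloadTooLarge as exc:
    print("rejected:", exc)
```

Because the check fails on the second 64-byte chunk, at most one chunk beyond the limit is ever read, regardless of what the header claimed.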