Tags: python, pydantic, pydantic-v2

How to create custom Pydantic type for Python's "ElementTree"


I want to create a custom Pydantic type for Python's ElementTree. It should accept a string as input, try to parse the XML string via ElementTree.fromstring, and raise an appropriate error if invalid XML is found.

I tried creating a custom type with Pydantic's Annotated approach, but got an error saying that ElementTree.Element does not define __get_pydantic_core_schema__.
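
For reference, a minimal reproduction of this kind of failure might look like the sketch below. It assumes Pydantic v2 and uses the bare ElementTree.Element class as a field type; the model name `Broken` is made up for illustration, and the exact error text may differ between Pydantic versions.

```python
import xml.etree.ElementTree as ET

from pydantic import BaseModel
from pydantic.errors import PydanticSchemaGenerationError

schema_error = None
try:
    # Pydantic builds the schema when the class is created, so the
    # failure happens here, before any instance is constructed.
    class Broken(BaseModel):
        response_xml: ET.Element
except PydanticSchemaGenerationError as exc:
    schema_error = exc

# Pydantic has no built-in schema for ET.Element, so class creation fails.
print(type(schema_error).__name__)
```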


Solution

  • I could not find many resources on this, so here is how I proceeded. It is working quite well.

    First create an _ElementTreeAnnotation:

    import xml.etree.ElementTree as _ET
    from typing import Annotated, Any, Optional
    
    from defusedxml.ElementTree import fromstring, tostring
    from pydantic import GetCoreSchemaHandler
    from pydantic_core import core_schema
    
    class _ElementTreeAnnotation:
        @classmethod
        def validate_value(cls, value: str) -> _ET.Element:
            """
            Validates input value; if it's a string, it tries to parse it as XML
            """
            try:
                tree = fromstring(value)
            except (_ET.ParseError, ValueError) as exc:
                # defusedxml's security exceptions subclass ValueError
                raise ValueError("Invalid XML: cannot parse") from exc
    
            return tree
    
        @classmethod
        def serialize(cls, instance: _ET.Element) -> str:
            """
            Serializes an ET.Element instance to an XML string.
            """
            return tostring(instance, encoding="unicode")
    
        @classmethod
        def _discriminator(cls, v: Any) -> Optional[str]:
            if isinstance(v, str):
                return "from_str"
    
            if isinstance(v, _ET.Element):
                return "from_instance"
    
            return None
    
        @classmethod
        def __get_pydantic_core_schema__(
            cls,
            _source_type: Any,
            _handler: GetCoreSchemaHandler,
        ) -> core_schema.CoreSchema:
            """
            Defines a Pydantic core schema for XML element validation and serialization.
            We return a pydantic_core.CoreSchema that behaves in the following ways:
    
            * `str` will be parsed as `etree.Element` instances
            * `etree.Element` instances will be parsed as `etree.Element` instances without any changes
            * Nothing else will pass validation
            * Serialization always returns the string representation of `etree.Element`
            """
    
            # schema for validating a string and converting it to an ElementTree element
            from_str_schema = core_schema.chain_schema(
                [
                    core_schema.str_schema(),
                    core_schema.no_info_plain_validator_function(cls.validate_value),
                ]
            )
    
            # json schema: always use "from_str_schema"
            # python schema: use discriminated unions
            # tagged unions: ElementTree type is returned as it is while for string types "from_str_schema" is used
            # for any other data type custom error "invalid_union_tag" is raised
            # serialization: serialize ElementTree object to string
            return core_schema.json_or_python_schema(
                json_schema=from_str_schema,
                python_schema=core_schema.tagged_union_schema(
                    choices={
                        "from_instance": core_schema.is_instance_schema(_ET.Element),
                        "from_str": from_str_schema
                    },
                    discriminator=cls._discriminator,
                    custom_error_type="invalid_union_tag",
                    custom_error_message="Unable to extract union tag using discriminator",
                    custom_error_context={
                        "discriminator": "String or etree.Element"
                    }
                ),
                serialization=core_schema.plain_serializer_function_ser_schema(cls.serialize)
            )
    
    

    Then create a custom pydantic type:

    # create custom pydantic type for python's ElementTree.Element
    XmlElement = Annotated[
        _ET.Element, _ElementTreeAnnotation
    ]
    

    Now you can use it like this:

    from pydantic import BaseModel, Field
    
    class FooBar(BaseModel):
        response_xml: XmlElement = Field(description="Response XML")
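
    To try the whole thing end to end, here is a condensed, self-contained sketch of the same idea that can be run on its own. The helper names `_parse`, `_tag` and `_XmlAnnotation` are made up for this sketch, the custom error settings are left out for brevity, and the fallback to the standard-library parser when defusedxml is not installed is only an assumption to keep the sketch runnable (it is not safe for untrusted XML).

```python
import xml.etree.ElementTree as ET
from typing import Annotated, Any, Optional

from pydantic import BaseModel, Field, GetCoreSchemaHandler, ValidationError
from pydantic_core import core_schema

try:
    from defusedxml.ElementTree import fromstring  # hardened parser
except ImportError:
    # Sketch-only fallback: the stdlib parser is NOT safe for untrusted XML.
    from xml.etree.ElementTree import fromstring


def _parse(value: str) -> ET.Element:
    """Parse an XML string, turning parser failures into ValueError."""
    try:
        return fromstring(value)
    except (ET.ParseError, ValueError) as exc:
        raise ValueError("Invalid XML: cannot parse") from exc


def _tag(value: Any) -> Optional[str]:
    """Pick the union branch based on the input's Python type."""
    if isinstance(value, str):
        return "from_str"
    if isinstance(value, ET.Element):
        return "from_instance"
    return None


class _XmlAnnotation:
    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        from_str = core_schema.chain_schema(
            [core_schema.str_schema(), core_schema.no_info_plain_validator_function(_parse)]
        )
        return core_schema.json_or_python_schema(
            json_schema=from_str,
            python_schema=core_schema.tagged_union_schema(
                choices={
                    "from_instance": core_schema.is_instance_schema(ET.Element),
                    "from_str": from_str,
                },
                discriminator=_tag,
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda element: ET.tostring(element, encoding="unicode")
            ),
        )


XmlElement = Annotated[ET.Element, _XmlAnnotation]


class FooBar(BaseModel):
    response_xml: XmlElement = Field(description="Response XML")


# Strings are parsed, Element instances pass through, JSON input is parsed too.
from_text = FooBar(response_xml="<a><b>1</b></a>")
from_element = FooBar(response_xml=ET.Element("root"))
from_json = FooBar.model_validate_json('{"response_xml": "<a><b>1</b></a>"}')

# Dumping the model turns the Element back into an XML string.
dumped = from_text.model_dump()

# Malformed XML and unsupported input types are both rejected.
rejected = []
for bad in ("<a>", 123):
    try:
        FooBar(response_xml=bad)
    except ValidationError:
        rejected.append(bad)
```

    Here `from_text.response_xml` and `from_json.response_xml` are parsed Element objects, `dumped["response_xml"]` is a plain string again, and both bad inputs end up in `rejected`.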
    

    I have added comments to explain what each line is doing. The gist of it is:

    • str will be parsed as etree.Element instances.
    • etree.Element instances are passed through unchanged.
    • Nothing else will pass validation. It will raise a pydantic custom error with type invalid_union_tag.
    • Serialization will always return string representation of etree.Element.
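
    The rejection branch can be checked in isolation. The sketch below builds only the tagged union with pydantic_core's SchemaValidator, outside of any model. As a simplification for this sketch, the "from_str" branch accepts strings as-is instead of parsing them, so it only illustrates how the discriminator and the custom error interact.

```python
import xml.etree.ElementTree as ET
from typing import Any, Optional

from pydantic_core import SchemaValidator, ValidationError, core_schema


def _discriminator(value: Any) -> Optional[str]:
    if isinstance(value, str):
        return "from_str"
    if isinstance(value, ET.Element):
        return "from_instance"
    return None  # no tag found, so the union reports the custom error


validator = SchemaValidator(
    core_schema.tagged_union_schema(
        choices={
            "from_instance": core_schema.is_instance_schema(ET.Element),
            # Simplification for this sketch: strings are not parsed here.
            "from_str": core_schema.str_schema(),
        },
        discriminator=_discriminator,
        custom_error_type="invalid_union_tag",
        custom_error_message="Unable to extract union tag using discriminator",
    )
)

# An Element instance is returned as the very same object.
element = ET.Element("root")
passed_through = validator.validate_python(element)

# Anything that is neither str nor Element triggers the custom error type.
try:
    validator.validate_python(123)
except ValidationError as exc:
    error_type = exc.errors()[0]["type"]
```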

    Notes:

    • I am using the defusedxml package so that untrusted XML can be parsed safely.
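
    As a rough illustration of what defusedxml changes, the sketch below feeds the same document with an inline entity declaration to both parsers. It assumes defusedxml's default settings, under which entity declarations are refused; the sample document `ENTITY_DOC` is made up for this example. Since the exception defusedxml raises is a ValueError subclass, this also shows why validate_value catches ValueError alongside ParseError.

```python
import xml.etree.ElementTree as ET

from defusedxml import EntitiesForbidden
from defusedxml.ElementTree import fromstring

# A harmless document that declares and uses an internal entity.
ENTITY_DOC = '<!DOCTYPE r [<!ENTITY e "expanded">]><r>&e;</r>'

# The standard-library parser expands the entity without complaint.
stdlib_text = ET.fromstring(ENTITY_DOC).text

# defusedxml refuses the entity declaration by default.
blocked = None
try:
    fromstring(ENTITY_DOC)
except EntitiesForbidden as exc:
    blocked = exc
```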
