8

I want to specify a marshmallow schema. One of my fields needs to be validated, but it can be EITHER a string or a list of strings. I have tried the Raw field type; however, that allows everything through. Is there a way to validate just the two types that I want?

Something like,

value = fields.Str() or fields.List()
Jimmy Jo
  • 121
  • 1
  • 1
  • 4

4 Answers4

9

I had the same issue today, and I came up with this solution:

class ValueField(fields.Field):
    """Field that accepts either a string or a list of strings.

    Any other input, including a list that contains a non-string item, is
    rejected with a ``ValidationError``.
    """

    def _deserialize(self, value, attr, data, **kwargs):
        """Return ``value`` unchanged when it is a ``str`` or a list of ``str``.

        Raises:
            ValidationError: if ``value`` is neither a string nor a list
                whose items are all strings.
        """
        if isinstance(value, str):
            return value
        # Checking only ``isinstance(value, list)`` would let lists of
        # arbitrary objects through, so every element is checked as well.
        if isinstance(value, list) and all(isinstance(item, str) for item in value):
            return value
        raise ValidationError('Field should be str or list')


class Foo(Schema):
    """Example schema combining the custom ``ValueField`` with a built-in field."""

    # Validated by ValueField._deserialize (custom str-or-list check).
    value = ValueField()
    # Built-in marshmallow integer field.
    other_field = fields.Integer()

You can create a custom field and override the _deserialize method so that it checks, via isinstance, whether the value is one of the desired types. I hope it works for you.

foo.load({'value': 'asdf', 'other_field': 1})
>>> {'other_field': 1, 'value': 'asdf'}
foo.load({'value': ['asdf'], 'other_field': 1})
>>> {'other_field': 1, 'value': ['asdf']}
foo.load({'value': 1, 'other_field': 1})
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/Users/webinterpret/Envs/gl-gs-onboarding-api/lib/python3.7/site-packages/marshmallow/schema.py", line 723, in load
    data, many=many, partial=partial, unknown=unknown, postprocess=True
  File "/Users/webinterpret/Envs/gl-gs-onboarding-api/lib/python3.7/site-packages/marshmallow/schema.py", line 904, in _do_load
    raise exc
marshmallow.exceptions.ValidationError: {'value': ['Field should be str or list']}
Piotr BG
  • 333
  • 1
  • 6
5

solution for Mapping(s), similar to the above:

from typing import List, Mapping, Any
from marshmallow import Schema, fields
from marshmallow.exceptions import ValidationError


class UnionField(fields.Field):
    """Field that deserializes multi-type input data to app-level objects.

    The value is offered to each candidate field in turn; the result of the
    first candidate that deserializes it without a ``ValidationError`` is
    returned.
    """

    def __init__(self, val_types: List[fields.Field], **kwargs):
        """
        Parameters
        ----------
        val_types : List[fields.Field]
            Candidate field instances, tried in the order given.
        **kwargs
            Forwarded unchanged to ``fields.Field.__init__``.
        """
        self.valid_types = val_types
        super().__init__(**kwargs)

    def _deserialize(
        self, value: Any, attr: str = None, data: Mapping[str, Any] = None, **kwargs
    ):
        """
        Deserialize ``value`` with the first candidate field that accepts it.

        Parameters
        ----------
        value : Any
            The value to be deserialized.
        attr : str, optional
            The attribute/key in data to be deserialized. (default: None)
        data : Mapping[str, Any], optional
            The raw input data passed to ``Schema.load``. (default: None)

        Returns
        -------
        Any
            Whatever the first accepting candidate's ``deserialize`` returns.

        Raises
        ------
        ValidationError
            Raised when every candidate field rejects the value (or when no
            candidate fields were configured). Carries the messages
            collected from each failed candidate.
        """
        errors = []
        # Try each candidate type passed to UnionField via val_types.
        for field in self.valid_types:
            try:
                return field.deserialize(value, attr, data, **kwargs)
            except ValidationError as error:
                # Remember why this candidate failed and move on to the next
                # one; raising here would stop after the first candidate.
                errors.append(error.messages)
        # Only reached when no candidate accepted the value.
        raise ValidationError(
            errors or ["Value did not match any of the allowed field types."]
        )

Use:

class SampleSchema(Schema):
    """Example schema using ``UnionField`` for the values of a dict field."""

    # Dict with string keys; each value is handled by a UnionField that is
    # given a Str field and a Number field as its candidates.
    ex_attr = fields.Dict(keys=fields.Str(), values=UnionField([fields.Str(), fields.Number()]))

Credit: Anna K

bwl1289
  • 1,655
  • 1
  • 12
  • 10
3

Solution

Based on @bwl1289's answer. In addition, this custom field is inspired by typing.Union (from typing import Union).

# encoding: utf-8
"""
Marshmallow fields
------------------

Extension on the already available marshmallow fields
"""
from marshmallow import ValidationError, fields


class UnionField(fields.Field):
    """Field that accepts a value whose Python type is one of several allowed types.

    The value is returned unchanged when it is an instance of any of the
    configured types; otherwise a ``ValidationError`` is raised.
    """

    def __init__(self, types: list = None, *args, **kwargs) -> None:
        """Store the allowed types and initialise the base field.

        Args:
            types: Python classes the value may be an instance of.
            *args: Forwarded to ``fields.Field.__init__``.
            **kwargs: Forwarded to ``fields.Field.__init__``.

        Raises:
            AttributeError: if ``types`` is missing or empty.
        """
        super().__init__(*args, **kwargs)
        if not types:
            raise AttributeError('No types provided on union field')
        # Copy so that later mutation of the caller's list cannot silently
        # change what this field accepts.
        self.types = list(types)

    def _deserialize(self, value, attr, data, **kwargs):
        """Return ``value`` if it is an instance of an allowed type.

        Raises:
            ValidationError: if ``value`` matches none of ``self.types``.
        """
        # isinstance accepts a tuple of classes, so one call covers all types.
        if isinstance(value, tuple(self.types)):
            return value
        raise ValidationError(
            f'Field should be any of the following types: [{", ".join([str(i) for i in self.types])}]'
        )

__init__(self, types)

  • New parameter "types", which accepts a list of built-in Python types, alongside the default parameters of a marshmallow field.
  • super().__init__() runs the base Field initializer, so the standard field arguments are still handled.
  • If the "types" parameter is empty, an AttributeError is raised.

_deserialize()

  • Checks if the current value is an instance of one of the self.types provided in __init__.
  • Raises ValidationError with a formatted error message based on self.types.

Example

# encoding: utf-8
"""
Example
-------

Example for utilization
"""
from marshmallow import Schema


class AllTypes(Schema):
    """
    Example schema whose single field accepts any of several built-in types.
    """
    # The allowed Python types are passed via ``types``; ``metadata`` is an
    # extra keyword argument given to the field.
    some_field = UnionField(
        types=[str, int, float, dict, list, bool, set, tuple],
        metadata={
            "description": "Multiple types.",
        },
    )

UnitTest

# encoding: utf-8
"""
Test custom marshmallow fields
"""
from marshmallow import Schema, ValidationError
import pytest


def test_union_field():
    """UnionField accepts values of the listed types and rejects others."""

    class MultiType(Schema):
        test = UnionField(
            types=[str, int],
            metadata={
                "description": "String and Integer.",
            },
        )

    class AllTypes(Schema):
        test = UnionField(
            types=[str, int, float, dict, list, bool, set, tuple],
            metadata={
                "description": "Multiple types",
            },
        )

    # The field is instantiated while the class body executes, so an empty
    # ``types`` list fails at class-definition time.
    with pytest.raises(AttributeError):
        class NoTypes(Schema):
            test = UnionField(
                types=[],
                metadata={
                    "description": "No Type.",
                },
            )

    m = MultiType()
    assert m.load({'test': 'test'}) == {'test': 'test'}
    assert m.load({'test': 123}) == {'test': 123}
    # bool is a subclass of int, so the isinstance-based check accepts it.
    assert m.load({'test': False}) == {'test': False}

    rejected_values = [
        123.123,
        {'test': 'test'},
        ['test', 'test'],
        {1, 2, 3, 4},
        (1, 1, 2, 3, 4),
    ]
    for value in rejected_values:
        # One ``pytest.raises`` per value: the context manager exits at the
        # first exception, so sharing one block would test only the first.
        with pytest.raises(ValidationError):
            m.load({'test': value})

    a = AllTypes()
    accepted_values = [
        'test',
        123,
        123.123,
        {'test': 'test'},
        ['test', 'test'],
        False,
        {1, 2, 3, 4},
        (1, 1, 2, 3, 4),
    ]
    for value in accepted_values:
        # The field returns the value untouched, so the loaded data must
        # equal the input.
        assert a.load({'test': value}) == {'test': value}
YetAnotherDuck
  • 294
  • 4
  • 13
1

The marshmallow-oneofschema project has a nice solution here.
https://github.com/marshmallow-code/marshmallow-oneofschema

From their sample code:

import marshmallow
import marshmallow.fields
from marshmallow_oneofschema import OneOfSchema


class Foo:
    """Plain object holding a single ``foo`` value."""

    def __init__(self, foo):
        self.foo = foo

    def __repr__(self):
        # Readable form such as Foo('hello') instead of the default
        # <... object at 0x...> representation.
        return f"{type(self).__name__}({self.foo!r})"


class Bar:
    """Plain object holding a single ``bar`` value."""

    def __init__(self, bar):
        self.bar = bar

    def __repr__(self):
        # Readable form such as Bar(123) instead of the default
        # <... object at 0x...> representation.
        return f"{type(self).__name__}({self.bar!r})"


class FooSchema(marshmallow.Schema):
    """Schema for ``Foo`` objects: one required string field named ``foo``."""

    foo = marshmallow.fields.String(required=True)

    @marshmallow.post_load
    def make_foo(self, data, **kwargs):
        """Post-load hook: build a ``Foo`` from the loaded data."""
        return Foo(**data)


class BarSchema(marshmallow.Schema):
    """Schema for ``Bar`` objects: one required integer field named ``bar``."""

    bar = marshmallow.fields.Integer(required=True)

    @marshmallow.post_load
    def make_bar(self, data, **kwargs):
        """Post-load hook: build a ``Bar`` from the loaded data."""
        return Bar(**data)


class MyUberSchema(OneOfSchema):
    """Schema that selects ``FooSchema`` or ``BarSchema`` by a type name."""

    type_schemas = {"foo": FooSchema, "bar": BarSchema}

    def get_obj_type(self, obj):
        """Return the ``type_schemas`` key for ``obj``; raise for other classes."""
        # Checked in order: Foo first, then Bar.
        for model_cls, type_name in ((Foo, "foo"), (Bar, "bar")):
            if isinstance(obj, model_cls):
                return type_name
        raise Exception("Unknown object type: {}".format(obj.__class__.__name__))


# Serialize a mixed list of Foo and Bar objects in one call (many=True).
MyUberSchema().dump([Foo(foo="hello"), Bar(bar=123)], many=True)
# => [{'type': 'foo', 'foo': 'hello'}, {'type': 'bar', 'bar': 123}]

# Deserialize a list of dicts, each carrying a "type" key next to its data.
MyUberSchema().load(
    [{"type": "foo", "foo": "hello"}, {"type": "bar", "bar": 123}], many=True
)
# => [Foo('hello'), Bar(123)]
rcarlson
  • 21
  • 1