One improvement I would suggest is to call `dataclasses.fields()` only once
and then cache the default values taken from its result. This helps performance because, as written,
`dataclasses.fields()` is re-evaluated
every time `__str__` is invoked.
Here's a simple example using a metaclass approach.
Note that I've also modified it slightly so that it handles fields of mutable types that define a `default_factory`
(for instance, a list field).
from __future__ import annotations
import dataclasses
# adapted from `dataclasses` module
def _create_fn(name, args, body, *, globals=None, locals=None):
    """Build and return a function from source text.

    The function is compiled inside a wrapper named ``__create_fn__`` whose
    parameters are the keys of *locals*; the wrapper is then called with the
    values of *locals*, so those names are visible to the generated function
    as closure variables.

    Args:
        name: Name of the function to generate.
        args: Iterable of parameter names (source text), e.g. ``('self',)``.
        body: Iterable of source lines making up the function body. A line
            nested inside a compound statement must carry its own extra
            leading whitespace relative to the other body lines.
        globals: Mapping handed to ``exec`` as the global namespace, or
            ``None`` to let ``exec`` use its default.
        locals: Mapping of name -> value to expose to the generated function.

    Returns:
        The newly created function object.

    NOTE: the assembled text is passed to ``exec``; only call this with
    trusted ``name``/``args``/``body`` strings.
    """
    if locals is None:
        locals = {}
    args = ','.join(args)
    # Body lines sit two columns in: one level deeper than the one-column
    # ``def`` line below.  Using the same depth as the ``def`` would make the
    # generated source fail with an IndentationError.
    body = '\n'.join(f'  {b}' for b in body)
    # Compute the text of the entire function.
    txt = f' def {name}({args}):\n{body}'
    local_vars = ', '.join(locals.keys())
    txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
    ns = {}
    exec(txt, globals, ns)
    return ns['__create_fn__'](**locals)
def terse_str(cls_name, bases, cls_dict):  # Metaclass for class
    """Create a class whose ``__str__`` lists only non-default field values.

    Intended to be used as ``class X(metaclass=terse_str)`` on a dataclass.
    The ``__str__`` placed in the class namespace is a lazy bootstrap: on the
    first call it reads ``dataclasses.fields`` once, generates a specialised
    ``__str__`` with the defaults captured as closure variables, installs it
    on the instance's class and returns its result.

    Args:
        cls_name: Name of the class being created.
        bases: Tuple of base classes.
        cls_dict: Class namespace; ``__str__`` is added to it.

    Returns:
        The new class, built with ``type(cls_name, bases, cls_dict)``.
    """

    def __str__(self):
        # Specialise for the concrete class of ``self`` rather than for the
        # name captured when the base class was created, so a subclass is
        # rendered with its own name and its own field list.
        cls = type(self)
        _locals = {'_cls': cls, '_cls_name': cls.__name__, '_lazy': __str__}
        _body_lines = [
            # A subclass inherits this generated function; send its instances
            # back to the lazy bootstrap so they get their own specialisation.
            'if type(self) is not _cls:',
            ' return _lazy(self)',
            'lines=[]',
        ]
        for f in dataclasses.fields(self):
            name = f.name
            dflt_factory = f.default_factory
            if dflt_factory is not dataclasses.MISSING:
                # The factory is called once here; the produced value is only
                # ever compared against, never handed out.
                _locals[f'_dflt_{name}'] = dflt_factory()
            else:
                # ``f.default`` is MISSING for a field without a default.
                _locals[f'_dflt_{name}'] = f.default
            _body_lines.append(f'value=self.{name}')
            _body_lines.append(f'if value != _dflt_{name}:')
            _body_lines.append(f' lines.append(f"{name}={{value!r}}")')
        _body_lines.append('return f\'{_cls_name}({", ".join(lines)})\'')
        generated = _create_fn('__str__', ('self', ), _body_lines, locals=_locals)
        # set the __str__ with the cached `dataclass.fields`
        setattr(cls, '__str__', generated)
        # on initial run, compute and return __str__()
        return generated(self)

    cls_dict['__str__'] = __str__
    return type(cls_name, bases, cls_dict)
@dataclasses.dataclass
class X(metaclass=terse_str):
    # Example dataclass: every field has a default, so str() of an instance
    # shows only the fields that were changed.
    a: int = 1
    b: bool = False
    c: float = 2.0
    # Mutable default supplied through a factory.
    d: list[int] = dataclasses.field(default_factory=lambda: [1, 2, 3])
x1 = X(b=True)
x2 = X(b=False, c=3, d=[1, 2])
# Expected output, one instance per line:
#   X(b=True)
#   X(c=3, d=[1, 2])
print(x1, x2, sep='\n')
Finally, here's a quick-and-dirty test to confirm that caching is actually beneficial for repeated calls to `str()`
or `print()`:
import dataclasses
from timeit import timeit
def terse_str(cls):  # Decorator for class.
    """Give *cls* a ``__str__`` that shows only non-default field values.

    Args:
        cls: The class to modify; its instances must be dataclass instances
            by the time ``str()`` is called on them.

    Returns:
        The same class object, with ``__str__`` replaced.
    """

    def __str__(self):
        """Returns a string containing only the non-default field values."""
        shown = []
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if field.default_factory is not dataclasses.MISSING:
                # A factory field has ``default`` set to MISSING; comparing
                # against that would print the field even when it is
                # unchanged, so compare against a fresh factory value.
                default = field.default_factory()
            else:
                default = field.default
            if value != default:
                shown.append(f'{field.name}={value}')
        return f'{type(self).__name__}({", ".join(shown)})'

    cls.__str__ = __str__
    return cls
# adapted from `dataclasses` module
def _create_fn(name, args, body, *, globals=None, locals=None):
    """Build and return a function from source text.

    The function is compiled inside a wrapper named ``__create_fn__`` whose
    parameters are the keys of *locals*; the wrapper is then called with the
    values of *locals*, so those names are visible to the generated function
    as closure variables.

    Args:
        name: Name of the function to generate.
        args: Iterable of parameter names (source text), e.g. ``('self',)``.
        body: Iterable of source lines making up the function body. A line
            nested inside a compound statement must carry its own extra
            leading whitespace relative to the other body lines.
        globals: Mapping handed to ``exec`` as the global namespace, or
            ``None`` to let ``exec`` use its default.
        locals: Mapping of name -> value to expose to the generated function.

    Returns:
        The newly created function object.

    NOTE: the assembled text is passed to ``exec``; only call this with
    trusted ``name``/``args``/``body`` strings.
    """
    if locals is None:
        locals = {}
    args = ','.join(args)
    # Body lines sit two columns in: one level deeper than the one-column
    # ``def`` line below.  Using the same depth as the ``def`` would make the
    # generated source fail with an IndentationError.
    body = '\n'.join(f'  {b}' for b in body)
    # Compute the text of the entire function.
    txt = f' def {name}({args}):\n{body}'
    local_vars = ', '.join(locals.keys())
    txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
    ns = {}
    exec(txt, globals, ns)
    return ns['__create_fn__'](**locals)
def terse_str_meta(cls_name, bases, cls_dict):  # Metaclass for class
    """Create a class whose ``__str__`` lists only non-default field values.

    Intended to be used as ``class X(metaclass=terse_str_meta)`` on a
    dataclass. The ``__str__`` placed in the class namespace is a lazy
    bootstrap: on the first call it reads ``dataclasses.fields`` once,
    generates a specialised ``__str__`` with the defaults captured as closure
    variables, installs it on the instance's class and returns its result.

    Args:
        cls_name: Name of the class being created.
        bases: Tuple of base classes.
        cls_dict: Class namespace; ``__str__`` is added to it.

    Returns:
        The new class, built with ``type(cls_name, bases, cls_dict)``.
    """

    def __str__(self):
        # Specialise for the concrete class of ``self`` rather than for the
        # name captured when the base class was created, so a subclass is
        # rendered with its own name and its own field list.
        cls = type(self)
        _locals = {'_cls': cls, '_cls_name': cls.__name__, '_lazy': __str__}
        _body_lines = [
            # A subclass inherits this generated function; send its instances
            # back to the lazy bootstrap so they get their own specialisation.
            'if type(self) is not _cls:',
            ' return _lazy(self)',
            'lines=[]',
        ]
        for f in dataclasses.fields(self):
            name = f.name
            dflt_factory = f.default_factory
            if dflt_factory is not dataclasses.MISSING:
                # The factory is called once here; the produced value is only
                # ever compared against, never handed out.
                _locals[f'_dflt_{name}'] = dflt_factory()
            else:
                # ``f.default`` is MISSING for a field without a default.
                _locals[f'_dflt_{name}'] = f.default
            _body_lines.append(f'value=self.{name}')
            _body_lines.append(f'if value != _dflt_{name}:')
            _body_lines.append(f' lines.append(f"{name}={{value!r}}")')
        _body_lines.append('return f\'{_cls_name}({", ".join(lines)})\'')
        generated = _create_fn('__str__', ('self', ), _body_lines, locals=_locals)
        # set the __str__ with the cached `dataclass.fields`
        setattr(cls, '__str__', generated)
        # on initial run, compute and return __str__()
        return generated(self)

    cls_dict['__str__'] = __str__
    return type(cls_name, bases, cls_dict)
@dataclasses.dataclass
@terse_str
class X:
    # Benchmark baseline: __str__ comes from the ``terse_str`` decorator,
    # which calls dataclasses.fields() on every invocation.
    a: int = 1
    b: bool = False
    c: float = 2.0
@dataclasses.dataclass
class X_Cached(metaclass=terse_str_meta):
    # Same fields as ``X``, but __str__ is supplied by ``terse_str_meta``,
    # which generates it once and reuses it afterwards.
    a: int = 1
    b: bool = False
    c: float = 2.0
# Time construct-and-str() for each variant (timeit's default repeat count),
# printing each result as soon as it is measured.
simple_secs = timeit('str(X(b=True))', globals=globals())
print(f"Simple: {simple_secs:.3f}")
cached_secs = timeit('str(X_Cached(b=True))', globals=globals())
print(f"Cached: {cached_secs:.3f}")
print()
# Sanity check: show what each variant renders for the same input.
print(X(b=True))
print(X_Cached(b=True))
Results:
Simple: 1.038
Cached: 0.289