
Is there a way to find out how often a coroutine has been suspended?

For example:

import trio

async def sleep(count: int):
   for _ in range(count):
      await trio.sleep(1)

async def test():
   with track_suspensions() as ctx:
      await sleep(1)
   assert ctx.suspension_count == 1

   with track_suspensions() as ctx:
      await sleep(3)
   assert ctx.suspension_count == 3

trio.run(test)

...except that (as far as I know) track_suspensions() does not exist. Is there a way to implement it, or some other way to get this information? (I need it for a unit test.)

Nikratio
  • You can probably instrument (by monkeypatching) some of the trio internals. Or you could wrap `test` so that it returns an instrumented coroutine. – Bergi Jul 21 '23 at 23:47
  • Please don't use monkeypatches to instrument trio, we have a whole dedicated public API for adding instrumentation :-) https://trio.readthedocs.io/en/stable/reference-lowlevel.html#instrument-api – Nathaniel J. Smith Jul 23 '23 at 03:32

2 Answers


You can't do it with a context manager.

However, the async protocol is just a beefed-up iterator, and you can intercept Python's iterator protocol. Thus the following Tracker object counts how often the wrapped function yields.

import trio

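class Tracker:
    """Awaitable wrapper that counts how often the wrapped coroutine suspends."""

    def __init__(self):
        self.n = 0  # number of suspensions seen so far

    # "fn" precedes the "/", so it is positional-only and the wrapped
    # function may itself take a keyword argument named "fn".
    def __call__(self, fn, /, *a, **k):
        self.fn = fn
        self.a = a
        self.k = k
        return self

    def __await__(self):
        # Drive the wrapped coroutine by hand instead of using "yield from".
        # Note: exceptions thrown into this generator are not forwarded to "it".
        it = self.fn(*self.a, **self.k).__await__()
        try:
            r = None
            while True:
                r = it.send(r)   # run until the next suspension
                r = (yield r)    # pass the yielded value up to the event loop
                self.n += 1      # resumed again: one suspension completed
        except StopIteration as r:
            return r.value  # the wrapped coroutine's return value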


async def sleep(n):
    for i in range(n):
        await trio.sleep(0.1)
    return n**2


async def main():
    t = Tracker()
    res = await t(sleep,10)
    assert t.n == 10
    assert res == 100

trio.run(main)

The Tracker class itself also works with asyncio (and thus anyio) event loops; only the trio.sleep() and trio.run() calls need to be replaced.
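
To illustrate the asyncio case, here is a self-contained sketch of the same idea using only the standard library. It is a variation rather than the class above: it wraps an already-created coroutine object and also forwards exceptions thrown in by the event loop, since asyncio delivers cancellation that way. `CountingAwaitable`, `nap`, and `demo` are hypothetical names.

```python
import asyncio
import functools


class CountingAwaitable:
    """Hypothetical helper: awaits `coro` and counts its suspensions."""

    def __init__(self, coro):
        self.coro = coro
        self.suspension_count = 0

    def __await__(self):
        it = self.coro.__await__()
        # The first step starts the wrapped coroutine.
        step = functools.partial(it.send, None)
        while True:
            try:
                yielded = step()
            except StopIteration as stop:
                return stop.value  # the wrapped coroutine's return value
            self.suspension_count += 1
            try:
                value = yield yielded  # hand control to the event loop
            except GeneratorExit:
                it.close()
                raise
            except BaseException as exc:
                # Forward exceptions (e.g. cancellation) to the coroutine.
                step = functools.partial(it.throw, exc)
            else:
                step = functools.partial(it.send, value)


async def nap(count):
    for _ in range(count):
        await asyncio.sleep(0.01)
    return count ** 2


async def demo():
    wrapped = CountingAwaitable(nap(3))
    result = await wrapped
    return result, wrapped.suspension_count
```

`asyncio.run(demo())` returns the wrapped coroutine's result together with the number of suspensions observed.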

NB: The single slash in the __call__ argument list prevents a name clash when your function happens to require a keyword argument named "fn".

Matthias Urlichs

It's not a public API, but at least as of this writing, trio does track this information on the Task object, and uses it to implement trio.testing.assert_checkpoints and trio.testing.assert_no_checkpoints. If you can't use those, you could look at their implementation to come up with a temporary hack to get your test working, and then hopefully send a PR upstream so we can cover your use case in a more maintainable way :-)

Nathaniel J. Smith