Skip to content

API Reference

app special

app

App

Represents the Mognet application.

You can use these objects to:

  • Create and abort tasks
  • Check the status of tasks
  • Configure the middleware that runs on key lifecycle events of the app and its tasks
Source code in mognet/app/app.py
class App:
    """
    Represents the Mognet application.

    You can use these objects to:

    - Create and abort tasks
    - Check the status of tasks
    - Configure the middleware that runs on key lifecycle events of the app and its tasks
    """

    # Where results are stored.
    # NOTE: no class-level default; only exists after connect().
    result_backend: BaseResultBackend

    # Where task state is stored.
    # Task state is information that a task can save
    # during its execution, in case, for example, it gets
    # interrupted.
    # NOTE: no class-level default; only exists after connect().
    state_backend: BaseStateBackend

    # Task broker.
    # NOTE: no class-level default; only exists after connect().
    broker: BaseBroker

    # Mapping of [service name] -> dependency object,
    # should be accessed via Context#get_service.
    services: Dict[Any, Callable]

    # Holds references to all the tasks.
    task_registry: TaskRegistry

    # Whether connect() has run (reset by _stop()).
    _connected: bool

    # Configuration used to start this app.
    config: "AppConfig"

    # Worker running in this app instance.
    worker: Optional[Worker]

    # Background tasks spawned by this app.
    _consume_control_task: Optional[Future] = None
    _heartbeat_task: Optional[Future] = None

    _worker_task: Optional[Future]

    _middleware: List[Middleware]

    _loop: asyncio.AbstractEventLoop

    def __init__(
        self,
        *,
        name: str,
        config: "AppConfig",
    ) -> None:
        self.name = name

        self._connected = False

        self.config = config

        # Create the task registry and register it globally
        reg = task_registry.get(None)

        if reg is None:
            reg = TaskRegistry()
            reg.register_globally()

        self.task_registry = reg

        self._worker_task = None

        self.services = {}

        self._middleware = []

        self._load_modules()
        self.worker = None

        # Future created by start() and resolved by close(); start() awaits it,
        # so it stays pending for as long as the app runs.
        self._run_result = None

    def add_middleware(self, mw_inst: Middleware):
        """
        Adds middleware to this app.

        Middleware is called in the order in which it was added
        to the app. Adding middleware that is already registered is a no-op.
        """
        if mw_inst in self._middleware:
            return

        self._middleware.append(mw_inst)

    async def start(self):
        """
        Starts the app.

        Runs the "starting" middleware, connects to the broker and backends,
        spawns the heartbeat, control-queue and worker background tasks,
        runs the "started" middleware, and then waits until `close()`
        resolves the run result.
        """
        _log.info("Starting app %r", self.config.node_id)

        # We're inside a coroutine, so the running loop is the one to use.
        self._loop = asyncio.get_running_loop()

        # Created on the running loop; resolved by close() to let start() return.
        self._run_result = self._loop.create_future()

        self._log_tasks_and_queues()

        await self._call_on_starting_middleware()

        await self.connect()

        self._heartbeat_task = asyncio.create_task(self._background_heartbeat())
        self._consume_control_task = asyncio.create_task(self._consume_control_queue())

        self.worker = Worker(app=self, middleware=self._middleware)
        self._worker_task = asyncio.create_task(self.worker.run())

        _log.info("Started")

        await self._call_on_started_middleware()

        return await self._run_result

    async def get_current_status_of_nodes(
        self,
    ) -> AsyncGenerator[StatusResponseMessage, None]:
        """
        Query all nodes of this App and get their status.

        Responses that fail to parse are logged and skipped; the response
        stream is always closed when iteration ends.
        """

        request = QueryRequestMessage(name="Status")

        responses = self.broker.send_query_message(
            payload=MessagePayload(
                id=str(request.id),
                kind="Query",
                payload=request.model_dump(),
            )
        )

        try:
            async for response in responses:
                try:
                    yield StatusResponseMessage.model_validate(response)
                except asyncio.CancelledError:
                    break
                except Exception as exc:  # pylint: disable=broad-except
                    _log.error(
                        "Could not parse status response %r", response, exc_info=exc
                    )
        finally:
            await responses.aclose()

    async def submit(self, req: "Request", context: Optional[Context] = None) -> Result:
        """
        Submits a request for execution.

        If a context is defined, it will be used to create a parent-child
        relationship between the new-to-submit request and the one existing
        in the context instance. This is later used to cancel the whole task tree.

        Raises:
            ImproperlyConfigured: if the app has no result backend or broker yet
                (for example, `connect()` was never called).
            CouldNotSubmit: if anything fails while storing or sending the request.
        """
        # getattr(): these attributes only exist once connect() has run, so a
        # plain attribute access would raise AttributeError instead.
        if not getattr(self, "result_backend", None):
            raise ImproperlyConfigured("Result backend not defined")

        if not getattr(self, "broker", None):
            raise ImproperlyConfigured("Broker not connected")

        try:
            if req.kwargs_repr is None:
                req.kwargs_repr = format_kwargs_repr(req.args, req.kwargs)
                _log.debug("Set default kwargs_repr on Request %r", req)

            res = Result(
                self.result_backend,
                id=req.id,
                name=req.name,
                state=ResultState.PENDING,
                created=datetime.now(tz=timezone.utc),
                request_kwargs_repr=req.kwargs_repr,
            )

            if context is not None:
                # Set the parent-child relationship and update the request stack.
                parent_request = context.request
                res.parent_id = parent_request.id

                req.stack = [*parent_request.stack, parent_request.id]

                if res.parent_id is not None:
                    await self.result_backend.add_children(res.parent_id, req.id)

            await self.result_backend.set(req.id, res)

            # Store the metadata on the Result.
            if req.metadata:
                await res.set_metadata(**req.metadata)

            await self._on_submitting(req, context=context)

            payload = MessagePayload(
                id=str(req.id),
                kind="Request",
                payload=req.model_dump(),
                priority=req.priority,
            )

            _log.debug("Sending message %r", payload.id)

            await self.broker.send_task_message(self._get_task_route(req), payload)

            return res
        except Exception as exc:
            raise CouldNotSubmit(f"Could not submit {req!r}") from exc

    def get_task_queue_names(self) -> Set[str]:
        """
        Return the names of the queues that are going to be consumed,
        after applying defaults, inclusions, and exclusions.

        Exclusions take precedence over inclusions.
        """
        all_queues = {*self.config.task_routes.values(), self.config.default_task_route}

        _log.debug("All queues: %r", all_queues)

        configured_queues = self.config.task_queues

        configured_queues.ensure_valid()

        if configured_queues.exclude:
            _log.debug("Applying queue exclusions: %r", configured_queues.exclude)
            return all_queues - configured_queues.exclude

        if configured_queues.include:
            _log.debug("Applying queue inclusions: %r", configured_queues.include)
            return all_queues & configured_queues.include

        _log.debug("No inclusions or exclusions applied")

        return all_queues

    @overload
    def create_request(
        self,
        func: Callable[Concatenate["Context", _P], Awaitable[_Return]],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> Request[_Return]:
        """
        Creates a Request object from the function that was decorated with @task,
        and the provided arguments.

        This overload is just to document async def function return values.
        """
        ...

    @overload
    def create_request(
        self,
        func: Callable[Concatenate["Context", _P], _Return],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> Request[_Return]:
        """
        Creates a Request object from the function that was decorated with @task,
        and the provided arguments.

        This overload is just to document non-async def function return values.
        """
        ...

    def create_request(
        self,
        func: Callable[Concatenate["Context", _P], Any],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> Request:
        """
        Creates a Request object from the function that was decorated with @task,
        and the provided arguments.
        """
        return Request(
            name=self.task_registry.get_task_name(cast(Any, func)),
            args=args,
            kwargs=kwargs,
        )

    @overload
    async def run(
        self,
        request: Callable[Concatenate["Context", _P], Awaitable[_Return]],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> _Return:
        """
        Short-hand method for creating a Request from a function decorated with `@task`,
        (see `create_request`), submitting it (see `submit`) and waiting for the result (see `run(Request)`).
        """
        ...

    @overload
    async def run(
        self, request: "Request[_Return]", context: Optional[Context] = None
    ) -> _Return:
        """
        Runs the request and waits for the result.

        Call `submit` if you just want to send a request
        without waiting for the result.
        """

        ...

    async def run(self, request, *args, **kwargs) -> Any:
        """
        Submit a request (or a `@task`-decorated function call) and wait for its result.

        - `run(request, context=None)`: the extra arguments are forwarded to `submit`.
        - `run(func, *args, **kwargs)`: a Request is built from `func` and the
          task arguments (see `create_request`), then submitted without a context.
        """
        if isinstance(request, Request):
            res = await self.submit(request, *args, **kwargs)
        else:
            # The function itself must be passed to create_request(); the
            # remaining arguments are the task's arguments, not submit()'s.
            res = await self.submit(self.create_request(request, *args, **kwargs))

        return await res

    async def revoke(self, request_id: uuid.UUID, *, force: bool = False) -> Result:
        """
        Revoke the execution of a request.

        If the request is already completed, this method returns
        the associated result as-is. Optionally, `force=True` may be set
        in order to ignore the state check.

        This will also revoke any request that's launched as a child of this one,
        recursively.

        Returns the cancelled result.
        """
        res = await self.result_backend.get_or_create(request_id)

        if not force and res.done:
            _log.warning(
                "Attempting to cancel result %r that's already done, this is a no-op",
                res.id,
            )
            return res

        _log.info("Revoking request id=%r", res.id)

        await res.revoke()

        # Tell workers (via the control queue) to stop running this request.
        payload = MessagePayload(
            id=str(uuid.uuid4()),
            kind=Revoke.MESSAGE_KIND,
            payload=Revoke(id=request_id).model_dump(),
        )

        await self.broker.send_control_message(payload)

        child_count = await res.children.count()
        if child_count:
            _log.info("Revoking %r children of id=%r", child_count, res.id)

            # Abort children.
            async for child_id in res.children.iter_ids():
                await self.revoke(child_id, force=force)

        return res

    async def connect(self):
        """
        Connect this app and its components to their respective backends.

        Calling this again once connected is a no-op. If a connection step
        fails, the app is no longer flagged as connected (so a later call can
        retry) and the exception is re-raised.
        """
        if self._connected:
            return

        self.broker = self._create_broker()

        self.result_backend = self._create_result_backend()
        self.state_backend = self._create_state_backend()

        # Set before awaiting anything, so a second connect() call made while
        # this one is still in progress returns early instead of connecting twice.
        self._connected = True

        try:
            await self._setup_broker()

            _log.debug("Connecting to result backend %s", self.result_backend)

            await self.result_backend.connect()

            _log.debug("Connected to result backend %s", self.result_backend)

            _log.debug("Connecting to state backend %s", self.state_backend)

            await self.state_backend.connect()

            _log.debug("Connected to state backend %s", self.state_backend)
        except BaseException:
            # A failed attempt must not leave the flag set, otherwise every
            # later connect() call would silently return without connecting.
            self._connected = False
            raise

    async def __aenter__(self):
        await self.connect()

        return self

    async def __aexit__(self, *args, **kwargs):
        await self.close()

    @shield
    async def close(self):
        """Close this app and its components' backends."""

        _log.info("Closing app")

        await asyncio.shield(self._stop())

        # Unblock start(), which awaits this future.
        if self._run_result and not self._run_result.done():
            self._run_result.set_result(None)

        _log.info("Closed app")

    async def _stop(self):
        """
        Tear down everything start()/connect() set up, best-effort.

        Errors from individual components are logged so that the remaining
        components still get closed.
        """
        await self._call_on_stopping_middleware()

        if self._heartbeat_task is not None:
            self._heartbeat_task.cancel()

            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass
            except Exception as exc:  # pylint: disable=broad-except
                _log.debug("Error shutting down heartbeat task", exc_info=exc)

            self._heartbeat_task = None

        _log.debug("Closing queue listeners")

        if self._consume_control_task:
            self._consume_control_task.cancel()

            try:
                await self._consume_control_task
            except asyncio.CancelledError:
                pass
            except Exception as exc:  # pylint: disable=broad-except
                _log.debug("Error shutting down control consumption task", exc_info=exc)

            self._consume_control_task = None

        # Disconnect from the broker, this should NACK
        # all pending messages too.
        # getattr(): the broker/backends only exist once connect() has run.
        _log.debug("Closing broker connection")
        broker = getattr(self, "broker", None)
        if broker:
            try:
                await broker.close()
            except Exception as exc:  # pylint: disable=broad-except
                _log.error("Error closing broker", exc_info=exc)

        # Stop the worker
        await self._stop_worker()

        # Remove service instances. The instances are the dict's values
        # (keys are the service identifiers).
        for svc in self.services.values():
            if isinstance(svc, ClassService):
                try:
                    svc.close()
                    await svc.wait_closed()
                except Exception as exc:  # pylint: disable=broad-except
                    _log.error("Error closing service %r", svc, exc_info=exc)

        self.services.clear()

        # Finally, shut down the state and result backends.
        _log.debug("Closing backends")
        result_backend = getattr(self, "result_backend", None)
        if result_backend:
            try:
                await result_backend.close()
            except Exception as exc:  # pylint: disable=broad-except
                _log.error("Error closing result backend", exc_info=exc)

        state_backend = getattr(self, "state_backend", None)
        if state_backend:
            try:
                await state_backend.close()
            except Exception as exc:  # pylint: disable=broad-except
                _log.error("Error closing state backend", exc_info=exc)

        self._connected = False

        await self._call_on_stopped_middleware()

    async def purge_task_queues(self) -> Dict[str, int]:
        """
        Purge all known task queues.

        Returns a dict where the keys are the names of the queues,
        and the values are the number of messages purged.
        """
        deleted_per_queue = {}

        for queue in self.get_task_queue_names():
            _log.info("Purging task queue=%r", queue)
            deleted_per_queue[queue] = await self.broker.purge_task_queue(queue)

        return deleted_per_queue

    async def purge_control_queue(self) -> int:
        """
        Purges the control queue related to this app.

        Returns the number of messages purged.
        """
        return await self.broker.purge_control_queue()

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        # Only set once start() has run.
        return self._loop

    def _create_broker(self) -> BaseBroker:
        return AmqpBroker(config=self.config.broker, app=self)

    def _create_result_backend(self) -> BaseResultBackend:
        return RedisResultBackend(self.config.result_backend, app=self)

    def _create_state_backend(self) -> BaseStateBackend:
        return RedisStateBackend(self.config.state_backend, app=self)

    def _load_modules(self):
        # Import the configured modules (presumably so their @task functions register).
        for module in self.config.imports:
            importlib.import_module(module)

    def _log_tasks_and_queues(self):

        all_tasks = self.task_registry.registered_task_names

        tasks_msg = "\n".join(
            f"\t - {t!r} (queue={self._get_task_route(t)!r})" for t in all_tasks
        )

        _log.info("Registered %r tasks:\n%s", len(all_tasks), tasks_msg)

        all_queues = self.get_task_queue_names()

        queues_msg = "\n".join(f"\t - {q!r}" for q in all_queues)

        _log.info("Registered %r queues:\n%s", len(all_queues), queues_msg)

    async def _setup_broker(self):
        _log.debug("Connecting to broker %s", self.broker)

        await self.broker.connect()

        _log.debug("Connected to broker %r", self.broker)

        _log.debug("Setting up task queues")

        for queue_name in self.get_task_queue_names():
            await self.broker.setup_task_queue(TaskQueue(name=queue_name))

        _log.debug("Setup queues")

    async def _stop_worker(self):
        if self.worker is None or not self._worker_task:
            _log.debug("No worker running")
            return

        try:
            _log.debug("Closing worker")
            await self.worker.close()

            if self._worker_task is not None:
                self._worker_task.cancel()
                await self._worker_task

            _log.debug("Worker closed")
        except asyncio.CancelledError:
            pass
        except Exception as worker_close_exc:  # pylint: disable=broad-except
            _log.error(
                "Worker raised an exception while closing", exc_info=worker_close_exc
            )
        finally:
            self.worker = None
            self._worker_task = None

    async def _background_heartbeat(self):
        """
        Background task that checks if the event loop was blocked
        for too long.

        A crude check, it asyncio.sleep()s and checks if the time difference
        before and after sleeping is significantly higher. This could bring problems,
        for example, with task brokers, that may need to send periodic keep-alive messages
        to the broker in order to prevent connection drops.

        Error messages are logged in case the event loop got blocked for too long.
        """

        while True:
            current_ts = self.loop.time()

            await asyncio.sleep(5)

            next_ts = self.loop.time()

            diff = next_ts - current_ts

            if diff > 10:
                _log.error(
                    "Event loop seemed blocked for %.2fs (>10s), this could bring issues. Consider using asyncio.run_in_executor to run CPU-bound work",
                    diff,
                )
            else:
                _log.debug("Event loop heartbeat: %.2fs", diff)

    async def _consume_control_queue(self):
        """
        Reads messages from the control queue and dispatches them.
        """

        await self.broker.setup_control_queue()

        _log.debug("Listening on the control queue")

        async for msg in self.broker.consume_control_queue():
            try:
                await self._process_control_message(msg)
            except asyncio.CancelledError:
                break
            except Exception as exc:  # pylint: disable=broad-except
                _log.error(
                    "Could not process control queue message %r", msg, exc_info=exc
                )

    async def _process_control_message(self, msg: IncomingMessagePayload):
        """
        Handle one control message (Revoke or Status query); always ACKs it.
        """
        _log.debug("Received control message id=%r", msg.id)

        try:
            if msg.kind == Revoke.MESSAGE_KIND:
                abort = Revoke.model_validate(msg.payload)

                _log.debug("Received request to revoke request id=%r", abort.id)

                if self.worker is None:
                    _log.debug("No worker running. Discarding revoke message.")
                    return

                try:
                    # Cancel the task's execution and ACK it on the broker
                    # to prevent it from re-running.
                    await self.worker.cancel(
                        abort.id, message_action=MessageCancellationAction.ACK
                    )
                except Exception as exc:  # pylint: disable=broad-except
                    _log.error(
                        "Error while cancelling request id=%r", abort.id, exc_info=exc
                    )

                return

            if msg.kind == "Query":
                query = QueryRequestMessage.model_validate(msg.payload)

                if query.name == "Status":
                    # Get the status of this worker and reply to the incoming message

                    if self.worker is None:
                        _log.debug("No worker running for Status query")
                        running_request_ids = []
                    else:
                        running_request_ids = list(self.worker.running_tasks.keys())

                    reply = StatusResponseMessage(
                        node_id=self.config.node_id,
                        payload=StatusResponseMessage.Status(
                            running_request_ids=running_request_ids,
                        ),
                    )

                    payload = MessagePayload(
                        id=str(reply.id), kind=reply.kind, payload=reply.model_dump()
                    )

                    return await self.broker.send_reply(msg, payload)

                _log.warning("Unknown query name=%r, discarding", query.name)
                return

            _log.warning("Unknown message kind=%r, discarding", msg.kind)
        finally:
            await msg.ack()

    async def _on_submitting(self, req: "Request", context: Optional["Context"]):
        # Middleware failures are logged and never block submission.
        for mw_inst in self._middleware:
            try:
                await mw_inst.on_request_submitting(req, context=context)
            except Exception as mw_exc:  # pylint: disable=broad-except
                _log.error("Middleware failed", exc_info=mw_exc)

    def _get_task_route(self, req: Union[str, Request]):
        """
        Resolve the queue for a request or task name.

        Precedence: the Request's own `queue_name`, then `config.task_routes`,
        then `config.default_task_route`.
        """
        if isinstance(req, Request):
            if req.queue_name is not None:
                _log.debug(
                    "Request %r has a queue override to route to queue=%r",
                    req,
                    req.queue_name,
                )
                return req.queue_name

            req = req.name

        route = self.config.task_routes.get(req)

        if route is not None:
            _log.debug(
                "Request %r has a config-set route to queue=%r",
                req,
                route,
            )
            return route

        default_queue = self.config.default_task_route

        _log.debug(
            "Request %r has no route set, falling back to default queue=%r",
            req,
            default_queue,
        )

        return default_queue

    async def _call_on_starting_middleware(self):
        for mw in self._middleware:
            try:
                await mw.on_app_starting(self)
            except Exception as exc:  # pylint: disable=broad-except
                _log.debug(
                    "Middleware %r failed on 'on_app_starting'", mw, exc_info=exc
                )

    async def _call_on_started_middleware(self):
        for mw in self._middleware:
            try:
                await mw.on_app_started(self)
            except Exception as exc:  # pylint: disable=broad-except
                _log.debug("Middleware %r failed on 'on_app_started'", mw, exc_info=exc)

    async def _call_on_stopping_middleware(self):
        for mw in self._middleware:
            try:
                await mw.on_app_stopping(self)
            except Exception as exc:  # pylint: disable=broad-except
                _log.debug(
                    "Middleware %r failed on 'on_app_stopping'", mw, exc_info=exc
                )

    async def _call_on_stopped_middleware(self):
        for mw in self._middleware:
            try:
                await mw.on_app_stopped(self)
            except Exception as exc:  # pylint: disable=broad-except
                _log.debug("Middleware %r failed on 'on_app_stopped'", mw, exc_info=exc)
add_middleware(self, mw_inst)

Adds middleware to this app.

Middleware is called in the order in which it was added to the app.

Source code in mognet/app/app.py
def add_middleware(self, mw_inst: Middleware):
    """
    Register a middleware instance on this app.

    Middleware runs in the order in which it was added. Adding middleware
    that is already registered (compared with ``==``) has no effect.
    """
    registered = self._middleware
    if mw_inst not in registered:
        registered.append(mw_inst)
close(self) async

Close this app and its components' backends.

Source code in mognet/app/app.py
@shield
async def close(self):
    """
    Shut down this app and its components' backends.

    The stop sequence runs under ``asyncio.shield`` so cancelling the caller
    does not cancel it. Afterwards the run-result future, if one exists and
    is still pending, is resolved with ``None``.
    """
    _log.info("Closing app")

    await asyncio.shield(self._stop())

    pending = self._run_result
    if pending and not pending.done():
        pending.set_result(None)

    _log.info("Closed app")
connect(self) async

Connect this app and its components to their respective backends.

Source code in mognet/app/app.py
async def connect(self):
    """
    Connect this app and its components to their respective backends.

    Creates the broker, result backend and state backend, then connects each
    of them. Calling this again once connected is a no-op. If any connection
    step fails, the app is no longer flagged as connected, so a later call can
    retry, and the exception is re-raised.
    """
    if self._connected:
        return

    self.broker = self._create_broker()

    self.result_backend = self._create_result_backend()
    self.state_backend = self._create_state_backend()

    # Set before awaiting anything, so a second connect() call made while this
    # one is still in progress returns early instead of connecting twice.
    self._connected = True

    try:
        await self._setup_broker()

        _log.debug("Connecting to result backend %s", self.result_backend)

        await self.result_backend.connect()

        _log.debug("Connected to result backend %s", self.result_backend)

        _log.debug("Connecting to state backend %s", self.state_backend)

        await self.state_backend.connect()

        _log.debug("Connected to state backend %s", self.state_backend)
    except BaseException:
        # A failed attempt must not leave the flag set, otherwise every later
        # connect() call would silently return without connecting.
        self._connected = False
        raise
create_request(self, func, *args, **kwargs)

Creates a Request object from the function that was decorated with @task, and the provided arguments.

Source code in mognet/app/app.py
def create_request(
    self,
    func: Callable[Concatenate["Context", _P], Any],
    *args: _P.args,
    **kwargs: _P.kwargs,
) -> Request:
    """
    Build a `Request` for a function decorated with ``@task``.

    The task name is looked up in this app's task registry. The remaining
    positional and keyword arguments become the request's ``args`` and
    ``kwargs``.
    """
    task_name = self.task_registry.get_task_name(cast(Any, func))
    return Request(name=task_name, args=args, kwargs=kwargs)
get_current_status_of_nodes(self)

Query all nodes of this App and get their status.

Source code in mognet/app/app.py
async def get_current_status_of_nodes(
    self,
) -> AsyncGenerator[StatusResponseMessage, None]:
    """
    Broadcast a "Status" query to every node of this App and yield each reply.

    Replies that cannot be parsed as a `StatusResponseMessage` are logged and
    skipped. The broker's reply stream is always closed on exit.
    """
    query = QueryRequestMessage(name="Status")
    message = MessagePayload(
        id=str(query.id),
        kind="Query",
        payload=query.model_dump(),
    )
    replies = self.broker.send_query_message(payload=message)

    try:
        async for reply in replies:
            try:
                yield StatusResponseMessage.model_validate(reply)
            except asyncio.CancelledError:
                break
            except Exception as parse_error:  # pylint: disable=broad-except
                _log.error(
                    "Could not parse status response %r", reply, exc_info=parse_error
                )
    finally:
        await replies.aclose()
get_task_queue_names(self)

Return the names of the queues that are going to be consumed, after applying defaults, inclusions, and exclusions.

Source code in mognet/app/app.py
def get_task_queue_names(self) -> Set[str]:
    """
    Compute the set of task queue names this app will consume from.

    Starts from every queue referenced by ``config.task_routes`` plus
    ``config.default_task_route``, then applies ``config.task_queues``.
    Exclusions take precedence over inclusions. When neither is set, the
    full set is returned.
    """
    queues = set(self.config.task_routes.values())
    queues.add(self.config.default_task_route)

    _log.debug("All queues: %r", queues)

    selection = self.config.task_queues
    selection.ensure_valid()

    excluded = selection.exclude
    if excluded:
        _log.debug("Applying queue exclusions: %r", excluded)
        return queues - excluded

    included = selection.include
    if not included:
        _log.debug("No inclusions or exclusions applied")
        return queues

    _log.debug("Applying queue inclusions: %r", included)
    return queues & included
purge_control_queue(self) async

Purges the control queue related to this app.

Returns the number of messages purged.

Source code in mognet/app/app.py
async def purge_control_queue(self) -> int:
    """
    Remove every pending message from this app's control queue.

    Returns the number of messages the broker reports as purged.
    """
    purged = await self.broker.purge_control_queue()
    return purged
purge_task_queues(self) async

Purge all known task queues.

Returns a dict where the keys are the names of the queues, and the values are the number of messages purged.

Source code in mognet/app/app.py
async def purge_task_queues(self) -> Dict[str, int]:
    """
    Purge every queue returned by `get_task_queue_names()`.

    Returns a mapping of queue name to the number of messages the broker
    purged from that queue.
    """
    purged_counts = {}

    for queue_name in self.get_task_queue_names():
        _log.info("Purging task queue=%r", queue_name)
        purged_counts[queue_name] = await self.broker.purge_task_queue(queue_name)

    return purged_counts
revoke(self, request_id, *, force=False) async

Revoke the execution of a request.

If the request is already completed, this method returns the associated result as-is. Optionally, force=True may be set in order to ignore the state check.

This will also revoke any request that's launched as a child of this one, recursively.

Returns the cancelled result.

Source code in mognet/app/app.py
async def revoke(self, request_id: uuid.UUID, *, force: bool = False) -> Result:
    """
    Revoke the execution of a request.

    If the request is already completed, this method returns
    the associated result as-is. Optionally, `force=True` may be set
    in order to ignore the state check.

    This will also revoke any request that's launched as a child of this one,
    recursively.

    Returns the cancelled result.
    """
    res = await self.result_backend.get_or_create(request_id)

    if not force and res.done:
        _log.warning(
            "Attempting to cancel result %r that's already done, this is a no-op",
            res.id,
        )
        return res

    # Log the id (like the other messages here), not the whole Result object.
    _log.info("Revoking request id=%r", res.id)

    await res.revoke()

    # Broadcast a Revoke control message so the worker running it stops.
    payload = MessagePayload(
        id=str(uuid.uuid4()),
        kind=Revoke.MESSAGE_KIND,
        payload=Revoke(id=request_id).model_dump(),
    )

    await self.broker.send_control_message(payload)

    child_count = await res.children.count()
    if child_count:
        _log.info("Revoking %r children of id=%r", child_count, res.id)

        # Abort children.
        async for child_id in res.children.iter_ids():
            await self.revoke(child_id, force=force)

    return res
start(self) async

Starts the app.

Source code in mognet/app/app.py
async def start(self):
    """
    Starts the app.

    Runs the "starting" middleware, connects to the broker and backends,
    spawns the heartbeat, control-queue and worker background tasks, runs the
    "started" middleware, and then waits until `close()` resolves the run
    result.
    """
    _log.info("Starting app %r", self.config.node_id)

    # We're inside a coroutine, so the running loop is the one to use.
    self._loop = asyncio.get_running_loop()

    # Created on the running loop; resolved by close() to let start() return.
    self._run_result = self._loop.create_future()

    self._log_tasks_and_queues()

    await self._call_on_starting_middleware()

    await self.connect()

    self._heartbeat_task = asyncio.create_task(self._background_heartbeat())
    self._consume_control_task = asyncio.create_task(self._consume_control_queue())

    self.worker = Worker(app=self, middleware=self._middleware)
    self._worker_task = asyncio.create_task(self.worker.run())

    _log.info("Started")

    await self._call_on_started_middleware()

    return await self._run_result
submit(self, req, context=None) async

Submits a request for execution.

If a context is defined, it will be used to create a parent-child relationship between the new-to-submit request and the one existing in the context instance. This is later used to cancel the whole task tree.

Source code in mognet/app/app.py
async def submit(self, req: "Request", context: Optional[Context] = None) -> Result:
    """
    Submit ``req`` for execution and return its PENDING Result.

    In order, this: fills in ``req.kwargs_repr`` when it is missing; builds
    the Result; when ``context`` is given, links the request to
    ``context.request`` (extending ``req.stack`` and registering the new id
    as a child on the result backend); persists the Result and any request
    metadata; calls ``self._on_submitting``; and publishes the request on
    its task route through the broker.

    The parent-child links are what later allow a whole task tree to be
    cancelled.

    Raises:
        ImproperlyConfigured: the result backend or the broker is missing.
        CouldNotSubmit: any later step failed (the original error is chained).
    """
    if not self.result_backend:
        raise ImproperlyConfigured("Result backend not defined")

    if not self.broker:
        raise ImproperlyConfigured("Broker not connected")

    try:
        if req.kwargs_repr is None:
            req.kwargs_repr = format_kwargs_repr(req.args, req.kwargs)
            _log.debug("Set default kwargs_repr on Request %r", req)

        result = Result(
            self.result_backend,
            id=req.id,
            name=req.name,
            state=ResultState.PENDING,
            created=datetime.now(tz=timezone.utc),
            request_kwargs_repr=req.kwargs_repr,
        )

        if context is not None:
            parent = context.request
            result.parent_id = parent.id
            # The child's stack is the parent's stack plus the parent itself.
            req.stack = [*parent.stack, parent.id]

            if result.parent_id is not None:
                await self.result_backend.add_children(result.parent_id, req.id)

        await self.result_backend.set(req.id, result)

        if req.metadata:
            await result.set_metadata(**req.metadata)

        await self._on_submitting(req, context=context)

        message = MessagePayload(
            id=str(req.id),
            kind="Request",
            payload=req.model_dump(),
            priority=req.priority,
        )
        _log.debug("Sending message %r", message.id)

        await self.broker.send_task_message(self._get_task_route(req), message)
    except Exception as exc:
        raise CouldNotSubmit(f"Could not submit {req!r}") from exc

    return result

app_config

AppConfig (BaseModel)

Configuration for a Mognet application.

Source code in mognet/app/app_config.py
class AppConfig(BaseModel):
    """
    Configuration for a Mognet application.
    """

    # An ID for the node. Defaults to a string containing
    # the current PID and the hostname.
    node_id: str = Field(default_factory=_default_node_id)

    # Configuration for the result backend.
    result_backend: ResultBackendConfig

    # Configuration for the state backend.
    state_backend: StateBackendConfig

    # Configuration for the task broker.
    broker: BrokerConfig

    # List of modules to import.
    imports: List[str] = Field(default_factory=list)

    # Maximum number of tasks that this app can handle.
    max_tasks: Optional[int] = None

    # Maximum recursion depth for tasks that call other tasks.
    max_recursion: int = 64

    # Number of times a task that failed unexpectedly
    # (e.g. was killed with SIGKILL) can be retried.
    max_retries: int = 3

    # Default task route to send messages to.
    default_task_route: str = "tasks"

    # A mapping of [task name] -> [queue] that overrides the queue on which a task is listening.
    # If a task is not here, it will default to the queue set in [default_task_route].
    task_routes: Dict[str, str] = Field(default_factory=dict)

    # Specify which queues to listen, or not listen, on.
    task_queues: Queues = Field(default_factory=Queues)

    # The minimum prefetch count. Task consumption will start with
    # this value, and is then incremented based on the number of waiting
    # tasks that are running.
    # A higher value allows more tasks to run concurrently on this node.
    minimum_concurrency: int = 1

    # The maximum prefetch count. This helps ensure that not too many
    # recursive tasks run on this node. None (the default) presumably means
    # "no upper bound" - confirm in the Worker.
    # Bear in mind that, if set, you can run into deadlocks if you have
    # overly recursive tasks.
    maximum_concurrency: Optional[int] = None

    # Settings that can be passed to instances retrieved via
    # Context#get_service()
    services_settings: Dict[str, Any] = Field(default_factory=dict)

    # Maximum number of attempts to connect.
    max_reconnect_retries: int = 5

    # Time to wait between reconnects.
    reconnect_interval: float = 5

    @classmethod
    def from_file(cls, file_path: str) -> "AppConfig":
        """
        Load an AppConfig from the JSON file at ``file_path``.

        The file is read as UTF-8 and validated with ``model_validate_json``,
        so validation errors propagate to the caller.
        """
        with open(file_path, "r", encoding="utf-8") as config_file:
            return cls.model_validate_json(config_file.read())

backend special

Result Backends are used to retrieve Task Results from a persistent storage backend.

backend_config

Encoding (str, Enum)

Encodings that can be applied to stored result values.

Source code in mognet/backend/backend_config.py
class Encoding(str, Enum):
    """Compression applied to result values before they are stored."""

    # Values are gzip-compressed.
    GZIP = "gzip"

RedisResultBackendSettings (BaseModel)

Configuration for the Redis Result Backend

Source code in mognet/backend/backend_config.py
class RedisResultBackendSettings(BaseModel):
    """Configuration for the Redis Result Backend."""

    # Redis connection URL.
    url: str = "redis://localhost:6379/"

    # Expiry, in seconds, for result keys (21 days by default).
    # None means no expiry is set.
    result_ttl: Optional[int] = int(timedelta(days=21).total_seconds())

    # Expiry, in seconds, for result values (7 days by default). Kept shorter
    # than `result_ttl` so the results themselves outlive their values.
    result_value_ttl: Optional[int] = int(timedelta(days=7).total_seconds())

    # Encoding applied to stored result values; None stores them as-is.
    result_value_encoding: Optional[Encoding] = Encoding.GZIP

    # Reconnection parameters used by the backend's retry logic
    # (presumably an attempt count and a timeout in seconds - confirm in `_retry`).
    retry_connect_attempts: int = 10
    retry_connect_timeout: float = 30

    # Upper bound on connections in the Redis connection pool (None = no limit).
    # DANGER! Setting this to too low a value WILL cause issues opening connections!
    max_connections: Optional[int] = None

base_result_backend

BaseResultBackend

Base interface for implementing a Result Backend.

Source code in mognet/backend/base_result_backend.py
class BaseResultBackend(metaclass=ABCMeta):
    """Base interface for implementing a Result Backend."""

    config: ResultBackendConfig
    app: AppParameters

    def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None:
        super().__init__()

        self.config = config
        self.app = app

    @abstractmethod
    async def get(self, result_id: UUID) -> Optional[Result]:
        """
        Get a Result by its ID.
        If it doesn't exist, this method returns None.
        """
        raise NotImplementedError

    async def get_many(self, *result_ids: UUID) -> List[Result]:
        """
        Get a list of Results by specifying their IDs.

        The lookups run concurrently through ``self.get``. IDs whose Result
        does not exist are left out, so the list may be shorter than
        ``result_ids``; the remaining Results keep the order of their IDs.
        """
        all_results = await asyncio.gather(*(self.get(r_id) for r_id in result_ids))

        # Compare against None explicitly: a truthiness test would also drop
        # a Result that happens to evaluate as falsy.
        return [r for r in all_results if r is not None]

    async def get_or_create(self, result_id: UUID) -> Result:
        """
        Get a Result by its ID.
        If it doesn't exist, this method creates one.

        The returned Result will either be the existing one,
        or the newly-created one.
        """
        res = await self.get(result_id)

        if res is None:
            res = Result(self, id=result_id)
            await self.set(result_id, res)

        return res

    @abstractmethod
    async def set(self, result_id: UUID, result: Result) -> None:
        """
        Save a Result.
        """
        raise NotImplementedError

    async def wait(
        self, result_id: UUID, timeout: Optional[float] = None, poll: float = 0.1
    ) -> Result:
        """
        Wait until a result is ready.

        Polls ``self.get`` every ``poll`` seconds until the Result exists and
        is done. A falsy ``timeout`` (None or 0) means wait without a limit.

        Raises `asyncio.TimeoutError` if a timeout is set and exceeded.
        """

        async def waiter():
            while True:
                result = await self.get(result_id)

                if result is not None and result.done:
                    return result

                await asyncio.sleep(poll)

        if timeout:
            return await asyncio.wait_for(waiter(), timeout)

        return await waiter()

    @abstractmethod
    async def get_children_count(self, parent_result_id: UUID) -> int:
        """
        Return the number of children of a Result.

        Returns 0 if the Result doesn't exist.
        """
        raise NotImplementedError

    @abstractmethod
    def iterate_children_ids(
        self, parent_result_id: UUID, *, count: Optional[int] = None
    ) -> AsyncGenerator[UUID, None]:
        """
        Get an AsyncGenerator for the IDs for the children of a Result.

        The AsyncGenerator will be empty if the Result doesn't exist.
        """
        raise NotImplementedError

    def iterate_children(
        self, parent_result_id: UUID, *, count: Optional[int] = None
    ) -> AsyncGenerator[Result, None]:
        """
        Get an AsyncGenerator for the children of a Result.

        The AsyncGenerator will be empty if the Result doesn't exist.

        Not abstract, but this base version raises NotImplementedError, so
        backends that support it must override it.
        """
        raise NotImplementedError

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args, **kwargs):
        return None

    async def connect(self):
        """
        Explicit method to connect to the backend provided by
        this Result backend. The base implementation does nothing.
        """

    async def close(self):
        """
        Explicit method to close the backend provided by
        this Result backend. The base implementation does nothing.
        """

    @abstractmethod
    async def add_children(self, result_id: UUID, *children: UUID) -> None:
        """
        Add children to a parent Result.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_value(self, result_id: UUID) -> ResultValueHolder:
        """
        Get the value of a Result.

        If the value is lost, ResultValueLost is raised.
        """
        raise NotImplementedError

    @abstractmethod
    async def set_value(self, result_id: UUID, value: ResultValueHolder) -> None:
        """
        Set the value of a Result.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_metadata(self, result_id: UUID) -> Dict[str, Any]:
        """
        Get the metadata of a Result.

        Returns an empty Dict if the Result doesn't exist.
        """
        raise NotImplementedError

    @abstractmethod
    async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
        """
        Set metadata on a Result.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete(self, result_id: UUID, include_children: bool = True) -> None:
        """
        Delete a Result.
        """
        raise NotImplementedError

    @abstractmethod
    async def set_ttl(
        self, result_id: UUID, ttl: timedelta, include_children: bool = True
    ) -> None:
        """
        Set expiration on a Result.

        If include_children is True, children will have the same TTL set.
        """
        raise NotImplementedError
add_children(self, result_id, *children) async

Add children to a parent Result.

Source code in mognet/backend/base_result_backend.py
@abstractmethod
async def add_children(self, result_id: UUID, *children: UUID) -> None:
    """
    Register ``children`` as child Results of ``result_id``.

    Abstract: concrete backends must override this.
    """
    raise NotImplementedError
close(self) async

Explicit method to close the backend provided by this Result backend.

Source code in mognet/backend/base_result_backend.py
async def close(self):
    """
    Close whatever this Result backend connected to.

    The base version is a no-op; backends holding connections override it.
    """
    return None
connect(self) async

Explicit method to connect to the backend provided by this Result backend.

Source code in mognet/backend/base_result_backend.py
async def connect(self):
    """
    Open the connection this Result backend relies on.

    The base version is a no-op; backends needing a connection override it.
    """
    return None
delete(self, result_id, include_children=True) async

Delete a Result.

Source code in mognet/backend/base_result_backend.py
@abstractmethod
async def delete(self, result_id: UUID, include_children: bool = True) -> None:
    """
    Remove the Result identified by ``result_id``.

    Abstract: concrete backends must override this.
    """
    raise NotImplementedError
get(self, result_id) async

Get a Result by its ID. If it doesn't exist, this method returns None.

Source code in mognet/backend/base_result_backend.py
@abstractmethod
async def get(self, result_id: UUID) -> Optional[Result]:
    """
    Look up the Result with ID ``result_id``.

    Returns None when no such Result exists. Abstract: concrete backends
    must override this.
    """
    raise NotImplementedError
get_children_count(self, parent_result_id) async

Return the number of children of a Result.

Returns 0 if the Result doesn't exist.

Source code in mognet/backend/base_result_backend.py
@abstractmethod
async def get_children_count(self, parent_result_id: UUID) -> int:
    """
    Count the children registered under ``parent_result_id``.

    Implementations return 0 for an unknown Result. Abstract: concrete
    backends must override this.
    """
    raise NotImplementedError
get_many(self, *result_ids) async

Get a list of Results by specifying their IDs. Results that don't exist will be removed from this list.

Source code in mognet/backend/base_result_backend.py
async def get_many(self, *result_ids: UUID) -> List[Result]:
    """
    Get a list of Results by specifying their IDs.

    The lookups run concurrently through ``self.get``. IDs whose Result does
    not exist are left out, so the list may be shorter than ``result_ids``;
    the remaining Results keep the order of their IDs.
    """
    all_results = await asyncio.gather(*(self.get(r_id) for r_id in result_ids))

    # Compare against None explicitly: a truthiness test would also drop a
    # Result that happens to evaluate as falsy.
    return [r for r in all_results if r is not None]
get_metadata(self, result_id) async

Get the metadata of a Result.

Returns an empty Dict if the Result doesn't exist.

Source code in mognet/backend/base_result_backend.py
@abstractmethod
async def get_metadata(self, result_id: UUID) -> Dict[str, Any]:
    """
    Fetch the metadata mapping stored for ``result_id``.

    Implementations return an empty dict for an unknown Result. Abstract:
    concrete backends must override this.
    """
    raise NotImplementedError
get_or_create(self, result_id) async

Get a Result by its ID. If it doesn't exist, this method creates one.

The returned Result will either be the existing one, or the newly-created one.

Source code in mognet/backend/base_result_backend.py
async def get_or_create(self, result_id: UUID) -> Result:
    """
    Return the Result with ID ``result_id``, creating it if needed.

    An existing Result is returned untouched. Otherwise a fresh
    ``Result(self, id=result_id)`` is saved with ``self.set`` and returned.
    """
    existing = await self.get(result_id)
    if existing is not None:
        return existing

    created = Result(self, id=result_id)
    await self.set(result_id, created)
    return created
get_value(self, result_id) async

Get the value of a Result.

If the value is lost, ResultValueLost is raised.

Source code in mognet/backend/base_result_backend.py
@abstractmethod
async def get_value(self, result_id: UUID) -> ResultValueHolder:
    """
    Fetch the stored value of the Result ``result_id``.

    Implementations raise ResultValueLost when the value is gone. Abstract:
    concrete backends must override this.
    """
    raise NotImplementedError
iterate_children(self, parent_result_id, *, count=None)

Get an AsyncGenerator for the children of a Result.

The AsyncGenerator will be empty if the Result doesn't exist.

Source code in mognet/backend/base_result_backend.py
def iterate_children(
    self, parent_result_id: UUID, *, count: Optional[int] = None
) -> AsyncGenerator[Result, None]:
    """
    Asynchronously iterate over the child Results of ``parent_result_id``.

    Implementations yield nothing for an unknown Result. This base version
    is not abstract but always raises NotImplementedError.
    """
    raise NotImplementedError
iterate_children_ids(self, parent_result_id, *, count=None)

Get an AsyncGenerator for the IDs for the children of a Result.

The AsyncGenerator will be empty if the Result doesn't exist.

Source code in mognet/backend/base_result_backend.py
@abstractmethod
def iterate_children_ids(
    self, parent_result_id: UUID, *, count: Optional[int] = None
) -> AsyncGenerator[UUID, None]:
    """
    Asynchronously iterate over the child IDs of ``parent_result_id``.

    Implementations yield nothing for an unknown Result. Abstract: concrete
    backends must override this.
    """
    raise NotImplementedError
set(self, result_id, result) async

Save a Result.

Source code in mognet/backend/base_result_backend.py
@abstractmethod
async def set(self, result_id: UUID, result: Result) -> None:
    """
    Persist ``result`` under ``result_id``.

    Abstract: concrete backends must override this.
    """
    raise NotImplementedError
set_metadata(self, result_id, **kwargs) async

Set metadata on a Result.

Source code in mognet/backend/base_result_backend.py
@abstractmethod
async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
    """
    Store the keyword arguments as metadata of the Result ``result_id``.

    Abstract: concrete backends must override this.
    """
    raise NotImplementedError
set_ttl(self, result_id, ttl, include_children=True) async

Set expiration on a Result.

If include_children is True, children will have the same TTL set.

Source code in mognet/backend/base_result_backend.py
@abstractmethod
async def set_ttl(
    self, result_id: UUID, ttl: timedelta, include_children: bool = True
) -> None:
    """
    Make the Result ``result_id`` expire after ``ttl``.

    With ``include_children`` set, implementations apply the same TTL to
    the children. Abstract: concrete backends must override this.
    """
    raise NotImplementedError
set_value(self, result_id, value) async

Set the value of a Result.

Source code in mognet/backend/base_result_backend.py
@abstractmethod
async def set_value(self, result_id: UUID, value: ResultValueHolder) -> None:
    """
    Store ``value`` as the value of the Result ``result_id``.

    Abstract: concrete backends must override this.
    """
    raise NotImplementedError
wait(self, result_id, timeout=None, poll=0.1) async

Wait until a result is ready.

Raises asyncio.TimeoutError if a timeout is set and exceeded.

Source code in mognet/backend/base_result_backend.py
async def wait(
    self, result_id: UUID, timeout: Optional[float] = None, poll: float = 0.1
) -> Result:
    """
    Block until the Result ``result_id`` exists and is done, then return it.

    The backend is polled every ``poll`` seconds. A falsy ``timeout`` (None
    or 0) waits indefinitely; otherwise `asyncio.TimeoutError` is raised
    once ``timeout`` seconds have passed.
    """

    async def _poll_until_done():
        while True:
            current = await self.get(result_id)
            if current is not None and current.done:
                return current
            await asyncio.sleep(poll)

    if not timeout:
        return await _poll_until_done()

    return await asyncio.wait_for(_poll_until_done(), timeout)

memory_result_backend

MemoryResultBackend (BaseResultBackend)

Result backend that "persists" results in memory. Useful for testing, but this is not recommended for production setups.

Source code in mognet/backend/memory_result_backend.py
class MemoryResultBackend(BaseResultBackend):
    """
    Result backend that "persists" results in memory. Useful for testing,
    but this is not recommended for production setups.

    Everything lives in plain dicts on the instance, and ``set_ttl`` is a
    no-op, so nothing ever expires.
    """

    def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None:
        super().__init__(config, app)

        self._results: Dict[UUID, Result] = {}
        self._result_tree: Dict[UUID, Set[UUID]] = {}
        self._values: Dict[UUID, ResultValueHolder] = {}
        self._metadata: Dict[UUID, Dict[str, Any]] = {}

    async def get(self, result_id: UUID) -> Optional[Result]:
        return self._results.get(result_id, None)

    async def set(self, result_id: UUID, result: Result):
        self._results[result_id] = result

    async def get_children_count(self, parent_result_id: UUID) -> int:
        return len(self._result_tree.get(parent_result_id, set()))

    async def iterate_children_ids(
        self, parent_result_id: UUID, *, count: Optional[int] = None
    ) -> AsyncGenerator[UUID, None]:
        """
        Yield the child IDs of ``parent_result_id``, at most ``count`` of them.

        Yields nothing for an unknown parent. The children are snapshotted
        first, so ``add_children`` calls made while the consumer awaits
        between items cannot break the iteration.
        """
        children = tuple(self._result_tree.get(parent_result_id, ()))

        for idx, child in enumerate(children):
            if count is not None and idx >= count:
                break

            yield child

    async def iterate_children(
        self, parent_result_id: UUID, *, count: Optional[int] = None
    ) -> AsyncGenerator[Result, None]:
        async for child_id in self.iterate_children_ids(parent_result_id, count=count):
            child = self._results.get(child_id, None)

            if child is not None:
                yield child

    async def add_children(self, result_id: UUID, *children: UUID) -> None:
        self._result_tree.setdefault(result_id, set()).update(children)

    async def get_value(self, result_id: UUID) -> ResultValueHolder:
        value = self._values.get(result_id, None)

        if value is None:
            raise ResultValueLost(result_id)

        return value

    async def set_value(self, result_id: UUID, value: ResultValueHolder):
        self._values[result_id] = value

    async def get_metadata(self, result_id: UUID) -> Dict[str, Any]:
        """
        Return a copy of the metadata of ``result_id`` (empty if unknown).

        A copy, so that mutating the returned dict cannot silently alter the
        stored metadata.
        """
        return dict(self._metadata.get(result_id, {}))

    async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
        self._metadata.setdefault(result_id, {}).update(kwargs)

    async def delete(self, result_id: UUID, include_children: bool = True):
        """
        Delete a Result together with its value, metadata and children link.

        The children-tree entry is always dropped, mirroring the Redis backend
        deleting its children key. Detaching it before recursing also stops
        a cyclic tree from recursing forever.
        """
        children = self._result_tree.pop(result_id, set())

        if include_children:
            for child_id in children:
                await self.delete(child_id, include_children=include_children)

        self._results.pop(result_id, None)
        self._metadata.pop(result_id, None)
        self._values.pop(result_id, None)

    async def set_ttl(
        self, result_id: UUID, ttl: timedelta, include_children: bool = True
    ):
        # Expiry is not implemented for the in-memory backend.
        pass

    async def close(self):
        self._metadata = {}
        self._result_tree = {}
        self._results = {}
        self._values = {}

        return await super().close()
add_children(self, result_id, *children) async

Add children to a parent Result.

Source code in mognet/backend/memory_result_backend.py
async def add_children(self, result_id: UUID, *children: UUID) -> None:
    """Record ``children`` under ``result_id`` in the in-memory result tree."""
    if result_id not in self._result_tree:
        self._result_tree[result_id] = set()

    self._result_tree[result_id].update(children)
close(self) async

Explicit method to close the backend provided by this Result backend.

Source code in mognet/backend/memory_result_backend.py
async def close(self):
    """Forget every stored Result, value, metadata entry and child link."""
    self._results = {}
    self._values = {}
    self._metadata = {}
    self._result_tree = {}

    return await super().close()
delete(self, result_id, include_children=True) async

Delete a Result.

Source code in mognet/backend/memory_result_backend.py
async def delete(self, result_id: UUID, include_children: bool = True):
    """
    Delete a Result together with its value, metadata and children link.

    The children-tree entry is always dropped, mirroring the Redis backend
    deleting its children key. Detaching it before recursing also stops a
    cyclic tree from recursing forever.
    """
    children = self._result_tree.pop(result_id, set())

    if include_children:
        for child_id in children:
            await self.delete(child_id, include_children=include_children)

    self._results.pop(result_id, None)
    self._metadata.pop(result_id, None)
    self._values.pop(result_id, None)
get(self, result_id) async

Get a Result by its ID. If it doesn't exist, this method returns None.

Source code in mognet/backend/memory_result_backend.py
async def get(self, result_id: UUID) -> Optional[Result]:
    """Return the stored Result for ``result_id``, or None if unknown."""
    return self._results.get(result_id)
get_children_count(self, parent_result_id) async

Return the number of children of a Result.

Returns 0 if the Result doesn't exist.

Source code in mognet/backend/memory_result_backend.py
async def get_children_count(self, parent_result_id: UUID) -> int:
    """Number of children recorded for ``parent_result_id`` (0 if none)."""
    children = self._result_tree.get(parent_result_id)
    return 0 if children is None else len(children)
get_metadata(self, result_id) async

Get the metadata of a Result.

Returns an empty Dict if the Result doesn't exist.

Source code in mognet/backend/memory_result_backend.py
async def get_metadata(self, result_id: UUID) -> Dict[str, Any]:
    """
    Return a copy of the metadata of ``result_id`` (empty if unknown).

    A copy, so that mutating the returned dict cannot silently alter the
    stored metadata.
    """
    return dict(self._metadata.get(result_id, {}))
get_value(self, result_id) async

Get the value of a Result.

If the value is lost, ResultValueLost is raised.

Source code in mognet/backend/memory_result_backend.py
async def get_value(self, result_id: UUID) -> ResultValueHolder:
    """
    Return the stored value for ``result_id``.

    Raises ResultValueLost when nothing (or None) is stored for it.
    """
    stored = self._values.get(result_id)
    if stored is not None:
        return stored

    raise ResultValueLost(result_id)
iterate_children(self, parent_result_id, *, count=None)

Get an AsyncGenerator for the children of a Result.

The AsyncGenerator will be empty if the Result doesn't exist.

Source code in mognet/backend/memory_result_backend.py
async def iterate_children(
    self, parent_result_id: UUID, *, count: int = None
) -> AsyncGenerator[Result, None]:
    """Yield the stored child Results of ``parent_result_id``, skipping unknown IDs."""
    async for child_id in self.iterate_children_ids(parent_result_id, count=count):
        found = self._results.get(child_id)
        if found is None:
            continue

        yield found
iterate_children_ids(self, parent_result_id, *, count=None)

Get an AsyncGenerator for the IDs for the children of a Result.

The AsyncGenerator will be empty if the Result doesn't exist.

Source code in mognet/backend/memory_result_backend.py
async def iterate_children_ids(
    self, parent_result_id: UUID, *, count: Optional[int] = None
) -> AsyncGenerator[UUID, None]:
    """
    Yield the child IDs of ``parent_result_id``, at most ``count`` of them.

    Yields nothing for an unknown parent, as the interface requires, instead
    of raising KeyError. The children are snapshotted first, so
    ``add_children`` calls made while the consumer awaits between items
    cannot break the iteration.
    """
    children = tuple(self._result_tree.get(parent_result_id, ()))

    for idx, child in enumerate(children):
        # Stop before yielding item number `count` (0-based), so exactly
        # `count` IDs are produced at most.
        if count is not None and idx >= count:
            break

        yield child
set(self, result_id, result) async

Save a Result.

Source code in mognet/backend/memory_result_backend.py
async def set(self, result_id: UUID, result: Result):
    """Store ``result`` under ``result_id``, replacing any previous entry."""
    self._results.update({result_id: result})
set_metadata(self, result_id, **kwargs) async

Set metadata on a Result.

Source code in mognet/backend/memory_result_backend.py
async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
    """Merge ``kwargs`` into the metadata of ``result_id``, creating it if needed."""
    current = self._metadata.setdefault(result_id, {})
    for key, value in kwargs.items():
        current[key] = value
set_ttl(self, result_id, ttl, include_children=True) async

Set expiration on a Result.

If include_children is True, children will have the same TTL set.

Source code in mognet/backend/memory_result_backend.py
async def set_ttl(
    self, result_id: UUID, ttl: timedelta, include_children: bool = True
):
    """No-op: the in-memory backend never expires anything."""
    return None
set_value(self, result_id, value) async

Set the value of a Result.

Source code in mognet/backend/memory_result_backend.py
async def set_value(self, result_id: UUID, value: ResultValueHolder):
    """Store ``value`` for ``result_id``, replacing any previous value."""
    self._values.update({result_id: value})

redis_result_backend

RedisResultBackend (BaseResultBackend)

Result backend that uses Redis for persistence.

Source code in mognet/backend/redis_result_backend.py
class RedisResultBackend(BaseResultBackend):
    """
    Result backend that uses Redis for persistence.
    """

    def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None:
        super().__init__(config, app)

        redis_settings = config.redis

        # The client is created later by connect(); until then `_redis` raises.
        self._url = redis_settings.url
        self.__redis = None
        self._connected = False

        # References to the tasks spawned by .wait().
        self._waiters: List[asyncio.Future] = []

        # State consumed by the @_retry decorator.
        self._retry_connect_attempts = redis_settings.retry_connect_attempts
        self._retry_connect_timeout = redis_settings.retry_connect_timeout
        self._retry_lock = asyncio.Lock()

    @property
    def _redis(self) -> Redis:
        """The live Redis client; raises NotConnected before connecting."""
        client = self.__redis
        if client is not None:
            return client

        raise NotConnected

    @_retry
    async def get(self, result_id: UUID) -> Optional[Result]:
        """Load and decode the Result stored for ``result_id``, or None."""
        key = self._format_key(result_id)

        async with self._redis.pipeline(transaction=True) as pipeline:
            # HGETALL answers with an empty hash for missing keys, so EXISTS
            # is what tells "no such result" apart from an empty one.
            pipeline.exists(key)
            pipeline.hgetall(key)

            found, raw_fields, *_ = await shield(pipeline.execute())

        if found:
            return self._decode_result(raw_fields)

        return None

    @_retry
    async def get_or_create(self, result_id: UUID) -> Result:
        """
        Gets a result, or creates one if it doesn't exist.

        In one MULTI/EXEC pipeline this:

        - sets the ``id`` field of the result hash only if it is absent
          (HSETNX) and reads the whole hash back;
        - refreshes the result hash TTL when ``result_ttl`` is configured;
        - seeds the value hash with a "not ready" placeholder, field by field
          with HSETNX so an existing value is never overwritten, and then
          refreshes its TTL when ``result_value_ttl`` is configured.
        """
        result_key = self._format_key(result_id)
        value_key = self._format_key(result_id, "value")

        default_not_ready = ResultValueHolder.not_ready()
        encoded = self._encode_result_value(default_not_ready)

        async with self._redis.pipeline(transaction=True) as pip:
            pip.hsetnx(result_key, "id", json.dumps(str(result_id)).encode())
            pip.hgetall(result_key)

            if self.config.redis.result_ttl is not None:
                pip.expire(result_key, self.config.redis.result_ttl)

            for encoded_k, encoded_v in encoded.items():
                pip.hsetnx(value_key, encoded_k, encoded_v)

            # EXPIRE must come after the HSETNX calls that may create the
            # value hash: on a missing key it is a no-op, which would leave a
            # freshly created value hash without any TTL.
            if self.config.redis.result_value_ttl is not None:
                pip.expire(value_key, self.config.redis.result_value_ttl)

            created, value, *_ = await shield(pip.execute())

        # HSETNX replies 1 when it set the field (i.e. the id was new), else 0.
        if created:
            _log.debug("Created result %r on key %r", result_id, result_key)

        return self._decode_result(value)

    def _encode_result_value(self, value: ResultValueHolder) -> Dict[str, bytes]:
        """
        Serialize ``value`` to JSON bytes and wrap them in the hash fields
        stored in Redis.

        When the configured encoding is GZIP the contents are compressed and
        the ``encoding`` field says so; otherwise it holds ``b"null"``.
        """
        payload = value.model_dump_json().encode()
        use_gzip = self.config.redis.result_value_encoding == Encoding.GZIP

        if use_gzip:
            encoding = _json_bytes("gzip")
            payload = gzip.compress(payload)
        else:
            encoding = b"null"

        return {
            "contents": payload,
            "encoding": encoding,
            "content_type": _json_bytes("application/json"),
        }

    def _decode_result_value(self, encoded: Dict[bytes, bytes]) -> ResultValueHolder:
        """
        Inverse of ``_encode_result_value``: decompress when the stored
        encoding is gzip, check the content type, then parse the JSON.

        Raises ValueError for any content type other than application/json.
        """
        is_gzipped = encoded.get(b"encoding") == _json_bytes("gzip")
        raw = encoded[b"contents"]
        contents = gzip.decompress(raw) if is_gzipped else raw

        if encoded.get(b"content_type") == _json_bytes("application/json"):
            return ResultValueHolder.model_validate_json(contents)

        raise ValueError(f"Unknown content_type={encoded.get(b'content_type')!r}")

    @_retry
    async def set(self, result_id: UUID, result: Result):
        """Write ``result`` to its Redis hash and refresh the result TTL if configured."""
        key = self._format_key(result_id)
        ttl = self.config.redis.result_ttl

        async with self._redis.pipeline(transaction=True) as pipeline:
            fields = _encode_result(result)
            pipeline.hset(key, None, None, fields)

            if ttl is not None:
                pipeline.expire(key, ttl)

            await shield(pipeline.execute())

    def _format_key(self, result_id: UUID, subkey: Optional[str] = None) -> str:
        """
        Build the Redis key for a result, or for one of its sub-keys.

        The base key is ``"<app name>.mognet.result.<result id>"``. When
        ``subkey`` is given (this class uses ``"value"``, ``"children"`` and
        ``"metadata"``) it is appended as ``"<base>/<subkey>"``.
        """
        key = f"{self.app.name}.mognet.result.{result_id}"

        if subkey:
            key = f"{key}/{subkey}"

        # Arguments in the same order as the placeholders (id, then subkey).
        _log.debug(
            "Formatted result key=%r for id=%r and subkey=%r", key, result_id, subkey
        )

        return key

    @_retry
    async def add_children(self, result_id: UUID, *children: UUID):
        """
        SADD the given child IDs to the parent's children set.

        A call without children does nothing. When ``result_ttl`` is
        configured, the set's expiry is refreshed in the same transaction.
        """
        if children:
            children_key = self._format_key(result_id, "children")
            ttl = self.config.redis.result_ttl

            async with self._redis.pipeline(transaction=True) as pipeline:
                pipeline.sadd(children_key, *_encode_children(children))

                if ttl is not None:
                    pipeline.expire(children_key, ttl)

                await shield(pipeline.execute())

    async def get_value(self, result_id: UUID) -> ResultValueHolder:
        """Read and decode the value hash of ``result_id``; raise ResultValueLost if it is gone."""
        value_key = self._format_key(result_id, "value")

        async with self._redis.pipeline(transaction=True) as pipeline:
            pipeline.exists(value_key)
            pipeline.hgetall(value_key)

            found, fields = await shield(pipeline.execute())

            if found:
                return self._decode_result_value(fields)

            raise ResultValueLost(result_id)

    async def set_value(self, result_id: UUID, value: ResultValueHolder):
        """Encode ``value`` into the value hash, refreshing its TTL if configured."""
        value_key = self._format_key(result_id, "value")
        fields = self._encode_result_value(value)
        ttl = self.config.redis.result_value_ttl

        async with self._redis.pipeline(transaction=True) as pipeline:
            pipeline.hset(value_key, None, None, fields)

            if ttl is not None:
                pipeline.expire(value_key, ttl)

            await shield(pipeline.execute())

    async def delete(self, result_id: UUID, include_children: bool = True):
        """
        Delete the result hash plus its children, value and metadata keys.

        With ``include_children``, every child is deleted recursively first.
        """
        if include_children:
            async for child_id in self.iterate_children_ids(result_id):
                await self.delete(child_id, include_children=True)

        keys = [self._format_key(result_id)]
        keys.extend(
            self._format_key(result_id, part)
            for part in ("children", "value", "metadata")
        )

        await shield(self._redis.delete(*keys))

    async def set_ttl(
        self, result_id: UUID, ttl: timedelta, include_children: bool = True
    ):
        """
        Apply the expiry ``ttl`` to a Result and its auxiliary Redis keys.

        When ``include_children`` is true, the same TTL is first applied to
        every child Result, recursively. The EXPIRE commands are issued one
        at a time, not pipelined.
        """
        if include_children:
            async for child in self.iterate_children_ids(result_id):
                await self.set_ttl(child, ttl, include_children=True)

        keys = (
            self._format_key(result_id),
            self._format_key(result_id, "children"),
            self._format_key(result_id, "value"),
            self._format_key(result_id, "metadata"),
        )

        for key in keys:
            await shield(self._redis.expire(key, ttl))

    async def connect(self):
        """
        Connect to Redis, unless this backend is already marked as connected.

        The connected flag is set before awaiting, so a call made while a
        connection attempt is in flight returns immediately. If connecting
        fails, the flag is reset before the error propagates. Otherwise the
        backend would keep claiming to be connected and every later
        ``connect()`` would silently do nothing.
        """
        if self._connected:
            return

        self._connected = True

        try:
            await self._connect()
        except BaseException:
            self._connected = False
            raise

    async def close(self):
        """
        Shut this backend down.

        The backend is marked as disconnected first. Then any pending
        ``wait()`` loops are cancelled, and finally the Redis client is
        dropped and closed.
        """
        self._connected = False
        await self._close_waiters()
        await self._disconnect()

    async def get_children_count(self, parent_result_id: UUID) -> int:
        """Return the size (SCARD) of the parent's "children" set; 0 if it has none."""
        return await shield(
            self._redis.scard(self._format_key(parent_result_id, "children"))
        )

    async def iterate_children_ids(
        self, parent_result_id: UUID, *, count: Optional[float] = None
    ):
        """
        Asynchronously yield the UUIDs of a Result's children.

        The parent's "children" set is walked with SSCAN (``count`` is passed
        through as the scan hint). Each member is decoded from its raw bytes
        with ``UUID(bytes=...)``. Nothing is yielded if the set is missing.
        """
        key = self._format_key(parent_result_id, "children")

        async for raw in self._redis.sscan_iter(key, count=count):
            yield UUID(bytes=raw)

    async def iterate_children(
        self, parent_result_id: UUID, *, count: Optional[float] = None
    ):
        """
        Asynchronously yield the child Results of ``parent_result_id``.

        Each child ID is resolved with ``self.get``. Children for which that
        returns ``None`` are skipped.
        """
        child_ids = self.iterate_children_ids(parent_result_id, count=count)

        async for child_id in child_ids:
            result = await self.get(child_id)

            if result is None:
                continue

            yield result

    @_retry
    async def wait(
        self, result_id: UUID, timeout: Optional[float] = None, poll: float = 1
    ) -> Result:
        """
        Poll Redis until a Result reaches a ready state, then return it.

        The Result's "state" field is read every ``poll`` seconds. Once the
        state is in ``READY_STATES``, the full Result is fetched and returned.

        The polling task is registered in ``self._waiters`` so that
        ``close()`` can cancel it. It unregisters itself once done, so
        finished waiters do not pile up for the lifetime of the backend.

        Args:
            result_id: ID of the Result to wait for.
            timeout: Seconds to wait before giving up. A falsy value
                (``None`` or ``0``) waits indefinitely.
            poll: Seconds to sleep between state checks.

        Raises:
            ResultValueLost: if the Result has no stored state.
            RuntimeError: if the Result disappears between becoming ready
                and being fetched.
            asyncio.TimeoutError: if ``timeout`` elapses first.
        """
        # The state field can (but shouldn't) be null. Build the validator
        # once instead of on every poll iteration.
        state_adapter = TypeAdapter(Optional[ResultState])

        async def waiter():
            key = self._format_key(result_id=result_id)

            while True:
                raw_state = await shield(self._redis.hget(key, "state")) or b"null"

                state = state_adapter.validate_json(raw_state)

                if state is None:
                    raise ResultValueLost(result_id)

                if state in READY_STATES:
                    final_result = await self.get(result_id)

                    if final_result is None:
                        raise RuntimeError(
                            f"Result id={result_id!r} that previously existed no longer does"
                        )

                    return final_result

                await asyncio.sleep(poll)

        waiter_task = asyncio.create_task(
            waiter(),
            name=f"RedisResultBackend:wait_for:{result_id}",
        )

        if timeout:
            waiter_task = asyncio.create_task(asyncio.wait_for(waiter_task, timeout))

        def _unregister(task) -> None:
            # close() may already have popped (and cancelled) this task.
            try:
                self._waiters.remove(task)
            except ValueError:
                pass

        self._waiters.append(waiter_task)
        waiter_task.add_done_callback(_unregister)

        return await waiter_task

    async def get_metadata(self, result_id: UUID) -> Dict[str, Any]:
        """
        Return the metadata stored for a Result.

        The Result's "metadata" hash is read with HGETALL and decoded with
        ``_decode_json_dict``.
        """
        metadata_key = self._format_key(result_id, "metadata")
        raw_fields = await shield(self._redis.hgetall(metadata_key))
        return _decode_json_dict(raw_fields)

    async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
        """
        Write ``kwargs`` into a Result's "metadata" hash.

        Only the given fields are set; other existing fields are kept. Values
        are converted with ``_dict_to_json_dict`` (presumably JSON-encoding
        each one). When a result TTL is configured, the hash's expiry is
        refreshed in the same transaction. Nothing is written when no keyword
        arguments are given.
        """
        key = self._format_key(result_id, "metadata")

        if not kwargs:
            return

        async with self._redis.pipeline(transaction=True) as pipeline:
            pipeline.hset(key, None, None, _dict_to_json_dict(kwargs))

            ttl = self.config.redis.result_ttl
            if ttl is not None:
                pipeline.expire(key, ttl)

            await shield(pipeline.execute())

    def __repr__(self):
        # Never leak credentials embedded in the Redis URL.
        safe_url = censor_credentials(self._url)
        return f"RedisResultBackend(url={safe_url!r})"

    async def __aenter__(self):
        await self.connect()

        return self

    async def __aexit__(self, *args, **kwargs):
        await self.close()

    async def _close_waiters(self):
        """
        Cancel and await every task registered in ``self._waiters``.

        Tasks are popped from the end until the collection is empty. The
        expected ``CancelledError`` is swallowed; any other error raised by a
        waiter is only logged at debug level.
        """
        while self._waiters:
            task = self._waiters.pop()

            try:
                _log.debug("Cancelling waiter %r", task)
                task.cancel()
                await task
            except asyncio.CancelledError:
                pass
            except Exception as err:  # pylint: disable=broad-except
                _log.debug("Error on waiter task %r", task, exc_info=err)

    async def _create_redis(self):
        """Create a Redis client for ``self._url``, capped at the configured pool size."""
        _log.debug("Creating Redis connection")

        pool_size = self.config.redis.max_connections
        return await from_url(self._url, max_connections=pool_size)

    @_retry
    async def _connect(self):
        """
        Make sure a Redis client exists, then PING it to verify the connection.

        Wrapped with ``_retry``. An already-created client is reused across
        attempts rather than recreated.
        """
        if self.__redis is None:
            self.__redis = await self._create_redis()

        await shield(self._redis.ping())

    async def _disconnect(self):
        """Forget the current Redis client, closing it if there was one."""
        client, self.__redis = self.__redis, None

        if client is not None:
            _log.debug("Closing Redis connection")
            await client.close()

    def _decode_result(self, json_dict: Dict[bytes, bytes]) -> Result:
        """Build a ``Result`` bound to this backend from a raw Redis hash."""
        return Result(self, **_decode_json_dict(json_dict))
close(self) async

Explicitly close this Result backend: pending waiters are cancelled and the Redis connection is closed.

Source code in mognet/backend/redis_result_backend.py
async def close(self):
    """
    Shut this backend down.

    The backend is marked as disconnected first. Then any pending ``wait()``
    loops are cancelled, and finally the Redis client is dropped and closed.
    """
    self._connected = False
    await self._close_waiters()
    await self._disconnect()
connect(self) async

Explicitly connect this Result backend to Redis. Does nothing if it is already connected.

Source code in mognet/backend/redis_result_backend.py
async def connect(self):
    """
    Connect to Redis, unless this backend is already marked as connected.

    The connected flag is set before awaiting, so a call made while a
    connection attempt is in flight returns immediately. If connecting fails,
    the flag is reset before the error propagates. Otherwise the backend would
    keep claiming to be connected and every later ``connect()`` would silently
    do nothing.
    """
    if self._connected:
        return

    self._connected = True

    try:
        await self._connect()
    except BaseException:
        self._connected = False
        raise
delete(self, result_id, include_children=True) async

Delete a Result.

Source code in mognet/backend/redis_result_backend.py
async def delete(self, result_id: UUID, include_children: bool = True):
    """
    Delete a Result together with its "children", "value" and "metadata" keys.

    When ``include_children`` is true, every child Result is deleted first,
    recursively (depth-first), before this Result's own keys are removed with
    a single DEL.
    """
    if include_children:
        async for child in self.iterate_children_ids(result_id):
            await self.delete(child, include_children=True)

    keys = [self._format_key(result_id)]
    keys.extend(
        self._format_key(result_id, suffix)
        for suffix in ("children", "value", "metadata")
    )

    await shield(self._redis.delete(*keys))
get_children_count(self, parent_result_id) async

Return the number of children of a Result.

Returns 0 if the Result doesn't exist.

Source code in mognet/backend/redis_result_backend.py
async def get_children_count(self, parent_result_id: UUID) -> int:
    """Return the size (SCARD) of the parent's "children" set; 0 if it has none."""
    return await shield(
        self._redis.scard(self._format_key(parent_result_id, "children"))
    )
get_metadata(self, result_id) async

Get the metadata of a Result.

Returns an empty Dict if the Result doesn't exist.

Source code in mognet/backend/redis_result_backend.py
async def get_metadata(self, result_id: UUID) -> Dict[str, Any]:
    """
    Return the metadata stored for a Result.

    The Result's "metadata" hash is read with HGETALL and decoded with
    ``_decode_json_dict``.
    """
    metadata_key = self._format_key(result_id, "metadata")
    raw_fields = await shield(self._redis.hgetall(metadata_key))
    return _decode_json_dict(raw_fields)
get_or_create(self, result_id) async

Gets a result, or creates one if it doesn't exist.

Source code in mognet/backend/redis_result_backend.py
@_retry
async def get_or_create(self, result_id: UUID) -> Result:
    """
    Gets a result, or creates one if it doesn't exist.

    All of the following runs in a single transactional pipeline:

    1. Set the Result's "id" field only if absent (HSETNX). This is what
       creates a new Result.
    2. Read back the whole Result hash.
    3. Refresh the Result's TTL, when one is configured.
    4. Seed the "value" hash with a not-ready placeholder, field by field,
       without overwriting fields that already exist (HSETNX).
    5. Refresh the "value" TTL, when one is configured. This must come after
       step 4: EXPIRE on a key that does not exist yet is a no-op, so issuing
       it first would leave a newly created value hash without any TTL.
    """
    result_key = self._format_key(result_id)
    value_key = self._format_key(result_id, "value")
    encoded = self._encode_result_value(ResultValueHolder.not_ready())

    async with self._redis.pipeline(transaction=True) as pip:
        pip.hsetnx(result_key, "id", json.dumps(str(result_id)).encode())
        pip.hgetall(result_key)

        if self.config.redis.result_ttl is not None:
            pip.expire(result_key, self.config.redis.result_ttl)

        for encoded_k, encoded_v in encoded.items():
            pip.hsetnx(value_key, encoded_k, encoded_v)

        if self.config.redis.result_value_ttl is not None:
            pip.expire(value_key, self.config.redis.result_value_ttl)

        # HSETNX replies 1 when it set the field (i.e. the Result was just
        # created) and 0 when the field already existed.
        created, value, *_ = await shield(pip.execute())

    if created:
        _log.debug("Created result %r on key %r", result_id, result_key)

    return self._decode_result(value)
get_value(self, result_id) async

Get the value of a Result.

If the value is lost, ResultValueLost is raised.

Source code in mognet/backend/redis_result_backend.py
async def get_value(self, result_id: UUID) -> ResultValueHolder:
    """
    Fetch and decode the stored value of a Result.

    The existence check and the read run in one transactional pipeline, so
    both observe the same state of the "value" hash.

    Raises:
        ResultValueLost: if the "value" key does not exist in Redis.
    """
    key = self._format_key(result_id, "value")

    async with self._redis.pipeline(transaction=True) as pipeline:
        pipeline.exists(key)
        pipeline.hgetall(key)

        found, raw_fields = await shield(pipeline.execute())

        if found:
            return self._decode_result_value(raw_fields)

        raise ResultValueLost(result_id)
iterate_children(self, parent_result_id, *, count=None)

Get an AsyncGenerator for the children of a Result.

The AsyncGenerator will be empty if the Result doesn't exist.

Source code in mognet/backend/redis_result_backend.py
async def iterate_children(
    self, parent_result_id: UUID, *, count: Optional[float] = None
):
    """
    Asynchronously yield the child Results of ``parent_result_id``.

    Each child ID is resolved with ``self.get``. Children for which that
    returns ``None`` are skipped.
    """
    child_ids = self.iterate_children_ids(parent_result_id, count=count)

    async for child_id in child_ids:
        result = await self.get(child_id)

        if result is None:
            continue

        yield result
iterate_children_ids(self, parent_result_id, *, count=None)

Get an AsyncGenerator for the IDs for the children of a Result.

The AsyncGenerator will be empty if the Result doesn't exist.

Source code in mognet/backend/redis_result_backend.py
async def iterate_children_ids(
    self, parent_result_id: UUID, *, count: Optional[float] = None
):
    """
    Asynchronously yield the UUIDs of a Result's children.

    The parent's "children" set is walked with SSCAN (``count`` is passed
    through as the scan hint). Each member is decoded from its raw bytes with
    ``UUID(bytes=...)``. Nothing is yielded if the set is missing.
    """
    key = self._format_key(parent_result_id, "children")

    async for raw in self._redis.sscan_iter(key, count=count):
        yield UUID(bytes=raw)
set_metadata(self, result_id, **kwargs) async

Set metadata on a Result.

Source code in mognet/backend/redis_result_backend.py
async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
    """
    Write ``kwargs`` into a Result's "metadata" hash.

    Only the given fields are set; other existing fields are kept. Values are
    converted with ``_dict_to_json_dict`` (presumably JSON-encoding each one).
    When a result TTL is configured, the hash's expiry is refreshed in the
    same transaction. Nothing is written when no keyword arguments are given.
    """
    key = self._format_key(result_id, "metadata")

    if not kwargs:
        return

    async with self._redis.pipeline(transaction=True) as pipeline:
        pipeline.hset(key, None, None, _dict_to_json_dict(kwargs))

        ttl = self.config.redis.result_ttl
        if ttl is not None:
            pipeline.expire(key, ttl)

        await shield(pipeline.execute())
set_ttl(self, result_id, ttl, include_children=True) async

Set expiration on a Result.

If include_children is True, children will have the same TTL set.

Source code in mognet/backend/redis_result_backend.py
async def set_ttl(
    self, result_id: UUID, ttl: timedelta, include_children: bool = True
):
    """
    Apply the expiry ``ttl`` to a Result and its auxiliary Redis keys.

    When ``include_children`` is true, the same TTL is first applied to every
    child Result, recursively. The EXPIRE commands are issued one at a time,
    not pipelined.
    """
    if include_children:
        async for child in self.iterate_children_ids(result_id):
            await self.set_ttl(child, ttl, include_children=True)

    keys = (
        self._format_key(result_id),
        self._format_key(result_id, "children"),
        self._format_key(result_id, "value"),
        self._format_key(result_id, "metadata"),
    )

    for key in keys:
        await shield(self._redis.expire(key, ttl))
set_value(self, result_id, value) async

Set the value of a Result.

Source code in mognet/backend/redis_result_backend.py
async def set_value(self, result_id: UUID, value: ResultValueHolder):
    """
    Encode ``value`` and store it as the value of a Result.

    The encoded fields are written into the Result's "value" hash; when a
    result value TTL is configured, the hash's expiry is (re)applied in the
    same transaction.
    """
    key = self._format_key(result_id, "value")
    fields = self._encode_result_value(value)

    async with self._redis.pipeline(transaction=True) as pipeline:
        pipeline.hset(key, None, None, fields)

        value_ttl = self.config.redis.result_value_ttl
        if value_ttl is not None:
            pipeline.expire(key, value_ttl)

        await shield(pipeline.execute())

broker special

amqp_broker

AmqpBroker (BaseBroker)

Source code in mognet/broker/amqp_broker.py
class AmqpBroker(BaseBroker):
    """
    Broker implementation on top of AMQP, using aio_pika.

    One robust connection carries two channels: one for task queues and one
    for control messages, so that each can use its own prefetch count. Task
    messages are published to a durable direct exchange, routed by the AMQP
    task queue name. Control and query messages are published to a fanout
    exchange that every node's control queue is bound to.
    """

    _task_channel: Channel
    _control_channel: Channel

    _task_queues: Dict[str, Queue]

    _direct_exchange: Exchange
    _control_exchange: Exchange

    # Retry decorator for methods that can fail on transient connection
    # errors. Max attempts and wait timeout are given as attribute names,
    # which are set per instance in __init__.
    _retry = retryableasyncmethod(
        _RETRYABLE_ERRORS,
        max_attempts="_retry_connect_attempts",
        wait_timeout="_retry_connect_timeout",
    )

    def __init__(self, app: "App", config: BrokerConfig) -> None:
        super().__init__()

        self._connected = False
        self.__connection = None

        self.config = config

        # AMQP queue name -> declared queue; filled lazily.
        self._task_queues = {}
        self._control_queue = None

        # Lock to prevent duplicate queue declaration
        self._lock = Lock()

        self.app = app

        # Attributes for @retryableasyncmethod
        self._retry_connect_attempts = self.config.amqp.retry_connect_attempts
        self._retry_connect_timeout = self.config.amqp.retry_connect_timeout

        # List of callbacks for when connection drops
        self._on_connection_failed_callbacks: List[
            Callable[[Optional[BaseException]], Awaitable]
        ] = []

        # Strong references to in-flight connection-failed notification
        # tasks. The event loop only keeps weak references to tasks, so
        # without this they could be garbage-collected before finishing.
        self._background_tasks = set()

    @property
    def _connection(self) -> Connection:
        """The open AMQP connection. Raises ``NotConnected`` if there is none."""
        if self.__connection is None:
            raise NotConnected

        return self.__connection

    async def ack(self, delivery_tag: str):
        """Acknowledge the task message with ``delivery_tag`` on the task channel."""
        await self._task_channel.channel.basic_ack(delivery_tag)

    async def nack(self, delivery_tag: str):
        """Negatively acknowledge the task message with ``delivery_tag``."""
        await self._task_channel.channel.basic_nack(delivery_tag)

    @_retry
    async def set_task_prefetch(self, prefetch: int):
        """Set the task channel's prefetch count (applied with ``global_=True``)."""
        await self._task_channel.set_qos(prefetch_count=prefetch, global_=True)

    @_retry
    async def send_task_message(self, queue: str, payload: MessagePayload):
        """
        Publish ``payload`` as JSON to the task queue named ``queue``.

        The message goes through the direct exchange, routed by the AMQP
        queue name, and carries the payload's priority and ID.
        """
        amqp_queue = self._task_queue_name(queue)

        msg = Message(
            body=payload.model_dump_json().encode(),
            content_type="application/json",
            content_encoding="utf-8",
            priority=payload.priority,
            message_id=payload.id,
        )

        await self._direct_exchange.publish(msg, amqp_queue)

        _log.debug(
            "Message %r sent to queue=%r (amqp queue=%r)", payload.id, queue, amqp_queue
        )

    async def consume_tasks(
        self, queue: str
    ) -> AsyncGenerator[IncomingMessagePayload, None]:
        """Declare the task queue ``queue`` if needed, then yield its parsed messages."""

        amqp_queue = await self._get_or_create_task_queue(TaskQueue(name=queue))

        async for message in self._consume(amqp_queue):
            yield message

    async def consume_control_queue(
        self,
    ) -> AsyncGenerator[IncomingMessagePayload, None]:
        """Declare this node's control queue if needed, then yield its parsed messages."""

        amqp_queue = await self._get_or_create_control_queue()

        async for message in self._consume(amqp_queue):
            yield message

    @_retry
    async def send_control_message(self, payload: MessagePayload):
        """Broadcast ``payload`` to all nodes; the message expires after 300 seconds."""
        msg = Message(
            body=payload.model_dump_json().encode(),
            content_type="application/json",
            content_encoding="utf-8",
            message_id=payload.id,
            expiration=timedelta(seconds=300),
        )

        # No queue name set because this is a fanout exchange.
        await self._control_exchange.publish(msg, "")

    @_retry
    async def _send_query_message(self, payload: MessagePayload):
        """
        Declare this node's callback queue, then broadcast ``payload`` with
        ``reply_to`` pointing at it. Returns the callback queue.
        """
        callback_queue = await self._task_channel.declare_queue(
            name=self._callback_queue_name,
            durable=False,
            exclusive=False,
            auto_delete=True,
            arguments={
                "x-expires": 30000,
                "x-message-ttl": 30000,
            },
        )
        await callback_queue.bind(self._direct_exchange)

        msg = Message(
            body=payload.model_dump_json().encode(),
            content_type="application/json",
            content_encoding="utf-8",
            message_id=payload.id,
            expiration=timedelta(seconds=300),
            reply_to=callback_queue.name,
        )

        await self._control_exchange.publish(msg, "")

        return callback_queue

    async def send_query_message(
        self, payload: MessagePayload
    ) -> AsyncGenerator[QueryResponseMessage, None]:
        """
        Broadcast a query and yield the replies as ``QueryResponseMessage``.

        Iteration continues for as long as replies keep arriving; the caller
        decides when to stop. The callback queue is deleted once the
        generator finishes.
        """

        # Create a callback queue for getting the replies,
        # then send the message to the control exchange (fanout).
        # When done, delete the callback queue.

        callback_queue = None
        try:
            callback_queue = await self._send_query_message(payload)

            async with callback_queue.iterator() as iterator:
                async for message in iterator:
                    async with message.process():
                        contents: dict = json.loads(message.body)
                        msg = _AmqpIncomingMessagePayload(
                            broker=self, incoming_message=message, **contents
                        )
                        yield QueryResponseMessage.model_validate(msg.payload)
        finally:
            if callback_queue is not None:
                await callback_queue.delete()

    async def setup_control_queue(self):
        """Declare and bind this node's control queue, if not done already."""
        await self._get_or_create_control_queue()

    async def setup_task_queue(self, queue: TaskQueue):
        """Declare and bind the AMQP queue for ``queue``, if not done already."""
        await self._get_or_create_task_queue(queue)

    @_retry
    async def _create_connection(self):
        """Open a robust aio_pika connection named after this node."""
        connection = await aio_pika.connect_robust(
            self.config.amqp.url,
            reconnect_interval=self.app.config.reconnect_interval,
            client_properties={
                "connection_name": self.app.config.node_id,
            },
        )

        # Add callback for broadcasting unexpected connection drops
        connection.add_close_callback(self._send_connection_failed_events)

        return connection

    def add_connection_failed_callback(
        self, cb: Callable[[Optional[BaseException]], Awaitable]
    ):
        """Register ``cb`` to be called (with the error, if any) when the connection drops."""
        self._on_connection_failed_callbacks.append(cb)

    def _send_connection_failed_events(self, connection, exc=None):
        """
        Close callback registered on the AMQP connection.

        Unless the broker was closed on purpose (``self._connected`` is
        False), this logs the failure and runs every registered
        connection-failed callback as its own task. Errors raised by those
        callbacks are logged.
        """
        if not self._connected:
            _log.debug(
                "Not sending connection closed events because we are disconnected"
            )
            return

        _log.error("AMQP connection %r failed", connection, exc_info=exc)

        tasks = [cb(exc) for cb in self._on_connection_failed_callbacks]

        _log.info(
            "Notifying %r listeners of a disconnect",
            len(tasks),
        )

        def notify_task_completion_callback(fut: asyncio.Future):
            self._background_tasks.discard(fut)

            # Future.exception() raises CancelledError on a cancelled future,
            # so the cancellation check has to come first.
            if fut.cancelled():
                return

            err = fut.exception()

            if err is not None:
                _log.error("Error notifying connection dropped", exc_info=err)

        for task in tasks:
            notify_task = asyncio.create_task(task)
            self._background_tasks.add(notify_task)
            notify_task.add_done_callback(notify_task_completion_callback)

    async def connect(self):
        """
        Open the connection, both channels and the exchanges.

        Does nothing if already connected. If any step fails, the broker is
        closed again before the error propagates. This way a partially opened
        connection is not leaked, and later ``connect()`` calls retry instead
        of returning early as if connected.
        """
        if self._connected:
            return

        self._connected = True

        try:
            self.__connection = await self._create_connection()

            # Use two separate channels with separate prefetch counts.
            # This allows the task channel to increase the prefetch count
            # without affecting the control channel,
            # and allows the control channel to still receive messages, even if
            # the task channel has reached the full prefetch count.
            self._task_channel = await self._connection.channel()
            await self.set_task_prefetch(1)

            self._control_channel = await self._connection.channel()
            await self.set_control_prefetch(4)

            await self._create_exchanges()
        except BaseException:
            await self.close()
            raise

        _log.debug("Connected")

    async def set_control_prefetch(self, prefetch: int):
        """Set the control channel's prefetch count (applied with ``global_=False``)."""
        await self._control_channel.set_qos(prefetch_count=prefetch, global_=False)

    async def close(self):
        """
        Mark the broker as disconnected and close the connection, if open.

        The flag is cleared first, so the resulting close callback does not
        report the shutdown as a connection failure.
        """
        self._connected = False

        connection = self.__connection

        if connection is not None:
            self.__connection = None

            _log.debug("Closing connections")
            await connection.close()
            _log.debug("Connection closed")

    @_retry
    async def send_reply(self, message: IncomingMessagePayload, reply: MessagePayload):
        """
        Publish ``reply`` to the queue named in ``message.reply_to``.

        Raises:
            ValueError: if ``message`` has no ``reply_to``.
        """
        if not message.reply_to:
            raise ValueError("Message has no reply_to set")

        msg = Message(
            body=reply.model_dump_json().encode(),
            content_type="application/json",
            content_encoding="utf-8",
            message_id=reply.id,
        )

        await self._direct_exchange.publish(msg, message.reply_to)

    async def purge_task_queue(self, queue: str) -> int:
        """
        Purge the task queue ``queue`` and return how many messages were removed.

        Only queues already declared by this broker instance can be purged.
        For any other queue a warning is logged and 0 is returned.
        """
        amqp_queue = self._task_queue_name(queue)

        if amqp_queue not in self._task_queues:
            _log.warning(
                "Queue %r (amqp=%r) does not exist in this broker", queue, amqp_queue
            )
            return 0

        result = await self._task_queues[amqp_queue].purge()

        deleted_count: int = result.message_count

        _log.info(
            "Deleted %r messages from queue=%r (amqp=%r)",
            deleted_count,
            queue,
            amqp_queue,
        )

        return deleted_count

    async def purge_control_queue(self) -> int:
        """Purge this node's control queue; returns 0 if it was never declared."""
        if not self._control_queue:
            _log.debug("Not listening on any control queue, not purging it")
            return 0

        result = await self._control_queue.purge()

        return result.message_count

    def __repr__(self):
        return f"AmqpBroker(url={censor_credentials(self.config.amqp.url)!r})"

    async def __aenter__(self):
        await self.connect()

        return self

    async def __aexit__(self, *args, **kwargs):
        await self.close()

        return None

    async def _create_exchanges(self):
        """Declare the durable direct (task) exchange and the fanout (control) exchange."""
        self._direct_exchange = await self._task_channel.declare_exchange(
            self._direct_exchange_name,
            type=ExchangeType.DIRECT,
            durable=True,
        )
        self._control_exchange = await self._control_channel.declare_exchange(
            self._control_exchange_name,
            type=ExchangeType.FANOUT,
        )

    async def _consume(
        self, amqp_queue: Queue
    ) -> AsyncGenerator[IncomingMessagePayload, None]:
        """
        Yield parsed payloads from ``amqp_queue``.

        Messages whose body cannot be parsed are logged and ACKed, so they are
        dropped rather than redelivered.
        """

        async with amqp_queue.iterator() as queue_iterator:
            msg: IncomingMessage
            async for msg in queue_iterator:

                try:
                    contents: dict = json.loads(msg.body)

                    payload = _AmqpIncomingMessagePayload(
                        broker=self, incoming_message=msg, **contents
                    )

                    _log.debug("Successfully parsed message %r", payload.id)

                    yield payload
                except asyncio.CancelledError:
                    raise
                except Exception as exc:  # pylint: disable=broad-except
                    _log.error(
                        "Error parsing contents of message %r, discarding",
                        msg.correlation_id,
                        exc_info=exc,
                    )

                    try:
                        await asyncio.shield(msg.ack())
                    except Exception as ack_err:
                        _log.error(
                            "Could not ACK message %r for discarding",
                            msg.correlation_id,
                            exc_info=ack_err,
                        )

    def _task_queue_name(self, name: str) -> str:
        return f"{self.app.name}.{name}"

    @property
    def _control_queue_name(self) -> str:
        return f"{self.app.name}.mognet.control.{self.app.config.node_id}"

    @property
    def _callback_queue_name(self) -> str:
        return f"{self.app.name}.mognet.callback.{self.app.config.node_id}"

    @property
    def _control_exchange_name(self) -> str:
        return f"{self.app.name}.mognet.control"

    @property
    def _direct_exchange_name(self) -> str:
        return f"{self.app.name}.mognet.direct"

    @_retry
    async def _get_or_create_control_queue(self) -> Queue:
        """
        Return this node's control queue, declaring and binding it on first use.

        Checked again under ``self._lock`` so that concurrent callers declare
        the queue only once.
        """
        if self._control_queue is None:
            async with self._lock:
                if self._control_queue is None:
                    self._control_queue = await self._control_channel.declare_queue(
                        name=self._control_queue_name,
                        durable=False,
                        auto_delete=True,
                        arguments={
                            "x-expires": 30000,
                            "x-message-ttl": 30000,
                        },
                    )
                    await self._control_queue.bind(self._control_exchange)

                    _log.debug("Prepared control queue=%r", self._control_queue.name)

        return self._control_queue

    @_retry
    async def _get_or_create_task_queue(self, queue: TaskQueue) -> Queue:
        """
        Return the AMQP queue for ``queue``, declaring and binding it on first use.

        Checked again under ``self._lock`` so that concurrent callers declare
        the queue only once.
        """
        name = self._task_queue_name(queue.name)

        if name not in self._task_queues:
            async with self._lock:

                if name not in self._task_queues:
                    _log.debug("Preparing queue %r as AMQP queue=%r", queue, name)

                    self._task_queues[name] = await self._task_channel.declare_queue(
                        name,
                        durable=True,
                        arguments={"x-max-priority": queue.max_priority},
                    )

                    await self._task_queues[name].bind(self._direct_exchange)

                    _log.debug(
                        "Prepared task queue=%r as AMQP queue=%r", queue.name, name
                    )

        return self._task_queues[name]

    async def task_queue_stats(self, task_queue_name: str) -> QueueStats:
        """
        Get the stats of a task queue.
        """

        name = self._task_queue_name(task_queue_name)

        # AMQP can close the Channel on us if we try accessing an object
        # that doesn't exist. So, to avoid trouble when the same Channel
        # is being used to consume a queue, use an ephemeral channel
        # for this operation alone.
        async with self._connection.channel() as channel:
            try:
                queue = await channel.get_queue(name, ensure=False)

                declare_result = await queue.declare()
            except (
                aiormq.exceptions.ChannelNotFoundEntity,
                aio_pika.exceptions.ChannelClosed,
            ) as query_err:
                raise QueueNotFound(task_queue_name) from query_err

        return QueueStats(
            queue_name=task_queue_name,
            message_count=declare_result.message_count,
            consumer_count=declare_result.consumer_count,
        )
task_queue_stats(self, task_queue_name) async

Get the stats of a task queue.

Source code in mognet/broker/amqp_broker.py
async def task_queue_stats(self, task_queue_name: str) -> QueueStats:
    """
    Get the stats of a task queue.

    Raises ``QueueNotFound`` if the broker has no such queue.
    """
    amqp_name = self._task_queue_name(task_queue_name)

    # A lookup of a missing object can make AMQP close the channel it was
    # issued on. Use a throwaway channel so a channel that may be busy
    # consuming is never affected.
    async with self._connection.channel() as channel:
        try:
            queue = await channel.get_queue(amqp_name, ensure=False)
            declared = await queue.declare()
        except (
            aiormq.exceptions.ChannelNotFoundEntity,
            aio_pika.exceptions.ChannelClosed,
        ) as lookup_err:
            raise QueueNotFound(task_queue_name) from lookup_err

    return QueueStats(
        queue_name=task_queue_name,
        message_count=declared.message_count,
        consumer_count=declared.consumer_count,
    )

cli special

exceptions

GracefulShutdown (BaseException)

If this exception is raised from a coroutine, the Mognet app running from the CLI is closed gracefully.

Source code in mognet/cli/exceptions.py
class GracefulShutdown(BaseException):
    """
    Request a graceful shutdown of a Mognet app run from the CLI.

    Raise it from a coroutine to have the app closed gracefully. It derives
    from ``BaseException`` rather than ``Exception``, so generic
    ``except Exception`` handlers do not intercept it.
    """

main

callback(app, log_level='INFO', log_format='%(asctime)s:%(name)s:%(levelname)s:%(message)s')

Mognet CLI

Source code in mognet/cli/main.py
@main.callback()
def callback(
    app: str = typer.Argument(..., help="App module to import"),
    log_level: LogLevel = typer.Option("INFO", metavar="log-level"),
    log_format: str = typer.Option(
        "%(asctime)s:%(name)s:%(levelname)s:%(message)s", metavar="log-format"
    ),
):
    """Mognet CLI"""
    # NOTE: the docstring above is the CLI help text shown by typer.

    # Each LogLevel value is the name of a `logging` level constant.
    level = getattr(logging, log_level.value)
    logging.basicConfig(level=level, format=log_format)

    # Import the app once and share it with the subcommands.
    state["app_instance"] = _get_app(app)

models

LogLevel (Enum)

Log Level

Source code in mognet/cli/models.py
class LogLevel(Enum):
    """Log level accepted by the CLI's ``log-level`` option."""

    # Values are the names of the matching `logging` level constants.
    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"

OutputFormat (str, Enum)

CLI output format

Source code in mognet/cli/models.py
class OutputFormat(str, Enum):
    """
    Output format for CLI commands.

    Members subclass ``str``, so they compare equal to their plain string
    values (e.g. ``OutputFormat.JSON == "json"``).
    """

    TEXT = "text"
    JSON = "json"

nodes

status(format=OutputFormat.TEXT, text_label_format='{name}(id={id!r}, state={state!r})', json_indent=2, poll=None, timeout=30) async

Query each node for their status

Source code in mognet/cli/nodes.py
@group.command("status")
@run_in_loop
async def status(
    format: OutputFormat = typer.Option(OutputFormat.TEXT, metavar="format"),
    text_label_format: str = typer.Option(
        "{name}(id={id!r}, state={state!r})",
        metavar="text-label-format",
        help="Label format for text format",
    ),
    json_indent: int = typer.Option(2, metavar="json-indent"),
    poll: Optional[int] = typer.Option(
        None,
        metavar="poll",
        help="Polling interval, in seconds (default=None)",
    ),
    timeout: int = typer.Option(
        30,
        help="Timeout for querying nodes",
    ),
):
    """Query each node for their status"""
    # NOTE: the docstring above is the CLI help text shown by typer.

    async with state["app_instance"] as app:
        while True:
            each_node_status: List[StatusResponseMessage] = []

            async def read_status():
                async for node_status in app.get_current_status_of_nodes():
                    each_node_status.append(node_status)

            # Collect replies for at most `timeout` seconds. Nodes that have
            # not replied by then are simply left out of this report.
            try:
                await asyncio.wait_for(read_status(), timeout=timeout)
            except asyncio.TimeoutError:
                pass

            all_result_ids = set()

            for node_status in each_node_status:
                all_result_ids.update(node_status.payload.running_request_ids)

            # Fetch every running Result in one batch, dropping any that no
            # longer exist.
            all_results_by_id = {
                r.id: r
                for r in await app.result_backend.get_many(
                    *all_result_ids,
                )
                if r is not None
            }

            report = _CliStatusReport()

            for node_status in each_node_status:
                running_requests = [
                    all_results_by_id[r]
                    for r in node_status.payload.running_request_ids
                    if r in all_results_by_id
                ]
                # Oldest first. Results without a creation time are treated
                # as created "now".
                running_requests.sort(key=lambda r: r.created or now_utc())

                report.node_status.append(
                    _CliStatusReport.NodeStatus(
                        node_id=node_status.node_id, running_requests=running_requests
                    )
                )

            if poll:
                typer.clear()

            if format == OutputFormat.TEXT:
                table_headers = ("Node name", "Running requests")

                # `model_dump()` is the pydantic v2 API, matching the
                # `model_dump_json()` used for the JSON output below.
                table_data = [
                    (
                        n.node_id,
                        "\n".join(
                            text_label_format.format(**r.model_dump())
                            for r in n.running_requests
                        )
                        or "(Empty)",
                    )
                    for n in report.node_status
                ]

                typer.echo(
                    f"{len(report.node_status)} nodes replied as of {datetime.now()}:"
                )

                typer.echo(tabulate.tabulate(table_data, headers=table_headers))

            elif format == OutputFormat.JSON:
                typer.echo(report.model_dump_json(indent=json_indent))

            if not poll:
                break

            await asyncio.sleep(poll)

queues

purge(force=False) async

Purge task and control queues

Source code in mognet/cli/queues.py
@group.command("purge")
@run_in_loop
async def purge(force: bool = typer.Option(False)):
    """Purge task and control queues"""
    # Destructive: refuses to run (exit code 1) unless --force is passed.
    # Task queues are purged first, then the control queue, and the per-queue
    # counts are echoed back to the user.

    if not force:
        typer.echo("Must pass --force")
        raise typer.Exit(1)

    async with state["app_instance"] as app:
        await app.connect()

        counts_by_queue = await app.purge_task_queues()
        control_messages = await app.purge_control_queue()

        typer.echo("Purged the following queues:")
        for name, purged in counts_by_queue.items():
            typer.echo(f"\t- {name!r}: {purged!r}")

        typer.echo(f"Purged {control_messages!r} control messages")

run

run(include_queues=None, exclude_queues=None)

Run the app

Source code in mognet/cli/run.py
@group.callback()
def run(
    include_queues: Optional[str] = typer.Option(
        None,
        metavar="include-queues",
        help="Comma-separated list of the ONLY queues to listen on.",
    ),
    exclude_queues: Optional[str] = typer.Option(
        None,
        metavar="exclude-queues",
        help="Comma-separated list of the ONLY queues to NOT listen on.",
    ),
):
    """Run the app"""
    # Starts the app under `aiorun` and blocks until the event loop stops.
    # This function always ends by raising: `SystemExit(0)` after a normal
    # stop, or the exception captured by the loop's exception handler.
    # (The docstring above doubles as the CLI help text, so it stays short.)

    app = state["app_instance"]

    # Allow overriding the queues this app listens on.
    # NOTE: the assignments below modify the object held by
    # `app.config.task_queues` (assuming it is a plain attribute, not a copy).
    queues = app.config.task_queues

    if include_queues is not None:
        # Whitespace around each comma-separated queue name is stripped.
        queues.include = set(q.strip() for q in include_queues.split(","))

    if exclude_queues is not None:
        queues.exclude = set(q.strip() for q in exclude_queues.split(","))

    # NOTE(review): presumably rejects inconsistent include/exclude settings — confirm.
    queues.ensure_valid()

    async def start():
        # Main coroutine handed to aiorun: enter the app's async context and start it.
        async with app:
            await app.start()

    async def stop(_: AbstractEventLoop):
        # Passed to aiorun as `shutdown_callback`; closes the app on shut down.
        _log.info("Going to close app as part of a shut down")
        await app.close()

    # Exception re-raised after the loop stops. It defaults to a clean exit and
    # is overwritten by any non-GracefulShutdown exception the handler sees.
    pending_exception_to_raise = SystemExit(0)

    def custom_exception_handler(loop: AbstractEventLoop, context: dict):
        """See: https://docs.python.org/3/library/asyncio-eventloop.html#error-handling-api"""
        # Installed as the loop's exception handler below. It records unhandled
        # exceptions (GracefulShutdown is only logged) and stops the loop for
        # every error context it receives, with or without an exception.

        nonlocal pending_exception_to_raise

        exc = context.get("exception")

        if isinstance(exc, GracefulShutdown):
            _log.debug("Got GracefulShutdown")
        elif isinstance(exc, BaseException):
            pending_exception_to_raise = exc

            _log.error(
                "Unhandled exception; stopping loop: %r %r",
                context.get("message"),
                context,
                exc_info=pending_exception_to_raise,
            )

        loop.stop()

    loop = asyncio.get_event_loop()
    loop.set_exception_handler(custom_exception_handler)

    # Stopping on errors is decided by `custom_exception_handler` above, so
    # aiorun's own stop-on-unhandled-errors behavior is turned off.
    aiorun.run(
        start(), loop=loop, stop_on_unhandled_errors=False, shutdown_callback=stop
    )

    # NOTE(review): `pending_exception_to_raise` is never None (it starts as
    # SystemExit(0)), so this always raises and `return 0` is unreachable.
    if pending_exception_to_raise is not None:
        raise pending_exception_to_raise

    return 0

run_in_loop

run_in_loop(f)

Utility to run an async click/typer command function in an event loop (click and typer don't support async commands out of the box).

Source code in mognet/cli/run_in_loop.py
def run_in_loop(f):
    """
    Utility to run a click/typer command function in an event loop
    (because they don't support it out of the box)

    The returned wrapper has the same name/docstring as `f` (via `wraps`),
    and, when called, runs `f(*args, **kwargs)` to completion on the current
    thread's event loop, returning its result.

    If no usable loop is available (none is set, or the current one is
    closed), a new loop is created and installed as the current loop, so
    later calls and later `get_event_loop()` lookups share it.
    """

    def _current_or_new_loop() -> asyncio.AbstractEventLoop:
        # Return the thread's current event loop, creating one if needed.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # Python 3.14+ (and non-main threads on older versions) raise
            # here instead of implicitly creating a loop.
            loop = None

        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        return loop

    @wraps(f)
    def in_loop(*args, **kwargs):
        loop = _current_or_new_loop()
        return loop.run_until_complete(f(*args, **kwargs))

    return in_loop

tasks

get(task_id, include_value=False) async

Get a task's details

Source code in mognet/cli/tasks.py
@group.command("get")
@run_in_loop
async def get(
    task_id: UUID = typer.Argument(
        ...,
        metavar="id",
        help="Task ID to get",
    ),
    include_value: bool = typer.Option(
        False,
        metavar="include-value",
        help="If passed, the task's result (or exception) will be printed",
    ),
):
    """Get a task's details"""
    # Prints a two-column table describing the stored Result of `task_id`.
    # Exits with code 1 if the result backend has no such Result.

    async with state["app_instance"] as app:
        res = await app.result_backend.get(task_id)

        if res is None:
            _log.warning("Request %r does not exist", task_id)
            raise typer.Exit(1)

        table_data = [
            ("ID", res.id),
            ("Name", res.name),
            ("Arguments", res.request_kwargs_repr),
            ("State", res.state),
            ("Number of starts", res.number_of_starts),
            ("Number of stops", res.number_of_stops),
            ("Unexpected retry count", res.unexpected_retry_count),
            ("Parent", res.parent_id),
            ("Created at", res.created),
            ("Started at", res.started),
            ("Time in queue", res.queue_time),
            ("Finished at", res.finished),
            ("Runtime duration", res.duration),
            ("Node ID", res.node_id),
            ("Metadata", await res.get_metadata()),
        ]

        if include_value:
            # Only the value fetch can raise ResultValueLost; keep the try minimal.
            try:
                value = await res.value.get_raw_value()
            except ResultValueLost:
                table_data.append(("Result value", "<Lost>"))
            else:
                # An _ExceptionInfo value means the task raised; show its traceback.
                if isinstance(value, _ExceptionInfo):
                    table_data.append(("Error raised", value.traceback))
                else:
                    table_data.append(("Result value", repr(value)))

        # typer.echo (not print), consistent with the other CLI commands.
        typer.echo(tabulate.tabulate(table_data))

revoke(task_id, force=False) async

Revoke a task

Source code in mognet/cli/tasks.py
@group.command("revoke")
@run_in_loop
async def revoke(
    task_id: UUID = typer.Argument(
        ...,
        metavar="id",
        help="Task ID to revoke",
    ),
    force: bool = typer.Option(
        False,
        metavar="force",
        help="Attempt revoking anyway if the result is complete. Helps cleaning up cases where subtasks may have been spawned.",
    ),
):
    """Revoke a task"""
    # The revoke is sent even when no Result exists for `task_id`. In that
    # case a warning is logged and the command exits with code 1 afterwards.

    async with state["app_instance"] as app:
        missing = await app.result_backend.get(task_id) is None

        if missing:
            _log.warning("Request %r does not exist", task_id)

        await app.revoke(task_id, force=force)

        raise typer.Exit(1 if missing else 0)

tree(task_id, format=OutputFormat.TEXT, json_indent=2, text_label_format='{name}(id={id!r}, state={state!r})', max_depth=3, max_width=16, poll=None) async

Get the tree (descendants) of a task

Source code in mognet/cli/tasks.py
@group.command("tree")
@run_in_loop
async def tree(
    task_id: UUID = typer.Argument(
        ...,
        metavar="id",
        help="Task ID to get tree from",
    ),
    format: OutputFormat = typer.Option(OutputFormat.TEXT, metavar="format"),
    json_indent: int = typer.Option(2, metavar="json-indent"),
    text_label_format: str = typer.Option(
        "{name}(id={id!r}, state={state!r})",
        metavar="text-label-format",
        help="Label format for text format",
    ),
    max_depth: int = typer.Option(3, metavar="max-depth"),
    max_width: int = typer.Option(16, metavar="max-width"),
    poll: Optional[int] = typer.Option(None, metavar="poll"),
):
    """Get the tree (descendants) of a task"""
    # Renders the Result tree of `task_id` (bounded by max_depth/max_width)
    # as text or JSON. With `poll`, the screen is cleared and the tree is
    # re-rendered every `poll` seconds until interrupted.

    async with state["app_instance"] as app:
        while True:
            result = await app.result_backend.get(task_id)

            if result is None:
                # Same handling as the `get` and `revoke` commands.
                _log.warning("Request %r does not exist", task_id)
                raise typer.Exit(1)

            _log.info("Building tree for result id=%r", result.id)

            # Named so it does not shadow this command function.
            result_tree = await result.tree(max_depth=max_depth, max_width=max_width)

            if poll:
                typer.clear()

            if format == "text":
                t = treelib.Tree()

                def build_tree(n: ResultTree, parent: Optional[ResultTree] = None):
                    # Add `n` under `parent` (root when None), then recurse into children.
                    t.create_node(
                        tag=text_label_format.format(**n.dict()),
                        identifier=n.result.id,
                        parent=None if parent is None else parent.result.id,
                    )

                    for c in n.children:
                        build_tree(c, parent=n)

                build_tree(result_tree)

                t.show()

            elif format == "json":
                typer.echo(result_tree.model_dump_json(indent=json_indent))

            if not poll:
                break

            await asyncio.sleep(poll)

context special

context

Context

Context for a request.

Allows access to the App instance, task state, and the request that is part of this task execution.

Source code in mognet/context/context.py
class Context:
    """
    Context for a request.

    Allows access to the App instance, task state,
    and the request that is part of this task execution.
    """

    app: "App"

    state: "State"

    request: "Request"

    # IDs of the Results this task is currently waiting on through `run` or
    # `gather`. The task is suspended while this set is non-empty.
    _dependencies: Set[UUID]

    def __init__(
        self,
        app: "App",
        request: "Request",
        state: "State",
        worker: "Worker",
    ):
        self.app = app
        self.state = state
        self.request = request
        self._worker = worker

        self._dependencies = set()

        # Exposed here so task code can build requests without reaching into `app`.
        self.create_request = self.app.create_request

    async def submit(self, request: "Request"):
        """
        Submits a new request as part of this one.

        The difference from this method to the one defined in the `App` class
        is that this one will submit the new request as a child request of
        the one that's a part of this `Context` instance. This allows
        the subrequests to be cancelled if the parent is also cancelled.
        """
        return await self.app.submit(request, self)

    @overload
    async def run(self, request: Request[_Return]) -> _Return:
        """
        Submits a Request to be run as part of this one (see `submit`), and waits for the result
        """
        ...

    @overload
    async def run(
        self,
        request: Callable[Concatenate["Context", _P], _Return],
        *args: _P.args,
        **kwargs: _P.kwargs
    ) -> _Return:
        """
        Short-hand method for creating a Request from a function decorated with `@task`,
        (see `create_request`), submitting it (see `submit`) and waiting for the result (see `run(Request)`).

        This overload is for documenting non-async def functions.
        """
        ...

    # This overload unwraps the Awaitable object
    @overload
    async def run(
        self,
        request: Callable[Concatenate["Context", _P], Awaitable[_Return]],
        *args: _P.args,
        **kwargs: _P.kwargs
    ) -> _Return:
        """
        Short-hand method for creating a Request from a function decorated with `@task`,
        (see `create_request`), submitting it (see `submit`) and waiting for the result (see `run(Request)`).

        This overload is for documenting async def functions.
        """
        ...

    async def run(self, request, *args, **kwargs):
        """
        Submits and runs a new request as part of this one.

        See `submit` for the difference between this and the equivalent
        `run` method on the `App` class.

        If `request` is not already a `Request`, one is created from it and
        `*args`/`**kwargs` through `create_request`.

        While the child request runs, its ID is registered as a dependency of
        this task. The task is suspended when it goes from having no
        dependencies to having some, and resumed once none remain.

        Cancellation propagates to the caller as `asyncio.CancelledError`; in
        that case this task is not resumed.
        """

        if not isinstance(request, Request):
            request = self.create_request(request, *args, **kwargs)

        cancelled = False
        try:
            had_dependencies = bool(self._dependencies)

            self._dependencies.add(request.id)

            # If we transition from having no dependencies
            # to having some, then we should suspend.
            if not had_dependencies and self._dependencies:
                # Shielded so the state change runs to completion even if we
                # are cancelled while awaiting it.
                await asyncio.shield(self._suspend())

            self._log_dependencies()

            return await self.app.run(request, self)
        except asyncio.CancelledError:
            cancelled = True
            # Must re-raise: swallowing CancelledError would make this method
            # return None and let a cancelled task carry on running.
            raise
        finally:
            self._dependencies.remove(request.id)

            self._log_dependencies()

            if not self._dependencies and not cancelled:
                await asyncio.shield(self._resume())

    def _log_dependencies(self):
        """Debug-log how many dependencies this task is currently waiting on."""
        _log.debug(
            "Task %r is waiting on %r dependencies",
            self.request,
            len(self._dependencies),
        )

    async def gather(
        self, *results_or_ids: Union["Result", UUID], return_exceptions: bool = False
    ) -> List[Any]:
        """
        Waits on several Results (or Result IDs) as part of this task.

        UUIDs are looked up in the app's result backend first. While waiting,
        the Result IDs are registered as dependencies of this task. As with
        `run`, the task is suspended on the transition from no dependencies
        to some, and resumed once none remain (unless cancelled).

        `return_exceptions` is forwarded to `asyncio.gather`. The returned list
        holds the outcome of awaiting each Result, in input order. Calling
        this with no arguments returns `[]` without suspending or resuming.

        Raises `ResultLost` if a given UUID has no Result in the backend.
        """
        if not results_or_ids:
            return []

        # Resolve IDs to Results before registering any dependency, so that
        # a missing Result fails fast without touching suspend/resume state.
        results: List["Result"] = []
        for result_or_id in results_or_ids:
            if isinstance(result_or_id, UUID):
                result = await self.app.result_backend.get(result_or_id)

                if result is None:
                    raise ResultLost(result_or_id)
            else:
                result = result_or_id

            results.append(result)

        cancelled = False
        try:
            # If we transition from having no dependencies
            # to having some, then we should suspend.
            had_dependencies = bool(self._dependencies)
            self._dependencies.update(r.id for r in results)

            self._log_dependencies()

            if not had_dependencies and self._dependencies:
                await asyncio.shield(self._suspend())

            return await asyncio.gather(*results, return_exceptions=return_exceptions)
        except asyncio.CancelledError:
            cancelled = True
            raise
        finally:
            self._dependencies.difference_update(r.id for r in results)

            self._log_dependencies()

            if not self._dependencies and not cancelled:
                await asyncio.shield(self._resume())

    @overload
    def get_service(
        self, func: Type[ClassService[_Return]], *args, **kwargs
    ) -> _Return:
        ...

    @overload
    def get_service(
        self,
        func: Callable[Concatenate["Context", _P], _Return],
        *args: _P.args,
        **kwargs: _P.kwargs
    ) -> _Return:
        ...

    def get_service(self, func, *args, **kwargs):
        """
        Get a service to use in the task function.
        This can be used for dependency injection purposes.

        `ClassService` subclasses are instantiated once with the app config,
        entered (`__enter__`), and cached in `app.services`. Any other
        callable is cached as itself unless `app.services` already maps it to
        something else. The resolved service is then called with this
        context and `*args`/`**kwargs`, and its return value is returned.
        """

        if inspect.isclass(func) and issubclass(func, ClassService):
            if func not in self.app.services:
                # This cast() is only here to silence Pylance (because it thinks the class is abstract)
                instance: ClassService = cast(Any, func)(self.app.config)
                self.app.services[func] = instance.__enter__()

            svc = self.app.services[func]
        else:
            # setdefault lets a pre-registered override in app.services win.
            svc = self.app.services.setdefault(func, func)

        return svc(self, *args, **kwargs)

    async def _suspend(self):
        """
        Marks this task as waiting on dependencies.

        The Result is only moved to the suspended state if it is currently
        RUNNING. The worker is then notified via `add_waiting_task`.
        """
        _log.debug("Suspending %r", self.request)

        result = await self.get_result()

        if result.state == ResultState.RUNNING:
            await result.suspend()

        await self._worker.add_waiting_task(self.request.id)

    async def get_result(self):
        """
        Gets the Result associated with this task.

        WARNING: Do not `await` the returned Result instance! You will run
        into a deadlock (you will be awaiting yourself)

        Raises `ResultLost` if the result backend has no Result for this request.
        """
        result = await self.app.result_backend.get(self.request.id)

        if result is None:
            raise ResultLost(self.request.id)

        return result

    def call_threadsafe(self, coro: Awaitable[_Return]) -> _Return:
        """
        NOTE: ONLY TO BE USED WITH SYNC TASKS!

        Utility function that will run the coroutine in the app's event loop
        in a thread-safe way. The calling thread blocks until the coroutine
        finishes, then gets its return value (or its exception is raised).
        Calling this from the event loop's own thread would deadlock, because
        that loop could not run the coroutine while blocked here.

        In reality this is a wrapper for `asyncio.run_coroutine_threadsafe(...)`

        Use as follows:

        ```
        context.call_threadsafe(context.submit(...))
        ```
        """
        return asyncio.run_coroutine_threadsafe(coro, loop=self.app.loop).result()

    async def set_metadata(self, **kwargs: Any):
        """
        Update metadata on the Result associated with the current task.

        Raises `ResultLost` (from `get_result`) if that Result no longer exists.
        """

        result = await self.get_result()
        return await result.set_metadata(**kwargs)

    async def _resume(self):
        """
        Counterpart of `_suspend`.

        Resumes the Result if it is SUSPENDED and notifies the worker via
        `remove_suspended_task`.
        """
        _log.debug("Resuming %r", self.request)

        result = await self.get_result()

        if result.state == ResultState.SUSPENDED:
            await result.resume()

        await self._worker.remove_suspended_task(self.request.id)
call_threadsafe(self, coro)

NOTE: ONLY TO BE USED WITH SYNC TASKS!

Utility function that will run the coroutine in the app's event loop in a thread-safe way.

In reality this is a wrapper for asyncio.run_coroutine_threadsafe(...)

Use as follows:

context.call_threadsafe(context.submit(...))
Source code in mognet/context/context.py
def call_threadsafe(self, coro: Awaitable[_Return]) -> _Return:
    """
    NOTE: ONLY TO BE USED WITH SYNC TASKS!

    Utility function that will run the coroutine in the app's event loop
    in a thread-safe way. The calling thread blocks until the coroutine
    finishes, then gets its return value (or its exception is raised).
    Calling this from the event loop's own thread would deadlock, because
    that loop could not run the coroutine while blocked here.

    In reality this is a wrapper for `asyncio.run_coroutine_threadsafe(...)`

    Use as follows:

    ```
    context.call_threadsafe(context.submit(...))
    ```
    """
    return asyncio.run_coroutine_threadsafe(coro, loop=self.app.loop).result()
get_result(self) async

Gets the Result associated with this task.

WARNING: Do not await the returned Result instance! You will run into a deadlock (you will be awaiting yourself)

Source code in mognet/context/context.py
async def get_result(self):
    """
    Fetch this task's own Result from the app's result backend.

    WARNING: never `await` the Result you get back. It represents this very
    task, so awaiting it would wait on yourself and deadlock.

    Raises `ResultLost` when the backend has no Result for this request.
    """
    request_id = self.request.id
    found = await self.app.result_backend.get(request_id)

    if found is not None:
        return found

    raise ResultLost(request_id)
get_service(self, func, *args, **kwargs)

Get a service to use in the task function. This can be used for dependency injection purposes.

Source code in mognet/context/context.py
def get_service(self, func, *args, **kwargs):
    """
    Resolve `func` to a service and call it with this context.

    Intended for dependency injection inside task functions. A
    `ClassService` subclass is built once from the app config, entered,
    and cached in `app.services`. Any other callable is cached as itself,
    unless `app.services` already holds an entry for it. The resolved
    service is invoked as `service(self, *args, **kwargs)`.
    """
    registry = self.app.services

    is_class_service = inspect.isclass(func) and issubclass(func, ClassService)

    if not is_class_service:
        return registry.setdefault(func, func)(self, *args, **kwargs)

    if func not in registry:
        # cast() only silences Pylance, which considers the class abstract.
        created: ClassService = cast(Any, func)(self.app.config)
        registry[func] = created.__enter__()

    return registry[func](self, *args, **kwargs)
run(self, request, *args, **kwargs) async

Submits and runs a new request as part of this one.

See submit for the difference between this and the equivalent run method on the App class.

Source code in mognet/context/context.py
async def run(self, request, *args, **kwargs):
    """
    Submits and runs a new request as part of this one.

    See `submit` for the difference between this and the equivalent
    `run` method on the `App` class.

    If `request` is not already a `Request`, one is created from it and
    `*args`/`**kwargs` through `create_request`.

    While the child request runs, its ID is registered as a dependency of
    this task. The task is suspended when it goes from having no
    dependencies to having some, and resumed once none remain.

    Cancellation propagates to the caller as `asyncio.CancelledError`; in
    that case this task is not resumed.
    """

    if not isinstance(request, Request):
        request = self.create_request(request, *args, **kwargs)

    cancelled = False
    try:
        had_dependencies = bool(self._dependencies)

        self._dependencies.add(request.id)

        # If we transition from having no dependencies
        # to having some, then we should suspend.
        if not had_dependencies and self._dependencies:
            # Shielded so the state change runs to completion even if we
            # are cancelled while awaiting it.
            await asyncio.shield(self._suspend())

        self._log_dependencies()

        return await self.app.run(request, self)
    except asyncio.CancelledError:
        cancelled = True
        # Must re-raise: swallowing CancelledError would make this method
        # return None and let a cancelled task carry on running.
        raise
    finally:
        self._dependencies.remove(request.id)

        self._log_dependencies()

        if not self._dependencies and not cancelled:
            await asyncio.shield(self._resume())
set_metadata(self, **kwargs) async

Update metadata on the Result associated with the current task.

Source code in mognet/context/context.py
async def set_metadata(self, **kwargs: Any):
    """
    Merge the given keyword arguments into the metadata of this task's Result.

    Whatever the Result's own `set_metadata` returns is passed back.
    """
    return await (await self.get_result()).set_metadata(**kwargs)
submit(self, request) async

Submits a new request as part of this one.

The difference from this method to the one defined in the App class is that this one will submit the new request as a child request of the one that's a part of this Context instance. This allows the subrequests to be cancelled if the parent is also cancelled.

Source code in mognet/context/context.py
async def submit(self, request: "Request"):
    """
    Submit `request` as a child of the request this Context belongs to.

    Unlike `App.submit` called on its own, this passes the Context along,
    so the new request is tied to this one. Cancelling the parent can then
    also cancel the subrequest.
    """
    parent_context = self
    return await self.app.submit(request, parent_context)

decorators special

task_decorator

task(*, name=None)

Register a function as a task that can be run.

The name argument is recommended, but not required. It is used as an identifier for which task to run when creating Request objects.

If the name is not provided, the function's full name (module + name) is used instead. Bear in mind that this means that if you rename the module or the function, things may break during rolling upgrades.

Source code in mognet/decorators/task_decorator.py
def task(*, name: Optional[str] = None):
    """
    Decorator factory that registers a function as a runnable task.

    `name` identifies the task when Request objects are created. It is
    optional but recommended: without it, the function's full name
    (module + name) is used instead, so renaming the module or the function
    may break things during rolling upgrades.

    The decorated function is added to the global task registry (one is
    created and registered globally first if none is set) and is returned
    unchanged.
    """

    def task_decorator(t: _T) -> _T:
        registry = task_registry.get(None)

        if registry is None:
            _log.debug("No global task registry set. Creating one")
            registry = TaskRegistry()
            registry.register_globally()

        registry.add_task_function(cast(Callable, t), name=name)
        return t

    return task_decorator

exceptions special

base_exceptions

ConnectionError (MognetError)

Base class for connection errors

Source code in mognet/exceptions/base_exceptions.py
class ConnectionError(MognetError):
    """
    Common base for Mognet's connection-related errors.

    Note: this name shadows Python's builtin `ConnectionError` in modules
    that import it, and it does not subclass the builtin.
    """

CouldNotSubmit (MognetError)

The Request could not be submitted

Source code in mognet/exceptions/base_exceptions.py
class CouldNotSubmit(MognetError):
    """
    Raised when submitting a Request did not succeed.
    """

ImproperlyConfigured (MognetError)

Base class for configuration-based errors

Source code in mognet/exceptions/base_exceptions.py
class ImproperlyConfigured(MognetError):
    """
    Common base for errors caused by invalid or inconsistent configuration.
    """

MognetError (Exception)

Base class for all Mognet errors

Source code in mognet/exceptions/base_exceptions.py
class MognetError(Exception):
    """
    Base class for Mognet's own error types.

    Note: not every exception Mognet defines derives from this class; some,
    such as `InvalidTaskArguments` and `Pause`, subclass `Exception` directly.
    """

NotConnected (ConnectionError)

Not connected. Either call connect(), or use a context manager

Source code in mognet/exceptions/base_exceptions.py
class NotConnected(ConnectionError):
    """
    The operation needs a connection that has not been established yet.

    Call `connect()` first, or use the object as a context manager.
    """

result_exceptions

ResultLost (ResultError)

Raised when the result itself was lost (potentially due to key eviction)

Source code in mognet/exceptions/result_exceptions.py
class ResultLost(ResultError):
    """
    The Result record itself could not be found
    (for example, because its key was evicted from the backend).

    The missing Result's ID is available as `result_id`.
    """

    def __init__(self, result_id: UUID) -> None:
        super().__init__(result_id)
        self.result_id = result_id

    def __str__(self) -> str:
        return "Result id={!r} lost".format(self.result_id)

ResultValueLost (ResultError)

Raised when the value for a result was lost (potentially due to key eviction)

Source code in mognet/exceptions/result_exceptions.py
class ResultValueLost(ResultError):
    """
    The Result exists, but its stored value could not be found
    (for example, because the value's key was evicted from the backend).

    The affected Result's ID is available as `result_id`.
    """

    def __init__(self, result_id: UUID) -> None:
        super().__init__(result_id)
        self.result_id = result_id

    def __str__(self) -> str:
        return "Value for result id={!r} lost".format(self.result_id)

Revoked (ResultFailed)

Raised when a task is revoked, either by timing out, or manual revoking.

Source code in mognet/exceptions/result_exceptions.py
class Revoked(ResultFailed):
    """
    The task was revoked, either manually or because it timed out.
    """

    def __str__(self) -> str:
        return "Result {!r} was revoked".format(self.result)

task_exceptions

InvalidErrorInfo (BaseModel)

Information about a validation error

Source code in mognet/exceptions/task_exceptions.py
class InvalidErrorInfo(BaseModel):
    """
    Describes a single validation error found in a task's arguments.
    """

    # Where the invalid value is located (see the `Loc` type).
    loc: Loc
    # Error message.
    msg: str
    # Error type identifier (presumably pydantic's error `type` code — confirm).
    type: str

InvalidTaskArguments (Exception)

Raised when the arguments to a task could not be validated.

Source code in mognet/exceptions/task_exceptions.py
class InvalidTaskArguments(Exception):
    """
    The arguments given to a task failed validation.

    Each individual problem is kept in `errors` as an `InvalidErrorInfo`.
    """

    def __init__(self, errors: List[InvalidErrorInfo]) -> None:
        super().__init__(errors)
        self.errors = errors

    @classmethod
    def from_validation_error(cls, validation_error: ValidationError):
        """Build an instance from every error reported by `validation_error`."""
        infos = list(map(InvalidErrorInfo.model_validate, validation_error.errors()))
        return cls(infos)

Pause (Exception)

Tasks may raise this when they want to stop execution and have their message return to the Task Broker.

Once the message is retrieved again, task execution will resume.

Source code in mognet/exceptions/task_exceptions.py
class Pause(Exception):
    """
    Raise this from a task to stop executing for now and send the task's
    message back to the Task Broker.

    Execution resumes once that message is retrieved again.
    """

too_many_retries

TooManyRetries (MognetError)

Raised when a task is retried too many times due to unforeseen errors (e.g., SIGKILL).

The number of retries for any particular task can be configured through the App's configuration, in max_retries.

Source code in mognet/exceptions/too_many_retries.py
class TooManyRetries(MognetError):
    """
    A task was retried more times than allowed because of unforeseen errors
    (e.g., SIGKILL).

    The limit comes from the App's configuration (`max_retries`). The task
    ID, the observed retry count and the limit are kept as attributes.
    """

    def __init__(
        self,
        request_id: UUID,
        actual_retries: int,
        max_retries: int,
    ) -> None:
        super().__init__(request_id, actual_retries, max_retries)

        self.request_id = request_id
        self.actual_retries = actual_retries
        self.max_retries = max_retries

    def __str__(self) -> str:
        return (
            f"Task id={self.request_id!r} has been retried "
            f"{self.actual_retries!r} times, "
            f"which is more than the limit of {self.max_retries!r}"
        )

middleware special

middleware

Middleware (Protocol)

Defines middleware that can hook into different parts of a Mognet App's lifecycle.

Source code in mognet/middleware/middleware.py
class Middleware(Protocol):
    """
    Defines middleware that can hook into different parts of a Mognet App's lifecycle.

    Every hook here is an async method with an empty body.
    """

    async def on_app_starting(self, app: "App") -> None:
        """
        Called when the app is starting, but before it starts connecting to the backends.

        For example, you can use this for some early initialization of singleton-type objects in your app.
        """

    async def on_app_started(self, app: "App") -> None:
        """
        Called when the app has started.

        For example, you can use this for initialization that should only
        happen once the app has finished starting up.
        """

    async def on_app_stopping(self, app: "App") -> None:
        """
        Called when the app is preparing to stop, but before it starts disconnecting.

        For example, you can use this for cleaning up objects that were previously set up.
        """

    async def on_app_stopped(self, app: "App") -> None:
        """
        Called when the app has stopped.

        For example, you can use this for cleaning up objects that were previously set up.
        """

    async def on_task_starting(self, context: "Context") -> None:
        """
        Called when a task is starting.

        You can use this, for example, to track a task on a database.
        """

    async def on_task_completed(
        self, result: "Result", context: Optional["Context"] = None
    ) -> None:
        """
        Called when a task has completed its execution.

        You can use this, for example, to track a task on a database.
        """

    async def on_request_submitting(
        self, request: "Request", context: Optional["Context"] = None
    ) -> None:
        """
        Called when a Request object is going to be submitted to the Broker.

        You can use this, for example, both to track the task on a database, or to modify
        the Request object (e.g., to modify arguments, or set metadata).
        """

    async def on_running_task_count_changed(self, running_task_count: int) -> None:
        """
        Called when the Worker's task count changes.

        This can be used to determine when the Worker has nothing to do.
        """
on_app_started(self, app) async

Called when the app has started.

For example, you can use this for initialization that should only happen once the app has finished starting up.

Source code in mognet/middleware/middleware.py
async def on_app_started(self, app: "App") -> None:
    """
    Called when the app has started.

    For example, you can use this for initialization that should only
    happen once the app has finished starting up.
    """
on_app_starting(self, app) async

Called when the app is starting, but before it starts connecting to the backends.

For example, you can use this for some early initialization of singleton-type objects in your app.

Source code in mognet/middleware/middleware.py
async def on_app_starting(self, app: "App") -> None:
    """
    Hook invoked while the app starts up, before any backend connection is made.

    A good place for early set-up, such as creating singleton-type objects
    your app relies on.
    """
on_app_stopped(self, app) async

Called when the app has stopped.

For example, you can use this for cleaning up objects that were previously set up.

Source code in mognet/middleware/middleware.py
async def on_app_stopped(self, app: "App") -> None:
    """
    Hook invoked after the app has stopped.

    Useful for tearing down objects that were set up earlier.
    """
on_app_stopping(self, app) async

Called when the app is preparing to stop, but before it starts disconnecting.

For example, you can use this for cleaning up objects that were previously set up.

Source code in mognet/middleware/middleware.py
async def on_app_stopping(self, app: "App") -> None:
    """
    Hook invoked when the app begins shutting down, before it disconnects.

    Useful for tearing down objects that were set up earlier.
    """
on_request_submitting(self, request, context=None) async

Called when a Request object is going to be submitted to the Broker.

You can use this, for example, both to track the task on a database, or to modify the Request object (e.g., to modify arguments, or set metadata).

Source code in mognet/middleware/middleware.py
async def on_request_submitting(
    self, request: "Request", context: Optional["Context"] = None
):
    """
    Hook invoked just before a Request is handed to the Broker.

    Use it to record the task (e.g. in a database) or to adjust the Request
    itself, for instance by changing its arguments or setting metadata.
    """
on_running_task_count_changed(self, running_task_count) async

Called when the Worker's task count changes.

This can be used to determine when the Worker has nothing to do.

Source code in mognet/middleware/middleware.py
async def on_running_task_count_changed(self, running_task_count: int) -> None:
    """
    Called when the Worker's task count changes.

    This can be used to determine when the Worker has nothing to do.

    :param running_task_count: The new number of running tasks.
    """
on_task_completed(self, result, context=None) async

Called when a task has completed its execution.

You can use this, for example, to track a task on a database.

Source code in mognet/middleware/middleware.py
async def on_task_completed(
    self, result: "Result", context: Optional["Context"] = None
) -> None:
    """
    Called when a task has completed its execution.

    You can use this, for example, to track a task in a database.

    :param result: The Result of the completed task.
    :param context: The Context of the task, if any.
    """
on_task_starting(self, context) async

Called when a task is starting.

You can use this, for example, to track a task on a database.

Source code in mognet/middleware/middleware.py
async def on_task_starting(self, context: "Context") -> None:
    """
    Called when a task is starting.

    You can use this, for example, to track a task in a database.

    :param context: The Context of the starting task.
    """

model special

result

Result (BaseModel)

Represents the result of executing a Request.

It contains, along with the return value (or raised exception), information on the resulting state, how many times it started, and timing information.

Source code in mognet/model/result.py
class Result(BaseModel):
    """
    Represents the result of executing a [`Request`][mognet.Request].

    It contains, along with the return value (or raised exception),
    information on the resulting state, how many times it started,
    and timing information.
    """

    id: UUID
    name: Optional[str] = None
    state: ResultState = ResultState.PENDING

    number_of_starts: int = 0
    number_of_stops: int = 0

    parent_id: Optional[UUID] = None

    created: Optional[datetime] = None
    started: Optional[datetime] = None
    finished: Optional[datetime] = None

    node_id: Optional[str] = None

    request_kwargs_repr: Optional[str] = None

    _backend: "BaseResultBackend" = PrivateAttr()
    # Lazily-created helpers; see the `children` and `value` properties.
    _children: Optional[ResultChildren] = PrivateAttr()
    _value: Optional[ResultValue] = PrivateAttr()

    def __init__(self, backend: "BaseResultBackend", **data) -> None:
        super().__init__(**data)
        self._backend = backend
        self._children = None
        self._value = None

    @property
    def children(self) -> ResultChildren:
        """Get an iterator on the children of this Result. Non-recursive."""

        if self._children is None:
            self._children = ResultChildren(self, self._backend)

        return self._children

    @property
    def value(self) -> ResultValue:
        """Get information about the value of this Result"""
        if self._value is None:
            self._value = ResultValue(self, self._backend)

        return self._value

    @property
    def duration(self) -> Optional[timedelta]:
        """
        Returns the time it took to complete this result.

        Returns None if the result did not start or finish.
        """
        if not self.started or not self.finished:
            return None

        return self.finished - self.started

    @property
    def queue_time(self) -> Optional[timedelta]:
        """
        Returns the time it took to start the task associated to this result.

        Returns None if the task did not start.
        """
        if not self.created or not self.started:
            return None

        return self.started - self.created

    @property
    def done(self):
        """
        True if the result is in a terminal state (e.g., SUCCESS, FAILURE).
        See `READY_STATES`.
        """
        return self.state in READY_STATES

    @property
    def successful(self):
        """True if the result was successful."""
        return self.state in SUCCESS_STATES

    @property
    def failed(self):
        """True if the result failed or was revoked."""
        return self.state in ERROR_STATES

    @property
    def revoked(self):
        """True if the result was revoked."""
        return self.state == ResultState.REVOKED

    @property
    def unexpected_retry_count(self) -> int:
        """
        Return the number of times the task associated with this result was retried
        as a result of an unexpected error, such as a SIGKILL.
        """
        return max(0, self.number_of_starts - self.number_of_stops)

    async def wait(self, *, timeout: Optional[float] = None, poll: float = 0.1) -> None:
        """Wait for the task associated with this result to finish."""
        updated_result = await self._backend.wait(self.id, timeout=timeout, poll=poll)

        await self._refresh(updated_result)

    async def revoke(self) -> "Result":
        """
        Revoke this Result.

        This shouldn't be called directly, use the method on the App class instead,
        as that will also revoke the children, recursively.
        """
        self.state = ResultState.REVOKED
        self.number_of_stops += 1
        self.finished = now_utc()
        await self._update()
        return self

    async def get(self) -> Any:
        """
        Gets the value of this `Result` instance.

        Raises `ResultNotReady` if it's not ready yet.
        Raises `Revoked` if the result was revoked.
        Raises any stored exception if the result failed.

        Returns the stored value otherwise.

        Use `value.get_raw_value()` if you want access to the raw value.
        Call `wait` to wait for the value to be available.

        Optionally, `await` the result instance.
        """

        if not self.done:
            raise ResultNotReady()

        # Revocation does not depend on the stored value, so check it before
        # making a round trip to the backend.
        if self.state == ResultState.REVOKED:
            raise Revoked(self)

        value = await self.value.get_raw_value()

        if self.failed:
            # No stored value: raise a generic failure for this result.
            if value is None:
                value = ResultFailed(self)

            # Re-hydrate exceptions.
            if isinstance(value, _ExceptionInfo):
                raise value.exception

            # Anything else that isn't an exception is wrapped so it can be raised.
            if not isinstance(value, BaseException):
                value = Exception(value)

            raise value

        return value

    async def set_result(
        self,
        value: Any,
        state: ResultState = ResultState.SUCCESS,
    ) -> "Result":
        """
        Store `value` and move this Result to `state` (SUCCESS by default).

        The stored value is what will be returned when one `get()`s this
        Result's value.
        """
        await self.value.set_raw_value(value)

        self.finished = now_utc()

        self.state = state
        self.number_of_stops += 1

        await self._update()

        return self

    async def set_error(
        self,
        exc: BaseException,
        state: ResultState = ResultState.FAILURE,
    ) -> "Result":
        """
        Set this Result to an error state, and store the exception
        which will be raised if one attempts to `get()` this Result's
        value.
        """

        _log.debug("Setting result id=%r to %r", self.id, state)

        await self.value.set_raw_value(exc)

        self.finished = now_utc()

        self.state = state
        self.number_of_stops += 1

        await self._update()

        return self

    async def start(self, *, node_id: Optional[str] = None) -> "Result":
        """
        Sets this `Result` as RUNNING, and logs the event.
        """
        self.started = now_utc()
        self.node_id = node_id

        self.state = ResultState.RUNNING
        self.number_of_starts += 1

        await self._update()

        return self

    async def resume(self, *, node_id: Optional[str] = None) -> "Result":
        """
        Sets this `Result` as RUNNING again, and logs the event.

        Unlike `start`, this keeps the existing `started` timestamp, and only
        overwrites `node_id` when one is given.
        """
        if node_id is not None:
            self.node_id = node_id

        self.state = ResultState.RUNNING
        self.number_of_starts += 1

        await self._update()

        return self

    async def suspend(self) -> "Result":
        """
        Sets this `Result` as SUSPENDED, and logs the event.
        """

        self.state = ResultState.SUSPENDED
        self.number_of_stops += 1

        await self._update()

        return self

    async def tree(self, max_depth: int = 3, max_width: int = 500) -> "ResultTree":
        """
        Gets the tree of this result.

        :param max_depth: The maximum depth of the tree that's to be generated.
            This filters out results whose recursion levels are greater than it.
        :param max_width: The maximum number of children fetched for each result.
            Results with more children are truncated (a warning is logged).
        """
        from .result_tree import ResultTree

        async def get_tree(result: Result, depth=1):
            _log.debug(
                "Getting tree of result id=%r, depth=%r max_depth=%r",
                result.id,
                depth,
                max_depth,
            )

            node = ResultTree(result=result, children=[])

            # Fetch the count once; both truncation checks below need it.
            children_count = await result.children.count()

            if depth >= max_depth and children_count:
                _log.warning(
                    "Result id=%r has %r or more levels of children, which is more than the limit of %r. Results will be truncated",
                    result.id,
                    depth,
                    max_depth,
                )
                return node

            if children_count > max_width:
                _log.warning(
                    "Result id=%r has %r children, which is more than the limit of %r. Results will be truncated",
                    result.id,
                    children_count,
                    max_width,
                )

            async for child in result.children.iter_instances(count=max_width):
                node.children.append(await get_tree(child, depth=depth + 1))

            node.children.sort(key=lambda r: r.result.created or now_utc())

            return node

        return await get_tree(self, depth=1)

    async def get_metadata(self) -> Dict[str, Any]:
        """Get the metadata associated with this Result."""
        return await self._backend.get_metadata(self.id)

    async def set_metadata(self, **kwargs: Any) -> None:
        """Set metadata on this Result."""
        await self._backend.set_metadata(self.id, **kwargs)

    async def _refresh(self, updated_result: Optional["Result"] = None):
        """
        Copy every field except `id` from `updated_result` onto this instance,
        fetching it from the backend when not given.

        Raises RuntimeError if the backend no longer has this Result.
        """
        if updated_result is None:
            updated_result = await self._backend.get(self.id)

        if updated_result is None:
            raise RuntimeError("Result no longer present")

        for k, v in updated_result.__dict__.items():
            if k == "id":
                continue

            setattr(self, k, v)

        # The state may have moved on (e.g. PENDING -> SUCCESS). Drop the cached
        # ResultValue so the next read goes back to the backend, instead of
        # returning a value holder that was fetched before this refresh.
        self._value = None

    async def _update(self):
        """Persist this Result to the backend."""
        await self._backend.set(self.id, self)

    def __repr__(self):
        v = f"Result[{self.name or 'unknown'}, id={self.id!r}, state={self.state!r}]"

        if self.request_kwargs_repr is not None:
            v += f"({self.request_kwargs_repr})"

        return v

    # Results are hashed by their ID.
    def __hash__(self) -> int:
        return hash(f"Result_{self.id}")

    # Implemented for asyncio's `await` functionality:
    # `await result` waits for the result, then returns `get()`.
    def __await__(self):
        yield from self.wait().__await__()
        value = yield from self.get().__await__()
        return value

    async def delete(self, include_children: bool = True):
        """
        Delete this Result from the backend.

        By default, this will delete children too.
        """
        await self._backend.delete(self.id, include_children=include_children)

    async def set_ttl(self, ttl: timedelta, include_children: bool = True):
        """
        Set TTL on this Result.

        By default, this will set it on the children too.
        """
        await self._backend.set_ttl(self.id, ttl, include_children=include_children)
children: ResultChildren property readonly

Get an iterator on the children of this Result. Non-recursive.

done property readonly

True if the result is in a terminal state (e.g., SUCCESS, FAILURE). See READY_STATES.

duration: Optional[datetime.timedelta] property readonly

Returns the time it took to complete this result.

Returns None if the result did not start or finish.

failed property readonly

True if the result failed or was revoked.

queue_time: Optional[datetime.timedelta] property readonly

Returns the time it took to start the task associated to this result.

Returns None if the task did not start.

revoked property readonly

True if the result was revoked.

successful property readonly

True if the result was successful.

unexpected_retry_count: int property readonly

Return the number of times the task associated with this result was retried as a result of an unexpected error, such as a SIGKILL.

value: ResultValue property readonly

Get information about the value of this Result

__hash__(self) special

Return hash(self).

Source code in mognet/model/result.py
def __hash__(self) -> int:
    """Hash a Result by its ID, namespaced with a `Result_` prefix."""
    return hash("Result_" + str(self.id))
delete(self, include_children=True) async

Delete this Result from the backend.

By default, this will delete children too.

Source code in mognet/model/result.py
async def delete(self, include_children: bool = True):
    """
    Remove this Result from the result backend.

    Children are removed as well unless `include_children` is False.
    """
    backend = self._backend
    await backend.delete(self.id, include_children=include_children)
get(self) async

Gets the value of this Result instance.

Raises ResultNotReady if it's not ready yet. Raises any stored exception if the result failed.

Returns the stored value otherwise.

Use value.get_raw_value() if you want access to the raw value. Call wait to wait for the value to be available.

Optionally, await the result instance.

Source code in mognet/model/result.py
async def get(self) -> Any:
    """
    Gets the value of this `Result` instance.

    Raises `ResultNotReady` if it's not ready yet.
    Raises `Revoked` if the result was revoked.
    Raises any stored exception if the result failed.

    Returns the stored value otherwise.

    Use `value.get_raw_value()` if you want access to the raw value.
    Call `wait` to wait for the value to be available.

    Optionally, `await` the result instance.
    """

    if not self.done:
        raise ResultNotReady()

    # Revocation does not depend on the stored value, so check it before
    # making a round trip to the backend.
    if self.state == ResultState.REVOKED:
        raise Revoked(self)

    value = await self.value.get_raw_value()

    if self.failed:
        # No stored value: raise a generic failure for this result.
        if value is None:
            value = ResultFailed(self)

        # Re-hydrate exceptions.
        if isinstance(value, _ExceptionInfo):
            raise value.exception

        # Anything else that isn't an exception is wrapped so it can be raised.
        if not isinstance(value, BaseException):
            value = Exception(value)

        raise value

    return value
get_metadata(self) async

Get the metadata associated with this Result.

Source code in mognet/model/result.py
async def get_metadata(self) -> Dict[str, Any]:
    """Fetch this Result's metadata from the result backend."""
    metadata = await self._backend.get_metadata(self.id)
    return metadata
model_post_init(self, context, /)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that's what pydantic-core passes when calling it.

Parameters:

- self (BaseModel, required): The BaseModel instance.
- context (Any, required): The context.
Source code in mognet/model/result.py
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """Initialise the private attributes of a pydantic model instance.

    Meant to behave like a `BaseModel` method. `context` is accepted only
    because pydantic-core passes it when calling this hook; it is not used here.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    # Nothing to do if the private storage has already been set up.
    if getattr(self, '__pydantic_private__', None) is not None:
        return

    defaults = {}
    for attr_name, private_attr in self.__private_attributes__.items():
        default_value = private_attr.get_default()
        # Attributes without a default are left out entirely.
        if default_value is PydanticUndefined:
            continue
        defaults[attr_name] = default_value
    object_setattr(self, '__pydantic_private__', defaults)
revoke(self) async

Revoke this Result.

This shouldn't be called directly, use the method on the App class instead, as that will also revoke the children, recursively.

Source code in mognet/model/result.py
async def revoke(self) -> "Result":
    """
    Mark this Result as REVOKED and persist it.

    Prefer the App's revoke method over calling this directly, since the App
    also revokes the children, recursively.
    """
    self.finished = now_utc()
    self.number_of_stops += 1
    self.state = ResultState.REVOKED
    await self._backend.set(self.id, self)
    return self
set_error(self, exc, state='FAILURE') async

Set this Result to an error state, and store the exception which will be raised if one attempts to get() this Result's value.

Source code in mognet/model/result.py
async def set_error(
    self,
    exc: BaseException,
    state: ResultState = ResultState.FAILURE,
) -> "Result":
    """
    Store `exc` as this Result's value and move it to `state`
    (FAILURE by default), then persist it.

    The stored exception is raised when one `get()`s this Result's value.
    """

    _log.debug("Setting result id=%r to %r", self.id, state)

    await self.value.set_raw_value(exc)

    self.number_of_stops += 1
    self.state = state
    self.finished = now_utc()

    await self._update()
    return self
set_metadata(self, **kwargs) async

Set metadata on this Result.

Source code in mognet/model/result.py
async def set_metadata(self, **kwargs: Any) -> None:
    """Store the given keyword arguments as metadata on this Result."""
    backend = self._backend
    await backend.set_metadata(self.id, **kwargs)
set_result(self, value, state='SUCCESS') async

Set this Result to a success state, and store the value which will be returned when one get()s this Result's value.

Source code in mognet/model/result.py
async def set_result(
    self,
    value: Any,
    state: ResultState = ResultState.SUCCESS,
) -> "Result":
    """
    Store `value` as this Result's value and move it to `state`
    (SUCCESS by default), then persist it.

    The stored value is what will be returned when one `get()`s this
    Result's value.
    """
    await self.value.set_raw_value(value)

    self.number_of_stops += 1
    self.state = state
    self.finished = now_utc()

    await self._update()
    return self
set_ttl(self, ttl, include_children=True) async

Set TTL on this Result.

By default, this will set it on the children too.

Source code in mognet/model/result.py
async def set_ttl(self, ttl: timedelta, include_children: bool = True):
    """
    Apply a time-to-live of `ttl` to this Result in the backend.

    Children get the same TTL unless `include_children` is False.
    """
    result_id = self.id
    await self._backend.set_ttl(result_id, ttl, include_children=include_children)
start(self, *, node_id=None) async

Sets this Result as RUNNING, and logs the event.

Source code in mognet/model/result.py
async def start(self, *, node_id: Optional[str] = None) -> "Result":
    """
    Mark this `Result` as RUNNING on `node_id`, record the start time,
    and persist it.
    """
    self.state = ResultState.RUNNING
    self.number_of_starts += 1

    self.node_id = node_id
    self.started = now_utc()

    await self._update()
    return self
suspend(self) async

Sets this Result as SUSPENDED, and logs the event.

Source code in mognet/model/result.py
async def suspend(self) -> "Result":
    """
    Mark this `Result` as SUSPENDED, count it as a stop, and persist it.
    """
    self.number_of_stops += 1
    self.state = ResultState.SUSPENDED

    await self._update()
    return self
tree(self, max_depth=3, max_width=500) async

Gets the tree of this result.

:param max_depth: The maximum depth of the tree that's to be generated. This filters out results whose recursion levels are greater than it.

Source code in mognet/model/result.py
async def tree(self, max_depth: int = 3, max_width: int = 500) -> "ResultTree":
    """
    Gets the tree of this result.

    :param max_depth: The maximum depth of the tree that's to be generated.
        This filters out results whose recursion levels are greater than it.
    :param max_width: The maximum number of children fetched for each result.
        Results with more children are truncated (a warning is logged).
    """
    from .result_tree import ResultTree

    async def get_tree(result: Result, depth=1):
        _log.debug(
            "Getting tree of result id=%r, depth=%r max_depth=%r",
            result.id,
            depth,
            max_depth,
        )

        node = ResultTree(result=result, children=[])

        # Fetch the count once; both truncation checks below need it.
        children_count = await result.children.count()

        if depth >= max_depth and children_count:
            _log.warning(
                "Result id=%r has %r or more levels of children, which is more than the limit of %r. Results will be truncated",
                result.id,
                depth,
                max_depth,
            )
            return node

        if children_count > max_width:
            _log.warning(
                "Result id=%r has %r children, which is more than the limit of %r. Results will be truncated",
                result.id,
                children_count,
                max_width,
            )

        async for child in result.children.iter_instances(count=max_width):
            node.children.append(await get_tree(child, depth=depth + 1))

        node.children.sort(key=lambda r: r.result.created or now_utc())

        return node

    return await get_tree(self, depth=1)
wait(self, *, timeout=None, poll=0.1) async

Wait for the task associated with this result to finish.

Source code in mognet/model/result.py
async def wait(self, *, timeout: Optional[float] = None, poll: float = 0.1) -> None:
    """
    Wait for this result's task to finish, then refresh this instance
    with the result the backend returned.

    `timeout` and `poll` are forwarded to the result backend's `wait`.
    """
    backend = self._backend
    finished = await backend.wait(self.id, timeout=timeout, poll=poll)
    await self._refresh(finished)

ResultChildren

The children of a Result.

Source code in mognet/model/result.py
class ResultChildren:
    """Access to the children of a Result, delegated to the result backend."""

    def __init__(self, result: "Result", backend: "BaseResultBackend") -> None:
        self._result = result
        self._backend = backend

    async def count(self) -> int:
        """Return how many children the parent Result has."""
        parent_id = self._result.id
        return await self._backend.get_children_count(parent_id)

    def iter_ids(self, *, count: Optional[int] = None) -> AsyncGenerator[UUID, None]:
        """Iterate over the children's IDs, yielding at most `count` when given."""
        backend = self._backend
        return backend.iterate_children_ids(self._result.id, count=count)

    def iter_instances(
        self, *, count: Optional[int] = None
    ) -> AsyncGenerator["Result", None]:
        """Iterate over the child Results, yielding at most `count` when given."""
        backend = self._backend
        return backend.iterate_children(self._result.id, count=count)

    async def add(self, *children_ids: UUID):
        """For internal use: register `children_ids` as children of the parent Result."""
        parent_id = self._result.id
        await self._backend.add_children(parent_id, *children_ids)
add(self, *children_ids) async

For internal use.

Source code in mognet/model/result.py
async def add(self, *children_ids: UUID):
    """For internal use: register `children_ids` as children of the parent Result."""
    parent_id = self._result.id
    await self._backend.add_children(parent_id, *children_ids)
count(self) async

The number of children.

Source code in mognet/model/result.py
async def count(self) -> int:
    """Return how many children the parent Result has."""
    parent_id = self._result.id
    return await self._backend.get_children_count(parent_id)
iter_ids(self, *, count=None)

Iterate the IDs of the children, optionally limited to a set count.

Source code in mognet/model/result.py
def iter_ids(self, *, count: Optional[int] = None) -> AsyncGenerator[UUID, None]:
    """Iterate over the children's IDs, yielding at most `count` when given."""
    backend = self._backend
    return backend.iterate_children_ids(self._result.id, count=count)
iter_instances(self, *, count=None)

Iterate the instances of the children, optionally limited to a set count.

Source code in mognet/model/result.py
def iter_instances(
    self, *, count: Optional[int] = None
) -> AsyncGenerator["Result", None]:
    """Iterate over the child Results, yielding at most `count` when given."""
    backend = self._backend
    return backend.iterate_children(self._result.id, count=count)

ResultValue

Represents information about the value of a Result.

Source code in mognet/model/result.py
class ResultValue:
    """
    Represents information about the value of a Result.

    The value holder fetched from the backend is cached on this object.
    """

    def __init__(self, result: "Result", backend: "BaseResultBackend") -> None:
        self._result = result
        self._backend = backend

        # Populated lazily by `get_value_holder`; cleared by `set_raw_value`.
        self._value_holder: Optional[ResultValueHolder] = None

    async def get_value_holder(self) -> ResultValueHolder:
        """
        Get the holder of this Result's value. It is fetched from the backend
        on first access and cached afterwards.
        """
        if self._value_holder is None:
            self._value_holder = await self._backend.get_value(self._result.id)

        return self._value_holder

    async def get_raw_value(self) -> Any:
        """Get the value. In case this is an exception, it won't be raised."""
        holder = await self.get_value_holder()
        return holder.deserialize()

    async def set_raw_value(self, value: Any):
        """
        Store `value` in the backend. Exceptions are stored as `_ExceptionInfo`.
        """
        if isinstance(value, BaseException):
            value = _ExceptionInfo.from_exception(value)

        holder = ResultValueHolder(raw_value=value, value_type=_serialize_name(value))
        await self._backend.set_value(self._result.id, holder)

        # Invalidate any holder fetched before this write. Otherwise it would
        # keep being returned, hiding the value that was just stored.
        self._value_holder = None
get_raw_value(self) async

Get the value. In case this is an exception, it won't be raised.

Source code in mognet/model/result.py
async def get_raw_value(self) -> Any:
    """Return the deserialized value. A stored exception is returned, not raised."""
    return (await self.get_value_holder()).deserialize()

ResultValueHolder (BaseModel)

Holds information about the type of the Result's value, and the raw value itself.

Use deserialize() to parse the value according to the type.

Source code in mognet/model/result.py
class ResultValueHolder(BaseModel):
    """
    Holds a Result's raw value, together with the name of the value's type.

    Call `deserialize()` to turn the raw value back into an instance of that type.
    """

    value_type: str
    raw_value: Any

    def deserialize(self) -> Any:
        """
        Parse `raw_value` according to `value_type`.

        Returns None when there is no raw value, and the raw value unchanged
        when no type name is set.
        """
        raw = self.raw_value
        if raw is None:
            return None

        # NOTE(review): `value_type` is annotated as `str`, so this branch only
        # matters if validation was bypassed — kept for safety.
        if self.value_type is None:
            return raw

        target_type = _get_attr(self.value_type)
        return TypeAdapter(target_type).validate_python(raw)

    @classmethod
    def not_ready(cls):
        """
        Build a holder whose value is a serialized `ResultNotReady` exception.
        """
        pending_error = _ExceptionInfo.from_exception(ResultNotReady())
        type_name = _serialize_name(pending_error)
        return cls(value_type=type_name, raw_value=pending_error)
not_ready() classmethod

Creates a value holder which is not ready yet.

Source code in mognet/model/result.py
@classmethod
def not_ready(cls):
    """
    Build a value holder for a Result that has no value yet.

    Its raw value is a serialized `ResultNotReady` exception.
    """
    pending_error = _ExceptionInfo.from_exception(ResultNotReady())
    type_name = _serialize_name(pending_error)
    return cls(value_type=type_name, raw_value=pending_error)

result_state

ResultState (str, Enum)

States that a task execution, and its result, can be in.

Source code in mognet/model/result_state.py
class ResultState(str, Enum):
    """
    The states a task execution, and therefore its result, can be in.
    """

    # The associated task hasn't started yet.
    PENDING = "PENDING"

    # The associated task is running right now.
    RUNNING = "RUNNING"

    # The associated task was suspended: it either yielded subtasks,
    # or its worker was shut down gracefully.
    SUSPENDED = "SUSPENDED"

    # The associated task completed successfully.
    SUCCESS = "SUCCESS"

    # The associated task failed.
    FAILURE = "FAILURE"

    # The associated task was aborted.
    REVOKED = "REVOKED"

    # The task is invalid.
    INVALID = "INVALID"

    def __repr__(self):
        # Show just the quoted member name, e.g. 'SUCCESS'.
        return repr(self.name)

primitives special

request

Request (BaseModel, Generic)

Represents the Mognet request.

Source code in mognet/primitives/request.py
class Request(BaseModel, Generic[TReturn]):
    """
    A request to run a Mognet task.
    """
    id: UUID = Field(default_factory=uuid4)
    name: str

    args: tuple = ()
    kwargs: Dict[str, Any] = Field(default_factory=dict)

    stack: List[UUID] = Field(default_factory=list)

    # Metadata to be stored on the Result that corresponds to this Request.
    metadata: Dict[str, Any] = Field(default_factory=dict)

    # Deadline for running this request.
    # - datetime: the deadline is compared against `datetime.now(tz=timezone.utc)`;
    #   if it has already passed, the task is discarded and marked as `REVOKED`.
    # - timedelta: the task's coroutine gets that long to run, and is then
    #   cancelled and marked as `REVOKED`.
    # As with manual revoking, there is no guarantee that a timed-out task
    # actually stops running.
    deadline: Optional[Union[timedelta, datetime]] = None

    # Send the message to this queue instead of the default one.
    queue_name: Optional[str] = None

    # Optional kwargs representation for debugging, stored on the matching Result.
    # When unset, it is filled in when the request is submitted.
    # This leaks argument values, so censor any sensitive data yourself.
    kwargs_repr: Optional[str] = None

    # Task priority; higher values run with higher priority.
    priority: Priority = 5

    def __repr__(self):
        suffix = "" if self.kwargs_repr is None else f"({self.kwargs_repr})"
        return f"{self.name}[id={self.id!r}]{suffix}"

service special

class_service

ClassService (Generic)

Base class for object-based services retrieved through Context#get_service()

To get instances of a class-based service using Context#get_service(), pass the class itself, and an instance of the class will be returned.

Note that the instances are singletons.

The instances, when created, get access to the app's configuration.

Source code in mognet/service/class_service.py
class ClassService(Generic[_TReturn], metaclass=ABCMeta):
    """
    Base class for class-based services retrieved through Context#get_service().

    Pass the class itself to Context#get_service() to get an instance of it.
    Note that the instances are *singletons*, and that each one receives
    the app's configuration when created.
    """

    def __init__(self, config: "AppConfig") -> None:
        self.config = config

    @abstractmethod
    def __call__(self, context: "Context", *args, **kwds) -> _TReturn:
        raise NotImplementedError

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always release resources; exceptions are not suppressed.
        self.close()
        return None

    def close(self):
        """Release resources held by this service. No-op by default."""

    async def wait_closed(self):
        """Wait until the service is fully closed. No-op by default."""

state special

state

State

Represents state that can persist across task restarts.

Has facilities for getting, setting, and removing values.

The task's state is deleted when the task finishes.

Source code in mognet/state/state.py
class State:
    """
    Represents state that can persist across task restarts.

    Values can be read, written, and removed by key.

    The task's state is deleted when the task finishes.
    """

    request_id: UUID

    def __init__(self, app: "App", request_id: UUID) -> None:
        self._app = app
        self.request_id = request_id

    @property
    def _backend(self):
        # Looked up on every access, from the app's `state_backend`.
        return self._app.state_backend

    async def get(self, key: str, default: Any = None) -> Any:
        """Get the value stored under `key`; `default` is passed to the backend."""
        backend = self._backend
        return await backend.get(self.request_id, key, default)

    async def set(self, key: str, value: Any):
        """Store `value` under `key`."""
        backend = self._backend
        return await backend.set(self.request_id, key, value)

    async def pop(self, key: str, default: Any = None) -> Any:
        """Remove `key` from the state and return its value."""
        backend = self._backend
        return await backend.pop(self.request_id, key, default)

    async def clear(self):
        """Remove every value from this state."""
        backend = self._backend
        return await backend.clear(self.request_id)
clear(self) async

Clear all values.

Source code in mognet/state/state.py
async def clear(self):
    """Remove every value stored for this request's state."""
    backend = self._backend
    outcome = await backend.clear(self.request_id)
    return outcome
get(self, key, default=None) async

Get a value.

Source code in mognet/state/state.py
async def get(self, key: str, default: Any = None) -> Any:
    """Fetch the value stored under ``key``; ``default`` is forwarded to the backend."""
    backend = self._backend
    found = await backend.get(self.request_id, key, default)
    return found
pop(self, key, default=None) async

Delete a value from the state and return its value.

Source code in mognet/state/state.py
async def pop(self, key: str, default: Any = None) -> Any:
    """Remove ``key`` from the state and return its value; ``default`` is forwarded to the backend."""
    backend = self._backend
    removed = await backend.pop(self.request_id, key, default)
    return removed
set(self, key, value) async

Set a value.

Source code in mognet/state/state.py
async def set(self, key: str, value: Any):
    """Store ``value`` under ``key`` in this request's state."""
    backend = self._backend
    outcome = await backend.set(self.request_id, key, value)
    return outcome

state_backend_config

RedisStateBackendSettings (BaseModel)

Configuration for the Redis State Backend

Source code in mognet/state/state_backend_config.py
class RedisStateBackendSettings(BaseModel):
    """Configuration for the Redis State Backend"""

    # Redis connection URL.
    url: str = "redis://localhost:6379/"

    # How long each task's state should live for.
    # NOTE(review): units are not visible here; presumably seconds (7200 = 2h) — confirm.
    state_ttl: int = 7200

    # Set the limit of connections on the Redis connection pool.
    # None means no explicit limit is configured here (the client's own
    # behaviour applies — confirm against the backend implementation).
    # DANGER! Setting this to too low a value WILL cause issues opening connections!
    max_connections: Optional[int] = None

testing special

pytest_integration

create_app_fixture(app)

Create a Pytest fixture for a Mognet application.

Source code in mognet/testing/pytest_integration.py
def create_app_fixture(app: App):
    """Create a Pytest fixture for a Mognet application.

    The returned async fixture enters ``app`` as an async context manager,
    runs ``app.start()`` in a background task and yields the app to the test.
    On teardown it closes the app and then cancels and awaits the background
    start task. The start task is cancelled even when ``app.close()`` fails,
    so it never outlives the fixture.
    """

    @pytest_asyncio.fixture
    async def app_fixture():
        async with app:
            start_task = asyncio.create_task(app.start())
            try:
                yield app
            finally:
                try:
                    await app.close()
                finally:
                    # Stop the background `start()` task regardless of how
                    # close() went. Its cancellation (or any error it already
                    # finished with) is expected here and deliberately ignored.
                    start_task.cancel()
                    try:
                        await start_task
                    except (asyncio.CancelledError, Exception):  # pylint: disable=broad-except
                        pass

    return app_fixture

tools special

backports special

aioitertools

Backport of https://github.com/RedRoserade/aioitertools/blob/f86552753e626cb71a3a305b9ec890f97d771e6b/aioitertools/asyncio.py#L93

Should be upstreamed here: https://github.com/omnilib/aioitertools/pull/103

as_generated(iterables, *, return_exceptions=False)

Yield results from one or more async iterables, in the order they are produced. Like `as_completed`, but for async iterators or generators instead of futures. Creates a separate task to drain each iterable, and a single queue for results. If `return_exceptions` is `False`, then any exception will be raised, pending iterables and tasks will be cancelled, and async generators will be closed. If `return_exceptions` is `True`, any exceptions will be yielded as results, and execution will continue until all iterables have been fully consumed.

Example:

    async def generator(x):
        for i in range(x):
            yield i

    gen1 = generator(10)
    gen2 = generator(12)

    async for value in as_generated([gen1, gen2]):
        ...  # intermixed values yielded from gen1 and gen2

Source code in mognet/tools/backports/aioitertools.py
async def as_generated(
    iterables: Iterable[AsyncIterable[T]],
    *,
    return_exceptions: bool = False,
) -> AsyncIterable[T]:
    """
    Yield results from one or more async iterables, in the order they are produced.

    Like ``as_completed``, but for async iterators or generators instead of futures.
    One task drains each iterable into a single shared queue of results.

    If ``return_exceptions`` is ``False``, the first exception raised by an
    iterable is re-raised here; pending drain tasks are then cancelled and
    async generators closed. If ``return_exceptions`` is ``True``, exceptions
    are yielded as ordinary results and iteration continues until every
    iterable is exhausted.

    Example::

        async def generator(x):
            for i in range(x):
                yield i

        gen1 = generator(10)
        gen2 = generator(12)

        async for value in as_generated([gen1, gen2]):
            ...  # intermixed values yielded from gen1 and gen2
    """

    results: asyncio.Queue[dict] = asyncio.Queue()

    # Number of drain tasks that have not finished yet. It is set only after
    # all tasks are created, which is safe: none of them runs before the
    # first `await` below.
    remaining: int = 0

    async def drain(source: AsyncIterable[T]) -> None:
        nonlocal remaining

        try:
            async for element in source:
                await results.put({"value": element})
        except asyncio.CancelledError:
            if isinstance(source, AsyncGenerator):  # pragma:nocover
                with suppress(Exception):
                    await source.aclose()
            raise
        except Exception as error:  # pylint: disable=broad-except
            await results.put({"exception": error})
        finally:
            remaining -= 1
            # The last drain task to finish tells the consumer loop to stop.
            if not remaining:
                await results.put({"done": True})

    drainers = [asyncio.ensure_future(drain(source)) for source in iterables]

    if not drainers:
        # Nothing to do
        return

    remaining = len(drainers)

    try:
        while True:
            message = await results.get()

            if "value" in message:
                yield message["value"]
                continue

            if "exception" in message:
                if not return_exceptions:
                    raise message["exception"]
                yield message["exception"]
                continue

            if "done" in message:
                break
    except (asyncio.CancelledError, GeneratorExit):
        pass
    finally:
        unfinished = [drainer for drainer in drainers if not drainer.done()]
        for drainer in unfinished:
            drainer.cancel()

        for drainer in drainers:
            with suppress(asyncio.CancelledError):
                await drainer

kwargs_repr

format_kwargs_repr(args, kwargs, *, value_max_length=64)

Utility function to create an args + kwargs representation.

Source code in mognet/tools/kwargs_repr.py
def format_kwargs_repr(
    args: tuple,
    kwargs: dict,
    *,
    value_max_length: Optional[int] = 64,
) -> str:
    """
    Render positional and keyword arguments the way they would appear in a call.

    Positional values come first, then ``name=value`` pairs in ``kwargs``
    order. Every value is formatted with ``_format_value`` using
    ``value_max_length``, and the parts are joined with ``", "``.
    """
    rendered = [_format_value(value, max_length=value_max_length) for value in args]
    rendered.extend(
        f"{name}={_format_value(value, max_length=value_max_length)}"
        for name, value in kwargs.items()
    )
    return ", ".join(rendered)

retries

retryableasyncmethod(types, *, max_attempts, wait_timeout, lock=None, on_retry=None)

Decorator to wrap an async method and make it retryable.

Source code in mognet/tools/retries.py
def retryableasyncmethod(
    types: Tuple[Type[BaseException], ...],
    *,
    max_attempts: Union[int, str],
    wait_timeout: Union[float, str],
    lock: Union[asyncio.Lock, str] = None,
    on_retry: Union[Callable[[BaseException], Awaitable], str] = None,
):
    """
    Decorator to wrap an async method and make it retryable.

    Any of ``max_attempts``, ``wait_timeout``, ``lock`` and ``on_retry`` may be
    given as a string. That string names an attribute of the instance (``self``)
    that is read on every call. Exceptions matching ``types`` are logged and
    retried, with an exponential backoff starting at 1s (or ``wait_timeout`` if
    smaller) and capped at ``wait_timeout``. After each failed attempt,
    including the last one, the wrapper waits and then calls ``on_retry``
    (guarded by ``lock`` when given). The last exception is re-raised once all
    attempts are exhausted.
    """

    def make_retryable(func: _T) -> _T:
        if inspect.isasyncgenfunction(func):
            raise TypeError("Async generator functions are not supported")

        wrapped: Any = cast(Any, func)

        @wraps(wrapped)
        async def async_retryable_decorator(self, *args, **kwargs):
            # Resolve options in a fixed order; string values name attributes of `self`.
            if isinstance(on_retry, str):
                on_retry_fn = getattr(self, on_retry)
            elif callable(on_retry):
                on_retry_fn = on_retry
            else:
                on_retry_fn = _noop

            attempts: int = (
                getattr(self, max_attempts)
                if isinstance(max_attempts, str)
                else max_attempts
            )

            max_wait: float = (
                getattr(self, wait_timeout)
                if isinstance(wait_timeout, str)
                else wait_timeout
            )

            if isinstance(lock, str):
                retry_lock = getattr(self, lock)
            else:
                # Either the lock object itself, or None when none was configured.
                retry_lock = lock

            # Exponential backoff: start at 1s, never exceed the configured maximum.
            delay = min(1, max_wait)
            last_exc = None

            for attempt in range(1, attempts + 1):
                try:
                    return await wrapped(self, *args, **kwargs)
                except types as exc:
                    _log.error("Attempt %r/%r failed", attempt, attempts, exc_info=exc)
                    last_exc = exc

                _log.debug("Waiting %.2fs before next attempt", delay)
                await asyncio.sleep(delay)
                delay = min(delay * 2, max_wait)

                if retry_lock is None:
                    await on_retry_fn(last_exc)
                elif retry_lock.locked():
                    _log.debug("Already retrying, possibly on another method")
                else:
                    async with retry_lock:
                        _log.debug("Calling retry method")
                        await on_retry_fn(last_exc)

            if last_exc is None:
                # Only reachable when no attempt was made (attempts < 1).
                last_exc = Exception("All %r attempts failed" % attempts)

            raise last_exc

        return cast(_T, async_retryable_decorator)

    return make_retryable

worker special

worker

MessageCancellationAction (str, Enum)

What to do with a task's broker message when the worker cancels the task: nothing, ACK, or NACK.

Source code in mognet/worker/worker.py
class MessageCancellationAction(str, Enum):
    """
    What to do with a task's broker message when the worker cancels the task.

    Passed as ``message_action`` to ``Worker.cancel``, which forwards it to the
    request processor holder's ``cancel``. The worker uses ``NOTHING`` after
    losing the broker connection (there is no point in NACKing then) and
    ``NACK`` when the worker is closed.
    """

    NOTHING = "nothing"
    ACK = "ack"
    NACK = "nack"

Worker

Workers are responsible for running the fetch -> run -> store result loop, for the task queues that are configured.

Source code in mognet/worker/worker.py
class Worker:
    """
    Workers are responsible for running the fetch -> run -> store result
    loop, for the task queues that are configured.
    """

    # Request id -> holder of (message payload, request, processing task).
    running_tasks: Dict[UUID, "_RequestProcessorHolder"]

    # Set of tasks that are suspended
    _waiting_tasks: Set[UUID]

    app: "App"

    def __init__(
        self,
        *,
        app: "App",
        middleware: List["Middleware"] = None,
    ) -> None:
        self.app = app
        self.running_tasks = {}
        self._waiting_tasks = set()
        self._middleware = middleware or []

        self._current_prefetch = 1

        self._queue_consumption_tasks: List[AsyncGenerator] = []
        self._consume_task = None

        # Strong references to fire-and-forget tasks. The event loop keeps only
        # weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._background_tasks: Set[asyncio.Task] = set()

    def _spawn_background(self, coro) -> asyncio.Task:
        """Schedule ``coro`` as a task and keep a reference to it until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def run(self):
        """
        Register the broker connection-lost callback and consume the task queues
        until consumption ends. Cancellation ends the run quietly; other errors
        are logged and not re-raised.
        """
        _log.debug("Starting worker")

        try:
            self.app.broker.add_connection_failed_callback(self._handle_connection_lost)

            await self.start_consuming()
        except asyncio.CancelledError:
            _log.debug("Stopping run")
            return
        except Exception as exc:  # pylint: disable=broad-except
            _log.error("Error during consumption", exc_info=exc)

    async def _handle_connection_lost(self, exc: BaseException = None):
        _log.error("Handling connection lost event, stopping all tasks", exc_info=exc)

        # No point in NACKing, because we have been disconnected
        await self._cancel_all_tasks(message_action=MessageCancellationAction.NOTHING)

    async def _cancel_all_tasks(self, *, message_action: MessageCancellationAction):
        """Cancel every running task, readjust the prefetch and forget waiting tasks."""
        all_req_ids = list(self.running_tasks)

        _log.debug("Cancelling all %r running tasks", len(all_req_ids))

        try:
            for req_id in all_req_ids:
                await self.cancel(req_id, message_action=message_action)

            await self._adjust_prefetch()
        finally:
            self._waiting_tasks.clear()

    async def stop_consuming(self):
        """
        Close every queue consumer (waiting up to 5s each), then cancel the
        aggregation task (waiting up to 15s). Errors are logged, not raised.
        """
        _log.debug("Closing queue consumption tasks")

        consumers = self._queue_consumption_tasks
        while consumers:
            consumer = consumers.pop(0)

            try:
                await asyncio.wait_for(consumer.aclose(), 5)
            except (asyncio.CancelledError, GeneratorExit, asyncio.TimeoutError):
                pass
            except Exception as consume_exc:  # pylint: disable=broad-except
                _log.debug("Error closing consumer", exc_info=consume_exc)

        consume_task = self._consume_task
        self._consume_task = None

        if consume_task is not None:
            _log.debug("Closing aggregation task")

            try:
                consume_task.cancel()
                await asyncio.wait_for(consume_task, 15)

                _log.debug("Closed consumption task")
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass
            except Exception as consume_err:  # pylint: disable=broad-except
                _log.error("Error shutting down consumer task", exc_info=consume_err)

    async def close(self):
        """
        Stops execution, cancelling all running tasks.
        """

        _log.debug("Closing worker")

        await self.stop_consuming()

        # Cancel and NACK all messages currently on this worker.
        await self._cancel_all_tasks(message_action=MessageCancellationAction.NACK)

        _log.debug("Closed worker")

    def _remove_running_task(self, req_id: UUID):
        """Forget ``req_id``'s holder (if any), notify middleware, and return the holder."""
        holder = self.running_tasks.pop(req_id, None)

        self._spawn_background(self._emit_running_task_count_change())

        return holder

    def _add_running_task(self, req_id: UUID, holder: "_RequestProcessorHolder"):
        """Track ``holder`` as running for ``req_id`` and notify middleware."""
        self.running_tasks[req_id] = holder
        self._spawn_background(self._emit_running_task_count_change())

    async def cancel(self, req_id: UUID, *, message_action: MessageCancellationAction):
        """
        Cancels, if any, the execution of a request.
        Whoever calls this method is responsible for updating the result on the backend
        accordingly.
        """
        holder = self._remove_running_task(req_id)

        if holder is None:
            _log.debug("Request id=%r is not running on this worker", req_id)
            return

        _log.info("Cancelling task %r", req_id)

        result = await self.app.result_backend.get_or_create(req_id)

        # Only suspend the result on the backend if it was running in our node.
        if (
            result.state == ResultState.RUNNING
            and result.node_id == self.app.config.node_id
        ):
            await asyncio.shield(result.suspend())
            _log.debug("Result for task %r suspended", req_id)

        _log.debug("Waiting for coroutine of task %r to finish", req_id)

        # Wait for the task to finish, this allows it to clean up.
        try:
            await asyncio.wait_for(holder.cancel(message_action=message_action), 15)
        except asyncio.TimeoutError:
            _log.warning(
                "Handler for task id=%r took longer than 15s to shut down", req_id
            )
        except Exception as exc:  # pylint: disable=broad-except
            _log.error(
                "Handler for task=%r failed while shutting down", req_id, exc_info=exc
            )

        _log.debug("Stopped handler of task id=%r", req_id)

    def _create_context(self, request: "Request") -> "Context":
        """Build the task Context for ``request``; requires a state backend."""
        if not self.app.state_backend:
            raise RuntimeError("No state backend defined")

        return Context(
            self.app,
            request,
            State(self.app, request.id),
            self,
        )

    async def _run_request(self, req: Request) -> None:
        """
        Processes a request, validating it before running.
        """

        _log.debug("Received request %r", req)

        # Check if we're not trying to (re) start something which is already done,
        # for cases when a request is cancelled before it's started.
        # Even worse, check that we're not trying to start a request whose
        # result might have been evicted.
        result = await self.app.result_backend.get(req.id)

        if result is None:
            _log.error(
                "Attempting to run task %r, but it's result doesn't exist on the backend. Discarding",
                req,
            )
            await self.remove_suspended_task(req.id)
            return

        context = self._create_context(req)

        if result.done:
            _log.error(
                "Attempting to re-run task %r, when it's already done with state %r. Discarding",
                req,
                result.state,
            )
            return await asyncio.shield(self._on_complete(context, result))

        # Check if we should even start, because:
        # 1. We might be in a crash loop (if the process gets killed without cleanup, the numbers won't match),
        # 2. Infinite recursion
        # 3. We might be part of a parent request that was revoked (or doesn't exist)
        # 4. We might be too late.

        # 1. Too many starts
        retry_count = result.unexpected_retry_count

        if retry_count > 0:
            _log.warning(
                "Task %r has been retried %r times (max=%r)",
                req.id,
                retry_count,
                self.app.config.max_retries,
            )

        if retry_count > self.app.config.max_retries:
            _log.error(
                "Discarding task %r because it has exceeded the maximum retry count of %r",
                req,
                self.app.config.max_retries,
            )

            result = await result.set_error(
                TooManyRetries(req.id, retry_count, self.app.config.max_retries)
            )

            return await asyncio.shield(self._on_complete(context, result))

        if req.stack:

            # 2. Recursion
            if len(req.stack) > self.app.config.max_recursion:
                result = await result.set_error(RecursionError())
                return await asyncio.shield(self._on_complete(context, result))

            # 3. Parent task(s) aborted (or doesn't exist)
            for parent_id in reversed(req.stack):
                parent_result = await self.app.result_backend.get(parent_id)

                if parent_result is None:
                    result = await result.set_error(
                        Exception(f"Parent request id={parent_id} does not exist"),
                        state=ResultState.REVOKED,
                    )
                    return await asyncio.shield(self._on_complete(context, result))

                if parent_result.state == ResultState.REVOKED:
                    result = await result.set_error(
                        Exception(f"Parent request id={parent_result.id} was revoked"),
                        state=ResultState.REVOKED,
                    )
                    return await asyncio.shield(self._on_complete(context, result))

        # 4. Request arrived past the deadline.

        if isinstance(req.deadline, datetime):
            # One cannot compare naive and aware datetime,
            # so create equivalent datetime objects.
            now = datetime.now(tz=req.deadline.tzinfo)

            if req.deadline < now:
                _log.error(
                    "Request %r arrived too late. Deadline is %r, current date is %r. Marking it as REVOKED and discarding",
                    req,
                    req.deadline,
                    now,
                )
                result = await asyncio.shield(
                    result.set_error(asyncio.TimeoutError(), state=ResultState.REVOKED)
                )
                return await asyncio.shield(self._on_complete(context, result))

        # Get the function for the task. Fail if the task is not registered in
        # our app's context.
        try:
            task_function = self.app.task_registry.get_task_function(req.name)
        except UnknownTask as unknown_task:
            _log.error(
                "Request %r is for an unknown task: %r",
                req,
                req.name,
                exc_info=unknown_task,
            )
            result = await result.set_error(unknown_task, state=ResultState.INVALID)
            return await asyncio.shield(self._on_complete(context, result))

        # Mark this as running.
        await result.start(node_id=self.app.config.node_id)

        await asyncio.shield(self._on_starting(context))

        try:
            # Create a validated version of the function.
            # This not only does argument validation, but it also parses the values
            # into objects.

            validated_func = validate_call(
                task_function, config=_TaskFuncArgumentValidationConfig
            )

            if inspect.iscoroutinefunction(task_function):
                fut = validated_func(context, *req.args, **req.kwargs)
            else:
                _log.debug(
                    "Handler for task %r is not a coroutine function, running in the loop's default executor",
                    req.name,
                )

                # Run non-coroutine functions inside an executor.
                # This allows them to run without blocking the event loop
                # (providing the GIL does not block it either)
                fut = self.app.loop.run_in_executor(
                    None,
                    functools.partial(validated_func, context, *req.args, **req.kwargs),
                )

        except V2ValidationError as exc:
            _log.error(
                "Could not call task function %r because of a validation error",
                task_function,
                exc_info=exc,
            )

            invalid = InvalidTaskArguments.from_validation_error(exc)

            result = await asyncio.shield(
                result.set_error(invalid, state=ResultState.INVALID)
            )

            return await asyncio.shield(self._on_complete(context, result))

        if req.deadline is not None:
            if isinstance(req.deadline, datetime):
                # One cannot compare naive and aware datetime,
                # so create equivalent datetime objects.
                now = datetime.now(tz=req.deadline.tzinfo)
                timeout = (req.deadline - now).total_seconds()
            else:
                timeout = req.deadline.total_seconds()

            _log.debug("Applying %.2fs timeout to request %r", timeout, req)

            fut = asyncio.wait_for(fut, timeout=timeout)

        # Start executing.
        try:
            value = await fut

            # The request may have been cancelled (and removed) meanwhile; only
            # store the result if this worker still owns it.
            if req.id in self.running_tasks:
                await asyncio.shield(result.set_result(value))

                _log.info(
                    "Request %r finished with status %r in %.2fs",
                    req,
                    result.state,
                    (result.duration or timedelta()).total_seconds(),
                )

                await asyncio.shield(self._on_complete(context, result))
        except Pause:
            _log.info(
                "Handler for %r requested to be paused. Suspending it on the Result Backend and NACKing the message",
                req,
            )

            holder = self.running_tasks.pop(req.id, None)
            if holder is not None:
                await asyncio.shield(result.suspend())
                await asyncio.shield(holder.message.nack())
        except asyncio.CancelledError:
            _log.debug("Handler for task %r cancelled", req)

            # Re-raise the cancellation, this will be caught in the parent function
            # and prevent ack/nack
            raise
        except V2ValidationError as exc:  # will enter for sync functions here.
            _log.error(
                "Could not call task function %r because of a validation error",
                task_function,
                exc_info=exc,
            )

            invalid = InvalidTaskArguments.from_validation_error(exc)

            result = await asyncio.shield(
                result.set_error(invalid, state=ResultState.INVALID)
            )

            return await asyncio.shield(self._on_complete(context, result))
        except Exception as exc:  # pylint: disable=broad-except
            state = ResultState.FAILURE

            # The task's coroutine may raise `asyncio.TimeoutError` itself, so there's
            # no guarantee that the timeout we catch is actually related to the request's timeout.
            # So, this heuristic is not the best.
            # TODO: A way to improve it would be to double-check if the deadline itself is expired.
            if req.deadline is not None and isinstance(exc, asyncio.TimeoutError):
                state = ResultState.REVOKED

            if req.id in self.running_tasks:
                result = await asyncio.shield(result.set_error(exc, state=state))
                await asyncio.shield(self._on_complete(context, result))

            duration = result.duration

            if duration is not None:
                _log.error(
                    "Handler for task %r failed in %.2fs with state %r",
                    req,
                    duration.total_seconds(),
                    state,
                    exc_info=exc,
                )
            else:
                _log.error(
                    "Handler for task %r failed with state %r",
                    req,
                    state,
                    exc_info=exc,
                )

    async def _on_complete(self, context: "Context", result: Result):
        """
        Clear the task's state if its result is done, stop tracking it as
        waiting, and call every middleware's ``on_task_completed``. Middleware
        errors are logged, not raised.
        """
        if result.done:
            await context.state.clear()

        await self.remove_suspended_task(context.request.id)

        for middleware in self._middleware:
            try:
                _log.debug("Calling 'on_task_completed' middleware: %r", middleware)
                await asyncio.shield(
                    middleware.on_task_completed(result, context=context)
                )
            except Exception as mw_exc:  # pylint: disable=broad-except
                _log.error("Middleware %r failed", middleware, exc_info=mw_exc)

    async def _on_starting(self, context: "Context"):
        """Call every middleware's ``on_task_starting``; errors are logged, not raised."""
        _log.info("Starting task %r", context.request)

        for middleware in self._middleware:
            try:
                _log.debug("Calling 'on_task_starting' middleware: %r", middleware)
                await asyncio.shield(middleware.on_task_starting(context))
            except Exception as mw_exc:  # pylint: disable=broad-except
                _log.error("Middleware %r failed", middleware, exc_info=mw_exc)

    def _process_request_message(self, payload: IncomingMessagePayload) -> asyncio.Task:
        """
        Creates an asyncio.Task which will process the enclosed Request
        in the background.

        Returns said task, after adding completion handlers to it.
        """
        _log.debug("Parsing input of message id=%r as Request", payload.id)
        req = Request.model_validate(payload.payload)

        async def request_processor():
            try:
                await self._run_request(req)

                _log.debug("ACK message id=%r for request=%r", payload.id, req)
                await asyncio.shield(payload.ack())
            except asyncio.CancelledError:
                _log.debug("Cancelled execution of request=%r", req)
                return
            except Exception as exc:  # pylint: disable=broad-except
                _log.error(
                    "Fatal error processing request=%r, NAK message id=%r",
                    req,
                    payload.id,
                    exc_info=exc,
                )
                await asyncio.shield(payload.nack())

        def on_processing_done(fut: Future):
            self._remove_running_task(req.id)

            # `Future.exception()` raises CancelledError on a cancelled future,
            # so the cancellation check has to come first.
            if fut.cancelled():
                _log.debug("Processing of %r was cancelled", req)
                return

            exc = fut.exception()

            if exc is not None:
                _log.error("Fatal error processing %r", req, exc_info=exc)
            else:
                _log.debug("Processed %r successfully", req)

        task = asyncio.create_task(request_processor())
        task.add_done_callback(on_processing_done)

        holder = _RequestProcessorHolder(payload, req, task)

        self._add_running_task(req.id, holder)

        return task

    def start_consuming(self):
        """Start (at most once) the task that consumes the queues, and return it."""
        if self._consume_task is not None:
            return self._consume_task

        self._consume_task = asyncio.create_task(self._start_consuming())

        return self._consume_task

    async def _start_consuming(self):

        queues = self.app.get_task_queue_names()

        _log.info("Going to consume %r queues", len(queues))

        try:
            await self._adjust_prefetch()

            for queue in queues:
                _log.info("Start consuming task queue=%r", queue)
                self._queue_consumption_tasks.append(
                    self.app.broker.consume_tasks(queue)
                )

            async for payload in as_generated(self._queue_consumption_tasks):
                try:
                    if payload.kind == "Request":
                        self._process_request_message(payload)
                    else:
                        raise ValueError(f"Unknown kind={payload.kind!r}")
                except asyncio.CancelledError:
                    break
                except Exception as exc:  # pylint: disable=broad-except
                    _log.error(
                        "Error processing message=%r, discarding it",
                        payload,
                        exc_info=exc,
                    )
                    await asyncio.shield(payload.ack())
        finally:
            _log.debug("Stopped consuming task queues")

    async def add_waiting_task(self, task_id: UUID):
        """Mark ``task_id`` as suspended/waiting and readjust the prefetch."""
        self._waiting_tasks.add(task_id)
        await self._adjust_prefetch()

    async def remove_suspended_task(self, task_id: UUID):
        """Stop tracking ``task_id`` as waiting (no-op if absent) and readjust the prefetch."""
        self._waiting_tasks.discard(task_id)
        await self._adjust_prefetch()

    @property
    def waiting_task_count(self):
        return len(self._waiting_tasks)

    async def _emit_running_task_count_change(self):
        """Tell every middleware the current number of running tasks; errors are logged."""
        for middleware in self._middleware:
            try:
                _log.debug("Calling 'on_running_task_count_changed' on %r", middleware)
                await middleware.on_running_task_count_changed(len(self.running_tasks))
            except Exception as mw_exc:  # pylint: disable=broad-except
                _log.error(
                    "'on_running_task_count_changed' failed on %r",
                    middleware,
                    exc_info=mw_exc,
                )

    async def _adjust_prefetch(self):
        """
        Set the broker's task prefetch to ``minimum_concurrency`` plus the number
        of waiting tasks, capped at ``maximum_concurrency`` when configured.
        Does nothing when not consuming or when the value is unchanged.
        """
        if self._consume_task is None:
            _log.debug("Not adjusting prefetch because not consuming the queue")
            return

        minimum_prefetch = self.app.config.minimum_concurrency

        prefetch = self.waiting_task_count + minimum_prefetch

        max_prefetch = self.app.config.maximum_concurrency

        if max_prefetch is not None and prefetch >= max_prefetch:
            _log.error(
                "Maximum prefetch value of %r reached! No more tasks will be fetched on this node",
                max_prefetch,
            )
            prefetch = max_prefetch

        if prefetch == self._current_prefetch:
            _log.debug(
                "Current prefetch is the same as the new prefetch (%r), not adjusting it",
                prefetch,
            )
            return

        _log.debug(
            "Currently have %r tasks suspended waiting for others. Setting prefetch=%r from previous=%r",
            self.waiting_task_count,
            prefetch,
            self._current_prefetch,
        )

        self._current_prefetch = prefetch
        await self.app.broker.set_task_prefetch(self._current_prefetch)
cancel(self, req_id, *, message_action) async

Cancels, if any, the execution of a request. Whoever calls this method is responsible for updating the result on the backend accordingly.

Source code in mognet/worker/worker.py
async def cancel(self, req_id: UUID, *, message_action: MessageCancellationAction):
    """
    Cancel the execution of request ``req_id`` if it is running on this worker.

    If this node was running it, the result is suspended on the backend. The
    processor is then given up to 15s to shut down. The caller remains
    responsible for updating the result on the backend accordingly.
    """
    holder = self._remove_running_task(req_id)
    if holder is None:
        _log.debug("Request id=%r is not running on this worker", req_id)
        return

    _log.info("Cancelling task %r", req_id)

    result = await self.app.result_backend.get_or_create(req_id)

    # Suspend the backend result only when it was running on this node.
    owned_by_this_node = (
        result.state == ResultState.RUNNING
        and result.node_id == self.app.config.node_id
    )
    if owned_by_this_node:
        await asyncio.shield(result.suspend())
        _log.debug("Result for task %r suspended", req_id)

    _log.debug("Waiting for coroutine of task %r to finish", req_id)

    # Give the processor a chance to clean up before giving up on it.
    shutdown = holder.cancel(message_action=message_action)
    try:
        await asyncio.wait_for(shutdown, 15)
    except asyncio.TimeoutError:
        _log.warning(
            "Handler for task id=%r took longer than 15s to shut down", req_id
        )
    except Exception as exc:  # pylint: disable=broad-except
        _log.error(
            "Handler for task=%r failed while shutting down", req_id, exc_info=exc
        )

    _log.debug("Stopped handler of task id=%r", req_id)
close(self) async

Stops execution, cancelling all running tasks.

Source code in mognet/worker/worker.py
async def close(self):
    """
    Shut the worker down: stop consuming the queues, then cancel every task
    still running on this worker and NACK its message.
    """
    _log.debug("Closing worker")

    await self.stop_consuming()

    nack = MessageCancellationAction.NACK
    await self._cancel_all_tasks(message_action=nack)

    _log.debug("Closed worker")