API Reference
app
special
app
App
Represents the Mognet application.
You can use these objects to:
- Create and abort tasks
- Check the status of tasks
- Configure the middleware that runs on key lifecycle events of the app and its tasks
Source code in mognet/app/app.py
class App:
"""
Represents the Mognet application.
You can use these objects to:
- Create and abort tasks
- Check the status of tasks
- Configure the middleware that runs on key lifecycle events of the app and its tasks
"""
# Where results are stored.
result_backend: BaseResultBackend
# Where task state is stored.
# Task state is information that a task can save
# during its execution, in case, for example, it gets
# interrupted.
state_backend: BaseStateBackend
# Task broker.
broker: BaseBroker
# Mapping of [service name] -> dependency object,
# should be accessed via Context#get_service.
services: Dict[Any, Callable]
# Holds references to all the tasks.
task_registry: TaskRegistry
_connected: bool
# Configuration used to start this app.
config: "AppConfig"
# Worker running in this app instance.
worker: Optional[Worker]
# Background tasks spawned by this app.
_consume_control_task: Optional[Future] = None
_heartbeat_task: Optional[Future] = None
_worker_task: Optional[Future]
_middleware: List[Middleware]
_loop: asyncio.AbstractEventLoop
def __init__(
self,
*,
name: str,
config: "AppConfig",
) -> None:
self.name = name
self._connected = False
self.config = config
# Create the task registry and register it globally
reg = task_registry.get(None)
if reg is None:
reg = TaskRegistry()
reg.register_globally()
self.task_registry = reg
self._worker_task = None
self.services = {}
self._middleware = []
self._load_modules()
self.worker = None
# Event that gets set when the app is closed
self._run_result = None
def add_middleware(self, mw_inst: Middleware):
"""
Adds middleware to this app.
        Middleware is called in the order in which it was added
        to the app.
"""
if mw_inst in self._middleware:
return
self._middleware.append(mw_inst)
async def start(self):
"""
Starts the app.
"""
_log.info("Starting app %r", self.config.node_id)
self._loop = asyncio.get_event_loop()
self._run_result = asyncio.Future()
self._log_tasks_and_queues()
await self._call_on_starting_middleware()
await self.connect()
self._heartbeat_task = asyncio.create_task(self._background_heartbeat())
self._consume_control_task = asyncio.create_task(self._consume_control_queue())
self.worker = Worker(app=self, middleware=self._middleware)
self._worker_task = asyncio.create_task(self.worker.run())
_log.info("Started")
await self._call_on_started_middleware()
return await self._run_result
async def get_current_status_of_nodes(
self,
) -> AsyncGenerator[StatusResponseMessage, None]:
"""
Query all nodes of this App and get their status.
"""
request = QueryRequestMessage(name="Status")
responses = self.broker.send_query_message(
payload=MessagePayload(
id=str(request.id),
kind="Query",
payload=request.model_dump(),
)
)
try:
async for response in responses:
try:
yield StatusResponseMessage.model_validate(response)
except asyncio.CancelledError:
break
except Exception as exc: # pylint: disable=broad-except
_log.error(
"Could not parse status response %r", response, exc_info=exc
)
finally:
await responses.aclose()
async def submit(self, req: "Request", context: Optional[Context] = None) -> Result:
"""
Submits a request for execution.
        If a context is defined, it will be used to create a parent-child
        relationship between the newly submitted request and the one in
        the context instance. This is later used to cancel the whole task tree.
"""
if not self.result_backend:
raise ImproperlyConfigured("Result backend not defined")
if not self.broker:
raise ImproperlyConfigured("Broker not connected")
try:
if req.kwargs_repr is None:
req.kwargs_repr = format_kwargs_repr(req.args, req.kwargs)
_log.debug("Set default kwargs_repr on Request %r", req)
res = Result(
self.result_backend,
id=req.id,
name=req.name,
state=ResultState.PENDING,
created=datetime.now(tz=timezone.utc),
request_kwargs_repr=req.kwargs_repr,
)
if context is not None:
# Set the parent-child relationship and update the request stack.
parent_request = context.request
res.parent_id = parent_request.id
req.stack = [*parent_request.stack, parent_request.id]
if res.parent_id is not None:
await self.result_backend.add_children(res.parent_id, req.id)
await self.result_backend.set(req.id, res)
# Store the metadata on the Result.
if req.metadata:
await res.set_metadata(**req.metadata)
await self._on_submitting(req, context=context)
payload = MessagePayload(
id=str(req.id),
kind="Request",
payload=req.model_dump(),
priority=req.priority,
)
_log.debug("Sending message %r", payload.id)
await self.broker.send_task_message(self._get_task_route(req), payload)
return res
except Exception as exc:
raise CouldNotSubmit(f"Could not submit {req!r}") from exc
def get_task_queue_names(self) -> Set[str]:
"""
Return the names of the queues that are going to be consumed,
after applying defaults, inclusions, and exclusions.
"""
all_queues = {*self.config.task_routes.values(), self.config.default_task_route}
_log.debug("All queues: %r", all_queues)
configured_queues = self.config.task_queues
configured_queues.ensure_valid()
if configured_queues.exclude:
_log.debug("Applying queue exclusions: %r", configured_queues.exclude)
return all_queues - configured_queues.exclude
if configured_queues.include:
_log.debug("Applying queue inclusions: %r", configured_queues.include)
return all_queues & configured_queues.include
_log.debug("No inclusions or exclusions applied")
return all_queues
@overload
def create_request(
self,
func: Callable[Concatenate["Context", _P], Awaitable[_Return]],
*args: _P.args,
**kwargs: _P.kwargs,
) -> Request[_Return]:
"""
Creates a Request object from the function that was decorated with @task,
and the provided arguments.
This overload is just to document async def function return values.
"""
...
@overload
def create_request(
self,
func: Callable[Concatenate["Context", _P], _Return],
*args: _P.args,
**kwargs: _P.kwargs,
) -> Request[_Return]:
"""
Creates a Request object from the function that was decorated with @task,
and the provided arguments.
This overload is just to document non-async def function return values.
"""
...
def create_request(
self,
func: Callable[Concatenate["Context", _P], Any],
*args: _P.args,
**kwargs: _P.kwargs,
) -> Request:
"""
Creates a Request object from the function that was decorated with @task,
and the provided arguments.
"""
return Request(
name=self.task_registry.get_task_name(cast(Any, func)),
args=args,
kwargs=kwargs,
)
@overload
async def run(
self,
request: Callable[Concatenate["Context", _P], Awaitable[_Return]],
*args: _P.args,
**kwargs: _P.kwargs,
) -> _Return:
"""
Short-hand method for creating a Request from a function decorated with `@task`,
(see `create_request`), submitting it (see `submit`) and waiting for the result (see `run(Request)`).
"""
...
@overload
async def run(
self, request: "Request[_Return]", context: Optional[Context] = None
) -> _Return:
"""
Runs the request and waits for the result.
Call `submit` if you just want to send a request
without waiting for the result.
"""
...
    async def run(self, request, *args, **kwargs) -> Any:
        if not isinstance(request, Request):
            # The first argument is the task function; build a Request from it.
            request = self.create_request(request, *args, **kwargs)
            res = await self.submit(request)
        else:
            res = await self.submit(request, *args, **kwargs)
        return await res
async def revoke(self, request_id: uuid.UUID, *, force: bool = False) -> Result:
"""
Revoke the execution of a request.
If the request is already completed, this method returns
the associated result as-is. Optionally, `force=True` may be set
in order to ignore the state check.
This will also revoke any request that's launched as a child of this one,
recursively.
Returns the cancelled result.
"""
res = await self.result_backend.get_or_create(request_id)
if not force and res.done:
_log.warning(
"Attempting to cancel result %r that's already done, this is a no-op",
res.id,
)
return res
_log.info("Revoking request id=%r", res)
await res.revoke()
payload = MessagePayload(
id=str(uuid.uuid4()),
kind=Revoke.MESSAGE_KIND,
payload=Revoke(id=request_id).model_dump(),
)
await self.broker.send_control_message(payload)
child_count = await res.children.count()
if child_count:
_log.info("Revoking %r children of id=%r", child_count, res.id)
# Abort children.
async for child_id in res.children.iter_ids():
await self.revoke(child_id, force=force)
return res
async def connect(self):
"""Connect this app and its components to their respective backends."""
if self._connected:
return
self.broker = self._create_broker()
self.result_backend = self._create_result_backend()
self.state_backend = self._create_state_backend()
self._connected = True
await self._setup_broker()
_log.debug("Connecting to result backend %s", self.result_backend)
await self.result_backend.connect()
_log.debug("Connected to result backend %s", self.result_backend)
_log.debug("Connecting to state backend %s", self.state_backend)
await self.state_backend.connect()
_log.debug("Connected to state backend %s", self.state_backend)
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, *args, **kwargs):
await self.close()
@shield
async def close(self):
"""Close this app and its components's backends."""
_log.info("Closing app")
await asyncio.shield(self._stop())
if self._run_result and not self._run_result.done():
self._run_result.set_result(None)
_log.info("Closed app")
async def _stop(self):
await self._call_on_stopping_middleware()
if self._heartbeat_task is not None:
self._heartbeat_task.cancel()
try:
await self._heartbeat_task
except BaseException: # pylint: disable=broad-except
pass
self._heartbeat_task = None
_log.debug("Closing queue listeners")
if self._consume_control_task:
self._consume_control_task.cancel()
try:
await self._consume_control_task
except asyncio.CancelledError:
pass
except Exception as exc: # pylint: disable=broad-except
_log.debug("Error shutting down control consumption task", exc_info=exc)
self._consume_control_task = None
# Disconnect from the broker, this should NACK
# all pending messages too.
_log.debug("Closing broker connection")
if self.broker:
await self.broker.close()
# Stop the worker
await self._stop_worker()
# Remove service instances
        for svc in self.services.values():
if isinstance(svc, ClassService):
try:
svc.close()
await svc.wait_closed()
except Exception as exc: # pylint: disable=broad-except
_log.error("Error closing service %r", svc, exc_info=exc)
self.services.clear()
# Finally, shut down the state and result backends.
_log.debug("Closing backends")
if self.result_backend:
try:
await self.result_backend.close()
except Exception as exc: # pylint: disable=broad-except
_log.error("Error closing result backend", exc_info=exc)
if self.state_backend:
try:
await self.state_backend.close()
except Exception as exc: # pylint: disable=broad-except
_log.error("Error closing state backend", exc_info=exc)
self._connected = False
await self._call_on_stopped_middleware()
async def purge_task_queues(self) -> Dict[str, int]:
"""
Purge all known task queues.
Returns a dict where the keys are the names of the queues,
and the values are the number of messages purged.
"""
deleted_per_queue = {}
for queue in self.get_task_queue_names():
_log.info("Purging task queue=%r", queue)
deleted_per_queue[queue] = await self.broker.purge_task_queue(queue)
return deleted_per_queue
async def purge_control_queue(self) -> int:
"""
Purges the control queue related to this app.
Returns the number of messages purged.
"""
return await self.broker.purge_control_queue()
@property
def loop(self) -> asyncio.AbstractEventLoop:
return self._loop
def _create_broker(self) -> BaseBroker:
return AmqpBroker(config=self.config.broker, app=self)
def _create_result_backend(self) -> BaseResultBackend:
return RedisResultBackend(self.config.result_backend, app=self)
def _create_state_backend(self) -> BaseStateBackend:
return RedisStateBackend(self.config.state_backend, app=self)
def _load_modules(self):
for module in self.config.imports:
importlib.import_module(module)
def _log_tasks_and_queues(self):
all_tasks = self.task_registry.registered_task_names
tasks_msg = "\n".join(
f"\t - {t!r} (queue={self._get_task_route(t)!r})" for t in all_tasks
)
_log.info("Registered %r tasks:\n%s", len(all_tasks), tasks_msg)
all_queues = self.get_task_queue_names()
queues_msg = "\n".join(f"\t - {q!r}" for q in all_queues)
_log.info("Registered %r queues:\n%s", len(all_queues), queues_msg)
async def _setup_broker(self):
_log.debug("Connecting to broker %s", self.broker)
await self.broker.connect()
_log.debug("Connected to broker %r", self.broker)
_log.debug("Setting up task queues")
for queue_name in self.get_task_queue_names():
await self.broker.setup_task_queue(TaskQueue(name=queue_name))
_log.debug("Setup queues")
async def _stop_worker(self):
if self.worker is None or not self._worker_task:
_log.debug("No worker running")
return
try:
_log.debug("Closing worker")
await self.worker.close()
if self._worker_task is not None:
self._worker_task.cancel()
await self._worker_task
_log.debug("Worker closed")
except asyncio.CancelledError:
pass
except Exception as worker_close_exc: # pylint: disable=broad-except
_log.error(
"Worker raised an exception while closing", exc_info=worker_close_exc
)
finally:
self.worker = None
self._worker_task = None
async def _background_heartbeat(self):
"""
Background task that checks if the event loop was blocked
for too long.
        This is a crude check: it calls asyncio.sleep() and compares the
        elapsed time against the sleep duration; a significantly higher value
        means the loop was blocked. A blocked loop can cause problems, for
        example with task brokers that need to send periodic keep-alive
        messages to prevent connection drops.
Error messages are logged in case the event loop got blocked for too long.
"""
while True:
current_ts = self.loop.time()
await asyncio.sleep(5)
next_ts = self.loop.time()
diff = next_ts - current_ts
if diff > 10:
_log.error(
"Event loop seemed blocked for %.2fs (>10s), this could bring issues. Consider using asyncio.run_in_executor to run CPU-bound work",
diff,
)
else:
_log.debug("Event loop heartbeat: %.2fs", diff)
async def _consume_control_queue(self):
"""
Reads messages from the control queue and dispatches them.
"""
await self.broker.setup_control_queue()
_log.debug("Listening on the control queue")
async for msg in self.broker.consume_control_queue():
try:
await self._process_control_message(msg)
except asyncio.CancelledError:
break
except Exception as exc: # pylint: disable=broad-except
_log.error(
"Could not process control queue message %r", msg, exc_info=exc
)
async def _process_control_message(self, msg: IncomingMessagePayload):
_log.debug("Received control message id=%r", msg.id)
try:
if msg.kind == Revoke.MESSAGE_KIND:
abort = Revoke.model_validate(msg.payload)
_log.debug("Received request to revoke request id=%r", abort.id)
if self.worker is None:
_log.debug("No worker running. Discarding revoke message.")
return
try:
# Cancel the task's execution and ACK it on the broker
# to prevent it from re-running.
await self.worker.cancel(
abort.id, message_action=MessageCancellationAction.ACK
)
except Exception as exc: # pylint: disable=broad-except
_log.error(
"Error while cancelling request id=%r", abort.id, exc_info=exc
)
return
if msg.kind == "Query":
query = QueryRequestMessage.model_validate(msg.payload)
if query.name == "Status":
# Get the status of this worker and reply to the incoming message
if self.worker is None:
_log.debug("No worker running for Status query")
running_request_ids = []
else:
running_request_ids = list(self.worker.running_tasks.keys())
reply = StatusResponseMessage(
node_id=self.config.node_id,
payload=StatusResponseMessage.Status(
running_request_ids=running_request_ids,
),
)
payload = MessagePayload(
id=str(reply.id), kind=reply.kind, payload=reply.model_dump()
)
return await self.broker.send_reply(msg, payload)
_log.warning("Unknown query name=%r, discarding", query.name)
return
_log.warning("Unknown message kind=%r, discarding", msg.kind)
finally:
await msg.ack()
async def _on_submitting(self, req: "Request", context: Optional["Context"]):
for mw_inst in self._middleware:
try:
await mw_inst.on_request_submitting(req, context=context)
except Exception as mw_exc: # pylint: disable=broad-except
_log.error("Middleware failed", exc_info=mw_exc)
def _get_task_route(self, req: Union[str, Request]):
if isinstance(req, Request):
if req.queue_name is not None:
_log.debug(
"Request %r has a queue override to route to queue=%r",
req,
req.queue_name,
)
return req.queue_name
req = req.name
route = self.config.task_routes.get(req)
if route is not None:
_log.debug(
"Request %r has a config-set route to queue=%r",
req,
route,
)
return route
default_queue = self.config.default_task_route
_log.debug(
"Request %r has no route set, falling back to default queue=%r",
req,
default_queue,
)
return default_queue
async def _call_on_starting_middleware(self):
for mw in self._middleware:
try:
await mw.on_app_starting(self)
except Exception as exc: # pylint: disable=broad-except
_log.debug(
"Middleware %r failed on 'on_app_starting'", mw, exc_info=exc
)
async def _call_on_started_middleware(self):
for mw in self._middleware:
try:
await mw.on_app_started(self)
except Exception as exc: # pylint: disable=broad-except
_log.debug("Middleware %r failed on 'on_app_started'", mw, exc_info=exc)
async def _call_on_stopping_middleware(self):
for mw in self._middleware:
try:
await mw.on_app_stopping(self)
except Exception as exc: # pylint: disable=broad-except
_log.debug(
"Middleware %r failed on 'on_app_stopping'", mw, exc_info=exc
)
async def _call_on_stopped_middleware(self):
for mw in self._middleware:
try:
await mw.on_app_stopped(self)
except Exception as exc: # pylint: disable=broad-except
_log.debug("Middleware %r failed on 'on_app_stopped'", mw, exc_info=exc)
add_middleware(self, mw_inst)
Adds middleware to this app.
Middleware is called in the order in which it was added to the app.
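A minimal usage sketch (LoggingMiddleware is hypothetical; the import path is assumed, while the hook names follow the middleware calls in the source above):

from mognet.middleware import Middleware  # import path assumed

class LoggingMiddleware(Middleware):
    async def on_app_started(self, app):
        print(f"app started: {app.name}")

    async def on_request_submitting(self, req, context=None):
        print(f"submitting: {req.name}")

app.add_middleware(LoggingMiddleware())  # app: an App instance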
close(self)
async
Close this app and its components' backends.
connect(self)
async
Connect this app and its components to their respective backends.
Source code in mognet/app/app.py
async def connect(self):
"""Connect this app and its components to their respective backends."""
if self._connected:
return
self.broker = self._create_broker()
self.result_backend = self._create_result_backend()
self.state_backend = self._create_state_backend()
self._connected = True
await self._setup_broker()
_log.debug("Connecting to result backend %s", self.result_backend)
await self.result_backend.connect()
_log.debug("Connected to result backend %s", self.result_backend)
_log.debug("Connecting to state backend %s", self.state_backend)
await self.state_backend.connect()
_log.debug("Connected to state backend %s", self.state_backend)
create_request(self, func, *args, **kwargs)
Creates a Request object from the function that was decorated with @task, and the provided arguments.
Source code in mognet/app/app.py
def create_request(
self,
func: Callable[Concatenate["Context", _P], Any],
*args: _P.args,
**kwargs: _P.kwargs,
) -> Request:
"""
Creates a Request object from the function that was decorated with @task,
and the provided arguments.
"""
return Request(
name=self.task_registry.get_task_name(cast(Any, func)),
args=args,
kwargs=kwargs,
)
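A sketch under assumptions: add is a hypothetical function registered with the @task decorator (decorator import path and signature assumed), showing how the function and its arguments map to a Request:

from mognet import task  # import path assumed

@task(name="example.add")
async def add(context, a: int, b: int) -> int:
    return a + b

req = app.create_request(add, 1, 2)  # -> Request(name="example.add", args=(1, 2))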
get_current_status_of_nodes(self)
Query all nodes of this App and get their status.
Source code in mognet/app/app.py
async def get_current_status_of_nodes(
self,
) -> AsyncGenerator[StatusResponseMessage, None]:
"""
Query all nodes of this App and get their status.
"""
request = QueryRequestMessage(name="Status")
responses = self.broker.send_query_message(
payload=MessagePayload(
id=str(request.id),
kind="Query",
payload=request.model_dump(),
)
)
try:
async for response in responses:
try:
yield StatusResponseMessage.model_validate(response)
except asyncio.CancelledError:
break
except Exception as exc: # pylint: disable=broad-except
_log.error(
"Could not parse status response %r", response, exc_info=exc
)
finally:
await responses.aclose()
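Usage sketch, assuming app is a started App instance; the field accesses follow the StatusResponseMessage construction shown in _process_control_message:

async for status in app.get_current_status_of_nodes():
    print(status.node_id, status.payload.running_request_ids)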
get_task_queue_names(self)
Return the names of the queues that are going to be consumed, after applying defaults, inclusions, and exclusions.
Source code in mognet/app/app.py
def get_task_queue_names(self) -> Set[str]:
"""
Return the names of the queues that are going to be consumed,
after applying defaults, inclusions, and exclusions.
"""
all_queues = {*self.config.task_routes.values(), self.config.default_task_route}
_log.debug("All queues: %r", all_queues)
configured_queues = self.config.task_queues
configured_queues.ensure_valid()
if configured_queues.exclude:
_log.debug("Applying queue exclusions: %r", configured_queues.exclude)
return all_queues - configured_queues.exclude
if configured_queues.include:
_log.debug("Applying queue inclusions: %r", configured_queues.include)
return all_queues & configured_queues.include
_log.debug("No inclusions or exclusions applied")
return all_queues
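A hedged routing sketch: with the configuration below (backend and broker sub-configs are hypothetical placeholders), the method returns {"tasks", "heavy"}; setting task_queues to include only "heavy" would narrow it to {"heavy"} (the Queues field names are assumed from the include/exclude checks above):

config = AppConfig(
    result_backend=result_backend_config,  # hypothetical sub-configs, omitted here
    state_backend=state_backend_config,
    broker=broker_config,
    default_task_route="tasks",
    task_routes={"example.heavy_task": "heavy"},
)
# app.get_task_queue_names() -> {"tasks", "heavy"}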
purge_control_queue(self)
async
Purges the control queue related to this app.
Returns the number of messages purged.
purge_task_queues(self)
async
Purge all known task queues.
Returns a dict where the keys are the names of the queues, and the values are the number of messages purged.
Source code in mognet/app/app.py
async def purge_task_queues(self) -> Dict[str, int]:
"""
Purge all known task queues.
Returns a dict where the keys are the names of the queues,
and the values are the number of messages purged.
"""
deleted_per_queue = {}
for queue in self.get_task_queue_names():
_log.info("Purging task queue=%r", queue)
deleted_per_queue[queue] = await self.broker.purge_task_queue(queue)
return deleted_per_queue
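Usage sketch:

purged = await app.purge_task_queues()
for queue_name, message_count in purged.items():
    print(f"purged {message_count} messages from {queue_name!r}")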
revoke(self, request_id, *, force=False)
async
Revoke the execution of a request.
If the request is already completed, this method returns the associated
result as-is. Optionally, force=True may be set in order to ignore the
state check.
This will also revoke any request that's launched as a child of this one, recursively.
Returns the cancelled result.
Source code in mognet/app/app.py
async def revoke(self, request_id: uuid.UUID, *, force: bool = False) -> Result:
"""
Revoke the execution of a request.
If the request is already completed, this method returns
the associated result as-is. Optionally, `force=True` may be set
in order to ignore the state check.
This will also revoke any request that's launched as a child of this one,
recursively.
Returns the cancelled result.
"""
res = await self.result_backend.get_or_create(request_id)
if not force and res.done:
_log.warning(
"Attempting to cancel result %r that's already done, this is a no-op",
res.id,
)
return res
_log.info("Revoking request id=%r", res)
await res.revoke()
payload = MessagePayload(
id=str(uuid.uuid4()),
kind=Revoke.MESSAGE_KIND,
payload=Revoke(id=request_id).model_dump(),
)
await self.broker.send_control_message(payload)
child_count = await res.children.count()
if child_count:
_log.info("Revoking %r children of id=%r", child_count, res.id)
# Abort children.
async for child_id in res.children.iter_ids():
await self.revoke(child_id, force=force)
return res
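Usage sketch (request_id is a hypothetical uuid.UUID of a previously submitted request):

result = await app.revoke(request_id)
# Or skip the done-state check:
result = await app.revoke(request_id, force=True)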
start(self)
async
Starts the app.
Source code in mognet/app/app.py
async def start(self):
"""
Starts the app.
"""
_log.info("Starting app %r", self.config.node_id)
self._loop = asyncio.get_event_loop()
self._run_result = asyncio.Future()
self._log_tasks_and_queues()
await self._call_on_starting_middleware()
await self.connect()
self._heartbeat_task = asyncio.create_task(self._background_heartbeat())
self._consume_control_task = asyncio.create_task(self._consume_control_queue())
self.worker = Worker(app=self, middleware=self._middleware)
self._worker_task = asyncio.create_task(self.worker.run())
_log.info("Started")
await self._call_on_started_middleware()
return await self._run_result
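A minimal entry-point sketch, assuming a config.json compatible with AppConfig.from_file (module paths taken from the "Source code in …" locations in this reference):

import asyncio
from mognet.app.app import App
from mognet.app.app_config import AppConfig

async def main():
    config = AppConfig.from_file("config.json")
    app = App(name="example", config=config)
    await app.start()  # blocks until close() resolves the run future

asyncio.run(main())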
submit(self, req, context=None)
async
Submits a request for execution.
If a context is defined, it will be used to create a parent-child relationship between the newly submitted request and the one in the context instance. This is later used to cancel the whole task tree.
Source code in mognet/app/app.py
async def submit(self, req: "Request", context: Optional[Context] = None) -> Result:
"""
Submits a request for execution.
        If a context is defined, it will be used to create a parent-child
        relationship between the newly submitted request and the one in
        the context instance. This is later used to cancel the whole task tree.
"""
if not self.result_backend:
raise ImproperlyConfigured("Result backend not defined")
if not self.broker:
raise ImproperlyConfigured("Broker not connected")
try:
if req.kwargs_repr is None:
req.kwargs_repr = format_kwargs_repr(req.args, req.kwargs)
_log.debug("Set default kwargs_repr on Request %r", req)
res = Result(
self.result_backend,
id=req.id,
name=req.name,
state=ResultState.PENDING,
created=datetime.now(tz=timezone.utc),
request_kwargs_repr=req.kwargs_repr,
)
if context is not None:
# Set the parent-child relationship and update the request stack.
parent_request = context.request
res.parent_id = parent_request.id
req.stack = [*parent_request.stack, parent_request.id]
if res.parent_id is not None:
await self.result_backend.add_children(res.parent_id, req.id)
await self.result_backend.set(req.id, res)
# Store the metadata on the Result.
if req.metadata:
await res.set_metadata(**req.metadata)
await self._on_submitting(req, context=context)
payload = MessagePayload(
id=str(req.id),
kind="Request",
payload=req.model_dump(),
priority=req.priority,
)
_log.debug("Sending message %r", payload.id)
await self.broker.send_task_message(self._get_task_route(req), payload)
return res
except Exception as exc:
raise CouldNotSubmit(f"Could not submit {req!r}") from exc
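A submission sketch, reusing the hypothetical add task from the create_request example; awaiting the returned Result follows the run() implementation above:

req = app.create_request(add, 1, 2)
res = await app.submit(req)
value = await res  # Result objects are awaitable, per run()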
app_config
AppConfig (BaseModel)
Configuration for a Mognet application.
Source code in mognet/app/app_config.py
class AppConfig(BaseModel):
"""
Configuration for a Mognet application.
"""
# An ID for the node. Defaults to a string containing
# the current PID and the hostname.
node_id: str = Field(default_factory=_default_node_id)
# Configuration for the result backend.
result_backend: ResultBackendConfig
# Configuration for the state backend.
state_backend: StateBackendConfig
# Configuration for the task broker.
broker: BrokerConfig
# List of modules to import
imports: List[str] = Field(default_factory=list)
# Maximum number of tasks that this app can handle.
max_tasks: Optional[int] = None
# Maximum recursion depth for tasks that call other tasks.
max_recursion: int = 64
# Defines the number of times a task that unexpectedly
# failed (i.e., SIGKILL) can be retried.
max_retries: int = 3
# Default task route to send messages to.
default_task_route: str = "tasks"
# A mapping of [task name] -> [queue] that overrides the queue on which a task is listening.
# If a task is not here, it will default to the queue set in [default_task_route].
task_routes: Dict[str, str] = Field(default_factory=dict)
    # Specify which queues to listen on, and which to exclude.
task_queues: Queues = Field(default_factory=Queues)
    # The minimum prefetch count. Task consumption starts with
    # this value, which is then incremented based on the number of
    # running tasks that are waiting.
# A higher value allows more tasks to run concurrently on this node.
minimum_concurrency: int = 1
    # The maximum prefetch count. This helps ensure that not too many
# recursive tasks run on this node.
# Bear in mind that, if set, you can run into deadlocks if you have
# overly recursive tasks.
maximum_concurrency: Optional[int] = None
# Settings that can be passed to instances retrieved via
# Context#get_service()
services_settings: Dict[str, Any] = Field(default_factory=dict)
@classmethod
def from_file(cls, file_path: str) -> "AppConfig":
with open(file_path, "r", encoding="utf-8") as config_file:
return cls.model_validate_json(config_file.read())
# Maximum number of attempts to connect
max_reconnect_retries: int = 5
# Time to wait between reconnects
reconnect_interval: float = 5
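A loading sketch, using the from_file helper defined above:

config = AppConfig.from_file("config.json")
print(config.node_id, config.default_task_route)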
backend
special
Result Backends are used to retrieve Task Results from a persistent storage backend.
backend_config
Encoding (str, Enum)
RedisResultBackendSettings (BaseModel)
Configuration for the Redis Result Backend
Source code in mognet/backend/backend_config.py
class RedisResultBackendSettings(BaseModel):
"""Configuration for the Redis Result Backend"""
url: str = "redis://localhost:6379/"
# TTL for the results.
result_ttl: Optional[int] = int(timedelta(days=21).total_seconds())
# TTL for the result values. This is set lower than `result_ttl` to keep
# the results themselves available for longer.
result_value_ttl: Optional[int] = int(timedelta(days=7).total_seconds())
# Encoding for the result values.
result_value_encoding: Optional[Encoding] = Encoding.GZIP
retry_connect_attempts: int = 10
retry_connect_timeout: float = 30
# Set the limit of connections on the Redis connection pool.
# DANGER! Setting this to too low a value WILL cause issues opening connections!
max_connections: Optional[int] = None
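A configuration sketch overriding the TTL defaults:

from datetime import timedelta

settings = RedisResultBackendSettings(
    url="redis://localhost:6379/0",
    result_ttl=int(timedelta(days=7).total_seconds()),
    result_value_ttl=int(timedelta(days=1).total_seconds()),
)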
base_result_backend
BaseResultBackend
Base interface to implement a Result Backend.
Source code in mognet/backend/base_result_backend.py
class BaseResultBackend(metaclass=ABCMeta):
"""Base interface to implemenent a Result Backend."""
config: ResultBackendConfig
app: AppParameters
def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None:
super().__init__()
self.config = config
self.app = app
@abstractmethod
async def get(self, result_id: UUID) -> Optional[Result]:
"""
        Get a Result by its ID.
If it doesn't exist, this method returns None.
"""
raise NotImplementedError
async def get_many(self, *result_ids: UUID) -> List[Result]:
"""
Get a list of Results by specifying their IDs.
        Results that don't exist are omitted from the returned list.
        """
        all_results = await asyncio.gather(*[self.get(r_id) for r_id in result_ids])
        return [r for r in all_results if r is not None]
async def get_or_create(self, result_id: UUID) -> Result:
"""
        Get a Result by its ID.
If it doesn't exist, this method creates one.
The returned Result will either be the existing one,
or the newly-created one.
"""
res = await self.get(result_id)
if res is None:
res = Result(self, id=result_id)
await self.set(result_id, res)
return res
@abstractmethod
async def set(self, result_id: UUID, result: Result) -> None:
"""
Save a Result.
"""
raise NotImplementedError
async def wait(
self, result_id: UUID, timeout: Optional[float] = None, poll: float = 0.1
) -> Result:
"""
Wait until a result is ready.
Raises `asyncio.TimeoutError` if a timeout is set and exceeded.
"""
async def waiter():
while True:
result = await self.get(result_id)
if result is not None and result.done:
return result
await asyncio.sleep(poll)
if timeout:
return await asyncio.wait_for(waiter(), timeout)
return await waiter()
@abstractmethod
async def get_children_count(self, parent_result_id: UUID) -> int:
"""
Return the number of children of a Result.
Returns 0 if the Result doesn't exist.
"""
raise NotImplementedError
@abstractmethod
def iterate_children_ids(
self, parent_result_id: UUID, *, count: Optional[int] = None
) -> AsyncGenerator[UUID, None]:
"""
Get an AsyncGenerator for the IDs for the children of a Result.
The AsyncGenerator will be empty if the Result doesn't exist.
"""
raise NotImplementedError
def iterate_children(
self, parent_result_id: UUID, *, count: Optional[int] = None
) -> AsyncGenerator[Result, None]:
"""
Get an AsyncGenerator for the children of a Result.
The AsyncGenerator will be empty if the Result doesn't exist.
"""
raise NotImplementedError
async def __aenter__(self):
return self
async def __aexit__(self, *args, **kwargs):
return None
async def connect(self):
"""
Explicit method to connect to the backend provided by
this Result backend.
"""
async def close(self):
"""
Explicit method to close the backend provided by
this Result backend.
"""
@abstractmethod
async def add_children(self, result_id: UUID, *children: UUID) -> None:
"""
Add children to a parent Result.
"""
raise NotImplementedError
@abstractmethod
async def get_value(self, result_id: UUID) -> ResultValueHolder:
"""
Get the value of a Result.
If the value is lost, ResultValueLost is raised.
"""
raise NotImplementedError
@abstractmethod
async def set_value(self, result_id: UUID, value: ResultValueHolder) -> None:
"""
Set the value of a Result.
"""
raise NotImplementedError
@abstractmethod
async def get_metadata(self, result_id: UUID) -> Dict[str, Any]:
"""
Get the metadata of a Result.
Returns an empty Dict if the Result doesn't exist.
"""
raise NotImplementedError
@abstractmethod
async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
"""
Set metadata on a Result.
"""
raise NotImplementedError
@abstractmethod
async def delete(self, result_id: UUID, include_children: bool = True) -> None:
"""
Delete a Result.
"""
raise NotImplementedError
@abstractmethod
async def set_ttl(
self, result_id: UUID, ttl: timedelta, include_children: bool = True
) -> None:
"""
Set expiration on a Result.
If include_children is True, children will have the same TTL set.
"""
raise NotImplementedError
add_children(self, result_id, *children)
async
close(self)
async
connect(self)
async
delete(self, result_id, include_children=True)
async
get(self, result_id)
async
Get a Result by its ID. If it doesn't exist, this method returns None.
get_children_count(self, parent_result_id)
async
Return the number of children of a Result.
Returns 0 if the Result doesn't exist.
get_many(self, *result_ids)
async
Get a list of Results by specifying their IDs. Results that don't exist are omitted from the returned list.
Source code in mognet/backend/base_result_backend.py
async def get_many(self, *result_ids: UUID) -> List[Result]:
"""
Get a list of Results by specifying their IDs.
        Results that don't exist are omitted from the returned list.
        """
        all_results = await asyncio.gather(*[self.get(r_id) for r_id in result_ids])
        return [r for r in all_results if r is not None]
get_metadata(self, result_id)
async
Get the metadata of a Result.
Returns an empty Dict if the Result doesn't exist.
get_or_create(self, result_id)
async
Get a Result by its ID. If it doesn't exist, this method creates one.
The returned Result will either be the existing one, or the newly-created one.
Source code in mognet/backend/base_result_backend.py
async def get_or_create(self, result_id: UUID) -> Result:
"""
        Get a Result by its ID.
If it doesn't exist, this method creates one.
The returned Result will either be the existing one,
or the newly-created one.
"""
res = await self.get(result_id)
if res is None:
res = Result(self, id=result_id)
await self.set(result_id, res)
return res
get_value(self, result_id)
async
Get the value of a Result.
If the value is lost, ResultValueLost is raised.
iterate_children(self, parent_result_id, *, count=None)
Get an AsyncGenerator for the children of a Result.
The AsyncGenerator will be empty if the Result doesn't exist.
iterate_children_ids(self, parent_result_id, *, count=None)
Get an AsyncGenerator for the IDs for the children of a Result.
The AsyncGenerator will be empty if the Result doesn't exist.
set(self, result_id, result)
async
set_metadata(self, result_id, **kwargs)
async
set_ttl(self, result_id, ttl, include_children=True)
async
Set expiration on a Result.
If include_children is True, children will have the same TTL set.
set_value(self, result_id, value)
async
wait(self, result_id, timeout=None, poll=0.1)
async
Wait until a result is ready.
Raises asyncio.TimeoutError if a timeout is set and exceeded.
Source code in mognet/backend/base_result_backend.py
async def wait(
self, result_id: UUID, timeout: Optional[float] = None, poll: float = 0.1
) -> Result:
"""
Wait until a result is ready.
Raises `asyncio.TimeoutError` if a timeout is set and exceeded.
"""
async def waiter():
while True:
result = await self.get(result_id)
if result is not None and result.done:
return result
await asyncio.sleep(poll)
if timeout:
return await asyncio.wait_for(waiter(), timeout)
return await waiter()
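Usage sketch (backend is a connected result backend and result_id an existing result's UUID, both assumed):

import asyncio

try:
    result = await backend.wait(result_id, timeout=30.0, poll=0.5)
except asyncio.TimeoutError:
    ...  # the result was not ready within 30 seconds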
memory_result_backend
MemoryResultBackend (BaseResultBackend)
Result backend that "persists" results in memory. Useful for testing, but this is not recommended for production setups.
Source code in mognet/backend/memory_result_backend.py
class MemoryResultBackend(BaseResultBackend):
"""
Result backend that "persists" results in memory. Useful for testing,
but this is not recommended for production setups.
"""
def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None:
super().__init__(config, app)
self._results: Dict[UUID, Result] = {}
self._result_tree: Dict[UUID, Set[UUID]] = {}
self._values: Dict[UUID, ResultValueHolder] = {}
self._metadata: Dict[UUID, Dict[str, Any]] = {}
async def get(self, result_id: UUID) -> Optional[Result]:
return self._results.get(result_id, None)
async def set(self, result_id: UUID, result: Result):
self._results[result_id] = result
async def get_children_count(self, parent_result_id: UUID) -> int:
return len(self._result_tree.get(parent_result_id, set()))
    async def iterate_children_ids(
        self, parent_result_id: UUID, *, count: Optional[int] = None
    ) -> AsyncGenerator[UUID, None]:
        children = self._result_tree.get(parent_result_id, set())
        for idx, child in enumerate(children):
            yield child
            if count is not None and idx + 1 >= count:
                break
    async def iterate_children(
        self, parent_result_id: UUID, *, count: Optional[int] = None
    ) -> AsyncGenerator[Result, None]:
async for child_id in self.iterate_children_ids(parent_result_id, count=count):
child = self._results.get(child_id, None)
if child is not None:
yield child
async def add_children(self, result_id: UUID, *children: UUID) -> None:
self._result_tree.setdefault(result_id, set()).update(children)
async def get_value(self, result_id: UUID) -> ResultValueHolder:
value = self._values.get(result_id, None)
if value is None:
raise ResultValueLost(result_id)
return value
async def set_value(self, result_id: UUID, value: ResultValueHolder):
self._values[result_id] = value
async def get_metadata(self, result_id: UUID) -> Dict[str, Any]:
meta = self._metadata.get(result_id, {})
return meta
async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
self._metadata.setdefault(result_id, {}).update(kwargs)
async def delete(self, result_id: UUID, include_children: bool = True):
if include_children:
for child_id in self._result_tree.get(result_id, set()):
await self.delete(child_id, include_children=include_children)
self._results.pop(result_id, None)
self._metadata.pop(result_id, None)
self._values.pop(result_id, None)
async def set_ttl(
self, result_id: UUID, ttl: timedelta, include_children: bool = True
):
pass
async def close(self):
self._metadata = {}
self._result_tree = {}
self._results = {}
self._values = {}
return await super().close()
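A test-oriented sketch; constructing the config and app arguments is assumed to follow the constructor signature above:

import uuid

backend = MemoryResultBackend(config, app)  # config/app construction assumed
result_id = uuid.uuid4()
result = await backend.get_or_create(result_id)
await backend.set_metadata(result_id, attempt=1)
assert await backend.get_metadata(result_id) == {"attempt": 1}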
add_children(self, result_id, *children)
async
close(self)
async
delete(self, result_id, include_children=True)
async
Delete a Result.
Source code in mognet/backend/memory_result_backend.py
async def delete(self, result_id: UUID, include_children: bool = True):
if include_children:
for child_id in self._result_tree.get(result_id, set()):
await self.delete(child_id, include_children=include_children)
self._results.pop(result_id, None)
self._metadata.pop(result_id, None)
self._values.pop(result_id, None)
get(self, result_id)
async
get_children_count(self, parent_result_id)
async
get_metadata(self, result_id)
async
get_value(self, result_id)
async
Get the value of a Result.
If the value is lost, ResultValueLost is raised.
iterate_children(self, parent_result_id, *, count=None)
Get an AsyncGenerator for the children of a Result.
The AsyncGenerator will be empty if the Result doesn't exist.
iterate_children_ids(self, parent_result_id, *, count=None)
Get an AsyncGenerator for the IDs for the children of a Result.
The AsyncGenerator will be empty if the Result doesn't exist.
set(self, result_id, result)
async
set_metadata(self, result_id, **kwargs)
async
set_ttl(self, result_id, ttl, include_children=True)
async
redis_result_backend
RedisResultBackend (BaseResultBackend)
Result backend that uses Redis for persistence.
Source code in mognet/backend/redis_result_backend.py
class RedisResultBackend(BaseResultBackend):
"""
Result backend that uses Redis for persistence.
"""
def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None:
super().__init__(config, app)
self._url = config.redis.url
self.__redis = None
self._connected = False
# Holds references to tasks which are spawned by .wait()
self._waiters: List[asyncio.Future] = []
# Attributes for @_retry
self._retry_connect_attempts = self.config.redis.retry_connect_attempts
self._retry_connect_timeout = self.config.redis.retry_connect_timeout
self._retry_lock = asyncio.Lock()
@property
def _redis(self) -> Redis:
if self.__redis is None:
raise NotConnected
return self.__redis
@_retry
async def get(self, result_id: UUID) -> Optional[Result]:
obj_key = self._format_key(result_id)
async with self._redis.pipeline(transaction=True) as pip:
# Since HGETALL returns an empty HASH for keys that don't exist,
# test if it exists at all and use that to check if we should return null.
pip.exists(obj_key)
pip.hgetall(obj_key)
exists, value, *_ = await shield(pip.execute())
if not exists:
return None
return self._decode_result(value)
@_retry
async def get_or_create(self, result_id: UUID) -> Result:
"""
Gets a result, or creates one if it doesn't exist.
"""
async with self._redis.pipeline(transaction=True) as pip:
result_key = self._format_key(result_id)
pip.hsetnx(result_key, "id", json.dumps(str(result_id)).encode())
pip.hgetall(result_key)
if self.config.redis.result_ttl is not None:
pip.expire(result_key, self.config.redis.result_ttl)
# Also set the value, to a default holding an absence of result.
value_key = self._format_key(result_id, "value")
default_not_ready = ResultValueHolder.not_ready()
encoded = self._encode_result_value(default_not_ready)
if self.config.redis.result_value_ttl is not None:
pip.expire(value_key, self.config.redis.result_value_ttl)
for encoded_k, encoded_v in encoded.items():
pip.hsetnx(value_key, encoded_k, encoded_v)
            created, value, *_ = await shield(pip.execute())
            if created:
                _log.debug("Created result %r on key %r", result_id, result_key)
return self._decode_result(value)
def _encode_result_value(self, value: ResultValueHolder) -> Dict[str, bytes]:
contents = value.model_dump_json().encode()
encoding = b"null"
if self.config.redis.result_value_encoding == Encoding.GZIP:
encoding = _json_bytes("gzip")
contents = gzip.compress(contents)
return {
"contents": contents,
"encoding": encoding,
"content_type": _json_bytes("application/json"),
}
def _decode_result_value(self, encoded: Dict[bytes, bytes]) -> ResultValueHolder:
if encoded.get(b"encoding") == _json_bytes("gzip"):
contents = gzip.decompress(encoded[b"contents"])
else:
contents = encoded[b"contents"]
if encoded.get(b"content_type") != _json_bytes("application/json"):
raise ValueError(f"Unknown content_type={encoded.get(b'content_type')!r}")
return ResultValueHolder.model_validate_json(contents)
@_retry
async def set(self, result_id: UUID, result: Result):
key = self._format_key(result_id)
async with self._redis.pipeline(transaction=True) as pip:
encoded = _encode_result(result)
pip.hset(key, None, None, encoded)
if self.config.redis.result_ttl is not None:
pip.expire(key, self.config.redis.result_ttl)
await shield(pip.execute())
def _format_key(self, result_id: UUID, subkey: str = None) -> str:
key = f"{self.app.name}.mognet.result.{str(result_id)}"
if subkey:
key = f"{key}/{subkey}"
        _log.debug(
            "Formatted result key=%r for id=%r and subkey=%r", key, result_id, subkey
        )
return key
@_retry
async def add_children(self, result_id: UUID, *children: UUID):
if not children:
return
# If there are children to add, add them to the set
# on Redis using SADD
children_key = self._format_key(result_id, "children")
async with self._redis.pipeline(transaction=True) as pip:
pip.sadd(children_key, *_encode_children(children))
if self.config.redis.result_ttl is not None:
pip.expire(children_key, self.config.redis.result_ttl)
await shield(pip.execute())
async def get_value(self, result_id: UUID) -> ResultValueHolder:
value_key = self._format_key(result_id, "value")
async with self._redis.pipeline(transaction=True) as pip:
pip.exists(value_key)
pip.hgetall(value_key)
exists, contents = await shield(pip.execute())
if not exists:
raise ResultValueLost(result_id)
return self._decode_result_value(contents)
async def set_value(self, result_id: UUID, value: ResultValueHolder):
value_key = self._format_key(result_id, "value")
encoded = self._encode_result_value(value)
async with self._redis.pipeline(transaction=True) as pip:
pip.hset(value_key, None, None, encoded)
if self.config.redis.result_value_ttl is not None:
pip.expire(value_key, self.config.redis.result_value_ttl)
await shield(pip.execute())
async def delete(self, result_id: UUID, include_children: bool = True):
if include_children:
async for child_id in self.iterate_children_ids(result_id):
await self.delete(child_id, include_children=True)
key = self._format_key(result_id)
children_key = self._format_key(result_id, "children")
value_key = self._format_key(result_id, "value")
metadata_key = self._format_key(result_id, "metadata")
await shield(self._redis.delete(key, children_key, value_key, metadata_key))
async def set_ttl(
self, result_id: UUID, ttl: timedelta, include_children: bool = True
):
if include_children:
async for child_id in self.iterate_children_ids(result_id):
await self.set_ttl(child_id, ttl, include_children=True)
key = self._format_key(result_id)
children_key = self._format_key(result_id, "children")
value_key = self._format_key(result_id, "value")
metadata_key = self._format_key(result_id, "metadata")
await shield(self._redis.expire(key, ttl))
await shield(self._redis.expire(children_key, ttl))
await shield(self._redis.expire(value_key, ttl))
await shield(self._redis.expire(metadata_key, ttl))
async def connect(self):
if self._connected:
return
self._connected = True
await self._connect()
async def close(self):
self._connected = False
await self._close_waiters()
await self._disconnect()
async def get_children_count(self, parent_result_id: UUID) -> int:
children_key = self._format_key(parent_result_id, "children")
return await shield(self._redis.scard(children_key))
async def iterate_children_ids(
        self, parent_result_id: UUID, *, count: Optional[int] = None
):
children_key = self._format_key(parent_result_id, "children")
raw_child_id: bytes
async for raw_child_id in self._redis.sscan_iter(children_key, count=count):
child_id = UUID(bytes=raw_child_id)
yield child_id
async def iterate_children(
        self, parent_result_id: UUID, *, count: Optional[int] = None
):
async for child_id in self.iterate_children_ids(parent_result_id, count=count):
child = await self.get(child_id)
if child is not None:
yield child
@_retry
async def wait(
self, result_id: UUID, timeout: Optional[float] = None, poll: float = 1
) -> Result:
async def waiter():
key = self._format_key(result_id=result_id)
# Type def for the state key. It can (but shouldn't)
# be null.
t = Optional[ResultState]
while True:
raw_state = await shield(self._redis.hget(key, "state")) or b"null"
state = TypeAdapter(t).validate_json(raw_state)
if state is None:
raise ResultValueLost(result_id)
if state in READY_STATES:
final_result = await self.get(result_id)
if final_result is None:
raise RuntimeError(
f"Result id={result_id!r} that previously existed no longer does"
)
return final_result
await asyncio.sleep(poll)
waiter_task = asyncio.create_task(
waiter(),
name=f"RedisResultBackend:wait_for:{result_id}",
)
if timeout:
waiter_task = asyncio.create_task(asyncio.wait_for(waiter_task, timeout))
self._waiters.append(waiter_task)
return await waiter_task
async def get_metadata(self, result_id: UUID) -> Dict[str, Any]:
key = self._format_key(result_id, "metadata")
value = await shield(self._redis.hgetall(key))
return _decode_json_dict(value)
async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
key = self._format_key(result_id, "metadata")
if not kwargs:
return
async with self._redis.pipeline(transaction=True) as pip:
pip.hset(key, None, None, _dict_to_json_dict(kwargs))
if self.config.redis.result_ttl is not None:
pip.expire(key, self.config.redis.result_ttl)
await shield(pip.execute())
def __repr__(self):
return f"RedisResultBackend(url={censor_credentials(self._url)!r})"
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, *args, **kwargs):
await self.close()
async def _close_waiters(self):
"""
Cancel any wait loop we have running.
"""
while self._waiters:
waiter_task = self._waiters.pop()
try:
_log.debug("Cancelling waiter %r", waiter_task)
waiter_task.cancel()
await waiter_task
except asyncio.CancelledError:
pass
except Exception as exc: # pylint: disable=broad-except
_log.debug("Error on waiter task %r", waiter_task, exc_info=exc)
async def _create_redis(self):
_log.debug("Creating Redis connection")
redis: Redis = await from_url(
self._url,
max_connections=self.config.redis.max_connections,
)
return redis
@_retry
async def _connect(self):
if self.__redis is None:
self.__redis = await self._create_redis()
await shield(self._redis.ping())
async def _disconnect(self):
redis = self.__redis
if redis is not None:
self.__redis = None
_log.debug("Closing Redis connection")
await redis.close()
def _decode_result(self, json_dict: Dict[bytes, bytes]) -> Result:
# Load the dict of JSON values first; then update it with overrides.
value = _decode_json_dict(json_dict)
return Result(self, **value)
close(self)
async
connect(self)
async
delete(self, result_id, include_children=True)
async
Delete a Result.
Source code in mognet/backend/redis_result_backend.py
async def delete(self, result_id: UUID, include_children: bool = True):
if include_children:
async for child_id in self.iterate_children_ids(result_id):
await self.delete(child_id, include_children=True)
key = self._format_key(result_id)
children_key = self._format_key(result_id, "children")
value_key = self._format_key(result_id, "value")
metadata_key = self._format_key(result_id, "metadata")
await shield(self._redis.delete(key, children_key, value_key, metadata_key))
get_children_count(self, parent_result_id)
async
Return the number of children of a Result.
Returns 0 if the Result doesn't exist.
get_metadata(self, result_id)
async
Get the metadata of a Result.
Returns an empty Dict if the Result doesn't exist.
get_or_create(self, result_id)
async
Gets a result, or creates one if it doesn't exist.
Source code in mognet/backend/redis_result_backend.py
@_retry
async def get_or_create(self, result_id: UUID) -> Result:
"""
Gets a result, or creates one if it doesn't exist.
"""
async with self._redis.pipeline(transaction=True) as pip:
result_key = self._format_key(result_id)
pip.hsetnx(result_key, "id", json.dumps(str(result_id)).encode())
pip.hgetall(result_key)
if self.config.redis.result_ttl is not None:
pip.expire(result_key, self.config.redis.result_ttl)
# Also set the value, to a default holding an absence of result.
value_key = self._format_key(result_id, "value")
default_not_ready = ResultValueHolder.not_ready()
encoded = self._encode_result_value(default_not_ready)
if self.config.redis.result_value_ttl is not None:
pip.expire(value_key, self.config.redis.result_value_ttl)
for encoded_k, encoded_v in encoded.items():
pip.hsetnx(value_key, encoded_k, encoded_v)
            created, value, *_ = await shield(pip.execute())
            if created:
                _log.debug("Created result %r on key %r", result_id, result_key)
return self._decode_result(value)
get_value(self, result_id)
async
Get the value of a Result.
If the value is lost, ResultValueLost is raised.
Source code in mognet/backend/redis_result_backend.py
async def get_value(self, result_id: UUID) -> ResultValueHolder:
value_key = self._format_key(result_id, "value")
async with self._redis.pipeline(transaction=True) as pip:
pip.exists(value_key)
pip.hgetall(value_key)
exists, contents = await shield(pip.execute())
if not exists:
raise ResultValueLost(result_id)
return self._decode_result_value(contents)
iterate_children(self, parent_result_id, *, count=None)
Get an AsyncGenerator for the children of a Result.
The AsyncGenerator will be empty if the Result doesn't exist.
iterate_children_ids(self, parent_result_id, *, count=None)
Get an AsyncGenerator for the IDs for the children of a Result.
The AsyncGenerator will be empty if the Result doesn't exist.
Source code in mognet/backend/redis_result_backend.py
async def iterate_children_ids(
        self, parent_result_id: UUID, *, count: Optional[int] = None
):
children_key = self._format_key(parent_result_id, "children")
raw_child_id: bytes
async for raw_child_id in self._redis.sscan_iter(children_key, count=count):
child_id = UUID(bytes=raw_child_id)
yield child_id
set_metadata(self, result_id, **kwargs)
async
Set metadata on a Result.
Source code in mognet/backend/redis_result_backend.py
async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
key = self._format_key(result_id, "metadata")
if not kwargs:
return
async with self._redis.pipeline(transaction=True) as pip:
pip.hset(key, None, None, _dict_to_json_dict(kwargs))
if self.config.redis.result_ttl is not None:
pip.expire(key, self.config.redis.result_ttl)
await shield(pip.execute())
set_ttl(self, result_id, ttl, include_children=True)
async
Set expiration on a Result.
If include_children is True, children will have the same TTL set.
Source code in mognet/backend/redis_result_backend.py
async def set_ttl(
self, result_id: UUID, ttl: timedelta, include_children: bool = True
):
if include_children:
async for child_id in self.iterate_children_ids(result_id):
await self.set_ttl(child_id, ttl, include_children=True)
key = self._format_key(result_id)
children_key = self._format_key(result_id, "children")
value_key = self._format_key(result_id, "value")
metadata_key = self._format_key(result_id, "metadata")
await shield(self._redis.expire(key, ttl))
await shield(self._redis.expire(children_key, ttl))
await shield(self._redis.expire(value_key, ttl))
await shield(self._redis.expire(metadata_key, ttl))
set_value(self, result_id, value)
async
Set the value of a Result.
Source code in mognet/backend/redis_result_backend.py
async def set_value(self, result_id: UUID, value: ResultValueHolder):
value_key = self._format_key(result_id, "value")
encoded = self._encode_result_value(value)
async with self._redis.pipeline(transaction=True) as pip:
pip.hset(value_key, None, None, encoded)
if self.config.redis.result_value_ttl is not None:
pip.expire(value_key, self.config.redis.result_value_ttl)
await shield(pip.execute())
broker
special
amqp_broker
AmqpBroker (BaseBroker)
Source code in mognet/broker/amqp_broker.py
class AmqpBroker(BaseBroker):
_task_channel: Channel
_control_channel: Channel
_task_queues: Dict[str, Queue]
_direct_exchange: Exchange
_control_exchange: Exchange
_retry = retryableasyncmethod(
_RETRYABLE_ERRORS,
max_attempts="_retry_connect_attempts",
wait_timeout="_retry_connect_timeout",
)
def __init__(self, app: "App", config: BrokerConfig) -> None:
super().__init__()
self._connected = False
self.__connection = None
self.config = config
self._task_queues = {}
self._control_queue = None
# Lock to prevent duplicate queue declaration
self._lock = Lock()
self.app = app
# Attributes for @retryableasyncmethod
self._retry_connect_attempts = self.config.amqp.retry_connect_attempts
self._retry_connect_timeout = self.config.amqp.retry_connect_timeout
# List of callbacks for when connection drops
self._on_connection_failed_callbacks: List[
Callable[[Optional[BaseException]], Awaitable]
] = []
@property
def _connection(self) -> Connection:
if self.__connection is None:
raise NotConnected
return self.__connection
async def ack(self, delivery_tag: str):
await self._task_channel.channel.basic_ack(delivery_tag)
async def nack(self, delivery_tag: str):
await self._task_channel.channel.basic_nack(delivery_tag)
@_retry
async def set_task_prefetch(self, prefetch: int):
await self._task_channel.set_qos(prefetch_count=prefetch, global_=True)
@_retry
async def send_task_message(self, queue: str, payload: MessagePayload):
amqp_queue = self._task_queue_name(queue)
msg = Message(
body=payload.model_dump_json().encode(),
content_type="application/json",
content_encoding="utf-8",
priority=payload.priority,
message_id=payload.id,
)
await self._direct_exchange.publish(msg, amqp_queue)
_log.debug(
"Message %r sent to queue=%r (amqp queue=%r)", payload.id, queue, amqp_queue
)
async def consume_tasks(
self, queue: str
) -> AsyncGenerator[IncomingMessagePayload, None]:
amqp_queue = await self._get_or_create_task_queue(TaskQueue(name=queue))
async for message in self._consume(amqp_queue):
yield message
async def consume_control_queue(
self,
) -> AsyncGenerator[IncomingMessagePayload, None]:
amqp_queue = await self._get_or_create_control_queue()
async for message in self._consume(amqp_queue):
yield message
@_retry
async def send_control_message(self, payload: MessagePayload):
msg = Message(
body=payload.model_dump_json().encode(),
content_type="application/json",
content_encoding="utf-8",
message_id=payload.id,
expiration=timedelta(seconds=300),
)
# No queue name set because this is a fanout exchange.
await self._control_exchange.publish(msg, "")
@_retry
async def _send_query_message(self, payload: MessagePayload):
callback_queue = await self._task_channel.declare_queue(
name=self._callback_queue_name,
durable=False,
exclusive=False,
auto_delete=True,
arguments={
"x-expires": 30000,
"x-message-ttl": 30000,
},
)
await callback_queue.bind(self._direct_exchange)
msg = Message(
body=payload.model_dump_json().encode(),
content_type="application/json",
content_encoding="utf-8",
message_id=payload.id,
expiration=timedelta(seconds=300),
reply_to=callback_queue.name,
)
await self._control_exchange.publish(msg, "")
return callback_queue
async def send_query_message(
self, payload: MessagePayload
) -> AsyncGenerator[QueryResponseMessage, None]:
# Create a callback queue for getting the replies,
# then send the message to the control exchange (fanout).
# When done, delete the callback queue.
callback_queue = None
try:
callback_queue = await self._send_query_message(payload)
async with callback_queue.iterator() as iterator:
async for message in iterator:
async with message.process():
contents: dict = json.loads(message.body)
msg = _AmqpIncomingMessagePayload(
broker=self, incoming_message=message, **contents
)
yield QueryResponseMessage.model_validate(msg.payload)
finally:
if callback_queue is not None:
await callback_queue.delete()
async def setup_control_queue(self):
await self._get_or_create_control_queue()
async def setup_task_queue(self, queue: TaskQueue):
await self._get_or_create_task_queue(queue)
@_retry
async def _create_connection(self):
connection = await aio_pika.connect_robust(
self.config.amqp.url,
reconnect_interval=self.app.config.reconnect_interval,
client_properties={
"connection_name": self.app.config.node_id,
},
)
# Add callback for broadcasting unexpected connection drops
connection.add_close_callback(self._send_connection_failed_events)
return connection
def add_connection_failed_callback(
self, cb: Callable[[Optional[BaseException]], Awaitable]
):
self._on_connection_failed_callbacks.append(cb)
def _send_connection_failed_events(self, connection, exc=None):
if not self._connected:
_log.debug(
"Not sending connection closed events because we are disconnected"
)
return
_log.error("AMQP connection %r failed", connection, exc_info=exc)
tasks = [cb(exc) for cb in self._on_connection_failed_callbacks]
_log.info(
"Notifying %r listeners of a disconnect",
len(tasks),
)
def notify_task_completion_callback(fut: asyncio.Future):
exc = fut.exception()
if exc and not fut.cancelled():
_log.error("Error notifying connection dropped", exc_info=exc)
for task in tasks:
notify_task = asyncio.create_task(task)
notify_task.add_done_callback(notify_task_completion_callback)
async def connect(self):
if self._connected:
return
self._connected = True
self.__connection = await self._create_connection()
# Use two separate channels with separate prefetch counts.
# This allows the task channel to increase the prefetch count
# without affecting the control channel,
# and allows the control channel to still receive messages, even if
# the task channel has reached the full prefetch count.
self._task_channel = await self._connection.channel()
await self.set_task_prefetch(1)
self._control_channel = await self._connection.channel()
await self.set_control_prefetch(4)
await self._create_exchanges()
_log.debug("Connected")
async def set_control_prefetch(self, prefetch: int):
await self._control_channel.set_qos(prefetch_count=prefetch, global_=False)
async def close(self):
self._connected = False
connection = self.__connection
if connection is not None:
self.__connection = None
_log.debug("Closing connections")
await connection.close()
_log.debug("Connection closed")
@_retry
async def send_reply(self, message: IncomingMessagePayload, reply: MessagePayload):
if not message.reply_to:
raise ValueError("Message has no reply_to set")
msg = Message(
body=reply.model_dump_json().encode(),
content_type="application/json",
content_encoding="utf-8",
message_id=reply.id,
)
await self._direct_exchange.publish(msg, message.reply_to)
async def purge_task_queue(self, queue: str) -> int:
amqp_queue = self._task_queue_name(queue)
if amqp_queue not in self._task_queues:
_log.warning(
"Queue %r (amqp=%r) does not exist in this broker", queue, amqp_queue
)
return 0
result = await self._task_queues[amqp_queue].purge()
deleted_count: int = result.message_count
_log.info(
"Deleted %r messages from queue=%r (amqp=%r)",
deleted_count,
queue,
amqp_queue,
)
return deleted_count
async def purge_control_queue(self) -> int:
if not self._control_queue:
_log.debug("Not listening on any control queue, not purging it")
return 0
result = await self._control_queue.purge()
return result.message_count
def __repr__(self):
return f"AmqpBroker(url={censor_credentials(self.config.amqp.url)!r})"
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, *args, **kwargs):
await self.close()
return None
async def _create_exchanges(self):
self._direct_exchange = await self._task_channel.declare_exchange(
self._direct_exchange_name,
type=ExchangeType.DIRECT,
durable=True,
)
self._control_exchange = await self._control_channel.declare_exchange(
self._control_exchange_name,
type=ExchangeType.FANOUT,
)
async def _consume(
self, amqp_queue: Queue
) -> AsyncGenerator[IncomingMessagePayload, None]:
async with amqp_queue.iterator() as queue_iterator:
msg: IncomingMessage
async for msg in queue_iterator:
try:
contents: dict = json.loads(msg.body)
payload = _AmqpIncomingMessagePayload(
broker=self, incoming_message=msg, **contents
)
_log.debug("Successfully parsed message %r", payload.id)
yield payload
except asyncio.CancelledError:
raise
except Exception as exc: # pylint: disable=broad-except
_log.error(
"Error parsing contents of message %r, discarding",
msg.correlation_id,
exc_info=exc,
)
try:
await asyncio.shield(msg.ack())
except Exception as ack_err:
_log.error(
"Could not ACK message %r for discarding",
msg.correlation_id,
exc_info=ack_err,
)
def _task_queue_name(self, name: str) -> str:
return f"{self.app.name}.{name}"
@property
def _control_queue_name(self) -> str:
return f"{self.app.name}.mognet.control.{self.app.config.node_id}"
@property
def _callback_queue_name(self) -> str:
return f"{self.app.name}.mognet.callback.{self.app.config.node_id}"
@property
def _control_exchange_name(self) -> str:
return f"{self.app.name}.mognet.control"
@property
def _direct_exchange_name(self) -> str:
return f"{self.app.name}.mognet.direct"
@_retry
async def _get_or_create_control_queue(self) -> Queue:
if self._control_queue is None:
async with self._lock:
if self._control_queue is None:
self._control_queue = await self._control_channel.declare_queue(
name=self._control_queue_name,
durable=False,
auto_delete=True,
arguments={
"x-expires": 30000,
"x-message-ttl": 30000,
},
)
await self._control_queue.bind(self._control_exchange)
_log.debug("Prepared control queue=%r", self._control_queue.name)
return self._control_queue
@_retry
async def _get_or_create_task_queue(self, queue: TaskQueue) -> Queue:
name = self._task_queue_name(queue.name)
if name not in self._task_queues:
async with self._lock:
if name not in self._task_queues:
_log.debug("Preparing queue %r as AMQP queue=%r", queue, name)
self._task_queues[name] = await self._task_channel.declare_queue(
name,
durable=True,
arguments={"x-max-priority": queue.max_priority},
)
await self._task_queues[name].bind(self._direct_exchange)
_log.debug(
"Prepared task queue=%r as AMQP queue=%r", queue.name, name
)
return self._task_queues[name]
async def task_queue_stats(self, task_queue_name: str) -> QueueStats:
"""
Get the stats of a task queue.
"""
name = self._task_queue_name(task_queue_name)
# AMQP can close the Channel on us if we try accessing an object
# that doesn't exist. So, to avoid trouble when the same Channel
# is being used to consume a queue, use an ephemeral channel
# for this operation alone.
async with self._connection.channel() as channel:
try:
queue = await channel.get_queue(name, ensure=False)
declare_result = await queue.declare()
except (
aiormq.exceptions.ChannelNotFoundEntity,
aio_pika.exceptions.ChannelClosed,
) as query_err:
raise QueueNotFound(task_queue_name) from query_err
return QueueStats(
queue_name=task_queue_name,
message_count=declare_result.message_count,
consumer_count=declare_result.consumer_count,
)
task_queue_stats(self, task_queue_name)
async
Get the stats of a task queue.
Source code in mognet/broker/amqp_broker.py
async def task_queue_stats(self, task_queue_name: str) -> QueueStats:
"""
Get the stats of a task queue.
"""
name = self._task_queue_name(task_queue_name)
# AMQP can close the Channel on us if we try accessing an object
# that doesn't exist. So, to avoid trouble when the same Channel
# is being used to consume a queue, use an ephemeral channel
# for this operation alone.
async with self._connection.channel() as channel:
try:
queue = await channel.get_queue(name, ensure=False)
declare_result = await queue.declare()
except (
aiormq.exceptions.ChannelNotFoundEntity,
aio_pika.exceptions.ChannelClosed,
) as query_err:
raise QueueNotFound(task_queue_name) from query_err
return QueueStats(
queue_name=task_queue_name,
message_count=declare_result.message_count,
consumer_count=declare_result.consumer_count,
)
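Since the broker is an async context manager, connecting and inspecting a queue takes only a few lines. A sketch, assuming `app` and `config` objects already exist and a queue named "default" was declared:
```
async with AmqpBroker(app=app, config=config) as broker:
    stats = await broker.task_queue_stats("default")  # raises QueueNotFound if absent
    print(stats.message_count, stats.consumer_count)
```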
cli
special
exceptions
GracefulShutdown (BaseException)
If this exception is raised from a coroutine, the Mognet app running from the CLI will be gracefully closed
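A sketch of triggering a graceful stop from application code running under the CLI (import path assumed from the module layout above):
```
from mognet.cli.exceptions import GracefulShutdown

async def nightly_restart():
    # When this propagates to the event loop's exception handler,
    # the CLI closes the app cleanly instead of reporting an error.
    raise GracefulShutdown()
```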
main
callback(app=..., log_level=..., log_format=...)
Mognet CLI
Source code in mognet/cli/main.py
@main.callback()
def callback(
app: str = typer.Argument(..., help="App module to import"),
log_level: LogLevel = typer.Option("INFO", metavar="log-level"),
log_format: str = typer.Option(
"%(asctime)s:%(name)s:%(levelname)s:%(message)s", metavar="log-format"
),
):
"""Mognet CLI"""
logging.basicConfig(
level=getattr(logging, log_level.value),
format=log_format,
)
app_instance = _get_app(app)
state["app_instance"] = app_instance
models
LogLevel (Enum)
nodes
status(format=..., text_label_format=..., json_indent=..., poll=..., timeout=...)
async
Query each node for their status
Source code in mognet/cli/nodes.py
@group.command("status")
@run_in_loop
async def status(
format: OutputFormat = typer.Option(OutputFormat.TEXT, metavar="format"),
text_label_format: str = typer.Option(
"{name}(id={id!r}, state={state!r})",
metavar="text-label-format",
help="Label format for text format",
),
json_indent: int = typer.Option(2, metavar="json-indent"),
poll: Optional[int] = typer.Option(
None,
metavar="poll",
help="Polling interval, in seconds (default=None)",
),
timeout: int = typer.Option(
30,
help="Timeout for querying nodes",
),
):
"""Query each node for their status"""
async with state["app_instance"] as app:
while True:
each_node_status: List[StatusResponseMessage] = []
async def read_status():
async for node_status in app.get_current_status_of_nodes():
each_node_status.append(node_status)
try:
await asyncio.wait_for(read_status(), timeout=timeout)
except asyncio.TimeoutError:
pass
all_result_ids = set()
for node_status in each_node_status:
all_result_ids.update(node_status.payload.running_request_ids)
all_results_by_id = {
r.id: r
for r in await app.result_backend.get_many(
*all_result_ids,
)
if r is not None
}
report = _CliStatusReport()
for node_status in each_node_status:
running_requests = [
all_results_by_id[r]
for r in node_status.payload.running_request_ids
if r in all_results_by_id
]
running_requests.sort(key=lambda r: r.created or now_utc())
report.node_status.append(
_CliStatusReport.NodeStatus(
node_id=node_status.node_id, running_requests=running_requests
)
)
if poll:
typer.clear()
if format == "text":
table_headers = ("Node name", "Running requests")
table_data = [
(
n.node_id,
"\n".join(
text_label_format.format(**r.dict())
for r in n.running_requests
)
or "(Empty)",
)
for n in report.node_status
]
typer.echo(
f"{len(report.node_status)} nodes replied as of {datetime.now()}:"
)
typer.echo(tabulate.tabulate(table_data, headers=table_headers))
elif format == "json":
typer.echo(report.model_dump_json(indent=json_indent))
if not poll:
break
await asyncio.sleep(poll)
queues
purge(force=...)
async
Purge task and control queues
Source code in mognet/cli/queues.py
@group.command("purge")
@run_in_loop
async def purge(force: bool = typer.Option(False)):
"""Purge task and control queues"""
if not force:
typer.echo("Must pass --force")
raise typer.Exit(1)
async with state["app_instance"] as app:
await app.connect()
purged_task_counts = await app.purge_task_queues()
purged_control_count = await app.purge_control_queue()
typer.echo("Purged the following queues:")
for queue_name, count in purged_task_counts.items():
typer.echo(f"\t- {queue_name!r}: {count!r}")
typer.echo(f"Purged {purged_control_count!r} control messages")
run
run(include_queues=..., exclude_queues=...)
Run the app
Source code in mognet/cli/run.py
@group.callback()
def run(
include_queues: Optional[str] = typer.Option(
None,
metavar="include-queues",
help="Comma-separated list of the ONLY queues to listen on.",
),
exclude_queues: Optional[str] = typer.Option(
None,
metavar="exclude-queues",
help="Comma-separated list of the ONLY queues to NOT listen on.",
),
):
"""Run the app"""
app = state["app_instance"]
# Allow overriding the queues this app listens on.
queues = app.config.task_queues
if include_queues is not None:
queues.include = set(q.strip() for q in include_queues.split(","))
if exclude_queues is not None:
queues.exclude = set(q.strip() for q in exclude_queues.split(","))
queues.ensure_valid()
async def start():
async with app:
await app.start()
async def stop(_: AbstractEventLoop):
_log.info("Going to close app as part of a shut down")
await app.close()
pending_exception_to_raise = SystemExit(0)
def custom_exception_handler(loop: AbstractEventLoop, context: dict):
"""See: https://docs.python.org/3/library/asyncio-eventloop.html#error-handling-api"""
nonlocal pending_exception_to_raise
exc = context.get("exception")
if isinstance(exc, GracefulShutdown):
_log.debug("Got GracefulShutdown")
elif isinstance(exc, BaseException):
pending_exception_to_raise = exc
_log.error(
"Unhandled exception; stopping loop: %r %r",
context.get("message"),
context,
exc_info=pending_exception_to_raise,
)
loop.stop()
loop = asyncio.get_event_loop()
loop.set_exception_handler(custom_exception_handler)
aiorun.run(
start(), loop=loop, stop_on_unhandled_errors=False, shutdown_callback=stop
)
if pending_exception_to_raise is not None:
raise pending_exception_to_raise
return 0
run_in_loop
run_in_loop(f)
Utility to run a click/typer command function in an event loop (because they don't support it out of the box)
Source code in mognet/cli/run_in_loop.py
tasks
get(task_id=..., include_value=...)
async
Get a task's details
Source code in mognet/cli/tasks.py
@group.command("get")
@run_in_loop
async def get(
task_id: UUID = typer.Argument(
...,
metavar="id",
help="Task ID to get",
),
include_value: bool = typer.Option(
False,
metavar="include-value",
help="If passed, the task's result (or exception) will be printed",
),
):
"""Get a task's details"""
async with state["app_instance"] as app:
res = await app.result_backend.get(task_id)
if res is None:
_log.warning("Request %r does not exist", task_id)
raise typer.Exit(1)
table_data = [
("ID", res.id),
("Name", res.name),
("Arguments", res.request_kwargs_repr),
("State", res.state),
("Number of starts", res.number_of_starts),
("Number of stops", res.number_of_stops),
("Unexpected retry count", res.unexpected_retry_count),
("Parent", res.parent_id),
("Created at", res.created),
("Started at", res.started),
("Time in queue", res.queue_time),
("Finished at", res.finished),
("Runtime duration", res.duration),
("Node ID", res.node_id),
("Metadata", await res.get_metadata()),
]
if include_value:
try:
value = await res.value.get_raw_value()
if isinstance(value, _ExceptionInfo):
table_data.append(("Error raised", value.traceback))
else:
table_data.append(("Result value", repr(value)))
except ResultValueLost:
table_data.append(("Result value", "<Lost>"))
print(tabulate.tabulate(table_data))
revoke(task_id=..., force=...)
async
Revoke a task
Source code in mognet/cli/tasks.py
@group.command("revoke")
@run_in_loop
async def revoke(
task_id: UUID = typer.Argument(
...,
metavar="id",
help="Task ID to revoke",
),
force: bool = typer.Option(
False,
metavar="force",
help="Attempt revoking anyway if the result is complete. Helps cleaning up cases where subtasks may have been spawned.",
),
):
"""Revoke a task"""
async with state["app_instance"] as app:
res = await app.result_backend.get(task_id)
ret_code = 0
if res is None:
_log.warning("Request %r does not exist", task_id)
ret_code = 1
await app.revoke(task_id, force=force)
raise typer.Exit(ret_code)
tree(task_id=..., format=..., json_indent=..., text_label_format=..., max_depth=..., max_width=..., poll=...)
async
Get the tree (descendants) of a task
Source code in mognet/cli/tasks.py
@group.command("tree")
@run_in_loop
async def tree(
task_id: UUID = typer.Argument(
...,
metavar="id",
help="Task ID to get tree from",
),
format: OutputFormat = typer.Option(OutputFormat.TEXT, metavar="format"),
json_indent: int = typer.Option(2, metavar="json-indent"),
text_label_format: str = typer.Option(
"{name}(id={id!r}, state={state!r})",
metavar="text-label-format",
help="Label format for text format",
),
max_depth: int = typer.Option(3, metavar="max-depth"),
max_width: int = typer.Option(16, metavar="max-width"),
poll: Optional[int] = typer.Option(None, metavar="poll"),
):
"""Get the tree (descendants) of a task"""
async with state["app_instance"] as app:
while True:
result = await app.result_backend.get(task_id)
if result is None:
raise RuntimeError(f"Result for request id={task_id!r} does not exist")
_log.info("Building tree for result id=%r", result.id)
tree = await result.tree(max_depth=max_depth, max_width=max_width)
if poll:
typer.clear()
if format == "text":
t = treelib.Tree()
def build_tree(n: ResultTree, parent: Optional[ResultTree] = None):
t.create_node(
tag=text_label_format.format(**n.dict()),
identifier=n.result.id,
parent=None if parent is None else parent.result.id,
)
for c in n.children:
build_tree(c, parent=n)
build_tree(tree)
t.show()
if format == "json":
print(tree.model_dump_json(indent=json_indent))
if not poll:
break
await asyncio.sleep(poll)
context
special
context
Context
Context for a request.
Allows access to the App instance, task state, and the request that is part of this task execution.
Source code in mognet/context/context.py
class Context:
"""
Context for a request.
Allows access to the App instance, task state,
and the request that is part of this task execution.
"""
app: "App"
state: "State"
request: "Request"
_dependencies: Set[UUID]
def __init__(
self,
app: "App",
request: "Request",
state: "State",
worker: "Worker",
):
self.app = app
self.state = state
self.request = request
self._worker = worker
self._dependencies = set()
self.create_request = self.app.create_request
async def submit(self, request: "Request"):
"""
Submits a new request as part of this one.
The difference between this method and the one defined in the `App` class
is that this one will submit the new request as a child request of
the one that's a part of this `Context` instance. This allows
the subrequests to be cancelled if the parent is also cancelled.
"""
return await self.app.submit(request, self)
@overload
async def run(self, request: Request[_Return]) -> _Return:
"""
Submits a Request to be run as part of this one (see `submit`), and waits for the result
"""
...
@overload
async def run(
self,
request: Callable[Concatenate["Context", _P], _Return],
*args: _P.args,
**kwargs: _P.kwargs
) -> _Return:
"""
Short-hand method for creating a Request from a function decorated with `@task`,
(see `create_request`), submitting it (see `submit`) and waiting for the result (see `run(Request)`).
This overload is for documenting non-async def functions.
"""
...
# This overload unwraps the Awaitable object
@overload
async def run(
self,
request: Callable[Concatenate["Context", _P], Awaitable[_Return]],
*args: _P.args,
**kwargs: _P.kwargs
) -> _Return:
"""
Short-hand method for creating a Request from a function decorated with `@task`,
(see `create_request`), submitting it (see `submit`) and waiting for the result (see `run(Request)`).
This overload is for documenting async def functions.
"""
...
async def run(self, request, *args, **kwargs):
"""
Submits and runs a new request as part of this one.
See `submit` for the difference between this and the equivalent
`run` method on the `App` class.
"""
if not isinstance(request, Request):
request = self.create_request(request, *args, **kwargs)
cancelled = False
try:
had_dependencies = bool(self._dependencies)
self._dependencies.add(request.id)
# If we transition from having no dependencies
# to having some, then we should suspend.
if not had_dependencies and self._dependencies:
await asyncio.shield(self._suspend())
self._log_dependencies()
return await self.app.run(request, self)
except asyncio.CancelledError:
cancelled = True
finally:
self._dependencies.remove(request.id)
self._log_dependencies()
if not self._dependencies and not cancelled:
await asyncio.shield(self._resume())
def _log_dependencies(self):
_log.debug(
"Task %r is waiting on %r dependencies",
self.request,
len(self._dependencies),
)
async def gather(
self, *results_or_ids: Union["Result", UUID], return_exceptions: bool = False
) -> List[Any]:
results = []
cancelled = False
try:
for result in results_or_ids:
if isinstance(result, UUID):
result = await self.app.result_backend.get(result)
results.append(result)
# If we transition from having no dependencies
# to having some, then we should suspend.
had_dependencies = bool(self._dependencies)
self._dependencies.update(r.id for r in results)
self._log_dependencies()
if not had_dependencies and self._dependencies:
await asyncio.shield(self._suspend())
return await asyncio.gather(*results, return_exceptions=return_exceptions)
except asyncio.CancelledError:
cancelled = True
raise
finally:
self._dependencies.difference_update(r.id for r in results)
self._log_dependencies()
if not self._dependencies and not cancelled:
await asyncio.shield(self._resume())
@overload
def get_service(
self, func: Type[ClassService[_Return]], *args, **kwargs
) -> _Return:
...
@overload
def get_service(
self,
func: Callable[Concatenate["Context", _P], _Return],
*args: _P.args,
**kwargs: _P.kwargs
) -> _Return:
...
def get_service(self, func, *args, **kwargs):
"""
Get a service to use in the task function.
This can be used for dependency injection purposes.
"""
if inspect.isclass(func) and issubclass(func, ClassService):
if func not in self.app.services:
# This cast() is only here to silence Pylance (because it thinks the class is abstract)
instance: ClassService = cast(Any, func)(self.app.config)
self.app.services[func] = instance.__enter__()
svc = self.app.services[func]
else:
svc = self.app.services.setdefault(func, func)
return svc(self, *args, **kwargs)
async def _suspend(self):
_log.debug("Suspending %r", self.request)
result = await self.get_result()
if result.state == ResultState.RUNNING:
await result.suspend()
await self._worker.add_waiting_task(self.request.id)
async def get_result(self):
"""
Gets the Result associated with this task.
WARNING: Do not `await` the returned Result instance! You will run
into a deadlock (you will be awaiting yourself)
"""
result = await self.app.result_backend.get(self.request.id)
if result is None:
raise ResultLost(self.request.id)
return result
def call_threadsafe(self, coro: Awaitable[_Return]) -> _Return:
"""
NOTE: ONLY TO BE USED WITH SYNC TASKS!
Utility function that will run the coroutine in the app's event loop
in a thread-safe way.
In reality this is a wrapper for `asyncio.run_coroutine_threadsafe(...)`
Use as follows:
```
context.call_threadsafe(context.submit(...))
```
"""
return asyncio.run_coroutine_threadsafe(coro, loop=self.app.loop).result()
async def set_metadata(self, **kwargs: Any):
"""
Update metadata on the Result associated with the current task.
"""
result = await self.get_result()
return await result.set_metadata(**kwargs)
async def _resume(self):
_log.debug("Resuming %r", self.request)
result = await self.get_result()
if result.state == ResultState.SUSPENDED:
await result.resume()
await self._worker.remove_suspended_task(self.request.id)
call_threadsafe(self, coro)
NOTE: ONLY TO BE USED WITH SYNC TASKS!
Utility function that will run the coroutine in the app's event loop in a thread-safe way.
In reality this is a wrapper for asyncio.run_coroutine_threadsafe(...)
Use as follows: `context.call_threadsafe(context.submit(...))`
Source code in mognet/context/context.py
def call_threadsafe(self, coro: Awaitable[_Return]) -> _Return:
"""
NOTE: ONLY TO BE USED WITH SYNC TASKS!
Utility function that will run the coroutine in the app's event loop
in a thread-safe way.
In reality this is a wrapper for `asyncio.run_coroutine_threadsafe(...)`
Use as follows:
```
context.call_threadsafe(context.submit(...))
```
"""
return asyncio.run_coroutine_threadsafe(coro, loop=self.app.loop).result()
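A sketch of the intended pattern from inside a synchronous task function (assuming plain, non-async functions can also be registered with `@task`; `double` is a hypothetical task):
```
@task(name="demo.sync")
def sync_task(context, n: int):
    # Block this thread until the child request finishes
    # on the app's event loop.
    return context.call_threadsafe(context.run(double, n))
```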
get_result(self)
async
Gets the Result associated with this task.
WARNING: Do not `await` the returned Result instance! You will run into a deadlock (you will be awaiting yourself).
Source code in mognet/context/context.py
async def get_result(self):
"""
Gets the Result associated with this task.
WARNING: Do not `await` the returned Result instance! You will run
into a deadlock (you will be awaiting yourself)
"""
result = await self.app.result_backend.get(self.request.id)
if result is None:
raise ResultLost(self.request.id)
return result
get_service(self, func, *args, **kwargs)
Get a service to use in the task function. This can be used for dependency injection purposes.
Source code in mognet/context/context.py
def get_service(self, func, *args, **kwargs):
"""
Get a service to use in the task function.
This can be used for dependency injection purposes.
"""
if inspect.isclass(func) and issubclass(func, ClassService):
if func not in self.app.services:
# This cast() is only here to silence Pylance (because it thinks the class is abstract)
instance: ClassService = cast(Any, func)(self.app.config)
self.app.services[func] = instance.__enter__()
svc = self.app.services[func]
else:
svc = self.app.services.setdefault(func, func)
return svc(self, *args, **kwargs)
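A sketch of a function-style service; `make_session` and `HttpSession` are hypothetical, and `get_service` passes the `Context` as the first argument:
```
def make_session(context, base_url: str):
    # Hypothetical factory: registered in app.services and called
    # as make_session(context, base_url) on each get_service call.
    return HttpSession(base_url)  # HttpSession is a hypothetical client type

# Inside a task:
session = context.get_service(make_session, "https://example.com")
```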
run(self, request, *args, **kwargs)
async
Submits and runs a new request as part of this one.
See `submit` for the difference between this and the equivalent `run` method on the `App`
Source code in mognet/context/context.py
async def run(self, request, *args, **kwargs):
"""
Submits and runs a new request as part of this one.
See `submit` for the difference between this and the equivalent
`run` method on the `App` class.
"""
if not isinstance(request, Request):
request = self.create_request(request, *args, **kwargs)
cancelled = False
try:
had_dependencies = bool(self._dependencies)
self._dependencies.add(request.id)
# If we transition from having no dependencies
# to having some, then we should suspend.
if not had_dependencies and self._dependencies:
await asyncio.shield(self._suspend())
self._log_dependencies()
return await self.app.run(request, self)
except asyncio.CancelledError:
cancelled = True
finally:
self._dependencies.remove(request.id)
self._log_dependencies()
if not self._dependencies and not cancelled:
await asyncio.shield(self._resume())
set_metadata(self, **kwargs)
async
Update metadata on the Result associated with the current task.
submit(self, request)
async
Submits a new request as part of this one.
The difference between this method and the one defined in the `App` class is that this one will submit the new request as a child request of the one that's a part of this `Context` instance. This allows the subrequests to be cancelled if the parent is also cancelled.
Source code in mognet/context/context.py
async def submit(self, request: "Request"):
"""
Submits a new request as part of this one.
The difference between this method and the one defined in the `App` class
is that this one will submit the new request as a child request of
the one that's a part of this `Context` instance. This allows
the subrequests to be cancelled if the parent is also cancelled.
"""
return await self.app.submit(request, self)
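Putting `run` and `submit` together: a sketch of a parent task delegating to a hypothetical child task, so that revoking the parent also revokes the child:
```
@task(name="demo.parent")
async def parent(context: Context):
    # `child` is a hypothetical @task function; run() submits it as a
    # child request of this one and waits for its value.
    partial = await context.run(child, 10)
    return partial * 2
```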
decorators
special
task_decorator
task(*, name=None)
Register a function as a task that can be run.
The name argument is recommended, but not required. It is used as an identifier for which task to run when creating Request objects.
If the name is not provided, the function's full name (module + name) is used instead. Bear in mind that this means that if you rename the module or the function, things may break during rolling upgrades.
Source code in mognet/decorators/task_decorator.py
def task(*, name: Optional[str] = None):
"""
Register a function as a task that can be run.
The name argument is recommended, but not required. It is used as an identifier
for which task to run when creating Request objects.
If the name is not provided, the function's full name (module + name) is used instead.
Bear in mind that this means that if you rename the module or the function, things may break
during rolling upgrades.
"""
def task_decorator(t: _T) -> _T:
reg = task_registry.get(None)
if reg is None:
_log.debug("No global task registry set. Creating one")
reg = TaskRegistry()
reg.register_globally()
reg.add_task_function(cast(Callable, t), name=name)
return t
return task_decorator
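For example, registering a task with an explicit name (recommended, so that renaming the module or function does not break rolling upgrades):
```
from mognet.decorators.task_decorator import task

@task(name="demo.add")
async def add(context, a: int, b: int) -> int:
    return a + b
```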
exceptions
special
base_exceptions
ConnectionError (MognetError)
CouldNotSubmit (MognetError)
ImproperlyConfigured (MognetError)
MognetError (Exception)
NotConnected (ConnectionError)
result_exceptions
ResultLost (ResultError)
Raised when the result itself was lost (potentially due to key eviction)
Source code in mognet/exceptions/result_exceptions.py
ResultValueLost (ResultError)
Raised when the value for a result was lost (potentially due to key eviction)
Source code in mognet/exceptions/result_exceptions.py
class ResultValueLost(ResultError):
"""
Raised when the value for a result was lost
(potentially due to key eviction)
"""
def __init__(self, result_id: UUID) -> None:
super().__init__(result_id)
self.result_id = result_id
def __str__(self) -> str:
return f"Value for result id={self.result_id!r} lost"
Revoked (ResultFailed)
Raised when a task is revoked, either by timing out, or manual revoking.
Source code in mognet/exceptions/result_exceptions.py
task_exceptions
InvalidErrorInfo (BaseModel)
InvalidTaskArguments (Exception)
Raised when the arguments to a task could not be validated.
Source code in mognet/exceptions/task_exceptions.py
class InvalidTaskArguments(Exception):
"""
Raised when the arguments to a task could not be validated.
"""
def __init__(self, errors: List[InvalidErrorInfo]) -> None:
super().__init__(errors)
self.errors = errors
@classmethod
def from_validation_error(cls, validation_error: ValidationError):
return cls([InvalidErrorInfo.model_validate(e) for e in validation_error.errors()])
Pause (Exception)
Tasks may raise this when they want to stop execution and have their message return to the Task Broker.
Once the message is retrieved again, task execution will resume.
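A sketch of a polling task that defers itself until work is available; `upstream_ready` is a hypothetical check, and the import path follows the module layout above:
```
from mognet.exceptions.task_exceptions import Pause

@task(name="demo.poll")
async def poll(context):
    if not await upstream_ready():  # hypothetical readiness check
        # The message returns to the broker; execution resumes on redelivery.
        raise Pause()
```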
too_many_retries
TooManyRetries (MognetError)
Raised when a task is retried too many times due to unforeseen errors (e.g., SIGKILL).
The number of retries for any particular task can be configured through the App's configuration, in `max_retries`.
Source code in mognet/exceptions/too_many_retries.py
class TooManyRetries(MognetError):
"""
Raised when a task is retried too many times due to unforeseen errors (e.g., SIGKILL).
The number of retries for any particular task can be configured through the App's
configuration, in `max_retries`.
"""
def __init__(
self,
request_id: UUID,
actual_retries: int,
max_retries: int,
) -> None:
super().__init__(request_id, actual_retries, max_retries)
self.request_id = request_id
self.max_retries = max_retries
self.actual_retries = actual_retries
def __str__(self) -> str:
return f"Task id={self.request_id!r} has been retried {self.actual_retries!r} times, which is more than the limit of {self.max_retries!r}"
middleware
special
middleware
Middleware (Protocol)
Defines middleware that can hook into different parts of a Mognet App's lifecycle.
Source code in mognet/middleware/middleware.py
class Middleware(Protocol):
"""
Defines middleware that can hook into different parts of a Mognet App's lifecycle.
"""
async def on_app_starting(self, app: "App") -> None:
"""
Called when the app is starting, but before it starts connecting to the backends.
For example, you can use this for some early initialization of singleton-type objects in your app.
"""
async def on_app_started(self, app: "App") -> None:
"""
Called when the app has started.
For example, you can use this for some early initialization of singleton-type objects in your app.
"""
async def on_app_stopping(self, app: "App") -> None:
"""
Called when the app is preparing to stop, but before it starts disconnecting.
For example, you can use this for cleaning up objects that were previously set up.
"""
async def on_app_stopped(self, app: "App") -> None:
"""
Called when the app has stopped.
For example, you can use this for cleaning up objects that were previously set up.
"""
async def on_task_starting(self, context: "Context"):
"""
Called when a task is starting.
You can use this, for example, to track a task on a database.
"""
async def on_task_completed(
self, result: "Result", context: Optional["Context"] = None
):
"""
Called when a task has completed its execution.
You can use this, for example, to track a task on a database.
"""
async def on_request_submitting(
self, request: "Request", context: Optional["Context"] = None
):
"""
Called when a Request object is going to be submitted to the Broker.
You can use this, for example, both to track the task on a database, or to modify
the Request object (e.g., to modify arguments, or set metadata).
"""
async def on_running_task_count_changed(self, running_task_count: int):
"""
Called when the Worker's task count changes.
This can be used to determine when the Worker has nothing to do.
"""
on_app_started(self, app)
async
Called when the app has started.
For example, you can use this for some early initialization of singleton-type objects in your app.
on_app_starting(self, app)
async
Called when the app is starting, but before it starts connecting to the backends.
For example, you can use this for some early initialization of singleton-type objects in your app.
on_app_stopped(self, app)
async
Called when the app has stopped.
For example, you can use this for cleaning up objects that were previously set up.
on_app_stopping(self, app)
async
Called when the app is preparing to stop, but before it starts disconnecting.
For example, you can use this for cleaning up objects that were previously set up.
on_request_submitting(self, request, context=None)
async
Called when a Request object is going to be submitted to the Broker.
You can use this, for example, both to track the task on a database, or to modify the Request object (e.g., to modify arguments, or set metadata).
Source code in mognet/middleware/middleware.py
async def on_request_submitting(
self, request: "Request", context: Optional["Context"] = None
):
"""
Called when a Request object is going to be submitted to the Broker.
You can use this, for example, both to track the task on a database, or to modify
the Request object (e.g., to modify arguments, or set metadata).
"""
on_running_task_count_changed(self, running_task_count)
async
Called when the Worker's task count changes.
This can be used to determine when the Worker has nothing to do.
on_task_completed(self, result, context=None)
async
Called when a task has completed its execution.
You can use this, for example, to track a task on a database.
on_task_starting(self, context)
async
Called when a task is starting.
You can use this, for example, to track a task on a database.
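Because `Middleware` is a `Protocol` with no-op default implementations, a middleware class only needs to override the hooks it cares about. A minimal sketch, assuming an existing `App` instance named `app`:
```
from mognet.middleware.middleware import Middleware

class LoggingMiddleware(Middleware):
    # Subclassing the Protocol explicitly inherits its no-op defaults,
    # so only the hooks of interest are overridden here.
    async def on_task_starting(self, context):
        print(f"Starting {context.request!r}")

    async def on_task_completed(self, result, context=None):
        print(f"Completed {result!r}")

app.add_middleware(LoggingMiddleware())
```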
model
special
result
Result (BaseModel)
Represents the result of executing a `Request`.
It contains, along with the return value (or raised exception), information on the resulting state, how many times it started, and timing information.
Source code in mognet/model/result.py
class Result(BaseModel):
"""
Represents the result of executing a [`Request`][mognet.Request].
It contains, along the return value (or raised exception),
information on the resulting state, how many times it started,
and timing information.
"""
id: UUID
name: Optional[str] = None
state: ResultState = ResultState.PENDING
number_of_starts: int = 0
number_of_stops: int = 0
parent_id: Optional[UUID] = None
created: Optional[datetime] = None
started: Optional[datetime] = None
finished: Optional[datetime] = None
node_id: Optional[str] = None
request_kwargs_repr: Optional[str] = None
_backend: "BaseResultBackend" = PrivateAttr()
_children: Optional[ResultChildren] = PrivateAttr()
_value: Optional[ResultValue] = PrivateAttr()
def __init__(self, backend: "BaseResultBackend", **data) -> None:
super().__init__(**data)
self._backend = backend
self._children = None
self._value = None
@property
def children(self) -> ResultChildren:
"""Get an iterator on the children of this Result. Non-recursive."""
if self._children is None:
self._children = ResultChildren(self, self._backend)
return self._children
@property
def value(self) -> ResultValue:
"""Get information about the value of this Result"""
if self._value is None:
self._value = ResultValue(self, self._backend)
return self._value
@property
def duration(self) -> Optional[timedelta]:
"""
Returns the time it took to complete this result.
Returns None if the result did not start or finish.
"""
if not self.started or not self.finished:
return None
return self.finished - self.started
@property
def queue_time(self) -> Optional[timedelta]:
"""
Returns the time it took to start the task associated to this result.
Returns None if the task did not start.
"""
if not self.created or not self.started:
return None
return self.started - self.created
@property
def done(self):
"""
True if the result is in a terminal state (e.g., SUCCESS, FAILURE).
See `READY_STATES`.
"""
return self.state in READY_STATES
@property
def successful(self):
"""True if the result was successful."""
return self.state in SUCCESS_STATES
@property
def failed(self):
"""True if the result failed or was revoked."""
return self.state in ERROR_STATES
@property
def revoked(self):
"""True if the result was revoked."""
return self.state == ResultState.REVOKED
@property
def unexpected_retry_count(self) -> int:
"""
Return the number of times the task associated with this result was retried
as a result of an unexpected error, such as a SIGKILL.
"""
return max(0, self.number_of_starts - self.number_of_stops)
async def wait(self, *, timeout: Optional[float] = None, poll: float = 0.1) -> None:
"""Wait for the task associated with this result to finish."""
updated_result = await self._backend.wait(self.id, timeout=timeout, poll=poll)
await self._refresh(updated_result)
async def revoke(self) -> "Result":
"""
Revoke this Result.
This shouldn't be called directly, use the method on the App class instead,
as that will also revoke the children, recursively.
"""
self.state = ResultState.REVOKED
self.number_of_stops += 1
self.finished = now_utc()
await self._backend.set(self.id, self)
return self
async def get(self) -> Any:
"""
Gets the value of this `Result` instance.
Raises `ResultNotReady` if it's not ready yet.
Raises any stored exception if the result failed
Returns the stored value otherwise.
Use `value.get_raw_value()` if you want access to the raw value.
Call `wait` to wait for the value to be available.
Optionally, `await` the result instance.
"""
if not self.done:
raise ResultNotReady()
value = await self.value.get_raw_value()
if self.state == ResultState.REVOKED:
raise Revoked(self)
if self.failed:
if value is None:
value = ResultFailed(self)
# Re-hydrate exceptions.
if isinstance(value, _ExceptionInfo):
raise value.exception
if not isinstance(value, BaseException):
value = Exception(value)
raise value
return value
async def set_result(
self,
value: Any,
state: ResultState = ResultState.SUCCESS,
) -> "Result":
"""
Set this Result to a success state, and store the value
which will be returned when one `get()`s this Result's value.
"""
await self.value.set_raw_value(value)
self.finished = now_utc()
self.state = state
self.number_of_stops += 1
await self._update()
return self
async def set_error(
self,
exc: BaseException,
state: ResultState = ResultState.FAILURE,
) -> "Result":
"""
Set this Result to an error state, and store the exception
which will be raised if one attempts to `get()` this Result's
value.
"""
_log.debug("Setting result id=%r to %r", self.id, state)
await self.value.set_raw_value(exc)
self.finished = now_utc()
self.state = state
self.number_of_stops += 1
await self._update()
return self
async def start(self, *, node_id: Optional[str] = None) -> "Result":
"""
Sets this `Result` as RUNNING, and logs the event.
"""
self.started = now_utc()
self.node_id = node_id
self.state = ResultState.RUNNING
self.number_of_starts += 1
await self._update()
return self
async def resume(self, *, node_id: Optional[str] = None) -> "Result":
if node_id is not None:
self.node_id = node_id
self.state = ResultState.RUNNING
self.number_of_starts += 1
await self._update()
return self
async def suspend(self) -> "Result":
"""
Sets this `Result` as SUSPENDED, and logs the event.
"""
self.state = ResultState.SUSPENDED
self.number_of_stops += 1
await self._update()
return self
async def tree(self, max_depth: int = 3, max_width: int = 500) -> "ResultTree":
"""
Gets the tree of this result.
:param max_depth: The maximum depth of the tree that's to be generated.
This filters out results whose recursion levels are greater than it.
"""
from .result_tree import ResultTree
async def get_tree(result: Result, depth=1):
_log.debug(
"Getting tree of result id=%r, depth=%r max_depth=%r",
result.id,
depth,
max_depth,
)
node = ResultTree(result=result, children=[])
if depth >= max_depth and (await result.children.count()):
_log.warning(
"Result id=%r has %r or more levels of children, which is more than the limit of %r. Results will be truncated",
result.id,
depth,
max_depth,
)
return node
children_count = await result.children.count()
if children_count > max_width:
_log.warning(
"Result id=%r has %r children, which is more than the limit of %r. Results will be truncated",
result.id,
children_count,
max_width,
)
async for child in result.children.iter_instances(count=max_width):
node.children.append(await get_tree(child, depth=depth + 1))
node.children.sort(key=lambda r: r.result.created or now_utc())
return node
return await get_tree(self, depth=1)
async def get_metadata(self) -> Dict[str, Any]:
"""Get the metadata associated with this Result."""
return await self._backend.get_metadata(self.id)
async def set_metadata(self, **kwargs: Any) -> None:
"""Set metadata on this Result."""
await self._backend.set_metadata(self.id, **kwargs)
async def _refresh(self, updated_result: Optional["Result"] = None):
updated_result = updated_result or await self._backend.get(self.id)
if updated_result is None:
raise RuntimeError("Result no longer present")
for k, v in updated_result.__dict__.items():
if k == "id":
continue
setattr(self, k, v)
async def _update(self):
await self._backend.set(self.id, self)
def __repr__(self):
v = f"Result[{self.name or 'unknown'}, id={self.id!r}, state={self.state!r}]"
if self.request_kwargs_repr is not None:
v += f"({self.request_kwargs_repr})"
return v
# Implemented for asyncio's `await` functionality.
def __hash__(self) -> int:
return hash(f"Result_{self.id}")
def __await__(self):
yield from self.wait().__await__()
value = yield from self.get().__await__()
return value
async def delete(self, include_children: bool = True):
"""
Delete this Result from the backend.
By default, this will delete children too.
"""
await self._backend.delete(self.id, include_children=include_children)
async def set_ttl(self, ttl: timedelta, include_children: bool = True):
"""
Set TTL on this Result.
By default, this will set it on the children too.
"""
await self._backend.set_ttl(self.id, ttl, include_children=include_children)
children: ResultChildren
property
readonly
Get an iterator on the children of this Result. Non-recursive.
done
property
readonly
True if the result is in a terminal state (e.g., SUCCESS, FAILURE).
See `READY_STATES`.
duration: Optional[datetime.timedelta]
property
readonly
Returns the time it took to complete this result.
Returns None if the result did not start or finish.
failed
property
readonly
True if the result failed or was revoked.
queue_time: Optional[datetime.timedelta]
property
readonly
Returns the time it took to start the task associated to this result.
Returns None if the task did not start.
revoked
property
readonly
True if the result was revoked.
successful
property
readonly
True if the result was successful.
unexpected_retry_count: int
property
readonly
Return the number of times the task associated with this result was retried as a result of an unexpected error, such as a SIGKILL.
value: ResultValue
property
readonly
Get information about the value of this Result
__hash__(self)
special
delete(self, include_children=True)
async
Delete this Result from the backend.
By default, this will delete children too.
get(self)
async
Gets the value of this `Result` instance.
Raises `ResultNotReady` if it's not ready yet.
Raises any stored exception if the result failed.
Returns the stored value otherwise.
Use `value.get_raw_value()` if you want access to the raw value.
Call `wait` to wait for the value to be available. Optionally, `await` the result instance.
Source code in mognet/model/result.py
async def get(self) -> Any:
"""
Gets the value of this `Result` instance.
Raises `ResultNotReady` if it's not ready yet.
Raises any stored exception if the result failed
Returns the stored value otherwise.
Use `value.get_raw_value()` if you want access to the raw value.
Call `wait` to wait for the value to be available.
Optionally, `await` the result instance.
"""
if not self.done:
raise ResultNotReady()
value = await self.value.get_raw_value()
if self.state == ResultState.REVOKED:
raise Revoked(self)
if self.failed:
if value is None:
value = ResultFailed(self)
# Re-hydrate exceptions.
if isinstance(value, _ExceptionInfo):
raise value.exception
if not isinstance(value, BaseException):
value = Exception(value)
raise value
return value
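Because `Result` implements `__await__`, waiting for completion and fetching the value can be combined into a single `await`. A sketch, assuming a request was submitted through an `App` instance named `app`:
```
result = await app.submit(request)  # `request` assumed to be built beforehand
value = await result                # wait()s for completion, then get()s the value
```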
get_metadata(self)
async
Get the metadata associated with this Result.
model_post_init(/, self, context)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| self | BaseModel | The BaseModel instance. | required |
| context | Any | The context. | required |
Source code in mognet/model/result.py
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
"""This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Args:
self: The BaseModel instance.
context: The context.
"""
if getattr(self, '__pydantic_private__', None) is None:
pydantic_private = {}
for name, private_attr in self.__private_attributes__.items():
default = private_attr.get_default()
if default is not PydanticUndefined:
pydantic_private[name] = default
object_setattr(self, '__pydantic_private__', pydantic_private)
revoke(self)
async
Revoke this Result.
This shouldn't be called directly; use the method on the App class instead, as that will also revoke the children, recursively.
Source code in mognet/model/result.py
async def revoke(self) -> "Result":
"""
Revoke this Result.
This shouldn't be called directly, use the method on the App class instead,
as that will also revoke the children, recursively.
"""
self.state = ResultState.REVOKED
self.number_of_stops += 1
self.finished = now_utc()
await self._backend.set(self.id, self)
return self
set_error(self, exc, state='FAILURE')
async
Set this Result to an error state, and store the exception which will be raised if one attempts to `get()` this Result's value.
Source code in mognet/model/result.py
async def set_error(
self,
exc: BaseException,
state: ResultState = ResultState.FAILURE,
) -> "Result":
"""
Set this Result to an error state, and store the exception
which will be raised if one attempts to `get()` this Result's
value.
"""
_log.debug("Setting result id=%r to %r", self.id, state)
await self.value.set_raw_value(exc)
self.finished = now_utc()
self.state = state
self.number_of_stops += 1
await self._update()
return self
set_metadata(self, **kwargs)
async
Set metadata on this Result.
set_result(self, value, state='SUCCESS')
async
Set this Result to a success state, and store the value which will be returned when one `get()`s this Result's value.
Source code in mognet/model/result.py
async def set_result(
self,
value: Any,
state: ResultState = ResultState.SUCCESS,
) -> "Result":
"""
Set this Result to a success state, and store the value
which will be returned when one `get()`s this Result's value.
"""
await self.value.set_raw_value(value)
self.finished = now_utc()
self.state = state
self.number_of_stops += 1
await self._update()
return self
set_ttl(self, ttl, include_children=True)
async
Set TTL on this Result.
By default, this will set it on the children too.
start(self, *, node_id=None)
async
Sets this `Result` as RUNNING, and logs the event.
Source code in mognet/model/result.py
suspend(self)
async
Sets this `Result` as SUSPENDED, and logs the event.
tree(self, max_depth=3, max_width=500)
async
Gets the tree of this result.
:param max_depth: The maximum depth of the tree that's to be generated. This filters out results whose recursion levels are greater than it.
Source code in mognet/model/result.py
async def tree(self, max_depth: int = 3, max_width: int = 500) -> "ResultTree":
"""
Gets the tree of this result.
:param max_depth: The maximum depth of the tree that's to be generated.
This filters out results whose recursion levels are greater than it.
"""
from .result_tree import ResultTree
async def get_tree(result: Result, depth=1):
_log.debug(
"Getting tree of result id=%r, depth=%r max_depth=%r",
result.id,
depth,
max_depth,
)
node = ResultTree(result=result, children=[])
if depth >= max_depth and (await result.children.count()):
_log.warning(
"Result id=%r has %r or more levels of children, which is more than the limit of %r. Results will be truncated",
result.id,
depth,
max_depth,
)
return node
children_count = await result.children.count()
if children_count > max_width:
_log.warning(
"Result id=%r has %r children, which is more than the limit of %r. Results will be truncated",
result.id,
children_count,
max_width,
)
async for child in result.children.iter_instances(count=max_width):
node.children.append(await get_tree(child, depth=depth + 1))
node.children.sort(key=lambda r: r.result.created or now_utc())
return node
return await get_tree(self, depth=1)
wait(self, *, timeout=None, poll=0.1)
async
Wait for the task associated with this result to finish.
Source code in mognet/model/result.py
ResultChildren
The children of a Result.
Source code in mognet/model/result.py
class ResultChildren:
"""The children of a Result."""
def __init__(self, result: "Result", backend: "BaseResultBackend") -> None:
self._result = result
self._backend = backend
async def count(self) -> int:
"""The number of children."""
return await self._backend.get_children_count(self._result.id)
def iter_ids(self, *, count: Optional[int] = None) -> AsyncGenerator[UUID, None]:
"""Iterate the IDs of the children, optionally limited to a set count."""
return self._backend.iterate_children_ids(self._result.id, count=count)
def iter_instances(
self, *, count: Optional[int] = None
) -> AsyncGenerator["Result", None]:
"""Iterate the instances of the children, optionally limited to a set count."""
return self._backend.iterate_children(self._result.id, count=count)
async def add(self, *children_ids: UUID):
"""For internal use."""
await self._backend.add_children(self._result.id, *children_ids)
add(self, *children_ids)
async
For internal use.
count(self)
async
The number of children.
iter_ids(self, *, count=None)
Iterate the IDs of the children, optionally limited to a set count.
iter_instances(self, *, count=None)
Iterate the instances of the children, optionally limited to a set count.
ResultValue
Represents information about the value of a Result.
Source code in mognet/model/result.py
class ResultValue:
"""
Represents information about the value of a Result.
"""
def __init__(self, result: "Result", backend: "BaseResultBackend") -> None:
self._result = result
self._backend = backend
self._value_holder: Optional[ResultValueHolder] = None
async def get_value_holder(self) -> ResultValueHolder:
if self._value_holder is None:
self._value_holder = await self._backend.get_value(self._result.id)
return self._value_holder
async def get_raw_value(self) -> Any:
"""Get the value. In case this is an exception, it won't be raised."""
holder = await self.get_value_holder()
return holder.deserialize()
async def set_raw_value(self, value: Any):
if isinstance(value, BaseException):
value = _ExceptionInfo.from_exception(value)
holder = ResultValueHolder(raw_value=value, value_type=_serialize_name(value))
await self._backend.set_value(self._result.id, holder)
get_raw_value(self)
async
Get the value. In case this is an exception, it won't be raised.
ResultValueHolder (BaseModel)
Holds information about the type of the Result's value, and the raw value itself.
Use `deserialize()` to parse the value according to the type.
Source code in mognet/model/result.py
class ResultValueHolder(BaseModel):
"""
Holds information about the type of the Result's value, and the raw value itself.
Use `deserialize()` to parse the value according to the type.
"""
value_type: str
raw_value: Any
def deserialize(self) -> Any:
if self.raw_value is None:
return None
if self.value_type is not None:
cls = _get_attr(self.value_type)
value = TypeAdapter(cls).validate_python(self.raw_value)
else:
value = self.raw_value
return value
@classmethod
def not_ready(cls):
"""
Creates a value holder which is not ready yet.
"""
value = _ExceptionInfo.from_exception(ResultNotReady())
return cls(value_type=_serialize_name(value), raw_value=value)
not_ready()
classmethod
Creates a value holder which is not ready yet.
result_state
ResultState (str, Enum)
States that a task execution, and its result, can be in.
Source code in mognet/model/result_state.py
class ResultState(str, Enum):
"""
States that a task execution, and its result, can be in.
"""
# The task associated with this result has not yet started.
PENDING = "PENDING"
# The task associated with this result is currently running.
RUNNING = "RUNNING"
# The task associated with this result was suspended, either because
# it yielded subtasks, or because the worker it was on was shut down
# gracefully.
SUSPENDED = "SUSPENDED"
# The task associated with this result finished successfully.
SUCCESS = "SUCCESS"
# The task associated with this result failed.
FAILURE = "FAILURE"
# The task associated with this result was aborted.
REVOKED = "REVOKED"
# The task associated with this result was invalid (for example, an unknown task or invalid arguments).
INVALID = "INVALID"
def __repr__(self):
return f"{self.name!r}"
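As a small illustration (not part of the library; the grouping used by Result.done may differ), the terminal states can be collected into a helper:
from mognet.model.result_state import ResultState

# States from which a result is not expected to change again.
TERMINAL_STATES = frozenset(
    {ResultState.SUCCESS, ResultState.FAILURE, ResultState.REVOKED, ResultState.INVALID}
)

def is_terminal(state: ResultState) -> bool:
    return state in TERMINAL_STATES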
primitives
special
request
Request (BaseModel, Generic)
Represents the Mognet request.
Source code in mognet/primitives/request.py
class Request(BaseModel, Generic[TReturn]):
"""
Represents the Mognet request.
"""
id: UUID = Field(default_factory=uuid4)
name: str
args: tuple = ()
kwargs: Dict[str, Any] = Field(default_factory=dict)
stack: List[UUID] = Field(default_factory=list)
# Metadata that's going to be put in the Result associated
# with this Request.
metadata: Dict[str, Any] = Field(default_factory=dict)
# Deadline to run this request.
# If it's a datetime, the deadline will be computed based on the difference
# to `datetime.now(tz=timezone.utc)`. If the deadline is already passed, the task
# is discarded and marked as `REVOKED`.
# If it's a timedelta, the task's coroutine will be given that long to run, after which
# it will be cancelled and marked as `REVOKED`.
# Note that, like with manual revoking, there is no guarantee that the timed out task will
# actually stop running.
deadline: Optional[Union[timedelta, datetime]] = None
# Overrides the queue the message will be sent to.
queue_name: Optional[str] = None
# Allow setting a kwargs representation for debugging purposes.
# This is stored on the corresponding Result.
# If not set, it's set when the request is submitted.
# Note that if there are arguments which contain sensitive data, this will leak their values,
# so you are responsible for ensuring such values are censored.
kwargs_repr: Optional[str] = None
# Task priority. The higher the value, the higher the priority.
priority: Priority = 5
def __repr__(self):
msg = f"{self.name}[id={self.id!r}]"
if self.kwargs_repr is not None:
msg += f"({self.kwargs_repr})"
return msg
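A construction sketch; the task name and arguments are hypothetical:
from datetime import timedelta
from mognet.primitives.request import Request

req = Request(
    name="my_app.resize_image",     # hypothetical task name
    kwargs={"path": "/tmp/in.png"},
    deadline=timedelta(minutes=5),  # cancelled and marked REVOKED after 5 minutes
    priority=8,                     # higher value = higher priority
)
The request can then be passed to App.submit().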
service
special
class_service
ClassService (Generic)
Base class for object-based services retrieved through Context#get_service().
To get instances of a class-based service using Context#get_service(), pass the class itself, and an instance of the class will be returned.
Note that the instances are singletons.
The instances, when created, get access to the app's configuration.
Source code in mognet/service/class_service.py
class ClassService(Generic[_TReturn], metaclass=ABCMeta):
"""
Base class for object-based services retrieved
through Context#get_service()
To get instances of a class-based service using
Context#get_service(), pass the class itself,
and an instance of the class will be returned.
Note that the instances are *singletons*.
The instances, when created, get access to the app's configuration.
"""
def __init__(self, config: "AppConfig") -> None:
self.config = config
@abstractmethod
def __call__(self, context: "Context", *args, **kwds) -> _TReturn:
raise NotImplementedError
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def close(self):
pass
async def wait_closed(self):
pass
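A minimal hypothetical service as a sketch; it relies only on what is documented above (singleton instances, access to the app's configuration):
from mognet.service.class_service import ClassService

class GreeterService(ClassService[str]):
    # Hypothetical example; self.config is the AppConfig passed on creation.
    def __call__(self, context, *args, **kwds) -> str:
        return f"hello from node {self.config.node_id}"
Inside a task, context.get_service(GreeterService) returns the singleton instance.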
state
special
state
State
Represents state that can persist across task restarts.
Has facilities for getting, setting, and removing values.
The task's state is deleted when the task finishes.
Source code in mognet/state/state.py
class State:
"""
Represents state that can persist across task restarts.
Has facilities for getting, setting, and removing values.
The task's state is deleted when the task finishes.
"""
request_id: UUID
def __init__(self, app: "App", request_id: UUID) -> None:
self._app = app
self.request_id = request_id
@property
def _backend(self):
return self._app.state_backend
async def get(self, key: str, default: Any = None) -> Any:
"""Get a value."""
return await self._backend.get(self.request_id, key, default)
async def set(self, key: str, value: Any):
"""Set a value."""
return await self._backend.set(self.request_id, key, value)
async def pop(self, key: str, default: Any = None) -> Any:
"""Delete a value from the state and return it's value."""
return await self._backend.pop(self.request_id, key, default)
async def clear(self):
"""Clear all values."""
return await self._backend.clear(self.request_id)
clear(self) (async): Clear all values.
get(self, key, default=None) (async): Get a value.
pop(self, key, default=None) (async): Delete a value from the state and return its value.
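A checkpointing sketch; context.state is the State object the worker attaches to the task's Context (see Worker._create_context below). The function and its work are hypothetical:
async def import_rows(context, rows) -> None:
    # Resume from the last checkpoint if the task was restarted.
    start = await context.state.get("next_row", 0)
    for i in range(start, len(rows)):
        ...  # process rows[i] (hypothetical work)
        await context.state.set("next_row", i + 1)
    # No explicit cleanup needed: the state is deleted when the task finishes.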
state_backend_config
RedisStateBackendSettings (BaseModel)
Configuration for the Redis State Backend
Source code in mognet/state/state_backend_config.py
class RedisStateBackendSettings(BaseModel):
"""Configuration for the Redis State Backend"""
url: str = "redis://localhost:6379/"
# How long each task's state should live for.
state_ttl: int = 7200
# Set the limit of connections on the Redis connection pool.
# DANGER! Setting this to too low a value WILL cause issues opening connections!
max_connections: Optional[int] = None
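A configuration sketch with hypothetical values:
from mognet.state.state_backend_config import RedisStateBackendSettings

settings = RedisStateBackendSettings(
    url="redis://redis.internal:6379/0",  # hypothetical Redis instance
    state_ttl=3600,                       # keep each task's state for 1 hour
    max_connections=64,
)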
testing
special
pytest_integration
create_app_fixture(app)
Create a Pytest fixture for a Mognet application.
Source code in mognet/testing/pytest_integration.py
def create_app_fixture(app: App):
"""Create a Pytest fixture for a Mognet application."""
@pytest_asyncio.fixture
async def app_fixture():
async with app:
start_task = asyncio.create_task(app.start())
yield app
await app.close()
try:
start_task.cancel()
await start_task
except BaseException: # pylint: disable=broad-except
pass
return app_fixture
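A usage sketch for a test module; my_project.app is a hypothetical module exposing the configured App:
import pytest
from mognet.testing.pytest_integration import create_app_fixture
from my_project.app import app  # hypothetical import

mognet_app = create_app_fixture(app)

@pytest.mark.asyncio
async def test_smoke(mognet_app):
    ...  # submit requests against the running app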
tools
special
backports
special
aioitertools
Backport of https://github.com/RedRoserade/aioitertools/blob/f86552753e626cb71a3a305b9ec890f97d771e6b/aioitertools/asyncio.py#L93
Should be upstreamed here: https://github.com/omnilib/aioitertools/pull/103
as_generated(iterables, *, return_exceptions=False)
Yield results from one or more async iterables, in the order they are produced.
Like as_completed, but for async iterators or generators instead of futures.
Creates a separate task to drain each iterable, and a single queue for results.
If return_exceptions is False, any exception will be raised; pending iterables and tasks will be cancelled, and async generators will be closed.
If return_exceptions is True, exceptions will be yielded as results, and execution will continue until all iterables have been fully consumed.
Example::
async def generator(x):
for i in range(x):
yield i
gen1 = generator(10)
gen2 = generator(12)
async for value in as_generated([gen1, gen2]):
... # intermixed values yielded from gen1 and gen2
Source code in mognet/tools/backports/aioitertools.py
async def as_generated(
iterables: Iterable[AsyncIterable[T]],
*,
return_exceptions: bool = False,
) -> AsyncIterable[T]:
"""
Yield results from one or more async iterables, in the order they are produced.
Like :func:`as_completed`, but for async iterators or generators instead of futures.
Creates a separate task to drain each iterable, and a single queue for results.
If ``return_exceptions`` is ``False``, then any exception will be raised, and
pending iterables and tasks will be cancelled, and async generators will be closed.
If ``return_exceptions`` is ``True``, any exceptions will be yielded as results,
and execution will continue until all iterables have been fully consumed.
Example::
async def generator(x):
for i in range(x):
yield i
gen1 = generator(10)
gen2 = generator(12)
async for value in as_generated([gen1, gen2]):
... # intermixed values yielded from gen1 and gen2
"""
queue: asyncio.Queue[dict] = asyncio.Queue()
tailer_count: int = 0
async def tailer(iterable: AsyncIterable[T]) -> None:
nonlocal tailer_count
try:
async for item in iterable:
await queue.put({"value": item})
except asyncio.CancelledError:
if isinstance(iterable, AsyncGenerator): # pragma:nocover
with suppress(Exception):
await iterable.aclose()
raise
except Exception as exc: # pylint: disable=broad-except
await queue.put({"exception": exc})
finally:
tailer_count -= 1
if tailer_count == 0:
await queue.put({"done": True})
tasks = [asyncio.ensure_future(tailer(iterable)) for iterable in iterables]
if not tasks:
# Nothing to do
return
tailer_count = len(tasks)
try:
while True:
i = await queue.get()
if "value" in i:
yield i["value"]
elif "exception" in i:
if return_exceptions:
yield i["exception"]
else:
raise i["exception"]
elif "done" in i:
break
except (asyncio.CancelledError, GeneratorExit):
pass
finally:
for task in tasks:
if not task.done():
task.cancel()
for task in tasks:
with suppress(asyncio.CancelledError):
await task
kwargs_repr
format_kwargs_repr(args, kwargs, *, value_max_length=64)
Utility function to create an args + kwargs representation.
Source code in mognet/tools/kwargs_repr.py
def format_kwargs_repr(
args: tuple,
kwargs: dict,
*,
value_max_length: Optional[int] = 64,
) -> str:
"""Utility function to create an args + kwargs representation."""
parts = []
for arg in args:
parts.append(_format_value(arg, max_length=value_max_length))
for arg_name, arg_value in kwargs.items():
parts.append(
f"{arg_name}={_format_value(arg_value, max_length=value_max_length)}"
)
return ", ".join(parts)
retries
retryableasyncmethod(types, *, max_attempts, wait_timeout, lock=None, on_retry=None)
Decorator to wrap an async method and make it retryable.
Source code in mognet/tools/retries.py
def retryableasyncmethod(
types: Tuple[Type[BaseException], ...],
*,
max_attempts: Union[int, str],
wait_timeout: Union[float, str],
lock: Optional[Union[asyncio.Lock, str]] = None,
on_retry: Optional[Union[Callable[[BaseException], Awaitable], str]] = None,
):
"""
Decorator to wrap an async method and make it retryable.
"""
def make_retryable(func: _T) -> _T:
if inspect.isasyncgenfunction(func):
raise TypeError("Async generator functions are not supported")
f: Any = cast(Any, func)
@wraps(f)
async def async_retryable_decorator(self, *args, **kwargs):
last_exc = None
retry = _noop
if isinstance(on_retry, str):
retry = getattr(self, on_retry)
elif callable(on_retry):
retry = on_retry
attempts: int
if isinstance(max_attempts, str):
attempts = getattr(self, max_attempts)
else:
attempts = max_attempts
timeout: float
if isinstance(wait_timeout, str):
timeout = getattr(self, wait_timeout)
else:
timeout = wait_timeout
retry_lock = None
if isinstance(lock, str):
retry_lock = getattr(self, lock)
elif lock is not None:
retry_lock = lock
# Use an exponential backoff, starting with 1s
# and with a maximum of whatever was configured
current_wait_timeout = min(1, timeout)
for attempt in range(1, attempts + 1):
try:
return await f(self, *args, **kwargs)
except types as exc:
_log.error("Attempt %r/%r failed", attempt, attempts, exc_info=exc)
last_exc = exc
_log.debug("Waiting %.2fs before next attempt", current_wait_timeout)
await asyncio.sleep(current_wait_timeout)
current_wait_timeout = min(current_wait_timeout * 2, timeout)
if retry_lock is not None:
if retry_lock.locked():
_log.debug("Already retrying, possibly on another method")
else:
async with retry_lock:
_log.debug("Calling retry method")
await retry(last_exc)
else:
await retry(last_exc)
if last_exc is None:
last_exc = Exception("All %r attempts failed" % attempts)
raise last_exc
return cast(_T, async_retryable_decorator)
return make_retryable
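A decoration sketch; the class and method are hypothetical. As the source shows, max_attempts, wait_timeout, lock and on_retry may also be given as strings naming instance attributes, resolved per-instance at call time:
import asyncio
from mognet.tools.retries import retryableasyncmethod

class Client:
    @retryableasyncmethod(
        (ConnectionError, asyncio.TimeoutError),
        max_attempts=5,
        wait_timeout=30.0,  # backoff doubles from 1s up to this cap
    )
    async def fetch(self, url: str):
        ...  # hypothetical network call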
worker
special
worker
MessageCancellationAction (str, Enum)
Worker
Workers are responsible for running the fetch -> run -> store result loop for the configured task queues.
Source code in mognet/worker/worker.py
class Worker:
"""
Workers are responsible for running the fetch -> run -> store result
loop for the configured task queues.
"""
running_tasks: Dict[UUID, "_RequestProcessorHolder"]
# Set of tasks that are suspended
_waiting_tasks: Set[UUID]
app: "App"
def __init__(
self,
*,
app: "App",
middleware: Optional[List["Middleware"]] = None,
) -> None:
self.app = app
self.running_tasks = {}
self._waiting_tasks = set()
self._middleware = middleware or []
self._current_prefetch = 1
self._queue_consumption_tasks: List[AsyncGenerator] = []
self._consume_task = None
async def run(self):
_log.debug("Starting worker")
try:
self.app.broker.add_connection_failed_callback(self._handle_connection_lost)
await self.start_consuming()
except asyncio.CancelledError:
_log.debug("Stopping run")
return
except Exception as exc: # pylint: disable=broad-except
_log.error("Error during consumption", exc_info=exc)
async def _handle_connection_lost(self, exc: Optional[BaseException] = None):
_log.error("Handling connection lost event, stopping all tasks", exc_info=exc)
# No point in NACKing, because we have been disconnected
await self._cancel_all_tasks(message_action=MessageCancellationAction.NOTHING)
async def _cancel_all_tasks(self, *, message_action: MessageCancellationAction):
all_req_ids = list(self.running_tasks)
_log.debug("Cancelling all %r running tasks", len(all_req_ids))
try:
for req_id in all_req_ids:
await self.cancel(req_id, message_action=message_action)
await self._adjust_prefetch()
finally:
self._waiting_tasks.clear()
async def stop_consuming(self):
_log.debug("Closing queue consumption tasks")
consumers = self._queue_consumption_tasks
while consumers:
consumer = consumers.pop(0)
try:
await asyncio.wait_for(consumer.aclose(), 5)
except (asyncio.CancelledError, GeneratorExit, asyncio.TimeoutError):
pass
except Exception as consume_exc: # pylint: disable=broad-except
_log.debug("Error closing consumer", exc_info=consume_exc)
consume_task = self._consume_task
self._consume_task = None
if consume_task is not None:
_log.debug("Closing aggregation task")
try:
consume_task.cancel()
await asyncio.wait_for(consume_task, 15)
_log.debug("Closed consumption task")
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
except Exception as consume_err: # pylint: disable=broad-except
_log.error("Error shutting down consumer task", exc_info=consume_err)
async def close(self):
"""
Stops execution, cancelling all running tasks.
"""
_log.debug("Closing worker")
await self.stop_consuming()
# Cancel and NACK all messages currently on this worker.
await self._cancel_all_tasks(message_action=MessageCancellationAction.NACK)
_log.debug("Closed worker")
def _remove_running_task(self, req_id: UUID):
fut = self.running_tasks.pop(req_id, None)
asyncio.create_task(self._emit_running_task_count_change())
return fut
def _add_running_task(self, req_id: UUID, holder: "_RequestProcessorHolder"):
self.running_tasks[req_id] = holder
asyncio.create_task(self._emit_running_task_count_change())
async def cancel(self, req_id: UUID, *, message_action: MessageCancellationAction):
"""
Cancels the execution of a request, if it is running on this worker.
Whoever calls this method is responsible for updating the result on the backend
accordingly.
"""
fut = self._remove_running_task(req_id)
if fut is None:
_log.debug("Request id=%r is not running on this worker", req_id)
return
_log.info("Cancelling task %r", req_id)
result = await self.app.result_backend.get_or_create(req_id)
# Only suspend the result on the backend if it was running in our node.
if (
result.state == ResultState.RUNNING
and result.node_id == self.app.config.node_id
):
await asyncio.shield(result.suspend())
_log.debug("Result for task %r suspended", req_id)
_log.debug("Waiting for coroutine of task %r to finish", req_id)
# Wait for the task to finish, this allows it to clean up.
try:
await asyncio.wait_for(fut.cancel(message_action=message_action), 15)
except asyncio.TimeoutError:
_log.warning(
"Handler for task id=%r took longer than 15s to shut down", req_id
)
except Exception as exc: # pylint: disable=broad-except
_log.error(
"Handler for task=%r failed while shutting down", req_id, exc_info=exc
)
_log.debug("Stopped handler of task id=%r", req_id)
def _create_context(self, request: "Request") -> "Context":
if not self.app.state_backend:
raise RuntimeError("No state backend defined")
return Context(
self.app,
request,
State(self.app, request.id),
self,
)
async def _run_request(self, req: Request) -> None:
"""
Processes a request, validating it before running.
"""
_log.debug("Received request %r", req)
# Check if we're not trying to (re) start something which is already done,
# for cases when a request is cancelled before it's started.
# Even worse, check that we're not trying to start a request whose
# result might have been evicted.
result = await self.app.result_backend.get(req.id)
if result is None:
_log.error(
"Attempting to run task %r, but it's result doesn't exist on the backend. Discarding",
req,
)
await self.remove_suspended_task(req.id)
return
context = self._create_context(req)
if result.done:
_log.error(
"Attempting to re-run task %r, when it's already done with state %r. Discarding",
req,
result.state,
)
return await asyncio.shield(self._on_complete(context, result))
# Check if we should even start, because:
# 1. We might be in a crash loop (if the process gets killed without cleanup, the numbers won't match),
# 2. Infinite recursion
# 3. We might be part of a parent request that was revoked (or doesn't exist)
# 4. We might be too late.
# 1. Too many starts
retry_count = result.unexpected_retry_count
if retry_count > 0:
_log.warning(
"Task %r has been retried %r times (max=%r)",
req.id,
retry_count,
self.app.config.max_retries,
)
if retry_count > self.app.config.max_retries:
_log.error(
"Discarding task %r because it has exceeded the maximum retry count of %r",
req,
self.app.config.max_retries,
)
result = await result.set_error(
TooManyRetries(req.id, retry_count, self.app.config.max_retries)
)
return await asyncio.shield(self._on_complete(context, result))
if req.stack:
# 2. Recursion
if len(req.stack) > self.app.config.max_recursion:
result = await result.set_error(RecursionError())
return await asyncio.shield(self._on_complete(context, result))
# 3. Parent task(s) aborted (or doesn't exist)
for parent_id in reversed(req.stack):
parent_result = await self.app.result_backend.get(parent_id)
if parent_result is None:
result = await result.set_error(
Exception(f"Parent request id={parent_id} does not exist"),
state=ResultState.REVOKED,
)
return await asyncio.shield(self._on_complete(context, result))
if parent_result.state == ResultState.REVOKED:
result = await result.set_error(
Exception(f"Parent request id={parent_result.id} was revoked"),
state=ResultState.REVOKED,
)
return await asyncio.shield(self._on_complete(context, result))
# 4. Request arrived past the deadline.
if isinstance(req.deadline, datetime):
# One cannot compare naive and aware datetime,
# so create equivalent datetime objects.
now = datetime.now(tz=req.deadline.tzinfo)
if req.deadline < now:
_log.error(
"Request %r arrived too late. Deadline is %r, current date is %r. Marking it as REVOKED and discarding",
req,
req.deadline,
now,
)
result = await asyncio.shield(
result.set_error(asyncio.TimeoutError(), state=ResultState.REVOKED)
)
return await asyncio.shield(self._on_complete(context, result))
# Get the function for the task. Fail if the task is not registered in
# our app's context.
try:
task_function = self.app.task_registry.get_task_function(req.name)
except UnknownTask as unknown_task:
_log.error(
"Request %r is for an unknown task: %r",
req,
req.name,
exc_info=unknown_task,
)
result = await result.set_error(unknown_task, state=ResultState.INVALID)
return await asyncio.shield(self._on_complete(context, result))
# Mark this as running.
await result.start(node_id=self.app.config.node_id)
await asyncio.shield(self._on_starting(context))
try:
# Create a validated version of the function.
# This not only does argument validation, but it also parses the values
# into objects.
validated_func = validate_call(
task_function, config=_TaskFuncArgumentValidationConfig
)
if inspect.iscoroutinefunction(task_function):
fut = validated_func(context, *req.args, **req.kwargs)
else:
_log.debug(
"Handler for task %r is not a coroutine function, running in the loop's default executor",
req.name,
)
# Run non-coroutine functions inside an executor.
# This allows them to run without blocking the event loop
# (providing the GIL does not block it either)
fut = self.app.loop.run_in_executor(
None,
functools.partial(validated_func, context, *req.args, **req.kwargs),
)
except V2ValidationError as exc:
_log.error(
"Could not call task function %r because of a validation error",
task_function,
exc_info=exc,
)
invalid = InvalidTaskArguments.from_validation_error(exc)
result = await asyncio.shield(
result.set_error(invalid, state=ResultState.INVALID)
)
return await asyncio.shield(self._on_complete(context, result))
if req.deadline is not None:
if isinstance(req.deadline, datetime):
# One cannot compare naive and aware datetime,
# so create equivalent datetime objects.
now = datetime.now(tz=req.deadline.tzinfo)
timeout = (req.deadline - now).total_seconds()
else:
timeout = req.deadline.total_seconds()
_log.debug("Applying %.2fs timeout to request %r", timeout, req)
fut = asyncio.wait_for(fut, timeout=timeout)
# Start executing.
try:
value = await fut
if req.id in self.running_tasks:
await asyncio.shield(result.set_result(value))
_log.info(
"Request %r finished with status %r in %.2fs",
req,
result.state,
(result.duration or timedelta()).total_seconds(),
)
await asyncio.shield(self._on_complete(context, result))
except Pause:
_log.info(
"Handler for %r requested to be paused. Suspending it on the Result Backend and NACKing the message",
req,
)
holder = self.running_tasks.pop(req.id, None)
if holder is not None:
await asyncio.shield(result.suspend())
await asyncio.shield(holder.message.nack())
except asyncio.CancelledError:
_log.debug("Handler for task %r cancelled", req)
# Re-raise the cancellation, this will be caught in the parent function
# and prevent ack/nack
raise
except V2ValidationError as exc:  # validation errors from sync functions surface here
_log.error(
"Could not call task function %r because of a validation error",
task_function,
exc_info=exc,
)
invalid = InvalidTaskArguments.from_validation_error(exc)
result = await asyncio.shield(
result.set_error(invalid, state=ResultState.INVALID)
)
return await asyncio.shield(self._on_complete(context, result))
except Exception as exc: # pylint: disable=broad-except
state = ResultState.FAILURE
# The task's coroutine may raise `asyncio.TimeoutError` itself, so there's
# no guarantee that the timeout we catch is actually related to the request's timeout.
# So, this heuristic is not the best.
# TODO: A way to improve it would be to double-check if the deadline itself is expired.
if req.deadline is not None and isinstance(exc, asyncio.TimeoutError):
state = ResultState.REVOKED
if req.id in self.running_tasks:
result = await asyncio.shield(result.set_error(exc, state=state))
await asyncio.shield(self._on_complete(context, result))
duration = result.duration
if duration is not None:
_log.error(
"Handler for task %r failed in %.2fs with state %r",
req,
duration.total_seconds(),
state,
exc_info=exc,
)
else:
_log.error(
"Handler for task %r failed with state %r",
req,
state,
exc_info=exc,
)
async def _on_complete(self, context: "Context", result: Result):
if result.done:
await context.state.clear()
await self.remove_suspended_task(context.request.id)
for middleware in self._middleware:
try:
_log.debug("Calling 'on_task_completed' middleware: %r", middleware)
await asyncio.shield(
middleware.on_task_completed(result, context=context)
)
except Exception as mw_exc: # pylint: disable=broad-except
_log.error("Middleware %r failed", middleware, exc_info=mw_exc)
async def _on_starting(self, context: "Context"):
_log.info("Starting task %r", context.request)
for middleware in self._middleware:
try:
_log.debug("Calling 'on_task_starting' middleware: %r", middleware)
await asyncio.shield(middleware.on_task_starting(context))
except Exception as mw_exc: # pylint: disable=broad-except
_log.error("Middleware %r failed", middleware, exc_info=mw_exc)
def _process_request_message(self, payload: IncomingMessagePayload) -> asyncio.Task:
"""
Creates an asyncio.Task which will process the enclosed Request
in the background.
Returns said task, after adding completion handlers to it.
"""
_log.debug("Parsing input of message id=%r as Request", payload.id)
req = Request.model_validate(payload.payload)
async def request_processor():
try:
await self._run_request(req)
_log.debug("ACK message id=%r for request=%r", payload.id, req)
await asyncio.shield(payload.ack())
except asyncio.CancelledError:
_log.debug("Cancelled execution of request=%r", req)
return
except Exception as exc: # pylint: disable=broad-except
_log.error(
"Fatal error processing request=%r, NAK message id=%r",
req,
payload.id,
exc_info=exc,
)
await asyncio.shield(payload.nack())
def on_processing_done(fut: Future):
self._remove_running_task(req.id)
exc = fut.exception()
if exc is not None and not fut.cancelled():
_log.error("Fatal error processing %r", req, exc_info=exc)
else:
_log.debug("Processed %r successfully", req)
task = asyncio.create_task(request_processor())
task.add_done_callback(on_processing_done)
holder = _RequestProcessorHolder(payload, req, task)
self._add_running_task(req.id, holder)
return task
def start_consuming(self):
if self._consume_task is not None:
return self._consume_task
self._consume_task = asyncio.create_task(self._start_consuming())
return self._consume_task
async def _start_consuming(self):
queues = self.app.get_task_queue_names()
_log.info("Going to consume %r queues", len(queues))
try:
await self._adjust_prefetch()
for queue in queues:
_log.info("Start consuming task queue=%r", queue)
self._queue_consumption_tasks.append(
self.app.broker.consume_tasks(queue)
)
async for payload in as_generated(self._queue_consumption_tasks):
try:
if payload.kind == "Request":
self._process_request_message(payload)
else:
raise ValueError(f"Unknown kind={payload.kind!r}")
except asyncio.CancelledError:
break
except Exception as exc: # pylint: disable=broad-except
_log.error(
"Error processing message=%r, discarding it",
payload,
exc_info=exc,
)
await asyncio.shield(payload.ack())
finally:
_log.debug("Stopped consuming task queues")
async def add_waiting_task(self, task_id: UUID):
self._waiting_tasks.add(task_id)
await self._adjust_prefetch()
async def remove_suspended_task(self, task_id: UUID):
try:
self._waiting_tasks.remove(task_id)
except KeyError:
pass
await self._adjust_prefetch()
@property
def waiting_task_count(self):
return len(self._waiting_tasks)
async def _emit_running_task_count_change(self):
for middleware in self._middleware:
try:
_log.debug("Calling 'on_running_task_count_changed' on %r", middleware)
await middleware.on_running_task_count_changed(len(self.running_tasks))
except Exception as mw_exc: # pylint: disable=broad-except
_log.error(
"'on_running_task_count_changed' failed on %r",
middleware,
exc_info=mw_exc,
)
async def _adjust_prefetch(self):
if self._consume_task is None:
_log.debug("Not adjusting prefetch because not consuming the queue")
return
minimum_prefetch = self.app.config.minimum_concurrency
prefetch = self.waiting_task_count + minimum_prefetch
max_prefetch = self.app.config.maximum_concurrency
if max_prefetch is not None and prefetch >= max_prefetch:
_log.error(
"Maximum prefetch value of %r reached! No more tasks will be fetched on this node",
max_prefetch,
)
prefetch = max_prefetch
if prefetch == self._current_prefetch:
_log.debug(
"Current prefetch is the same as the new prefetch (%r), not adjusting it",
prefetch,
)
return
_log.debug(
"Currently have %r tasks suspended waiting for others. Setting prefetch=%r from previous=%r",
self.waiting_task_count,
prefetch,
self._current_prefetch,
)
self._current_prefetch = prefetch
await self.app.broker.set_task_prefetch(self._current_prefetch)
cancel(self, req_id, *, message_action) (async)
Cancels the execution of a request, if it is running on this worker. Whoever calls this method is responsible for updating the result on the backend accordingly.
Source code in mognet/worker/worker.py
async def cancel(self, req_id: UUID, *, message_action: MessageCancellationAction):
"""
Cancels the execution of a request, if it is running on this worker.
Whoever calls this method is responsible for updating the result on the backend
accordingly.
"""
fut = self._remove_running_task(req_id)
if fut is None:
_log.debug("Request id=%r is not running on this worker", req_id)
return
_log.info("Cancelling task %r", req_id)
result = await self.app.result_backend.get_or_create(req_id)
# Only suspend the result on the backend if it was running in our node.
if (
result.state == ResultState.RUNNING
and result.node_id == self.app.config.node_id
):
await asyncio.shield(result.suspend())
_log.debug("Result for task %r suspended", req_id)
_log.debug("Waiting for coroutine of task %r to finish", req_id)
# Wait for the task to finish, this allows it to clean up.
try:
await asyncio.wait_for(fut.cancel(message_action=message_action), 15)
except asyncio.TimeoutError:
_log.warning(
"Handler for task id=%r took longer than 15s to shut down", req_id
)
except Exception as exc: # pylint: disable=broad-except
_log.error(
"Handler for task=%r failed while shutting down", req_id, exc_info=exc
)
_log.debug("Stopped handler of task id=%r", req_id)
close(self) (async): Stops execution, cancelling all running tasks.
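A sketch of revoking a request on the local node's worker; the import path follows the source location shown above, and app.worker may be None when no worker is running:
from mognet.worker.worker import MessageCancellationAction

async def cancel_locally(app, req_id) -> None:
    if app.worker is not None:
        # NACK returns the message to the broker; NOTHING leaves it alone.
        await app.worker.cancel(req_id, message_action=MessageCancellationAction.NACK)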