Почему поток выполнения прерывается на await?

Рейтинг: -2 · Ответов: 1 · Опубликовано: 15.06.2023

Вот мой код:

import asyncio
import logging
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from typing import Any, AsyncIterator
from queue import Queue, Empty
from asyncio import Task

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def concrete_progress_callback(process):
    """Log one finished iteration; `process` is the value it produced."""
    # UI, log or something update
    message = f'Iteration complete with: {process}...'
    logger.info(message)

def concrete_done_callback():
    """Log that every iteration has finished."""
    # result and errors handling
    logger.info('All iterations complete!')

class InvalidOperationStateError(RuntimeError):
    """Raised when a command is not allowed in the `Operation`'s current state."""


# NOTE: the class name is misspelled ("Excecute"); kept as-is because other
# code refers to it by this name.
class StopExcecuteOperation(RuntimeError):
    """Raised by a state command to make `Operation._execute()` stop consuming."""

class OperationState(ABC):
    """Interface of one state of an `Operation`.

    A state is bound to its `Operation` via `set_context()` and implements
    every command; each command returns an awaitable (or raises).
    """

    def set_context(self, context: 'Operation') -> None:
        """Attach the `Operation` this state acts upon."""
        self._context = context

    def __str__(self) -> str:
        # States are shown by their class name in log/error messages.
        return type(self).__name__

    @abstractmethod
    def step_execute(self) -> Awaitable[None]:
        """Process one step of the `Operation`."""

    @abstractmethod
    def run(self) -> Awaitable[None]:
        """Start the `Operation`."""

    @abstractmethod
    def suspend(self) -> Awaitable[None]:
        """Suspend the `Operation`."""

    @abstractmethod
    def resume(self) -> Awaitable[None]:
        """Resume the `Operation`."""

    @abstractmethod
    def cancel(self) -> Awaitable[None]:
        """Cancel the `Operation`."""

class BaseOperationState(OperationState):
    """State whose every command is rejected.

    Concrete states subclass it and override only the commands that are
    valid for them; anything else raises `InvalidOperationStateError`.
    """

    def _refusal(self, action: str) -> InvalidOperationStateError:
        """Build the error reporting that `action` is not allowed in this state."""
        return InvalidOperationStateError(f'Operation cannot be {action} from {self} state.')

    def step_execute(self) -> Awaitable[None]:
        raise self._refusal('executed')

    def run(self) -> Awaitable[None]:
        raise self._refusal('runned')

    def suspend(self) -> Awaitable[None]:
        raise self._refusal('suspended')

    def resume(self) -> Awaitable[None]:
        raise self._refusal('resumed')

    def cancel(self) -> Awaitable[None]:
        raise self._refusal('canceled')


class Idle(BaseOperationState):
    """`Operation` not started.

    Only `run()` is allowed; every other command raises
    `InvalidOperationStateError` (inherited from `BaseOperationState`).
    """

    def run(self) -> None:
        """Switch to `Runned`, enqueue the first step and start the consumer.

        Unlike the other commands this is a plain (non-async) method:
        `Operation.run()` calls it directly instead of enqueueing it.
        """
        self._context._transition_to(Runned())
        logger.debug('Operation is runned.')
        # Enqueued after the transition, so the queued bound method is
        # `Runned.step_execute` of the new state.
        self._context._step_execute()
        # Start the queue consumer in the background; `await operation`
        # awaits this task.
        self._context._set_target(asyncio.create_task(self._context._execute()))

class Runned(BaseOperationState):
    """`Operation` in progress.

    `run()` and `resume()` are not overridden here, so they are inherited from
    `BaseOperationState` and raise `InvalidOperationStateError`. Because
    `Operation.resume()` enqueues `self._state.resume` at call time, a
    resume requested while this state is current is bound to that raising
    method.
    """

    async def step_execute(self) -> None:
        """Pull one item from the iterator and report it.

        On exhaustion, call the done callback and move to `Done`; otherwise
        enqueue the next step and then call the progress callback.
        """
        try:
            target = await (anext(self._context._async_iterator))
        except StopAsyncIteration:
            self._context._done()
            self._context._transition_to(Done())
            logger.debug('Operation is done.')
        else:
            # Each step re-enqueues the next one, so iteration continues as
            # long as the consumer keeps draining the queue.
            self._context._step_execute()
            self._context._progress(target)

    async def suspend(self) -> None:
        """Move to `Suspended` and stop the current queue consumer."""
        self._context._transition_to(Suspended())
        logger.debug('Operation is suspended.')
        # `_execute()` catches this and leaves its loop. NOTE(review): commands
        # still waiting in the queue (e.g. a `resume` issued right after
        # `suspend`) are not processed unless another consumer is started.
        raise StopExcecuteOperation('Operation was suspended.')
    
    async def cancel(self) -> None:
        """Move to `Canceled`, drop all queued commands and stop the consumer."""
        self._context._transition_to(Canceled())
        logger.debug('Operation is canceled.')
        self._context._flush()
        raise StopExcecuteOperation('Operation was canceled.')

class Suspended(BaseOperationState):
    """`Operation` is suspended.

    Note that commands are bound to a state when they are enqueued: this
    `resume()` only runs if `Operation.resume()` was called while this state
    was already current.
    """
    
    async def resume(self) -> None:
        """Move back to `Runned` and start a new queue consumer task.

        No new step is enqueued here; the new consumer processes whatever is
        already in the queue.
        """
        self._context._transition_to(Runned())
        logger.debug('Operation is resumed.')
        self._context._set_target(asyncio.create_task(self._context._execute()))
    
    async def cancel(self) -> None:
        """Move to `Canceled`, drop all queued commands and stop the consumer."""
        self._context._transition_to(Canceled())
        logger.debug('Operation is canceled.')
        self._context._flush()
        raise StopExcecuteOperation('Operation was canceled.')

class Canceled(BaseOperationState):
    """Terminal state: the `Operation` was canceled.

    Every command is rejected (behaviour inherited from `BaseOperationState`).
    """


class Done(BaseOperationState):
    """Terminal state: the iterator was exhausted.

    Every command is rejected (behaviour inherited from `BaseOperationState`).
    """


class Error(BaseOperationState):
    """Terminal state entered when a state transition fails.

    Every command is rejected (behaviour inherited from `BaseOperationState`).
    """

class Operation:
    """Drives `async_iterator` step by step through a queue of state commands.

    `run()` acts immediately; `suspend()`, `resume()` and `cancel()` only
    enqueue a command that the background consumer (`_execute()`) runs later.
    Awaiting the operation awaits that consumer task.

    NOTE(review): each public command enqueues the bound method of the state
    that is current *when the command is issued*, not when it runs. Also,
    `_execute()` stops consuming as soon as a command raises
    `StopExcecuteOperation` (suspend/cancel), so commands queued after it,
    such as a `resume`, stay unprocessed unless a new consumer is started.
    """

    def __init__(self,
                async_iterator: AsyncIterator[int],
                progress_callback: Callable[[Any], None],
                done_callback: Callable[[], None]
        ) -> None:
        self._async_iterator = async_iterator
        self._progress = progress_callback
        self._done = done_callback
        self._state = Idle()
        self._state.set_context(self)
        # FIFO of zero-argument callables (bound state methods) awaited by
        # `_execute()`. NOTE: `Queue.put()` blocks the calling thread once
        # 256 items are pending, which would stall an event loop.
        self._execute_queue = Queue(256)
        # Consumer task that `await operation` waits for; set by the states.
        self._target = None

    def _set_target(self, new_target: Awaitable[None]) -> None:
        """Remember the consumer task that awaiting the operation waits for."""
        self._target = new_target

    def _transition_to(self, new_state: OperationState) -> None:
        """Change current state; on failure flush the queue and enter `Error`."""
        old_state = self._state
        try:
            self._state = new_state
            self._state.set_context(self)
        except Exception as err:
            if isinstance(err, InvalidOperationStateError):
                # `logger.exception` records the active exception itself. Do not
                # pass `err` as an extra argument: the message has no
                # %-placeholder for it, so formatting the record would fail.
                logger.exception(f'Transition of operation {self} from {old_state} state to {new_state} state is invalid.')
            self._flush()
            self._state = Error()
            self._state.set_context(self)
            raise

    async def _execute(self):
        """Consume queued commands until the queue is empty or a command stops execution."""
        while True:
            try:
                task = self._execute_queue.get(block=False)
            except Empty:
                # all tasks complete
                break
            else:
                try:
                    await task()
                except StopExcecuteOperation:
                    # suspend/cancel: stop consuming; anything left stays queued.
                    break
                finally:
                    self._execute_queue.task_done()

    def _step_execute(self) -> None:
        """Plants the execution step of `Operation`."""

        self._execute_queue.put(self._state.step_execute)

    def _flush(self) -> None:
        """Clear task queue immediately."""
        try:
            while True:
                _ = self._execute_queue.get(block=False)
                self._execute_queue.task_done()
        except Empty:
            pass

    def run(self) -> None:
        """Starts `Operation` immediately."""
        self._state.run()

    def suspend(self) -> None:
        """Plans to suspend `Operation` (binds the current state's `suspend`)."""
        self._execute_queue.put(self._state.suspend)

    def resume(self) -> None:
        """Plans to resume `Operation` (binds the current state's `resume`)."""
        self._execute_queue.put(self._state.resume)

    def cancel(self) -> None:
        """Plans to cancel `Operation` (binds the current state's `cancel`)."""
        self._execute_queue.put(self._state.cancel)

    def __await__(self):
        """Await the background consumer task started by `run()`.

        Raises:
            InvalidOperationStateError: if the operation was never run.
        """
        if self._target is None:
            raise InvalidOperationStateError('Operation cannot be awaited before it is run.')
        return self._target.__await__() # type: ignore

class ConcreteAsyncGenerator:
    """Async iterator producing start + 1, start + 2, ..., stop."""

    def __init__(self, start, stop):
        self.start = start
        self.stop = stop
        self.current = start

    def __aiter__(self) -> AsyncIterator[int]:
        return self

    async def __anext__(self):
        if not (self.current < self.stop):
            raise StopAsyncIteration
        #await asyncio.sleep(1)
        self.current += 1
        return self.current


async def main():
    """Demo: start, suspend and resume the operation back to back, then await it."""
    source = ConcreteAsyncGenerator(0, 500)
    operation = Operation(
        async_iterator=source,
        progress_callback=concrete_progress_callback,
        done_callback=concrete_done_callback,
    )

    # The commands below are issued without yielding to the event loop;
    # uncomment the sleeps to let it run in between.
    operation.run()
    #await asyncio.sleep(1)
    operation.suspend()
    #await asyncio.sleep(2)
    operation.resume()
    #await asyncio.sleep(1)
    #operation.cancel()

    await operation # type: ignore

    print('done')

asyncio.run(main())
  1. Почему операция не продолжается после `resume()`? Я же ожидаю её в строке 271 (`await operation`).
  2. Как можно исправить код?
  3. Что вообще можно улучшить в классе Operation?

Воистину, краткость — сестра таланта. Я помещаю здесь код, основанный на комментарии @andreymal.

Если он не оформит этот код как ответ, я сам отмечу его в качестве ответа.

import asyncio
import logging
from enum import Enum
from typing import Any, Callable, Generator, Generic, Optional, TypeVar
from collections.abc import AsyncGenerator, AsyncIterable

logger = logging.getLogger(__name__)

# Type variables; T1 is the item type produced by the iterator.
T1 = TypeVar("T1")
T2 = TypeVar("T2")


class OperationState(Enum):
    """States an `Operation` can be in."""

    CANCELED = -1
    DONE = 0
    IDLE = 1
    RUNNING = 2
    SUSPENDED = 3


class InvalidOperationStateError(RuntimeError):
    """Raised when a command is not allowed in the operation's current state."""

class Operation(Generic[T1]):
    """Represents an operation that can be started, paused, and cancelled.
    
    Accepts an `async_iterator` asynchronous iterator and calls `progress_callback` on each iteration.
    
    When the iterator is exhausted, `done_callback` is called.

    Suspension is honoured at every step boundary. While the operation is
    suspended, the runner does not request the next item, report progress,
    or report completion."""

    def __init__(
        self,
        async_iterator: AsyncIterable[T1],
        *,
        progress_callback: Optional[Callable[[T1], Any]] = None,
        done_callback: Optional[Callable[[], Any]] = None,
    ):
        self._async_iterator = async_iterator
        self._progress_callback = progress_callback
        self._done_callback = done_callback

        self._state = OperationState.IDLE
        # Cleared by suspend(), set by resume(); the runner waits on it.
        self._resume_event = asyncio.Event()
        self._runner_task: Optional[asyncio.Task[None]] = None

    @property
    def state(self) -> OperationState:
        """Current state of the operation."""
        return self._state
   
    def run(self, force=True) -> None:
        """Starts operation (must be called with a running event loop).

            Parameters:
                    `force` (`bool`):
                        Ignore the command if the state is not valid for it
            
            Raises:
                    `InvalidOperationStateError` (`RuntimeError`):
                        If operation has already been started once and `force` parameter is `False`.
        """
        if self._state == OperationState.IDLE:
            self._runner_task = asyncio.create_task(self._runner())
            self._state = OperationState.RUNNING
            logger.debug("Operation is runned.")
        elif not force:
            raise InvalidOperationStateError("Operation is already started")

    def suspend(self, force=True) -> None:
        """Suspends operation

            Parameters:
                    `force` (`bool`):
                        Ignore the command if the state is not valid for it
            
            Raises:
                    `InvalidOperationStateError` (`RuntimeError`):
                        If `Оperation` is not running and `force` parameter is `False`.
        """
        if self._state == OperationState.RUNNING:
            self._state = OperationState.SUSPENDED
            self._resume_event.clear()
            logger.debug("Operation is suspended.")
        elif not force:
            raise InvalidOperationStateError("Operation is not running")

    def resume(self, force=True) -> None:
        """Resumes operation

            Parameters:
                    `force` (`bool`):
                        Ignore the command if the state is not valid for it
            
            Raises:
                    `InvalidOperationStateError` (`RuntimeError`):
                        If operation has not been suspended and `force` parameter is `False`.
        """
        if self._state == OperationState.SUSPENDED:
            self._state = OperationState.RUNNING
            self._resume_event.set()
            logger.debug("Operation is resumed.")
        elif not force:
            raise InvalidOperationStateError("Operation is not suspended")

    def cancel(self, force=True) -> None:
        """Cancels operation

            Parameters:
                    `force` (`bool`):
                        Ignore the command if the state is not valid for it
            
            Raises:
                    `InvalidOperationStateError` (`RuntimeError`):
                        If operation was not running or suspended and `force` parameter is `False`.
        """
        if self._state in (OperationState.RUNNING, OperationState.SUSPENDED):
            if self._runner_task is not None:
                self._runner_task.cancel()
            self._state = OperationState.CANCELED
            logger.debug("Operation is canceled.")
        elif not force:
            raise InvalidOperationStateError("Operation is not running")

    async def wait(self, force=True) -> None:
        """Waiting for the operation to finish

            Parameters:
                    `force` (`bool`):
                        Ignore the command if the state is not valid for it
            
            Raises:
                    `InvalidOperationStateError` (`RuntimeError`):
                        If operation has not been started and `force` parameter is `False`.
        """
        if self._runner_task is not None:
            # asyncio.wait() does not re-raise the task's exception or
            # cancellation, so awaiting a cancelled operation returns normally.
            await asyncio.wait({self._runner_task})
        elif not force:
            raise InvalidOperationStateError("Operation is not started")

    async def _wait_while_suspended(self) -> None:
        """Block the runner for as long as the operation is suspended."""
        while self._state == OperationState.SUSPENDED:
            await self._resume_event.wait()

    async def _runner(self) -> None:
        iterator = self._async_iterator.__aiter__()
        while True:
            # Do not start producing the next item while suspended.
            await self._wait_while_suspended()
            try:
                target = await iterator.__anext__()
            except StopAsyncIteration:
                break
            # suspend() may have been called while __anext__() was awaiting.
            await self._wait_while_suspended()
            if self._progress_callback is not None:
                self._progress_callback(target)

        # A suspend() that arrived during the final __anext__() must also
        # delay completion; otherwise a suspended operation would finish.
        await self._wait_while_suspended()
        if self._done_callback is not None:
            self._done_callback()

        self._state = OperationState.DONE
        logger.debug("Operation is done.")

    def __await__(self) -> Generator[Any, None, None]:
        return self.wait().__await__()

С этим решением есть проблема. Если подобрать тайминги, оно может выдать такой результат:

DEBUG:asyncio:Using selector: EpollSelector
DEBUG:operating:Operation is runned.
INFO:__main__:Iteration complete with: 0...
INFO:__main__:Iteration complete with: 1...
INFO:__main__:Iteration complete with: 2...
INFO:__main__:Iteration complete with: 3...
INFO:__main__:Iteration complete with: 4...
INFO:__main__:Iteration complete with: 5...
INFO:__main__:Iteration complete with: 6...
INFO:__main__:Iteration complete with: 7...
INFO:__main__:Iteration complete with: 8...
INFO:__main__:Iteration complete with: 9...
DEBUG:operating:Operation is suspended.
INFO:__main__:All iterations complete!
DEBUG:operating:Operation is done.

Потому что вызовы методов не потокобезопасны и приводят к состоянию гонки. Поэтому я пока не отмечаю это как ответ. В исходном коде (см. первый код в вопросе) я использовал очередь, чтобы избежать этой проблемы.

Ответы

▲ 1 · Принят

Я учёл замечание @andreymal об активном ожидании в режиме приостановки и переписал код с учётом этого замечания. Теперь на время приостановки исполнение блокируется в ожидании события (условной переменной), а не крутится в цикле.

"""Support for operations."""

__all__ = (
    'AsyncOperation', 'AsAsyncOperation',
    'Operation', 'AsOperation',
    'OperationState',
)

import asyncio
import functools
import logging
import threading
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any, Callable, Generator, Generic, Iterator, Optional, TypeVar, Union

logger = logging.getLogger(__name__)

# Generic type variables; T1 is the item type produced by the iterator.
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")


class OperationState(Enum):
    """States shared by `Operation` and `AsyncOperation`."""

    CANCELED = -1
    DONE = 0
    IDLE = 1
    RUNNING = 2
    SUSPENDED = 3

class Operation(Generic[T1]):
    """Represents a thread safe operation that can be started, paused, and cancelled.
    
    Accepts an `iterator` and calls `progress_callback` on each iteration.
    
    When the iterator is exhausted, `done_callback` is called.

    The iterator is driven on a worker thread. Every command and every runner
    step happen while holding one `threading.Condition`. The runner holds it
    while producing an item and calling the callbacks, so commands take
    effect between steps."""

    def __init__(
        self,
        iterator: Iterator[T1],
        *,
        progress_callback: Optional[Callable[[T1], Any]] = None,
        done_callback: Optional[Callable[[], Any]] = None,
    ):
        self._iterator = iterator
        self._progress_callback = progress_callback
        self._done_callback = done_callback

        self._state = OperationState.IDLE
        self._lock_condition = threading.Condition()
        # Set by cancel(); checked by the runner before every step.
        self._stop_event = threading.Event()
        self._runner_task: Optional[threading.Thread] = None

    def __del__(self):
        # NOTE: the worker thread references this object through `self._runner`,
        # so in practice this only runs once that thread has finished.
        if self._runner_task is not None and self._runner_task.is_alive():
            self._stop_event.set()

    @property
    def state(self) -> OperationState:
        """Get current operation state"""
        
        with self._lock_condition:
            return self._state
   
    def run(self) -> bool:
        """Starts operation
        
        Returns `True` if successful"""
        
        with self._lock_condition:
            if self._state == OperationState.IDLE:
                self._state = OperationState.RUNNING
                logger.debug("Operation is runned.")
                self._runner_task = threading.Thread(target=self._runner)
                self._runner_task.start()
                return True
            else:
                return False

    def suspend(self) -> bool:
        """Suspends operation.
        
        Returns `True` if successful."""

        with self._lock_condition:
            if self._state == OperationState.RUNNING:
                self._state = OperationState.SUSPENDED
                logger.debug("Operation is suspended.")
                return True
            else:
                return False

    def resume(self) -> bool:
        """Resumes operation.
        
        Returns `True` if successful."""
        
        with self._lock_condition:
            if self._state == OperationState.SUSPENDED:
                self._state = OperationState.RUNNING
                self._lock_condition.notify_all()
                logger.debug("Operation is resumed.")
                return True
            else:
                return False

    def cancel(self) -> bool:
        """Cancels operation (takes effect before the runner's next step).
        
        Returns `True` if the cancellation request was registered."""
        
        with self._lock_condition:
            if self._state not in (OperationState.RUNNING, OperationState.SUSPENDED):
                return False
            if self._runner_task is None or not self._runner_task.is_alive():
                return False
            self._stop_event.set()
            # Wake the runner even if it is parked in the SUSPENDED state;
            # without this a suspended operation could never be cancelled.
            self._lock_condition.notify_all()
            return True

    def _runner(self) -> None:
        """Worker-thread loop: one iterator step per lock acquisition."""
        while True:
            with self._lock_condition:
                # Park while suspended, but also wake up for a cancel request.
                self._lock_condition.wait_for(
                    lambda: self._state == OperationState.RUNNING
                    or self._stop_event.is_set()
                )
                if self._stop_event.is_set():
                    self._state = OperationState.CANCELED
                    logger.debug("Operation is canceled.")
                    break
                try:
                    target = next(self._iterator)
                except StopIteration:
                    if self._done_callback is not None:
                        self._done_callback()
                    self._state = OperationState.DONE
                    logger.debug("Operation is done.")
                    break
                if self._progress_callback is not None:
                    self._progress_callback(target)

    def wait(self) -> bool:
        """Waiting for the operation to finish (blocks the calling thread).
        
        Returns `True` if the operation had been started, `False` otherwise."""

        if self._runner_task is not None:
            self._runner_task.join()
            return True
        else:
            return False
    
    join = wait

class AsyncOperation(Generic[T1]):
    """Represents a asynchronous operation that can be started, paused, and cancelled.
    
    Accepts an `async_iterator` asynchronous iterator and calls `progress_callback` on each iteration.
    
    When the iterator is exhausted, `done_callback` is called.

    Every command and every runner step happen while holding one
    `asyncio.Condition`. A command issued while the runner awaits
    `__anext__()` therefore waits until that step has finished."""

    def __init__(
        self,
        async_iterator: AsyncIterator[T1],
        *,
        progress_callback: Optional[Callable[[T1], Any]] = None,
        done_callback: Optional[Callable[[], Any]] = None,
    ):
        self._async_iterator = async_iterator
        self._progress_callback = progress_callback
        self._done_callback = done_callback

        self._state = OperationState.IDLE
        self._async_lock_condition = asyncio.Condition()
        self._runner_task: Optional[asyncio.Task[None]] = None

    def __del__(self):
        # Best effort: cancel a runner task that is still pending.
        if self._runner_task is not None and not self._runner_task.done():
            self._runner_task.cancel()

    @property
    async def state(self) -> OperationState:
        """Get current operation state (use as `await operation.state`)."""
        
        async with self._async_lock_condition:
            return self._state
   
    async def run(self) -> bool:
        """Starts operation
        
        Returns `True` if successful"""
        
        async with self._async_lock_condition:
            if self._state == OperationState.IDLE:
                self._state = OperationState.RUNNING
                logger.debug("Operation is runned.")
                self._runner_task = asyncio.create_task(self._runner())
                return True
            else:
                return False

    async def suspend(self) -> bool:
        """Suspends operation.
        
        Returns `True` if successful."""

        async with self._async_lock_condition:
            if self._state == OperationState.RUNNING:
                self._state = OperationState.SUSPENDED
                logger.debug("Operation is suspended.")
                return True
            else:
                return False

    async def resume(self) -> bool:
        """Resumes operation.
        
        Returns `True` if successful."""
        
        async with self._async_lock_condition:
            if self._state == OperationState.SUSPENDED:
                self._state = OperationState.RUNNING
                self._async_lock_condition.notify()
                logger.debug("Operation is resumed.")
                return True
            else:
                return False

    async def cancel(self) -> bool:
        """Cancels operation.
        
        Returns `True` if successful."""
        
        async with self._async_lock_condition:
            if self._state in (OperationState.RUNNING, OperationState.SUSPENDED):
                if self._runner_task is not None and not self._runner_task.cancelled():
                    self._runner_task.cancel()
                else:
                    return False
                self._state = OperationState.CANCELED
                return True
            else:
                return False

    async def _runner(self) -> None:
        """Task loop: one iterator step per acquisition of the condition."""
        while True:
            async with self._async_lock_condition:
                try:
                    # Releases the condition while suspended; resume() notifies.
                    await self._async_lock_condition.wait_for(lambda: self._state == OperationState.RUNNING)
                    target = await self._async_iterator.__anext__()
                except StopAsyncIteration:
                    if self._done_callback is not None:
                        self._done_callback()
                    self._state = OperationState.DONE
                    logger.debug("Operation is done.")
                    break
                else:
                    if self._progress_callback is not None:
                        self._progress_callback(target)

    async def wait(self) -> bool:
        """Waiting for the operation to finish.
        
        Returns `True` if the operation had been started (whether it finished
        or was cancelled), `False` if it was never started.

        Re-raises any exception raised by the runner. Cancelling the caller of
        `wait()` raises `asyncio.CancelledError` in the caller but leaves the
        operation itself running."""

        task = self._runner_task
        if task is None:
            return False
        # Awaiting the task directly would cancel it when the caller is
        # cancelled, and the caller's CancelledError would then be swallowed.
        # asyncio.wait() leaves the task alone and lets the caller's
        # cancellation propagate.
        await asyncio.wait({task})
        if task.cancelled():
            logger.debug("Operation is canceled.")
        else:
            task.result()  # re-raise a failure from the runner, if any
        return True
    
    def __await__(self) -> Generator[Any, None, bool]:
        return self.wait().__await__()

class AsOperation(Generic[T1, T2]):
    """Decorator wrapping an generator into an `Operation`.

    `AsOperation(progress_callback=..., done_callback=...)(factory)` returns a
    function that forwards its arguments to `factory` and wraps the resulting
    iterator in an `Operation` with the stored callbacks. The returned
    function carries `factory`'s name and docstring."""

    def __init__(
        self,
        progress_callback: Optional[Callable[[T1], Any]] = None,
        done_callback: Optional[Callable[[], Any]] = None
        ) -> None:
        self._progress_callback = progress_callback
        self.done_callback = done_callback
    
    def __call__(self, iterator: Callable[..., Iterator[T1]]) -> Callable[..., Operation]:
        def generator_wrapper(*args: Any, **kwargs: Any) -> Operation:
            return Operation(
                iterator=iterator(*args, **kwargs),
                progress_callback=self._progress_callback,
                done_callback=self.done_callback)
        # Copy __name__/__doc__/__wrapped__ etc. Skip merging '__dict__'
        # because `iterator` may be a class (its __dict__ is not a function's).
        functools.update_wrapper(generator_wrapper, iterator, updated=())
        return generator_wrapper

class AsAsyncOperation(Generic[T1, T2]):
    """Decorator wrapping an asynchronous generator into an `AsyncOperation`.

    `AsAsyncOperation(progress_callback=..., done_callback=...)(factory)`
    returns a function that forwards its arguments to `factory` and wraps the
    resulting async iterator in an `AsyncOperation` with the stored
    callbacks. The returned function carries `factory`'s name and docstring."""

    def __init__(
        self,
        progress_callback: Optional[Callable[[T1], Any]] = None,
        done_callback: Optional[Callable[[], Any]] = None
        ) -> None:
        self._progress_callback = progress_callback
        self.done_callback = done_callback
    
    def __call__(self, async_iterator: Callable[..., AsyncIterator[T1]]) -> Callable[..., AsyncOperation]:
        def async_generator_wrapper(*args: Any, **kwargs: Any) -> AsyncOperation:
            return AsyncOperation(
                async_iterator=async_iterator(*args, **kwargs),
                progress_callback=self._progress_callback,
                done_callback=self.done_callback)
        # Copy __name__/__doc__/__wrapped__ etc. Skip merging '__dict__'
        # because `async_iterator` may be a class.
        functools.update_wrapper(async_generator_wrapper, async_iterator, updated=())
        return async_generator_wrapper

Использование:

import asyncio
import logging
import time
from collections.abc import AsyncIterator
from typing import Iterator

from operating import AsAsyncOperation, AsOperation

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def concrete_progress_callback(process) -> None:
    """Log the value produced by one iteration."""
    # UI, log or something update
    text = f"Iteration complete with: {process}..."
    logger.info(text)


def concrete_done_callback() -> None:
    """Log that the underlying iterator is exhausted."""
    # result and errors handling
    logger.info("All iterations complete!")


@AsAsyncOperation(progress_callback=concrete_progress_callback, done_callback=concrete_done_callback)
class ConcreteAsyncGenerator:
    """Async iterator over start + 1 .. stop, sleeping 10 ms before each value."""

    def __init__(self, start: int, stop: int):
        self.start = start
        self.stop = stop
        self.current = start

    def __aiter__(self) -> AsyncIterator[int]:
        return self

    async def __anext__(self) -> int:
        if not (self.current < self.stop):
            raise StopAsyncIteration
        self.current += 1
        await asyncio.sleep(0.01)
        return self.current


@AsOperation(progress_callback=concrete_progress_callback, done_callback=concrete_done_callback)
class ConcreteGenerator:
    """Iterator over start + 1 .. stop, sleeping 10 ms before each value."""

    def __init__(self, start: int, stop: int):
        self.current = start
        self.start = start
        self.stop = stop

    def __iter__(self) -> Iterator[int]:
        return self

    def __next__(self) -> int:
        if self.current < self.stop:
            self.current += 1
            time.sleep(0.01)
            return self.current

        # A synchronous iterator signals exhaustion with StopIteration.
        # StopAsyncIteration belongs to __anext__, and the old `return 42`
        # made this iterator endless while leaving the raise unreachable.
        raise StopIteration


@AsAsyncOperation(progress_callback=concrete_progress_callback, done_callback=concrete_done_callback)
async def async_squares(start, stop):
    """Yield n * n for n in range(start, stop), pausing 10 ms after each value."""
    for n in range(start, stop):
        square = n * n
        yield square
        await asyncio.sleep(0.01)


@AsOperation(progress_callback=concrete_progress_callback, done_callback=concrete_done_callback)
def squares(start, stop):
    """Yield n * n for n in range(start, stop), pausing 10 ms after each value."""
    for n in range(start, stop):
        square = n * n
        yield square
        time.sleep(0.01)
    # Generator return value; the operation wrapper does not use it.
    return 'squares'


async def async_main() -> None:
    """Demonstrate `AsyncOperation`: a batch run, suspend/resume, then cancel."""
    sqrt_n = 100
    operations = [async_squares(i*sqrt_n, i*sqrt_n + sqrt_n)
                  for i in range(sqrt_n)]
    for operation in operations:
        await operation.run()
    # asyncio.wait() expects futures/tasks; since Python 3.11 it no longer
    # wraps other awaitables, so await each operation directly (as main()
    # does with wait()).
    for operation in operations:
        await operation

    operation = async_squares(0, 100)
    await operation.run()
    await asyncio.sleep(1)
    await operation.suspend()
    await asyncio.sleep(3)
    await operation.resume()
    await operation

    operation = ConcreteAsyncGenerator(0, 1000)
    await operation.run()
    await asyncio.sleep(4)
    await operation.suspend()
    await asyncio.sleep(3)
    await operation.resume()
    await asyncio.sleep(1)
    await operation.cancel()
    await operation.wait()


def main() -> None:
    """Demonstrate the thread-based `Operation`: batch, suspend/resume, cancel."""
    sqrt_n = 100
    batch = [squares(k * sqrt_n, k * sqrt_n + sqrt_n) for k in range(sqrt_n)]
    for op in batch:
        op.run()
    for op in batch:
        op.wait()

    # Suspend for a while, then let it run to completion.
    op = squares(0, 100)
    op.run()
    time.sleep(1)
    op.suspend()
    time.sleep(3)
    op.resume()
    op.join()

    # Suspend, resume, then cancel before the iterator is exhausted.
    op = ConcreteGenerator(0, 1000)
    op.run()
    time.sleep(4)
    op.suspend()
    time.sleep(3)
    op.resume()
    time.sleep(1)
    op.cancel()
    op.wait()


if __name__ == "__main__":
    asyncio.run(async_main())
    main()