Как модифицировать член класса в multiprocessing.Pool.map?

Рейтинг: 1Ответов: 1Опубликовано: 08.02.2023

Я пытаюсь распараллелить множественные вычисления численными методами. Есть функция, которая принимает скалярные значения x1, x2 и вычисляет значение функции в этой точке. Значение функции необходимо найти в множестве точек. Поэтому я попытался разбить эти вычисления по ядрам.

Попытка №1. Создана функция calcSingle, которая рассчитывает значение функции в заданной точке и сохраняет его в массиве - члене класса. Создана функция calc, которая формирует сетку исходных данных и запускает многоядерный расчёт с помощью multiprocessing.Pool.map.

class Solver(object):
    def __init__(self):
        self.mesh = np.zeros(0)

    def calcSingle(self, *l):
        (index, num), = l
        self.mesh[index] = num

    def calc(self):
        self.x1, self.x2 = np.meshgrid(np.arange(1, 5), np.arange(5, 15))
        self.mesh = np.zeros(self.x1.shape)
        p = Pool(4)
        p.map(self.calcSingle, zip(np.ndindex(self.mesh.shape), itertools.count(1)))

if __name__ == '__main__':
    solver = Solver()
    solver.calc()
    print(solver.mesh)

Этот код выводит массив нулей. Вероятно, когда функция calcSingle вызывается в отдельном процессе, self ссылается уже на какой-то другой объект. Возможно, это происходит потому, что multiprocessing.Pool повторно импортирует текущий модуль для каждого процесса. Тогда я стал искать информацию о передаче аргументов в функцию по ссылке. Такого, как оказалось, в Питоне нет, но есть рекомендации, что можно заворачивать требуемый аргумент в класс с одним атрибутом. По сути, класс Solver и есть таковой класс, так почему бы не передать его в функцию calcSingle?

Попытка №2.

class Solver(object):
    def __init__(self):
        self.mesh = np.zeros(0)

    def calcSingle(self, *l):
        (index, num, obj), = l
        obj.mesh[index] = num

    def calc(self):
        self.x1, self.x2 = np.meshgrid(np.arange(1, 5), np.arange(5, 15))
        self.mesh = np.zeros(self.x1.shape)
        p = Pool(4)
        p.map(self.calcSingle, zip(np.ndindex(self.mesh.shape), itertools.count(1), itertools.repeat(self)))

if __name__ == '__main__':
    solver = Solver()
    solver.calc()
    print(solver.mesh)

Этот код порождает неконтролируемое количество процессов. Почему - не понимаю.

Конечно, можно заметить, что функция map возвращает list значений, возвращённых целевой функцией, и из него можно потом создать массив требуемой структуры. Но интересует именно модификация уже имеющегося массива, так как он может быть и memmap.

Ответы

▲ 0

Массивы numpy нельзя сделать разделяемыми между процессами с возможностью записи. Вместо этого можно создать разделяемый C-массив требуемого размера и каждый раз в новом процессе создавать numpy-массив, указав ему в качестве буфера данных разделяемый массив. Тогда данные, которые программа будет записывать в numpy-массив, окажутся в разделяемом C-массиве, и исходный процесс сможет их увидеть.

import multiprocessing
from multiprocessing import Pool, RawArray
from ctypes import c_double
import numpy as np
import itertools

class Solver(object):
    def __init__(self):
        self.mesh = np.zeros(0)

    @staticmethod
    def processInit(arr, shape):   # Не может быть приватным
        global ndarray_shared
        ndarray_shared = np.ndarray(shape, dtype=float, buffer=arr)

    @staticmethod
    def calcSingle(*l):     # Не может быть приватным
        (index, num), = l
        ndarray_shared[index] = num

    def calc(self):
        x1, x2 = np.meshgrid(np.arange(1, 5), np.arange(5, 15))
        # Создаём C-массив в разделяемой памяти
        self.shared = RawArray(c_double, x1.size)
        # И на его основе создаём numpy-массив
        self.mesh_shared = np.ndarray(x2.shape, dtype=float, buffer=self.shared)

        with Pool(processes=multiprocessing.cpu_count(), initializer=Solver.processInit, initargs=(self.shared, self.mesh_shared.shape)) as p:
            p.map(Solver.calcSingle, zip(np.ndindex(x1.shape), itertools.count(1)))

if __name__ == '__main__':
    solver = Solver()
    solver.calc()
    print(solver.mesh_shared)