Помощь в улучшении алгоритма

Рейтинг: -1Ответов: 2Опубликовано: 16.04.2023

Есть двумерный массив размера N*M (пример N = 4, M = 3)

[1, 0, 0]

[1, 1, 0]

[0, 1, 0]

[0, 0, 2]

нужно получить сумму произведений элементов в массиве длины М удовлетворяющих требованию arr[i][j], arr[k][x], arr[q][p] где i<k<q и j<x<p

в примере удовлетворяют условию группы элементов

arr[0][0],arr[1][1],arr[3][2] 1,1,2 прозведени 2
arr[0][0],arr[2][1],arr[3][2] 1,1,2 прозведени 2
arr[1][0],arr[2][1],arr[3][2] 1,1,2 прозведени 2

я написал программу с временной сложностью O(M * N* M)

def dfs(row,col,arr,p):
    if col == len(arr[0]):
        return p
    s = 0
    while row < len(arr):
        s+=dfs(row+1,col+1,arr,p*arr[row][col])
        row+=1
    return s
dfs(0,0,arr,1)

Но есть чувство что его можно улучшить до O(M * N) с помощью динамического программирования. Как это можно сделать?

Ответы

▲ 3Принят

fi,j - искомая сумма произведений для матрицы [0, i] × [0, j]. Тогда

  • fi,-1 = 1
  • fi,j = 0, если i < j
  • fi,j = ai,j·fi-1,j-1 + fi-1,j - первое слагаемое для всех последовательностей, которые заканчиваются в углу ai,j, второе для тех которые заканчиваются раньше.
import functools


def f(a):

    @functools.cache
    def f(i, j):
        if j == 0:
            return 1
        if i < j:
            return 0
        return a[i][j] * f(i - 1, j - 1) + f(i - 1, j)

    return f(len(a) - 1, len(a[0]) - 1)


print(f([
    [1, 0, 0],
    [1, 1, 0],
    [0, 1, 0],
    [0, 0, 2]
]))

Сложность O(nm): благодаря кешу для каждой пары (i, j) f вычисляется один раз. Сложность одного вызова - константа. Могут быть проблемы с глубиной рекурсии. Тогда сделаем ручное кеширование:

def f(a):

    f = [[None] * len(r) for r in a]

    f[0][0] = a[0][0]
    for j in range(1, len(a[0])):
        f[0][j] = 0

    for i in range(1, len(a)):
        f[i][0] = a[i][0] + f[i - 1][0]
        for j in range(1, len(a[i])):
            f[i][j] = a[i][j] * f[i - 1][j - 1] + f[i - 1][j]

    return f[-1][-1]


print(f([
    [1, 0, 0],
    [1, 1, 0],
    [0, 1, 0],
    [0, 0, 2]
]))

Можно заметить что следующая строка кеша зависит только от предыдущей. Можно обойтись памятью O(m):

def f(a):
    f_row = [1] + [0] * len(a[0])
    for a_row in a:
        for j in reversed(range(len(a_row))):
            f_row[j + 1] += a_row[j] * f_row[j]
    return f_row[-1]


print(f([
    [1, 0, 0],
    [1, 1, 0],
    [0, 1, 0],
    [0, 0, 2]
]))
▲ 1

Я бы решал через DP, то есть

  1. создал бы массив той же размерности
  2. в каждой ячейке массива писал бы сумму всех произведений, что в этой ячейке заканчиваются.

То есть изначальный массив

[1, 0, 0]
[1, 1, 0]
[0, 1, 0]
[0, 0, 2]

DP массив поначалу забитый нулями

[0, 0, 0]
[0, 0, 0]
[0, 0, 0]
[0, 0, 0]

Первый столбец от 0 строки до len-2 строки просто проставляем равным изначальнму (первые шаги путей)

[1, 0, 0]
[1, 0, 0]
[0, 0, 0]
[0, 0, 0]

Второй столбец от 1 строки до len-1 строки - суммируем все возможные произведения для второго столбца

[1, 0, 0]
[1, 1, 0]
[0, 1*1+1*1, 0]
[0, 0, 0]

Третий столбец от 2 строки до len строки - делаем то же самое для третьего столбца (тут важно понимать, что для нахождения сумм третьего столбца нам нужен только второй столбец, первый уже не нужен)

[1, 0, 0]
[1, 1, 0]
[0, 2, 0]
[0, 0, 2*1 + 2*2]

Суммируем третий столбец (там уже все суммы произведений всех путей для каждой ячейки)

[1, 0, 0]
[1, 1, 0]
[0, 2, 0]
[0, 0, 6]

сумма третьего столбца = 6

Вот и всё по сути. Поскольку на каждой итерации мы для каждой ячейки перебираем весь следующий стоблец (N*N), и таких переборов у нас будет столько, сколько столбцов (M), то суммарная сложность алгоритма будет M*N*N(то есть M*(N**2))

Код на C# (я питонов ваших не умею) получился бы примерно такой

int SumOfMult(int[,] data)
{
    int N = data.GetLength(0);
    int M = data.GetLength(1);

    var dp = new int[N, M];

    for (int i = 0; i < N; i++)
    {
        dp[i, 0] = data[i, 0];
        for (int j = 0; j < M - 1; j++)
            for (int k = i + 1; k < N; k++)
                dp[k, j + 1] += data[k, j + 1] * dp[i, j];
    }

    int ret = 0;
    for (int i = 0; i < N; i++) ret += dp[i, M - 1];
    return ret;
}

Сложность вышла MNN, но если чуть чуть подшаманить, можно получить NNN при условии, если M>N

Проверка

Console.WriteLine(SumOfMult(new[,] {
    { 1, 0, 0 },
    { 1, 1, 0 },
    { 0, 1, 0 },
    { 0, 0, 2} }));

Вывод

6