Почему результат функции на python отличен от результата на C++ и как исправить?

Рейтинг: -6Ответов: 2Опубликовано: 04.05.2023

Функция на Python:

def inv(a, n): 
    if a == 0:
        return 0
    lm, hm = 1, 0
    low, high = a % n, n
    while low > 1:
        r = high//low
        nm, new = hm-lm*r, high-low*r
        lm, low, hm, high = nm, new, lm, low
    return lm % n

Функция на C++:

int inv(int a, int n) {
    if (a == 0) {
        return 0;
    }
    int lm = 1;
    int hm = 0;
    int low = a % n;
    int high = n;
    while (low > 1) {
        int r = high / low;
        int nm = hm - lm * r;
        int nw = high - low * r;
        lm = nm;
        low = nw;
        hm = lm;
        high = low;
    }
    return lm % n;
}

Python: inv(198411, 524287) -> 448444, C++: inv(198411, 524287) -> 0

Как исправить функцию на C++ что бы она работала как на Python?

Ответы

▲ 2Принят

Запись lm, low, hm, high = nm, new, lm, low аналогична

(lm, low, hm, high) = (nm, new, lm, low)

и она кортежу присваивает кортеж. Сначала создаётся один, а потом второй. Так как вы присваиваете последовательно в C++, не используя никаких дополнительных структур, то происходит ошибка логическая. Сначала вы присваиваете

  lm = nm;
  // потом уже другое значение
  hm = lm;

и происходит эта ошибка. Чтобы её исправить, можно добавить запасные переменные с копиями этих значений.

  int oldlm = lm ;
  lm = nm;
  int oldlow = low ;
  low = nw;
  hm = oldlm;
  high = oldlow;

Далее функция взятия остатка от деления немного отличается от C++ варианта при отрицательных переменных. В C++ она возвращает отрицательное число, а в Python - положительное. Это можно исправить, добавив делитель.

  return ( lm + n ) % n;

В итоге получилось такое :

int inv(int a, int n) {
    if (a == 0) {
        return 0;
    }
    int lm = 1;
    int hm = 0;
    int low = (a + n) % n ;
    int high = n;
    while (low > 1) {
        int r = high / low;
        int nm = hm - lm * r;
        int nw = high - low * r;
        int oldlm = lm ;
        lm = nm;
        int oldlow = low ;
        low = nw;
        hm = oldlm;
        high = oldlow;
    }
    return ( lm + n ) % n ;
}
▲ 2
int inv(int a, int n)
{
    if (a < 1 or n < 2)
        return -1;
    int high = n, hm = 0, low = a, lm = 1;
    while (low)
    {
        int q = high / low;
        int u = high - q*low;
        int v = hm - q*lm;
        high = low;
        hm   = lm;
        low  = u;
        lm   = v;
    }
    return high == 1 ? (hm + n) % n : -1;
}