Ошибка в написании алгоритма лексикографической сортировки для списка

Рейтинг: 1Ответов: 2Опубликовано: 11.03.2023

Задача:

Натуральные числа от 1 до N упорядочены лексикографически. Например, для N=25 результат этого упорядочения будет таким: 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 21, 22, 23, 24, 25, 3, 4, 5, 6, 7, 8, 9.

Требуется написать программу, которая определит, на каком месте оказалось число K.

Входные данные

Входной файл INPUT.TXT содержит два натуральных числа N и K, записанных через пробел (1 ≤ K ≤ N ≤ 10**4).

Выходные данные

Выходной файл OUTPUT.TXT должен содержать одно натуральное число – номер места, на котором оказалось число K.

Например:

Вход: 25 17

Выход: 9

Причина: элемент 17 в списке

[1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 21, 22, 23, 24, 25, 3, 4, 5, 6, 7, 8, 9]

находился под номером 9

Мой код:

def lexicographic_sort(arr: list[int]) -> list[int]:
    if len(arr) < 2:
        return arr
    mid = arr[len(arr) // 2]
    l = [i for num, i in enumerate(arr) if str(i) < str(mid) and num != len(arr) // 2]
    r = [i for num, i in enumerate(arr) if str(i) > str(mid) and num != len(arr) // 2]
    return lexicographic_sort(l) + [mid] + lexicographic_sort(r)


n, k = map(int, input().split())
include_list = list(range(1, n + 1))
arr_lexicographic = lexicographic_sort(include_list)
print(arr_lexicographic.index(k) + 1)

Проблема

Один из тестов на сайте не проходит. Где ошибка и как её исправить, чтобы программа работала корректно? Конечно, проще сделать код так и он пройдёт тесты (проверенно):

n, k = map(int, input().split())
include_list = list(range(1, n + 1))
include_list.sort(key=lambda x: str(x))
print(include_list.index(k) + 1)

Меня интересует, где в моей сортировке ошибка.

Сайт, на котором я тестирую: https://acmp.ru/asp/do/index.asp?main=task&id_course=2&id_section=12&id_topic=8&id_problem=38

Ответы

▲ 2Принят

Ответ

Алгоритм вы написали правильный. Но у быстрой сортировки есть неприятная особенность, с которой вы и познакомились.

n = 1999 переполняет стек:

$ echo 1999 1 | python lex_sort.py
Traceback (most recent call last):
  File "lex_sort.py", line 12, in <module>
    arr_lexicographic = lexicographic_sort(include_list)
  File "lex_sort.py", line 7, in lexicographic_sort
    return lexicographic_sort(l) + [mid] + lexicographic_sort(r)
  File "lex_sort.py", line 7, in lexicographic_sort
    return lexicographic_sort(l) + [mid] + lexicographic_sort(r)
  File "lex_sort.py", line 7, in lexicographic_sort
    return lexicographic_sort(l) + [mid] + lexicographic_sort(r)
  [Previous line repeated 993 more times]
  File "lex_sort.py", line 5, in lexicographic_sort
    l = [i for num, i in enumerate(arr) if str(i) < str(mid) and num != len(arr) // 2]
  File "lex_sort.py", line 5, in <listcomp>
    l = [i for num, i in enumerate(arr) if str(i) < str(mid) and num != len(arr) // 2]
RecursionError: maximum recursion depth exceeded while getting the str of an object

При первом разбиении mid = 1000. Такое разбиение очень неравновесное. Налево уходят только три числа 1, 10, 100, направо все остальные.

Распечатка лога print(len(l), mid, len(r)):

len(l)     mid  len(r)

     3    1000    1995
     1      10       1
     1    1002    1993
     0    1003    1992
     0    1004    1991
  1990     999       0
     0    1005    1989
  1988     998       0
     0    1006    1987
  1986     997       0
     0    1007    1985
  1984     996       0
     0    1008    1983
  1982     995       0
     0    1009    1981
  1980     994       0
     1    1010    1978
     0    1011    1977
  1976     993       0
     0    1012    1975
...

В середине массива скопились минимальные и максимальные элементы - 999 и ниже очень велики, 1000 и выше очень малы. Почти всегда при разбиении одна половина остаётся пустой. Глубина стека быстрой сортировки становится линейной. Стек переполняется – ошибка.

Надо что-то делать

Можно поправить сортировку на вариант в котором стек всегда небольшой. Но что толку? Ошибка пропадёт, но время сортировки станет квадратичным.

В соседнем ответе предложили в качестве пивота выбирать случайный элемент. Это устраняет проблему. В теории переполнение стека всё ещё возможно, на практике должны смениться геологические эпохи, погибнуть Солнце и Вселенная прежде чем это случиться один раз на списке из тысячи элементов.

Встроенные средства решают задачу, вы сами отметили это в вопросе:

n, k = map(int, input().split())
print(sorted(range(n + 1), key=str).index(k))

Но я хочу заглянуть дальше. Можно решить задачу без построения списка и его сортировки. Тогда мы сможем решать задачи поистине астрономического масштаба. Как насчёт n = 10100?

Обозначим f(n, p) - количество чисел начинающихся с префикса p и не превосходящих n.

Например: n = 3017318298, p = 143. Тогда в числа попадут 143, 1430-1439, 14300-14399, 143000-143999, …, 1430000000-1439999999. Их количество f(3017318298, 143) = 1 + 10 + 100 + 1000 + … + 10000000.

Если p является префиксом n, последняя группа не полная. Для n = 3017318298 и p = 301 последняя группа будет 3010000000-3017318298 и общее количество f(3017318298, 301) = 1 + 10 + 100 + 1000 + … + 1000000 + 7318299.

Так или иначе f(n, p) считается за время пропорциональное log2n. И даже это время можно улучшить до log n.

Вычислим позицию k = 483967505 в ряду для n = 3017318298.

До k в ряду идут следующие числа:

цифра группы чисел f(n, p) ∑f(n, p)
4 0…, 1…, 2…, 3…, 4, 0 + 2×1111111111 + 128429410 + 1 2350651633
8 40…, 41…, 42…, 43…, 44…, 45…, 46…, 47…, 48, 8×11111111 + 1 88888889
3 480…, 481…, 482…, 483, 3×1111111 + 1 3333334
9 4830…, 4831…, 4832…, 4833…, 4834…, 4835…, 4836…, 4837…, 4838…, 4839, 9×111111 + 1 1000000
6 48390…, 48391…, 48392…, 48393…, 48394…, 48395…, 48396, 6×11111 + 1 66667
7 483960…, 483961…, 483962…, 483963…, 483964…, 483965…, 483966…, 483967, 7×1111 + 1 7778
5 4839670…, 4839671…, 4839672…, 4839673…, 4839674…, 4839675, 5×111 + 1 556
0 48396750, 1 1
5 483967500…, 483967501…, 483967502…, 483967503…, 483967504…, 483967505. 5×1 + 1 6

Сумма чисел в правом столбце 2443948864. В это количество входят все числа до 483967505 и оно само. То есть, число 483967505 располагается на 2443948864-ом месте.

Код

def f(n, p):
    c = 0
    if p > 0:
        t = 1
        while True:
            if p + t <= n:
                c += t
            else:
                if p <= n:
                    c += n + 1 - p
                break
            p *= 10
            t *= 10
    return c


def main():
    w1, w2 = input().split()
    n = int(w1)

    i = 0
    p = 0
    for d in map(int, w2):
        i += sum(f(n, 10 * p + e) for e in range(d)) + 1
        p = 10 * p + d

    print(i)


main()
$ echo 25 17 | python position.py
9

Одна секунда для n = 10700:

$ time python -c "print(10**700, 10**700 - 1)" | python position.py
1000...000

real  0m0.965s
user  0m1.020s
sys   0m0.008s

Хотя это быстрое решение, можно ещё быстрее. Сложность текущего кода можно грубо оценить как log3n - куб от количества цифр в n. Думаю что можно сделать квадрат, если доработать вычисление f(n, p).

▲ 1

Не знаю почему, но если брать не элемент по середине, а случайный то все тесты проходит:

import random


def lexicographic_sort(arr: list[int]) -> list[int]:
    if len(arr) < 2:
        return arr
    index = random.randrange(len(arr))
    mid = arr[index]
    l = [i for num, i in enumerate(arr) if str(i) < str(mid) and num != index]
    r = [i for num, i in enumerate(arr) if str(i) > str(mid) and num != index]
    return lexicographic_sort(l) + [mid] + lexicographic_sort(r)


n, k = map(int, input().split())
include_list = list(range(1, n + 1))
arr_lexicographic = lexicographic_sort(include_list)
print(arr_lexicographic.index(k) + 1)

Либо же можно брать всегда первый элемент, так тоже прошёл