きなこの精進日記[python]

コドフォ Educational Contest 90 E Sum of Digits 解説[python]

ジャンル

整数問題 埋め込み問題

概要

f(x)を各位の和とする。
n,kが与えられる時
f(x)+f(x+1)+...+f(x+k) = n となるような最小の非負整数x を求めよ。

制約

0<=k<=9
1<=n<=150

考察

k が小さい(K<=9) ので1のくらいでの繰り上がりはあったとしても一度のみ
N,Kが共に小さく、150*10 通りを計算して埋め込めそうな気もする

K=0 の時、1,2,3,..9,19,29,...99,199,299....999 とn*9 を超えたら桁を増やしていく

K=1の時、奇数だったら作れそう。偶数の時、繰り上がりが必ず起こる
10 => 9,10 12 => 19,20, ..26=> 89,99 28=>189,199 で作れる

K=2の時、3の倍数なら繰り上がりがない範囲で作る。あまり1,2は繰り上がらないと無理そう
繰り上がる時、繰り上がる前の数の桁は9 でそれが繰り上がると0になってその次の桁が1増えるので結果的に3で割ったあまりは1増えるのと変わらない。したがって、繰り上がりがあっても無理。

K=3 を考えると、4の倍数ならいける。同様に考えると、繰り上がると9が減って1増えるので+- 0 になる。がここからはかなり厳しい。

K>=3 の時、取りうる値を考えると、数字が4つ連続するので、9*4*桁数-6 がその桁数の中で最も大きい値になる。(99..96,99..97,99..98,99..99 の時)
 つまり、N<=150の元での上界を考えると、桁数は5桁以内に収まることがわかる。
 したがって、<10**5 でループを回せば良い

実装

step1: K=0 の時の分岐を作る
step2: K=1の分岐を作る
step3: K=2の時の分岐を作る
step4: K>=3なら、10**5以内で全探索する。
計算量:K>=3の時が計算が重くて、O(t*N) N=10**5 で10**7オーダーになる

のようにK の値に応じて処理を分ける

lim を10**5固定にするとtestcase7 (K=5)でTLEしたのでlim をKに応じて変更するとAC
lim は何桁が最大化のみを見れば良くて、9*(K+1)*digit - (K+1)*K//2 >=150 になる最小のdigit が桁数になる

これで埋め込みなしで630msec

import sys
import math


"""
step1: K=0 の時の分岐を作る
step2: K=1の分岐を作る
step3: K=2の時の分岐を作る
step4: K>=3なら、10**5以内で全探索する。
計算量:K>=3の時が計算が重くて、O(t*N) N=10**5 で10**7オーダーになる

"""
read = sys.stdin.buffer.read
input = sys.stdin.buffer.readline
readlines = sys.stdin.buffer.readlines


def solve_sum_digit(n):
    res = 0
    while n:
        res += n % 10
        n //= 10
    return res


def solve_0(n):
    digit = math.ceil(n / 9)
    res = ""
    if n % 9:
        res += str(n % 9) + "9" * (digit - 1)
    else:
        res += "9" * (digit)
    return res


def solve_1(n):
    # 偶奇で分ける
    if n % 2:  # 奇数なら全て作れる
        if n <= 17:
            return n // 2
        else:
            # 2桁以上なら、最後は8になる
            digit = n // 18 + 1
            num_top = (n % 18) // 2 + 1
            return str(num_top) + "9" * (digit - 2) + "8"
    else:
        # 偶数なら必ず繰り上がりが必要で10以上は作れる
        # 10 => 9,10 12 => 19,20, ..26=> 89,99 28=>189,199 で作れる
        if n < 10:
            return -1
        elif n == 10:
            return 9
        else:
            digit = (n + 8) // 18 + 1
            num_top = ((n + 8) % 18) // 2 + 1
            if digit == 2:
                return str(num_top - 1) + str(9)
            else:
                return str(num_top) + "9" * (digit - 3) + "89"


def solve_2(n):
    if n % 3:
        return -1
    else:
        if n <= 24:
            return n // 3 - 1
        else:
            digit = n // 27 + 1
            num_top = (n % 27) // 3 + 1
            return str(num_top) + "9" * (digit - 2) + "7"


t = int(input())
# t = 1501
for case in range(t):
    n, k = map(int, input().split())
    # n = case // 10 + 1
    # k = case % 10
    if k == 0:
        ans = solve_0(n)
    elif k == 1:
        ans = solve_1(n)
    elif k == 2:
        ans = solve_2(n)
    else:
        pass_flag = 0
        if k == 3:
            lim = 10 ** 5
        elif k <= 5:
            lim = 10 ** 4
        else:
            lim = 10 ** 3
        for x in range(0, lim):
            sum_digit = 0
            for i in range(k + 1):
                sum_digit += solve_sum_digit(x + i)
            if sum_digit == n:
                # print(x)
                ans = x
                pass_flag = 1
                break
        if pass_flag == 0:
            # print(-1)
            ans = -1
    print(ans)