LC 2698: Optimized | Pruning | Casting Out Nines Principle

February 16, 2025

Find the problem description here. If you would like to read the solution on the LeetCode platform, find it here.

Intuition

Okay, I'm putting this here because I couldn't find a proper explanation anywhere of why this principle works for this problem, and also because I spent a lot of time understanding it and coming up with a solution.

So, everything else in my solution is the usual backtracking (I did it with numbers instead of strings, which gave a good speed boost), except for the optimization trick in the loop that kicks off the backtracking function.

So, why do the numbers that satisfy the given condition always leave a remainder of 0 or 1 when divided by 9?

There is a principle called "casting out nines", which shows that a number leaves the same remainder modulo 9 as the sum of its digits. So to get any number modulo 9, you just keep summing its digits until a single digit remains (a final 9 means the remainder is 0).

Example: for 8976, 8 + 9 + 7 + 6 = 30 and 3 + 0 = 3, so 8976 mod 9 = 3. That is indeed correct: 8976 = 997 * 9 + 3.

To learn more about this, read here.

You can also observe one more thing: by the same principle,

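any partition sum of a number, like 8 + 9 + 7 + 6 = 30, 8 + 97 + 6 = 111, or 89 + 7 + 6 = 102, has the same remainder modulo 9 (here, 3). This works because $10^k \equiv 1 \pmod 9$ for every $k$, so each part is congruent to the sum of its own digits, and every partition sum therefore reduces to the digit sum of the whole number.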

Now, this implies $S \equiv n^2 \pmod 9$, where $S$ is any partition sum of $n^2$ and $n$ is the current choice. To satisfy the condition in the question, $S$ must be equal to $n$. Hence,
$\Rightarrow n \equiv n^2 \pmod 9$
$\Rightarrow n^2 \equiv n \pmod 9$
$\Rightarrow n^2 - n \equiv 0 \pmod 9$
$\Rightarrow n(n - 1) \equiv 0 \pmod 9$

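The first two values that satisfy the above condition are 0 and 1, and after that 9, 10, 18, 19, and so on. In general, $n$ and $n - 1$ are consecutive, so they can't both be divisible by 3; therefore 9 must divide one of them entirely, which means $n \equiv 0$ or $n \equiv 1 \pmod 9$.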

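Hence, you can prune your choices to these numbers and run the partition backtracking only for them. Note that this is a necessary condition, not a sufficient one: 18 passes the filter, but no partition of 324 sums to 18, so the backtracking check is still needed.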


Code

class Solution {
public:
    bool partition(int remaining, int target, int curr) {
        // remaining exhausted, check if target is reached
        if (remaining == 0) return curr == target;
        int tens = 10;
        while (true) {
            // peel off the last k digits (tens == 10^k) as the next part
            int rem = remaining % tens;
            // rem only grows as tens grows, so once the sum overshoots the target, stop here
            if (curr + rem > target) return false;
            if (partition(remaining / tens, target, curr + rem)) {
                return true;
            }
            // no further partitions can be made
            if (remaining / tens == 0)    break;
            tens *= 10;
        }

        return false;
    }

    int punishmentNumber(int n) {
        int result = 0;
        for (int i = 1; i <= n; ++i) {
            // pruning based on the casting out nines principle: only i with
            // i % 9 == 0 or i % 9 == 1 can qualify. A fantastic, mind-blowing observation!
            if (i % 9 == 0 || i % 9 == 1) {
                if (partition(i * i, i, 0)) {
                    result += i * i;
                }
            }
        }

        return result;
    }
};

Good luck!