题目分析

本题要求的是至多选$m$次元素时能够得到的最大总价值,且对于第$i$个元素,每选择它一次,它的价值就会减少$decay[i]$,即从$value[i]$变成$value[i]-decay[i]$。

基本思路

大顶堆模拟

最直接的思路就是模拟选取元素,如果对于每一个位置都将其对应的二元组$(value[i], decay[i])$存入一个大顶堆中,那么在一次选取元素的过程中,直接取出堆顶元素即可最大化当前选取的价值。假设当前弹出的堆顶元素是$(v_i, d_i)$,那么根据题目要求,「选择元素$i$后其价值会减少$decay[i]$」,则需要将更新价值后的该元素,即二元组$(v_i-d_i, d_i)$再次压入堆中。

虽然这是一个暴力模拟方法,但依然可以加入一点优化,即在模拟$m$次选取的$while$循环中,如果某次选择元素时,整个堆中的最大值都小于0了,即可直接退出循环,因为这时再加下去也没有意义了,只会让总和变得更小。

根据上述思路,即可简单写出暴力模拟的代码如下:

class Solution:
    def maxTotalValue(self, value: list[int], decay: list[int], m: int) -> int:
        n = len(value)
        hp = []
        for i in range(n):
            heappush(hp, (-value[i], decay[i]))
        
        res = 0
        modNum = 1_000_000_007
        while m:
            m -= 1
            cur, d = heappop(hp)
            cur = -cur
            if cur < 0:
                break
            res = (res+cur)%modNum
            heappush(hp, (-(cur-d), d))
        return res

确实很简单,但实测只能通过221 / 561个测试用例,因此需要整个优化。

二分查找优化

如果说在整个数组中选至多$m$个元素比较困难,那么可以倒过来想,由于每一个值在选择的过程中它会变得越来越小,因此如果确定了一个最小值的界限,实际上就可以确定一个元素能够被选择的次数。具体操作是这样的:

  1. 已知要选的所有值都要大于阈值$X$,且对于元素$i$,它初始值是$v_i$,每选择它一次值就会减少$d_i$;
  2. 那么能够减少$d_i$的次数就是:$cnt=\lfloor \frac{v_i-X-1}{d_i} \rfloor$;(其中-1是为了保证减少后其值大于$X$)
  3. 因此该元素能够被选择的次数就是:$cnt+1$.

综上所述,如果$X$增大,那么总共选择的数量就一定非增(不变或减小),否则总共选择的数量就一定非减(不变或增大),即总共选择的数量随$X$的增大呈有序分布状态,因此可以考虑用二分查找的方式寻找一个能够满足总共选择数量小于等于$m$的最小阈值$X$。

得到阈值$X$后,即可重新遍历所有元素,根据上面计算出来的每个元素可以被选择的次数,利用等差数列求和公式来得到该元素可以给答案贡献的总和,公式推导如下:

  1. 由上面分析可以知道,一个元素$i$可以被选择的次数是:$\lfloor \frac{v_i-X-1}{d_i} \rfloor+1$,记为$c_i$;
  2. 因此该元素为答案贡献的值即为:$v_i+(v_i-d_i)+\dots+(v_i-(c_i-1)\times d_i)$;
  3. 利用等差数列求和公式即可化简上面的式子变成: $$ \frac{(v_i+v_i-(c_i-1)\times d_i)\times c_i}{2} $$

因此,对于所有元素都将以上公式计算出来的值加到答案中,即可得到所有元素一共可以贡献的值。

但这并不是答案,由于在二分查找$X$时已经保证:当阈值为$X$时,数组中可以选择的数量小于等于$m$,且当阈值为$X-1$时,数组中可以选择的数量大于$m$,也就意味着,可以选择的数字中,大于$X$的数量小于等于$m$,大于等于$X$的数量大于$m$,因此两者之差就是可以选择的数字中等于$X$的数量。又因为上面选择完成后,数组中的所有值最大只能是$X$,因此,如果在选择完成之后仍然有剩余选择次数,即可将剩余所有次数都用来选择$X$,加入答案即可。

复杂度

时间复杂度:$O(n\cdot logM)$,其中$M = max(value)$ 空间复杂度:$O(1)$

代码

class Solution:
    def maxTotalValue(self, value: list[int], decay: list[int], m: int) -> int:
        n = len(value)
        
        def check(mid):
            totCnt = 0
            for i in range(n):
                v, d = value[i], decay[i]
                if v > mid:
                    cnt = (v-mid-1)//d
                    totCnt += cnt+1
            return totCnt > m
        
        l, r = 0, max(value)
        x = 0
        while l<=r:
            mid = (l+r)//2
            if check(mid):
                l = mid+1
            else:
                r = mid-1
                x = mid
        
        res = 0
        picked = 0
        modNum = 1_000_000_007
        for i in range(n):
            v, d = value[i], decay[i]
            if v <= x:
                continue
            cnt = (v-x-1)//d
            res = (res+(v+v-cnt*d)*(cnt+1)//2)%modNum
            picked += cnt+1
        
        if picked < m:
            if x > 0:
                res = (res+(m-picked)*x)%modNum
            
        return res