E - Eating
Abridged Problem Statement
You can merge two cakes \(M\) times. And you can eat the cake by unit \(x\) infinite times, where \(x\) should be less than half the cake size.
But you cannot eat the cake whose size is smaller than \(K\).
Your goal is eat the cake as much as possible.
Observations
Observation 1
We could merge first and eat the cakes after that.
As you can eat the cake infinite times, you can eat 1 unit until the cake’s size becomes \(K\).
Observation 2
When the cake’s size become \(K\), only one bite on this cake in the future. Thus we always eat \(K/2\) unit in this case.
Observation 3
Maximize the cake you eat means minimize your waste.
Greedy
Intuitive thinking
Let’s denote \(w[i]\) as the remaining unit in the end.
- \(w[i]=(k+1)/2\) if \(a[i]>=k\),
- \(w[i]=a[i]\) if \(a[i]<k\).
To minimize the sum of \(w[i]\), a intuitive solution here is sort the \(w[i]\) and remove the first \(M\) \(w[i]\) in decreasing order.
Issue?
But there is another better choice:
We can merge two cakes both of size \(k>=a[i]>(k+1)/2\). In this case, we not only eliminated the wastage of one cake, but also decrease another cake’s waste value from \(a[i]\) to \((k+1)/2\).
Hence, our solution is:
- First merge all cake with size \((k+1)/2 \leq a[i] < k\) in decreasing order.
- After that, we eliminate the cake by merging to a large cake one by one.
-
import heapq class Solution: @staticmethod def solve(n: int, m: int, k: int, a: list[int]) -> int: q1, q2 = [], [] half = (k + 1) // 2 for x in a: if x < k and x >= half: heapq.heappush(q1, -x) else: heapq.heappush(q2, -x) while m > 0 and len(q1) >= 2: x = -heapq.heappop(q1) y = -heapq.heappop(q1) heapq.heappush(q2, -(x + y)) m -= 1 while m > 0 and q1: x = -heapq.heappop(q2) y = -heapq.heappop(q1) heapq.heappush(q2, -(x + y)) m -= 1 while m > 0 and len(q2) >= 2: x = -heapq.heappop(q2) y = -heapq.heappop(q2) heapq.heappush(q2, -(x + y)) m -= 1 ans = 0 while q2: x = -heapq.heappop(q2) if x >= k: ans += x - half return ans if __name__ == "__main__": n, m, k = map(int, input().split()) a = list(map(int, input().split())) print(Solution.solve(n, m, k, a))
-
#include <inttypes.h> #include <stdint.h> #include <stdio.h> #include <stdlib.h> int64_t solve(int n, int m, int k, int64_t* a); int main() { int n, m, k; scanf("%d%d%d", &n, &m, &k); int64_t* a = (int64_t*)malloc(sizeof(int64_t) * n); for (int i = 0; i < n; i++) { scanf("%lld", &a[i]); } printf("%" PRId64 "\n", solve(n, m, k, a)); free(a); return 0; } typedef struct { int64_t* heap; int size; int capacity; } PriorityQueue; PriorityQueue* createPriorityQueue(int capacity) { PriorityQueue* pq = (PriorityQueue*)malloc(sizeof(PriorityQueue)); pq->heap = (int64_t*)malloc(sizeof(int64_t) * capacity); pq->size = 0; pq->capacity = capacity; return pq; } void swap(int64_t* a, int64_t* b) { int64_t temp = *a; *a = *b; *b = temp; } void heapify(PriorityQueue* pq, int idx) { int largest = idx; int left = 2 * idx + 1; int right = 2 * idx + 2; if (left < pq->size && pq->heap[left] > pq->heap[largest]) largest = left; if (right < pq->size && pq->heap[right] > pq->heap[largest]) largest = right; if (largest != idx) { swap(&pq->heap[idx], &pq->heap[largest]); heapify(pq, largest); } } void push(PriorityQueue* pq, int64_t value) { if (pq->size == pq->capacity) { pq->capacity *= 2; pq->heap = (int64_t*)realloc(pq->heap, sizeof(int64_t) * pq->capacity); } int i = pq->size; pq->heap[i] = value; pq->size++; while (i > 0 && pq->heap[(i - 1) / 2] < pq->heap[i]) { swap(&pq->heap[i], &pq->heap[(i - 1) / 2]); i = (i - 1) / 2; } } int64_t top(PriorityQueue* pq) { if (pq->size <= 0) { return -1; // Error: empty queue } return pq->heap[0]; } int64_t pop(PriorityQueue* pq) { if (pq->size <= 0) { return -1; // Error: empty queue } int64_t root = pq->heap[0]; pq->heap[0] = pq->heap[pq->size - 1]; pq->size--; heapify(pq, 0); return root; } int isEmpty(PriorityQueue* pq) { return pq->size == 0; } int size(PriorityQueue* pq) { return pq->size; } void freePriorityQueue(PriorityQueue* pq) { free(pq->heap); free(pq); } int64_t solve(int n, int m, int k, int64_t* a) { PriorityQueue* q1 = createPriorityQueue(n); PriorityQueue* q2 = createPriorityQueue(n); int half = (k + 1) / 2; for (int i = 0; i < n; i++) { if (a[i] < k && a[i] >= half) { push(q1, a[i]); } else { push(q2, a[i]); } } while (m > 0 && size(q1) >= 2) { int64_t x = pop(q1); int64_t y = pop(q1); push(q2, x + y); m--; } while (m > 0 && !isEmpty(q1)) { int64_t x = pop(q2); int64_t y = pop(q1); push(q2, x + y); m--; } while (m > 0 && size(q2) >= 2) { int64_t x = pop(q2); int64_t y = pop(q2); push(q2, x + y); m--; } int64_t ans = 0; while (!isEmpty(q2)) { int64_t x = pop(q2); if (x >= k) { ans += x - half; } } freePriorityQueue(q1); freePriorityQueue(q2); return ans; }
-
#include <cstdint> #include <iostream> #include <queue> class Solve { public: static int64_t solve(int n, int m, int k, std::vector<int64_t> &a) { std::priority_queue<int64_t> q1, q2; int half = (k + 1) / 2; for (auto &x : a) { if (x < k && x >= half) { q1.push(x); } else { q2.push(x); } } while (m && q1.size() >= 2) { int64_t x = q1.top(); q1.pop(); int64_t y = q1.top(); q1.pop(); q2.push(x + y); m--; } while (m && !q1.empty()) { int64_t x = q2.top(); q2.pop(); int64_t y = q1.top(); q1.pop(); q2.push(x + y); m--; } while (m && q2.size() >= 2) { int64_t x = q2.top(); q2.pop(); int64_t y = q2.top(); q2.pop(); q2.push(x + y); m--; } int64_t ans = 0; while (!q2.empty()) { int64_t x = q2.top(); q2.pop(); if (x >= k) { ans += x - half; } } return ans; } }; int main() { int n, m, k; std::cin >> n >> m >> k; std::vector<int64_t> a(n); for (auto &x : a) { std::cin >> x; } std::cout << Solve::solve(n, m, k, a) << std::endl; return 0; }
-
import java.util.Collections; import java.util.PriorityQueue; import java.util.Scanner; public class Solution { public static long solve(int n, int m, int k, long[] a) { PriorityQueue<Long> q1 = new PriorityQueue<>(Collections.reverseOrder()); PriorityQueue<Long> q2 = new PriorityQueue<>(Collections.reverseOrder()); int half = (k + 1) / 2; for (int i = 0; i < n; i++) { if (a[i] < k && a[i] >= half) { q1.add(a[i]); } else { q2.add(a[i]); } } while (m > 0 && q1.size() >= 2) { long x = q1.poll(); long y = q1.poll(); q2.add(x + y); m--; } while (m > 0 && !q1.isEmpty()) { long x = q2.poll(); long y = q1.poll(); q2.add(x + y); m--; } while (m > 0 && q2.size() >= 2) { long x = q2.poll(); long y = q2.poll(); q2.add(x + y); m--; } long ans = 0; while (!q2.isEmpty()) { long x = q2.poll(); if (x >= k) { ans += x - half; } } return ans; } public static void main(String[] args) { Scanner scanner = new Scanner(System.in); int n = scanner.nextInt(); int m = scanner.nextInt(); int k = scanner.nextInt(); long[] a = new long[n]; for (int i = 0; i < n; i++) { a[i] = scanner.nextLong(); } System.out.println(Solution.solve(n, m, k, a)); scanner.close(); } }
-
package main import ( "bufio" "container/heap" "fmt" "os" ) type MaxHeap []int64 func (h MaxHeap) Len() int { return len(h) } func (h MaxHeap) Less(i, j int) bool { return h[i] > h[j] } func (h MaxHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h *MaxHeap) Push(x interface{}) { *h = append(*h, x.(int64)) } func (h *MaxHeap) Pop() interface{} { old := *h n := len(old) x := old[n-1] *h = old[0 : n-1] return x } func solve(n int, m int, k int, a []int64) int64 { q1 := &MaxHeap{} q2 := &MaxHeap{} heap.Init(q1) heap.Init(q2) half := (k + 1) / 2 for _, x := range a { if x < int64(k) && x >= int64(half) { heap.Push(q1, x) } else { heap.Push(q2, x) } } for m > 0 && q1.Len() >= 2 { x := heap.Pop(q1).(int64) y := heap.Pop(q1).(int64) heap.Push(q2, x+y) m-- } for m > 0 && q1.Len() > 0 { x := heap.Pop(q2).(int64) y := heap.Pop(q1).(int64) heap.Push(q2, x+y) m-- } for m > 0 && q2.Len() >= 2 { x := heap.Pop(q2).(int64) y := heap.Pop(q2).(int64) heap.Push(q2, x+y) m-- } ans := int64(0) for q2.Len() > 0 { x := heap.Pop(q2).(int64) if x >= int64(k) { ans += x - int64(half) } } return ans } func main() { var n, m, k int reader := bufio.NewReader(os.Stdin) fmt.Fscan(reader, &n, &m, &k) a := make([]int64, n) for i := 0; i < n; i++ { fmt.Fscan(reader, &a[i]) } fmt.Println(solve(n, m, k, a)) }