try lru cache

This commit is contained in:
Justin Lin
2026-05-02 16:19:50 +10:00
parent 8582780ac9
commit bd9292ccd4
+28 -26
View File
@@ -1,3 +1,4 @@
import functools
import math import math
import sys import sys
@@ -10,17 +11,18 @@ def inputs():
n, wet_c, dry_c = inputs() n, wet_c, dry_c = inputs()
costs = [[(math.inf, -1)] * (dry_c + 1) for _ in range(wet_c + 1)] cost_map = {}
costs[0][0] = (0, -1) cost_map[(0, 0)] = (0, -1)
bundles = [] bundles = []
for i in range(n): for i in range(n):
wet, dry, cost = inputs() wet, dry, cost = inputs()
bundles.append((wet, dry, cost)) bundles.append((wet, dry, cost))
if wet <= wet_c and dry <= dry_c: if wet <= wet_c and dry <= dry_c:
costs[wet][dry] = (cost, i) cost_map[(wet, dry)] = (cost, i)
@functools.lru_cache
def get_cost(w, d) -> float: def get_cost(w, d) -> float:
if w == 0 and d == 0: if w == 0 and d == 0:
return 0 return 0
@@ -28,48 +30,48 @@ def get_cost(w, d) -> float:
if w < 0 or d < 0: if w < 0 or d < 0:
return math.inf return math.inf
if costs[w][d][1] != -1: if (w, d) in cost_map:
print("reusing: ", w, d) # print("reusing: ", w, d)
return costs[w][d][0] return cost_map[(w, d)][0]
print(w, d)
# print(w, d)
cost_map[(w, d)] = (math.inf, -1)
for i in range(len(bundles)): for i in range(len(bundles)):
wet, dry, cost = bundles[i] wet, dry, cost = bundles[i]
new_cost = get_cost(w - wet, d - dry) + cost new_cost = get_cost(w - wet, d - dry) + cost
if costs[w][d][0] > new_cost: if cost_map[(w, d)][0] > new_cost:
costs[w][d] = (new_cost, i) cost_map[(w, d)] = (new_cost, i)
return costs[w][d][0] return cost_map[(w, d)][0]
def bottom_up() -> float: # def bottom_up() -> float:
for w in range(wet_c + 1): # for w in range(wet_c + 1):
for d in range(dry_c + 1): # for d in range(dry_c + 1):
for i in range(len(bundles)): # for i in range(len(bundles)):
wet, dry, cost = bundles[i] # wet, dry, cost = bundles[i]
if wet > w or dry > d: # if wet > w or dry > d:
continue # continue
new_cost = costs[w - wet][d - dry][0] + cost # new_cost = costs[w - wet][d - dry][0] + cost
if costs[w][d][0] > new_cost: # if costs[w][d][0] > new_cost:
costs[w][d] = (new_cost, i) # costs[w][d] = (new_cost, i)
return costs[wet_c][dry_c][0] # return costs[wet_c][dry_c][0]
def backtrack(w, d) -> list[int]: def backtrack(w, d) -> list[int]:
counter = [0] * n counter = [0] * n
b = costs[w][d][1] b = cost_map[(w, d)][1]
while b != -1: while b != -1:
counter[b] += 1 counter[b] += 1
wet, dry, _ = bundles[b] wet, dry, _ = bundles[b]
w -= wet w -= wet
d -= dry d -= dry
b = costs[w][d][1] b = cost_map[(w, d)][1]
return counter return counter
# c = get_cost(wet_c, dry_c) c = get_cost(wet_c, dry_c)
c = bottom_up() # c = bottom_up()
# print(costs) # print(costs)
if math.isinf(c): if math.isinf(c):
print(-1) print(-1)