# schedule.py
import sys, pickle, atexit
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, cast, get_args
from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ConstType, ImageDType, dtypes
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer, Device

# creation can recurse a lot
sys.setrecursionlimit(10000)

# optionally log the ops to disk: LOGOPS=<path> appends each scheduled kernel's ast (see create_schedule_with_vars)
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
  17. # *** ScheduleItem return type ***
  18. @dataclass(frozen=True)
  19. class ScheduleItem:
  20. ast: LazyOp
  21. bufs: Tuple[Buffer, ...]
  22. metadata: Optional[List[Metadata]] = None
  23. @property
  24. def outputs(self) -> Tuple[Buffer, ...]:
  25. """Read/write or write only buffers in the schedule."""
  26. return self.bufs[:len(self.ast.src)] if self.ast.op is MetaOps.KERNEL else self.bufs[0:1]
  27. @property
  28. def inputs(self) -> Tuple[Buffer, ...]:
  29. """Read only buffers in the schedule."""
  30. return self.bufs[len(self.ast.src):] if self.ast.op is MetaOps.KERNEL else self.bufs[1:]
  31. # *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
  32. def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
  33. realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
  34. reduce_info:Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> LazyOp:
  35. """recursively create a lazyop"""
  36. if buf is not buf.base: st, buf = buf.st+st, buf.base
  37. if (buf, st) in cache: return cache[(buf, st)]
  38. arg = buf.arg
  39. # consts are always fused and generated
  40. if buf.op is MetaOps.CONST:
  41. unbound_st, st_var_vals = st.simplify().unbind()
  42. var_vals.update(st_var_vals)
  43. if isinstance(buf.arg, Variable):
  44. val, var_val = buf.arg.unbind()
  45. var_vals.__setitem__(val, var_val)
  46. else:
  47. assert isinstance(buf.arg, get_args(ConstType)), f"cannot create ConstBuffer with value {buf.arg}"
  48. val = buf.arg
  49. return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))
  50. # if we aren't fusing it, it's a load and we add it to the inputs
  51. if buf.realized is not None or (buf in realizes and buf not in outputs):
  52. unbound_st, st_var_vals = st.simplify().unbind()
  53. var_vals.update(st_var_vals)
  54. if buf in assign_targets:
  55. # can only assign to contiguous read+write buffer
  56. if not unbound_st.contiguous:
  57. # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
  58. if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
  59. ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
  60. raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
  61. +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
  62. return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
  63. if buf not in inputs: inputs.append(buf)
  64. return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))
  65. # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
  66. if buf.op is MetaOps.CONTIGUOUS:
  67. assert buf in outputs
  68. return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, reduce_info, cache)
  69. if buf.op is MetaOps.ASSIGN:
  70. assert buf in outputs
  71. assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
  72. assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
  73. return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, reduce_info, cache)
  74. # if it's a reduce, we have to change the shapetracker
  75. if buf.op in ReduceOps:
  76. assert st.contiguous, "ReduceOps late fusion must be contiguous"
  77. st, arg = reduce_info[buf]
  78. # otherwise we fuse it like normal
  79. return cache.setdefault((buf, st), LazyOp(cast(Op,buf.op), tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, \
  80. reduce_info, cache) for x in buf.srcs), arg))
def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer], reduce_info:Dict, cache):
  """Find the ReduceOps fused into this kernel and record each one's input ShapeTracker and axis in `reduce_info`."""
  # stop at buffers realized elsewhere (those become loads) and at already-visited (buf, st) pairs
  if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs) or (buf, st) in cache: return
  cache.setdefault((buf, st))
  if buf is not buf.base: st, buf = buf.st+st, buf.base
  # below a reduce, sources are walked with the reduce input's own view instead of the accumulated one
  for x in buf.srcs: _recurse_reduceops(x, buf.srcs[0].st if buf.op in ReduceOps else st, realizes, outs, reduce_info, cache)
  if buf.op in ReduceOps:
    reduce_input, axis = buf.srcs[0], buf.arg
    assert st.contiguous
    # the reduce consumes its input at full shape
    st = ShapeTracker.from_shape(reduce_input.shape)
    reduce_info[buf] = (st, axis)
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
  """describe the computation for a LazyBuffer with LazyOp + inputs + var_vals

  Returns (ast, inputs, var_vals, metadata) for one kernel that writes every buffer in `outs`.
  """
  # optional copy kernel: express a same-device COPY as a uint8 load/store pair
  if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
    rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
    return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
  # non-kernel meta ops pass through as-is
  if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, []
  # var_vals starts from the variables already bound in the outputs' shapetrackers
  var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
  assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
  cache: Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp] = {}
  ast: List[LazyOp] = []
  inputs: List[LazyBuffer] = []
  reduce_info: Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]] = {}
  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
  for i, out in enumerate(outs):
    _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
    # if a reduce was found, the store shape comes from the last reduce (deque with maxlen=1 keeps only the last value)
    output_st = ShapeTracker.from_shape(reduce_st(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
    # an ASSIGN may carry an explicit output view in its arg
    output_view = out.arg[0] if out.op is MetaOps.ASSIGN and out.arg else output_st
    lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, reduce_info, cache=cache)
    output_view, vv = output_view.simplify().unbind()
    if vv: var_vals.update(vv)
    ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
  # metadata is collected from the fused (non-input) buffers
  return LazyOp(MetaOps.KERNEL, tuple(ast)), inputs, var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
  113. # *** DAG creation: decide which LazyBuffers should realize ***
  114. def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
  115. simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
  116. """recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
  117. if buf in allbufs or buf.base.realized is not None: return
  118. if GRAPH: log_lazybuffer(buf, scheduled)
  119. # view
  120. if buf.base != buf:
  121. # fuse some pads
  122. if len(buf.st.views) == 1 and buf.st.views[-1].mask is not None and all_int(buf.base.st.shape) and \
  123. prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
  124. simple_pads.add(buf.base)
  125. # realize all expands
  126. elif prod(buf.base.st.shape) < prod(buf.st.shape):
  127. if buf.base.op in ReduceOps and buf.base.srcs[0].base.op is MetaOps.CONST:
  128. pass # don't realize reduceops on const (unless base is forced_realize)
  129. # this was causing "test_lil_model" to fail
  130. if buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
  131. simple_pads.add(buf.base) # don't realize image to image casts. this is part of a larger problem
  132. else:
  133. realizes[buf.base] = None
  134. # check all other pads for safe fusion
  135. elif any(v.mask is not None for v in buf.st.views): simple_pads.add(buf.base)
  136. return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
  137. # base
  138. allbufs[buf] = None
  139. if buf.forced_realize: realizes[buf] = None
  140. if buf.op in MetaOps: realizes[buf.base] = None
  141. if buf.op is MetaOps.COPY:
  142. assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
  143. realizes[buf.srcs[0].base] = None
  144. if buf.op is MetaOps.VIEW: realizes[buf.srcs[0].base] = None
  145. for x in buf.srcs:
  146. if x.base.realized is None: children[x.base][buf] = None
  147. _recurse_lb(x, realizes, allbufs, simple_pads, children)
  148. def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
  149. if buf in realizes or buf.realized is not None: return True
  150. # NOTE: this broke to_image_idx and coder with JIT
  151. if buf.op in UNSAFE_PAD_OPS: return False
  152. return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer], cache:Set):
  """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
  # NOTE: adding the reduce `r` itself to `group` is the failure signal the caller checks (forced_realize)
  if (tr, st) in cache: return
  cache.add((tr, st))
  if tr in realizes:
    # can only fuse contiguous
    # max one reduceop per kernel
    if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.add(r)
    return group.add(tr)
  for tr_next in children[tr]:
    # max one reduceop per kernel
    if tr_next.op in ReduceOps: return group.add(r)
    # can only fuse contiguous
    if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
    _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group, cache)
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
  """create a graph for realizing the outputs

  Returns (graph, in_degree, prescheduled): the kernel dependency DAG, each kernel key's
  unscheduled-parent count, and for each key the tuple
  (outputs, ast, inputs, var_vals, metadata) where the last four come from _lower_lazybuffer.
  """
  # start by just realizing the buffers passed in
  realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
  allbufs: Dict[LazyBuffer, None] = {}
  simple_pads: Set[LazyBuffer] = set()
  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
  assign_targets = {x.srcs[1]:x for x in realizes if x.op is MetaOps.ASSIGN and x not in seen and x.realized is None}
  # check if we have to realize pads
  for p in simple_pads:
    if not _is_padding_okay(p, realizes):
      realizes[p] = None
  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
  for r in allbufs:
    if r.op not in ReduceOps or r in realizes: continue
    group: Set[LazyBuffer] = set()
    _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache=set())
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
    if not forced_realize and len(group) > 1:
      # create a multi output kernel if the LazyBuffers can cleanly group
      cache: Set[LazyBuffer] = set()
      rc_parents, rc_children = deque(group), deque(group)
      # walk ancestors of the group: a second reduceop above means we can't fuse
      while rc_parents:
        if (p:=rc_parents.pop()) in cache: continue
        cache.add(p)
        # max one reduceop per kernel
        if p.op in ReduceOps:
          forced_realize = True
          break
        rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
      # search descendants of the reduceop that can cleanly group
      cache.clear()
      realized_descendants: Set[LazyBuffer] = set()
      while rc_children and not forced_realize:
        if (c:=rc_children.pop()) in cache: continue
        cache.add(c)
        if c.op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
          realized_descendants.clear()
          break
        if c in realizes and c not in group: realized_descendants.add(c)
        rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
      group.update(realized_descendants)
    # can only fuse assign if no other assign_target is used in the kernel
    if not forced_realize and any(x.op is MetaOps.ASSIGN for x in group):
      parents = deque((r, *group))
      while parents and not forced_realize:
        if (p:=parents.pop().base).realized or p in realizes:
          if p in assign_targets and assign_targets[p] not in group: forced_realize, can_chase = True, False
          continue
        parents.extend(p.srcs)
    if forced_realize and (r.srcs[0].base.op is not MetaOps.CONST or any(x.shape != r.shape for x in children[r])):
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = tr.st
        while len(children[tr]) == 1:
          tr_next = next(iter(children[tr]))
          st_childs = dedup(s for s in tr_next.srcs if s.base is tr)
          if len(st_childs) > 1: break
          if st.size != st_childs[0].st.size: break
          st = st + st_childs[0].st
          if not st.contiguous or tr_next.op in ReduceOps: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
        if tr.op is UnaryOps.CAST and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
          tr = tr.srcs[0].base
      reduce_for_op[tr] = r
      realizes[tr] = None
    else: reduce_for_op.update((tr, r) for tr in group)
  # group the buffers to realize into output kernels (multi-output if MULTIOUTPUT and they share a reduce)
  output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
  for buf in realizes:
    if buf.realized is not None or buf.op is MetaOps.CONST or buf in seen: continue
    output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)
    # make things that can't be images not images
    if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
                                              not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
      if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
      buf.dtype = dtypes.float32
      # hack the underlying buffer too
      if buf.base is buf:
        assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
        buf.buffer.dtype = dtypes.float32
        buf.buffer.options = None
  # preschedule all buffers in realizes
  prescheduled = {group[0]:(group, *_lower_lazybuffer(group, realizes)) for group in output_groups.values()}
  schedule_targets = {out:ps for ps in prescheduled.values() for out in ps[0]}
  # build the dependency DAG between prescheduled kernels (lsi[2] is the kernel's input LazyBuffers)
  graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
  in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
  for key, lsi in prescheduled.items():
    if key not in in_degree: in_degree[key] = 0
    # realize outputs after all parents are realized
    scheduled_parents = set(schedule_targets[x][0][0] for x in lsi[2] if x in schedule_targets)
    for x in scheduled_parents:
      graph[x].append(key)
      in_degree[key] += 1
    # realize outputs before a parent is assigned to
    parents_assigns = set(schedule_targets[assign_targets[x]][0][0] for x in lsi[2] if x in assign_targets)
    for assign in parents_assigns:
      graph[key].append(assign)
      in_degree[assign] += 1
  return graph, in_degree, prescheduled
# *** DAG ordering: breadth first search ***

# schedule graphs accumulated for SAVE_SCHEDULE; dumped to disk at interpreter exit
SCHEDULES: List = []
def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
  """Linearize the graph of `outs` into ScheduleItems, plus the bound values of any symbolic Variables."""
  if seen is None: seen = set()
  graph, in_degree, prescheduled = _graph_schedule(outs, seen)
  # Kahn's algorithm: start from kernels with no unscheduled parents
  # ps is (outputs, ast, inputs, var_vals, metadata)
  queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
  schedule: List[ScheduleItem] = []
  var_vals: Dict[Variable, int] = {}
  kernel_number = GlobalCounters.kernel_count
  while queue:
    ps = queue.popleft()
    for buf in ps[0]: seen.add(buf)
    if GRAPH:
      kernel_number += 1
      for out in ps[0]: realized_lazybuffer(out, kernel_number)
    var_vals = merge_dicts([var_vals, ps[3]])
    for out in ps[0]: del out.srcs  # can only schedule once
    # zero-size buffers are dropped from the ScheduleItem's buffer list
    schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0), ps[4]))
    if logops and si.ast.op is MetaOps.KERNEL and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
    # a scheduled kernel unlocks its children
    for x in graph[ps[0][0]]:
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(prescheduled[x])
  if SAVE_SCHEDULE:
    def _save():
      print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
      with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
    # register the dump exactly once, on the first saved schedule
    if len(SCHEDULES) == 0: atexit.register(_save)
    SCHEDULES.extend((ps[1] for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
  # confirm everything was scheduled correctly
  if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
    raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
  return schedule, var_vals
  308. def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
  309. schedule, var_vals = create_schedule_with_vars(outs, seen)
  310. assert len(var_vals) == 0
  311. return schedule
# *** memory planning ***

def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], noopt_buffers=None, debug_prefix="") -> Dict[Buffer, Buffer]:
  """Plan buffer reuse: map each plannable Buffer to a (possibly shared) replacement Buffer.

  `buffers` is the per-step buffer usage list; a buffer may be reused once its last use has passed.
  Buffers in `noopt_buffers` are left untouched.
  """
  if getenv("NO_MEMORY_PLANNER"): return {}
  # lifetime of each base buffer: first/last step index where it appears
  first_appearance, last_appearance = {}, {}
  for i,u in enumerate(buffers):
    for buf in u:
      # skip already-allocated buffers, buffers still referenced by lazybuffers, and excluded ones
      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
      if buf.base not in first_appearance: first_appearance[buf.base] = i
      last_appearance[buf.base] = i
  # Sort buffers by size in descending order, prioritizing largest buffers for allocation first.
  # Track free segments, each containing (start, stop, and buffer that could be reused on this segment).
  free_segs: Dict[Tuple, List[Tuple[int, int, Buffer]]] = defaultdict(list) # Dict[buffer key, Tuple[start, end, buffer to reuse on the seg]]
  def find_replace_buffer(buf, st, en):
    # buffers are interchangeable when device/dtype/options match; nbytes must match too unless the allocator supports offsets
    key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())
    default_buf = (0, len(buffers) - 1, buf) # will return the buffer itself if the replace one is not found.
    seg_st, seg_en, seg_buf = next((free_segs[key].pop(i) for i,(sst,sen,_) in enumerate(free_segs[key]) if sst <= st and en <= sen), default_buf)
    # return the unused head/tail of the claimed segment to the free list
    free_segs[key] += [(seg_st, st - 1, seg_buf)] if st - 1 >= seg_st else []
    free_segs[key] += [(en + 1, seg_en, seg_buf)] if seg_en >= en + 1 else []
    return seg_buf if seg_buf.nbytes == buf.nbytes else Buffer(buf.device, buf.size, buf.dtype, base=seg_buf)
  buffer_requests = sorted([(first_appearance[buf], last_appearance[buf], buf) for buf in first_appearance.keys()], key=lambda x: -x[2].nbytes)
  assigned = {buf:find_replace_buffer(buf, st, en) for st, en, buf in buffer_requests}
  for i,u in enumerate(buffers):
    for buf in u:
      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
      # view buffers follow their base's assignment, keeping the same offset
      if buf._base is not None: assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf.base, buf.base).base, offset=buf.offset)
      else: assigned[buf] = assigned.get(buf, buf)
  if DEBUG >= 1 and len(ak:=dedup(x for x in assigned.keys() if x._base is None)) != len(av:=dedup(x for x in assigned.values() if x._base is None)):
    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
          f"{len(ak)} -> {len(av)} bufs")
  return assigned
  342. def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
  343. # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
  344. assigned = _internal_memory_planner([si.bufs for si in schedule],
  345. noopt_buffers={b for si in schedule if si.ast.op is not MetaOps.KERNEL for b in si.bufs})
  346. return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]