Skip to content

Commit

Permalink
Fix/orphan cleanup (#1043)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Dec 4, 2023
1 parent 9d92f9a commit 0b16786
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 178 deletions.
46 changes: 33 additions & 13 deletions burn-fusion/src/graph/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use super::RelativeGraphConverter;
use super::TensorOpsDescription;
use crate::Optimization;
use crate::{FusionBackend, HandleContainer};
use std::ops::RangeBounds;

/// The computational graph containing a list of [tensor operation descriptions](TensorOpsDescription).
pub struct Graph<B: FusionBackend> {
Expand Down Expand Up @@ -60,26 +59,47 @@ impl<B: FusionBackend> Graph<B> {
let mut context = self.converter.context(handles);
optimization.execute(&mut context);

self.cleanup_partial(0..num_keep, handles);
self.cleanup_partial(num_keep, handles);
}

pub(crate) fn execute_operations(&mut self, handles: &mut HandleContainer<B>) {
for (description, ops) in self.global.drain(..).zip(self.ops.drain(..)) {
for ops in self.ops.drain(..) {
ops.execute(handles);
description.cleanup_tensor(handles);
}

self.cleanup_total(handles);
}

fn cleanup_total(&mut self, handles: &mut HandleContainer<B>) {
self.global
.iter()
.flat_map(|desc| desc.nodes())
.for_each(|tensor| handles.free(tensor));
handles.free_orphans(&[]);

self.global.clear();
self.ops.clear();
self.cleanup_relative_graph();
}

fn cleanup_partial<R: RangeBounds<usize> + Clone>(
&mut self,
range: R,
handles: &mut HandleContainer<B>,
) {
for ops in self.global.drain(range.clone()) {
ops.cleanup_tensor(handles)
}
self.ops.drain(range);
fn cleanup_partial(&mut self, num_keep: usize, handles: &mut HandleContainer<B>) {
self.global[0..num_keep]
.iter()
.flat_map(|desc| desc.nodes())
.for_each(|tensor| handles.free(tensor));

self.global.drain(0..num_keep);

handles.free_orphans(
&self
.global
.iter()
.flat_map(|desc| desc.nodes())
.map(|tensor| &tensor.id)
.collect::<Vec<_>>(),
);

self.ops.drain(0..num_keep);

// Rebuild the relative graph when partially removing the global graph.
self.cleanup_relative_graph();
Expand Down
Loading

0 comments on commit 0b16786

Please sign in to comment.