From 4a973ab692c81ba2fb723a9646e87ebd3c8f9324 Mon Sep 17 00:00:00 2001 From: kylebd99 Date: Thu, 5 Dec 2024 14:20:31 -0800 Subject: [PATCH] Move get_julia_scheduler into the class definitions --- src/finch/compiled.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/finch/compiled.py b/src/finch/compiled.py index f56411d..2b7eb6b 100644 --- a/src/finch/compiled.py +++ b/src/finch/compiled.py @@ -34,18 +34,18 @@ class GalleyScheduler(AbstractScheduler): def __init__(self, verbose=False): self.verbose=verbose + def get_julia_scheduler(self): + return jl.Finch.galley_scheduler(verbose=self.verbose) + class DefaultScheduler(AbstractScheduler): def __init__(self, verbose=False): self.verbose=verbose -def get_julia_scheduler(opt): - if isinstance(opt, DefaultScheduler): - return jl.Finch.default_scheduler(verbose=opt.verbose) - elif isinstance(opt, GalleyScheduler): - return jl.Finch.galley_scheduler(verbose=opt.verbose) + def get_julia_scheduler(self): + return jl.Finch.default_scheduler(verbose=self.verbose) def set_optimizer(opt): - jl.Finch.set_scheduler_b(get_julia_scheduler(opt)) + jl.Finch.set_scheduler_b(opt.get_julia_scheduler()) return def compute(tensor: Tensor, *, verbose: bool = False, opt=None, tag=-1): @@ -53,5 +53,5 @@ def compute(tensor: Tensor, *, verbose: bool = False, opt=None, tag=-1): if opt == None: return Tensor(jl.Finch.compute(tensor._obj, verbose=verbose, tag=tag)) else: - return Tensor(jl.Finch.compute(tensor._obj, verbose=verbose, tag=tag, ctx=get_julia_scheduler(opt))) + return Tensor(jl.Finch.compute(tensor._obj, verbose=verbose, tag=tag, ctx=opt.get_julia_scheduler())) return tensor