Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Apr 6, 2024
1 parent be41968 commit 0ed9538
Showing 1 changed file with 136 additions and 71 deletions.
207 changes: 136 additions & 71 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,6 @@ end

import GPUCompiler: @safe_debug, @safe_info, @safe_warn, @safe_error

safe_println(head, tail) = ccall(:jl_safe_printf, Cvoid, (Cstring, Cstring...), "%s%s\n",head, tail)
macro safe_show(exs...)
blk = Expr(:block)
for ex in exs
push!(blk.args, :($safe_println($(sprint(Base.show_unquoted, ex)*" = "),
repr(begin local value = $(esc(ex)) end))))
end
isempty(exs) || push!(blk.args, :value)
return blk
end

if LLVM.has_orc_v1()
include("compiler/orcv1.jl")
else
Expand Down Expand Up @@ -4083,10 +4072,13 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
end
for e in toErase
if !isempty(collect(uses(e)))
@safe_show mod
@safe_show entry_f
@safe_show e
throw(AssertionError("Use after deletion"))
msg = sprint() do io
println(io, string(mod))
println(io, string(entry_f))
println(io, string(e))
println(io, "Use after deletion")
end
throw(AssertionError(msg))
end
LLVM.API.LLVMInstructionEraseFromParent(e)
end
Expand Down Expand Up @@ -4145,6 +4137,9 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
@assert eltype(ty) == value_type(wrapparm)
store!(builder, wrapparm, ptr)
push!(wrapper_args, ptr)
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(arg.typ, ctx, dl, seen))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ)))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
else
push!(wrapper_args, wrapparm)
for attr in collect(parameter_attributes(entry_f, arg.codegen.i))
Expand Down Expand Up @@ -4207,16 +4202,26 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function

position!(builder, def)
ret!(builder, extract_value!(builder, res, 0))

push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen))))
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType)))))
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
end
elseif sret
if sretPtr === nothing
ret!(builder)
else
push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen))))
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType)))))
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
ret!(builder, load!(builder, RT, sretPtr))
end
elseif LLVM.return_type(entry_ft) == LLVM.VoidType()
ret!(builder)
else
push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen))))
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType)))))
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
ret!(builder, res)
end
dispose(builder)
Expand All @@ -4232,14 +4237,52 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
attributes = function_attributes(wrapper_f)
push!(attributes, StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi)))))
push!(attributes, StringAttribute("enzymejl_rt", string(convert(UInt, unsafe_to_pointer(rt)))))

for prev in collect(function_attributes(entry_f))
if kind(prev) == kind(StringAttribute("enzyme_ta_norecur"))
push!(attributes, prev)
end
if kind(prev) == kind(StringAttribute("enzyme_parmremove"))
push!(attributes, prev)
end
if kind(prev) == kind(StringAttribute("enzyme_math"))
push!(attributes, prev)
end
if kind(prev) == kind(StringAttribute("enzyme_shouldrecompute"))
push!(attributes, prev)
end
if kind(prev) == kind(EnumAttribute("readonly"))
push!(attributes, prev)
end
if kind(prev) == kind(EnumAttribute("readnone"))
push!(attributes, prev)
end
if kind(prev) == kind(EnumAttribute("argmemonly"))
push!(attributes, prev)
end
if kind(prev) == kind(EnumAttribute("inaccessiblememonly"))
push!(attributes, prev)
end
if kind(prev) == kind(EnumAttribute("speculatable"))
push!(attributes, prev)
end
if kind(prev) == kind(EnumAttribute("nofree"))
push!(attributes, prev)
end
if kind(prev) == kind(StringAttribute("enzyme_inactive"))
push!(attributes, prev)
end
end

if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0
@safe_show mod
@safe_show LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction)
@safe_show wrapper_f
@safe_show parmsRemoved, retRemoved, prargs
flush(stdout)
throw(LLVM.LLVMException("broken function"))
msg = sprint() do io
println(io, string(mod))
println(io, LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction))
println(io, string(wrapper_f))
println(io, "parmsRemoved=", parmsRemoved, " retRemoved=", retRemoved, " prargs=", prargs)
println(io, "Broken function")
end
throw(LLVM.LLVMException(msg))
end

ModulePassManager() do pm
Expand Down Expand Up @@ -4334,19 +4377,17 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
end

if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0
@safe_show mod
@safe_show LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction)
@safe_show wrapper_f
flush(stdout)
throw(LLVM.LLVMException("broken function"))
msg = sprint() do io
println(io, string(mod))
println(io, LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction))
println(io, string(wrapper_f))
println(io, "Broken function")
end
throw(LLVM.LLVMException(msg))
end
return wrapper_f, returnRoots, boxedArgs, loweredArgs
end

function adim(::Array{T, N}) where {T, N}
return N
end

function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, toplevel::Bool=true,

Check warning on line 4392 in src/compiler.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/compiler.jl:4392:- libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, toplevel::Bool=true, src/compiler.jl:4393:- strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing) src/compiler.jl:4394:- params = job.config.params src/compiler.jl:4887:+ libraries::Bool=true, deferred_codegen::Bool=true, src/compiler.jl:4888:+ optimize::Bool=true, toplevel::Bool=true, src/compiler.jl:4889:+ strip::Bool=false, validate::Bool=true, only_entry::Bool=false, src/compiler.jl:4890:+ parent_job::Union{Nothing,CompilerJob}=nothing) src/compiler.jl:4891:+ params = job.config.params
strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing)
Expand Down Expand Up @@ -4629,7 +4670,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
name = meth.name
jlmod = meth.module

function handleCustom(name, attrs=[], setlink=true, noinl=true)
function handleCustom(llvmfn, name, attrs=[], setlink=true, noinl=true)
attributes = function_attributes(llvmfn)
custom[k_name] = linkage(llvmfn)
if setlink
Expand All @@ -4648,23 +4689,23 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};

julia_activity_rule(llvmfn)
if has_custom_rule
handleCustom("enzyme_custom", [StringAttribute("enzyme_preserve_primal", "*")])
handleCustom(llvmfn, "enzyme_custom", [StringAttribute("enzyme_preserve_primal", "*")])
continue
end

func = mi.specTypes.parameters[1]

sparam_vals = mi.specTypes.parameters[2:end] # mi.sparam_vals
if func == typeof(Base.eps) || func == typeof(Base.nextfloat) || func == typeof(Base.prevfloat)

Check warning on line 4699 in src/compiler.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/compiler.jl:4699:- if func == typeof(Base.eps) || func == typeof(Base.nextfloat) || func == typeof(Base.prevfloat) src/compiler.jl:4700:- handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"), src/compiler.jl:4701:- EnumAttribute("readnone", 0), src/compiler.jl:4702:- EnumAttribute("speculatable", 0), src/compiler.jl:4703:- StringAttribute("enzyme_shouldrecompute") src/compiler.jl:4704:- ]) src/compiler.jl:5194:+ if func == typeof(Base.eps) || func == typeof(Base.nextfloat) || src/compiler.jl:5195:+ func == typeof(Base.prevfloat) src/compiler.jl:5196:+ handleCustom(llvmfn, "jl_inactive_inout", src/compiler.jl:5197:+ [StringAttribute("enzyme_inactive"), src/compiler.jl:5198:+ EnumAttribute("readnone", 0), src/compiler.jl:5199:+ EnumAttribute("speculatable", 0), src/compiler.jl:5200:+ StringAttribute("enzyme_shouldrecompute")])
handleCustom("jl_inactive_inout", [StringAttribute("enzyme_inactive"),
handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"),
EnumAttribute("readnone", 0),
EnumAttribute("speculatable", 0),
StringAttribute("enzyme_shouldrecompute")
])
continue
end
if func == typeof(Base.to_tuple_type)
handleCustom("jl_to_tuple_type",
handleCustom(llvmfn, "jl_to_tuple_type",
[EnumAttribute("readonly", 0),

Check warning on line 4709 in src/compiler.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/compiler.jl:4709:- [EnumAttribute("readonly", 0), src/compiler.jl:4710:- EnumAttribute("inaccessiblememonly", 0), src/compiler.jl:4711:- EnumAttribute("speculatable", 0), src/compiler.jl:4712:- StringAttribute("enzyme_shouldrecompute"), src/compiler.jl:4713:- StringAttribute("enzyme_inactive"), src/compiler.jl:4714:- ]) src/compiler.jl:5205:+ [EnumAttribute("readonly", 0), src/compiler.jl:5206:+ EnumAttribute("inaccessiblememonly", 0), src/compiler.jl:5207:+ EnumAttribute("speculatable", 0), src/compiler.jl:5208:+ StringAttribute("enzyme_shouldrecompute"), src/compiler.jl:5209:+ StringAttribute("enzyme_inactive")])
EnumAttribute("inaccessiblememonly", 0),
EnumAttribute("speculatable", 0),
Expand All @@ -4675,7 +4716,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
end
if func == typeof(Base.Threads.threadid) || func == typeof(Base.Threads.nthreads)
name = (func == typeof(Base.Threads.threadid)) ? "jl_threadid" : "jl_nthreads"
handleCustom(name,
handleCustom(llvmfn, name,
[EnumAttribute("readonly", 0),

Check warning on line 4720 in src/compiler.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/compiler.jl:4720:- [EnumAttribute("readonly", 0), src/compiler.jl:4721:- EnumAttribute("inaccessiblememonly", 0), src/compiler.jl:4722:- EnumAttribute("speculatable", 0), src/compiler.jl:4723:- StringAttribute("enzyme_shouldrecompute"), src/compiler.jl:4724:- StringAttribute("enzyme_inactive"), src/compiler.jl:4725:- ]) src/compiler.jl:5215:+ [EnumAttribute("readonly", 0), src/compiler.jl:5216:+ EnumAttribute("inaccessiblememonly", 0), src/compiler.jl:5217:+ EnumAttribute("speculatable", 0), src/compiler.jl:5218:+ StringAttribute("enzyme_shouldrecompute"), src/compiler.jl:5219:+ StringAttribute("enzyme_inactive")])
EnumAttribute("inaccessiblememonly", 0),
EnumAttribute("speculatable", 0),
Expand All @@ -4690,15 +4731,15 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
# fn, but it doesn't presently so for now we will ensure this by hand
if func == typeof(Base.Checked.throw_overflowerr_binaryop)
llvmfn = functions(mod)[k.specfunc]
handleCustom("enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly")])
handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly")])
continue
end
if EnzymeRules.is_inactive_from_sig(mi.specTypes; world, method_table, caller)
handleCustom("enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree")])
handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree")])
continue
end
if EnzymeRules.is_inactive_noinl_from_sig(mi.specTypes; world, method_table, caller)
handleCustom("enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree")], false, false)
handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree")], false, false)
for bb in blocks(llvmfn)
for inst in instructions(bb)
if isa(inst, LLVM.CallInst)
Expand All @@ -4710,54 +4751,78 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
continue
end
if func == typeof(Base.enq_work) && length(sparam_vals) == 1 && first(sparam_vals) <: Task
handleCustom("jl_enq_work")
handleCustom(llvmfn, "jl_enq_work")
continue
end
if func == typeof(Base.wait) || func == typeof(Base._wait)
if length(sparam_vals) == 1 && first(sparam_vals) <: Task
handleCustom("jl_wait")
handleCustom(llvmfn, "jl_wait")
end
continue
end
if func == typeof(Base.Threads.threading_run)
if length(sparam_vals) == 1 || length(sparam_vals) == 2
handleCustom("jl_threadsfor")
handleCustom(llvmfn, "jl_threadsfor")
end
continue
end

name = nothing
arity = nothing
toinject = nothing
Tys = nothing
@inline function find_math_method()
if func keys(known_ops)
name, arity, toinject = known_ops[func]
Tys = (Float32, Float64)

if length(sparam_vals) == arity
T = first(sparam_vals)
legal = T Tys

if legal
if name == :ldexp
if !(sparam_vals[2] <: Integer)
legal = false
end
elseif name == :pow
if sparam_vals[2] <: Integer
name = :powi
elseif sparam_vals[2] != T
legal = false
end
elseif name == :jl_rem2pi
else
if !all(==(T), sparam_vals)
legal = false
end
end
end
if legal
return name, toinject, T
end
end
end

if func keys(known_ops)
name, arity, toinject = known_ops[func]
Tys = (Float32, Float64)
elseif func keys(cmplx_known_ops)
name, arity, toinject = cmplx_known_ops[func]
Tys = (Complex{Float32}, Complex{Float64})
else
continue
end
if func keys(cmplx_known_ops)
name, arity, toinject = cmplx_known_ops[func]
Tys = (Complex{Float32}, Complex{Float64})
if length(sparam_vals) == arity
T = first(sparam_vals)
legal = T Tys

length(sparam_vals) == arity || continue
T = first(sparam_vals)
isfloat = T Tys
if !isfloat
continue
if legal
if !all(==(T), sparam_vals)
legal = false
end
end
if legal
return name, toinject, T
end
end
end
return nothing, nothing, nothing
end
if name == :ldexp
sparam_vals[2] <: Integer || continue
elseif name == :pow
if sparam_vals[2] <: Integer
name = :powi
elseif sparam_vals[2] != T
continue
end
elseif name == :jl_rem2pi
else
all(==(T), sparam_vals) || continue

name, toinject, T = find_math_method()
if name === nothing
continue
end

if toinject !== nothing
Expand All @@ -4779,7 +4844,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
name = string(name)
name = T == Float32 ? name*"f" : name

handleCustom(name, [EnumAttribute("readnone", 0),
handleCustom(llvmfn, name, [EnumAttribute("readnone", 0),
StringAttribute("enzyme_shouldrecompute")])
end

Expand Down

0 comments on commit 0ed9538

Please sign in to comment.