Set-up JuliaFormatter (#67)
* Set-up JuliaFormatter

* Implemented formatting
stemann authored Jan 13, 2025
1 parent 4066bf2 commit 7c1d969
Showing 20 changed files with 869 additions and 673 deletions.
6 changes: 6 additions & 0 deletions .JuliaFormatter.toml
```diff
@@ -0,0 +1,6 @@
+style = "blue"
+
+ignore = ["src/Wrapper.jl"]
+pipe_to_function_call = false
+whitespace_in_kwargs = true
+whitespace_typedefs = true
```
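With this file at the repository root, the same formatting can be applied locally before pushing. A minimal sketch using JuliaFormatter's `format` entry point, which picks up `.JuliaFormatter.toml` automatically; `style = "blue"` selects BlueStyle, and `ignore` skips the generated `src/Wrapper.jl`:

```julia
using JuliaFormatter

# Rewrites files in place according to .JuliaFormatter.toml and returns
# true when everything was already formatted, so it also serves as a
# local pre-push check.
format(".")
```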
8 changes: 8 additions & 0 deletions .github/workflows/Format.yml
```diff
@@ -0,0 +1,8 @@
+name: Format suggestions
+on:
+  pull_request:
+jobs:
+  code-style:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: julia-actions/julia-format@v3
```
3 changes: 3 additions & 0 deletions README.md
```diff
@@ -1,5 +1,8 @@
 # Torch.jl

+[![Build Status](https://github.com/FluxML/Torch.jl/actions/workflows/CI.yaml/badge.svg?branch=master)](https://github.com/FluxML/Torch.jl/actions/workflows/CI.yaml?query=branch%3Amaster)
+[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/JuliaDiff/BlueStyle)
+
 Sensible extensions for exposing torch in Julia.

 This package is aimed at providing the `Tensor` type, which offloads all computations over to [ATen](https://pytorch.org/cppdocs/), the foundational tensor library for PyTorch, written in C++.
```
4 changes: 2 additions & 2 deletions deps/julia_wrapper_generator/generator.jl
```diff
@@ -22,11 +22,11 @@ function rewrite!(e::Expr)
 end

 function rewrite!(e::Expr, ::Val{:function})
-    rewrite!(e.args[2], Val(e.args[2].head))
+    return rewrite!(e.args[2], Val(e.args[2].head))
 end

 function rewrite!(e::Expr, ::Val{:block})
-    e.args[1] = Expr(:macrocall, Symbol("@runtime_error_check"), nothing, e.args[1])
+    return e.args[1] = Expr(:macrocall, Symbol("@runtime_error_check"), nothing, e.args[1])
 end

 function rewrite!(dag::ExprDAG)
```
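The `:block` method wraps the first expression of each generated function body in `@runtime_error_check`. A minimal sketch of that transformation applied by hand, assuming (as with the Clang.jl-generated wrappers) that the body's first argument is the `ccall` itself; `at_dim` is just an illustrative name here:

```julia
# A hypothetical generated body; real ones come from Clang.jl's ExprDAG.
body = Expr(:block, :(ccall((:at_dim, libtorch_c_api), Cint, (Ptr{Cvoid},), t)))

# The rewrite from generator.jl:
body.args[1] = Expr(:macrocall, Symbol("@runtime_error_check"), nothing, body.args[1])

# body is now equivalent to:
#   @runtime_error_check ccall((:at_dim, libtorch_c_api), Cint, (Ptr{Cvoid},), t)
```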
2 changes: 1 addition & 1 deletion deps/julia_wrapper_generator/generator.toml
```diff
@@ -1,6 +1,6 @@
 [general]
 library_name = "libtorch_c_api"
-output_file_path = "../../src/wrapper.jl"
+output_file_path = "../../src/Wrapper.jl"
 prologue_file_path = "./prologue.jl"
 module_name = "Wrapper"
 jll_pkg_name = "TorchCAPI_jll"
```
20 changes: 10 additions & 10 deletions deps/julia_wrapper_generator/prologue.jl
```diff
@@ -1,15 +1,15 @@
 function get_error()
-    err = cglobal((:myerr, libtorch_c_api), Cstring) |> unsafe_load
-    unsafe_string(err)
+    err = cglobal((:myerr, libtorch_c_api), Cstring) |> unsafe_load
+    return unsafe_string(err)
 end

 macro runtime_error_check(ex)
-    quote
-        x = $ex
-        if x == 1
-            cs = get_error()
-            flush_error()
-            throw(cs)
-        end
-    end |> esc
+    return quote
+        x = $ex
+        if x == 1
+            cs = get_error()
+            flush_error()
+            throw(cs)
+        end
+    end |> esc
 end
```
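Each wrapped C call returns an error flag, and the macro turns a return value of 1 into a thrown Julia error carrying the message recorded in `myerr`. A usage sketch with a hypothetical wrapper (the real signatures are generated into `src/Wrapper.jl`):

```julia
# Hypothetical wrapper; the actual argument list for atg_relu is defined
# by the generated code, not here.
function atg_relu(out__, self)
    @runtime_error_check ccall(
        (:atg_relu, libtorch_c_api),
        Cint,
        (Ptr{Ptr{Cvoid}}, Ptr{Cvoid}),
        out__, self,
    )
end
```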
30 changes: 16 additions & 14 deletions src/Torch.jl
```diff
@@ -13,7 +13,7 @@ using FillArrays

 TURN_ON_LOGGING = false

-include("wrapper.jl")
+include("Wrapper.jl")

 using .Wrapper
@@ -32,23 +32,25 @@ include("statistics.jl")
 include("grads.jl")
 include("utils.jl")

-@init @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
-    using .Flux
+@init @require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" begin
+    using .Flux

-    function (tbn::Flux.BatchNorm)(x::Tensor)
-        tbn.λ.(Torch.batchnorm(x, tbn.γ, tbn.β, tbn.μ, tbn.σ², 0, tbn.momentum, tbn.ϵ, 1))
-    end
+    function (tbn::Flux.BatchNorm)(x::Tensor)
+        return tbn.λ.(
+            Torch.batchnorm(x, tbn.γ, tbn.β, tbn.μ, tbn.σ², 0, tbn.momentum, tbn.ϵ, 1)
+        )
+    end

-    function Flux.Zygote.accum(t1::Tensor, t2::Tensor{T,N}) where {T,N}
-        ptr = Ref(Ptr{Cvoid}())
+    function Flux.Zygote.accum(t1::Tensor, t2::Tensor{T, N}) where {T, N}
+        ptr = Ref(Ptr{Cvoid}())

-        Torch.Wrapper.atg_add_(ptr, t1.ptr, t2.ptr)
-        Tensor{T,N}(ptr[], Torch.on(t1))
-    end
+        Torch.Wrapper.atg_add_(ptr, t1.ptr, t2.ptr)
+        return Tensor{T, N}(ptr[], Torch.on(t1))
+    end

-    eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_copy_data))
-    eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_dim))
-    torch(x) = Flux.fmap(to_tensor, x)
+    eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_copy_data))
+    eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_dim))
+    torch(x) = Flux.fmap(to_tensor, x)
 end

 end # module
```
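The `@require Flux` block wires `Tensor` into Flux: `BatchNorm` layers route through Torch's fused `batchnorm`, Zygote's gradient accumulation uses the in-place `atg_add_`, and `torch(x)` maps every array in a model onto `Tensor`s. A usage sketch, assuming `to_tensor` converts plain arrays as in the `torch` definition above:

```julia
using Flux, Torch

m = Chain(Dense(10, 5, relu), BatchNorm(5))

# fmap(to_tensor, m) walks the model and replaces each parameter array
# with a Torch Tensor.
tm = Torch.torch(m)

x = Torch.to_tensor(rand(Float32, 10, 4))
y = tm(x)  # BatchNorm now dispatches to Torch.batchnorm
```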
src/wrapper.jl → src/Wrapper.jl: file renamed without changes.
39 changes: 19 additions & 20 deletions src/broadcast.jl
```diff
@@ -9,23 +9,23 @@ using Base.Broadcast: broadcast_shape
 # Base.BroadcastStyle(::Type{Tensor}) = TensorStyle()

 for op in (:+, :-, :/)
-    @eval function broadcasted(::typeof($op), t1::Tensor, t2::Tensor)
-        $op(t1, t2)
-    end
+    @eval function broadcasted(::typeof($op), t1::Tensor, t2::Tensor)
+        return $op(t1, t2)
+    end
 end

 for op in (:+, :-)
-    @eval function broadcasted(::typeof($op), t1::Tensor, t2::TensorVector)
-        t_ = reshape(t2, -1, 1)
-        $op(t1, t_)
-    end
+    @eval function broadcasted(::typeof($op), t1::Tensor, t2::TensorVector)
+        t_ = reshape(t2, -1, 1)
+        return $op(t1, t_)
+    end
 end

-function broadcasted(::typeof(*), t1::Tensor{T,N}, t2::Tensor{T,M}) where {T,N,M}
-    ptr = Ref(Ptr{Cvoid}())
+function broadcasted(::typeof(*), t1::Tensor{T, N}, t2::Tensor{T, M}) where {T, N, M}
+    ptr = Ref(Ptr{Cvoid}())

-    atg_mul(ptr, t1.ptr, t2.ptr)
-    Tensor{T,max(N,M)}(ptr[], on(t1))
+    atg_mul(ptr, t1.ptr, t2.ptr)
+    return Tensor{T, max(N, M)}(ptr[], on(t1))
 end

 broadcasted(::typeof(NNlib.relu), t::Tensor) = NNlib.relu(t)
@@ -34,22 +34,21 @@ broadcasted(::typeof(identity), t::Tensor) = identity(t)
 broadcasted(::typeof(NNlib.sigmoid), t::Tensor) = NNlib.sigmoid(t)

 for op in (:+, :-, :*, :/)
-    @eval function broadcasted(::typeof($op), t::Tensor, args...)
-        $op(t, args...)
-    end
+    @eval function broadcasted(::typeof($op), t::Tensor, args...)
+        return $op(t, args...)
+    end
 end

 broadcasted(::typeof(sqrt), t::Tensor) = sqrt(t)

-function broadcasted(::typeof(copy), t::Tensor{T,N}) where {T,N}
-    t
+function broadcasted(::typeof(copy), t::Tensor{T, N}) where {T, N}
+    return t
 end

 @adjoint function broadcast(::typeof(NNlib.sigmoid), t::Tensor)
-
-    NNlib.sigmoid(t), Δ -> (∇sigmoid(Δ, t),)
+    return NNlib.sigmoid(t), Δ -> (∇sigmoid(Δ, t),)
 end

-@adjoint function broadcasted(::typeof(NNlib.relu), t::Tensor{T}) where T
-    relu(t), Δ -> (nothing, ∇leaky_relu(Δ, t, zero(T)),)
+@adjoint function broadcasted(::typeof(NNlib.relu), t::Tensor{T}) where {T}
+    return relu(t), Δ -> (nothing, ∇leaky_relu(Δ, t, zero(T)))
 end
```
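These `broadcasted` overloads intercept dot syntax, so element-wise expressions on `Tensor`s dispatch to single libtorch kernels instead of Julia's generic broadcast machinery. A sketch of the effect (again assuming a `to_tensor`-style conversion; exact constructors may differ):

```julia
using NNlib, Torch

t1 = Torch.to_tensor(rand(Float32, 3, 3))
t2 = Torch.to_tensor(rand(Float32, 3, 3))

y = t1 .+ t2         # lowers to broadcasted(+, t1, t2), i.e. Torch's own +
z = NNlib.relu.(t1)  # forwards to Torch's relu kernel, with an @adjoint for Zygote
```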
