From fad34ca7815b4e976196ae25ee84a88c8522eba1 Mon Sep 17 00:00:00 2001 From: Rounak Datta Date: Tue, 9 Apr 2024 17:20:46 +0530 Subject: [PATCH] Restructure; add furthur operations --- lib/neuron.ml | 49 +++++++++++++++++++++++++++++++++++-------------- lib/neuron.mli | 14 ++++++++++++-- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/lib/neuron.ml b/lib/neuron.ml index 0f50f7a..8a30e49 100644 --- a/lib/neuron.ml +++ b/lib/neuron.ml @@ -3,27 +3,48 @@ module Neuron = struct mutable data : float; mutable grad : float; mutable backward : unit -> unit; + + (* capturing the operator would be useful later when we add some viz *) op : string; prev : t list; } - let create data operator = { data; grad = 0.; - backward = (fun () -> ()); op = operator; prev = [] + let create data operator = { + data; + grad = 0.; + backward = (fun () -> ()); + op = operator; prev = []; } - let add a b = - let out = create (a.data +. b.data) "+" in - out.backward <- (fun () -> - a.grad <- a.grad +. out.grad; - b.grad <- b.grad +. out.grad; + let add base partner = + let resultant = create (base.data +. partner.data) "+" in + + resultant.backward <- (fun () -> + base.grad <- base.grad +. resultant.grad; + partner.grad <- partner.grad +. resultant.grad; ); - out + resultant - let mul a b = - let out = create (a.data *. b.data) "*" in - out.backward <- (fun () -> - a.grad <- a.grad +. b.data *. out.grad; - b.grad <- b.grad +. a.data *. out.grad; + let mul base partner = + let resultant = create (base.data *. partner.data) "*" in + + resultant.backward <- (fun () -> + base.grad <- base.grad +. partner.data *. resultant.grad; + partner.grad <- partner.grad +. base.data *. resultant.grad; ); - out + resultant + + let exp base exponent = + let resultant = create (base.data ** exponent) "**" in + + resultant.backward <- (fun () -> + base.grad <- base.grad +. exponent *. (base.data ** (exponent -. 1.)) *. resultant.grad; + ) + + let relu base = + let resultant = create (max 0. base.data) "relu" in + + resultant.backward <- (fun () -> + base.grad <- base.grad +. (if base.data > 0. then resultant.grad else 0.); + ) end diff --git a/lib/neuron.mli b/lib/neuron.mli index 2152871..a2dbea5 100644 --- a/lib/neuron.mli +++ b/lib/neuron.mli @@ -4,9 +4,19 @@ module Neuron : sig (* Constructor; constructs a unit neuron of a value and an operator. *) val create : float -> string -> t - (* Adds two values, resulting in a new value. *) + (* Handles the gradient flows in addition operation. *) val add : t -> t -> t - (* Multiplies two values, resulting in a new value. *) + (* Handles the gradient flows in multiplication operation. *) val mul : t -> t -> t + + (* Handles the gradient flows in exponent / power operation. *) + (* second argument is the exponent. *) + val exp : t -> int -> t + + (* Handles the gradient flows in ReLU operation. *) + val relu : t -> t + + (* Handles backpropagation of the gradients. *) + val backpropagate : t -> unit end