Skip to content

Commit

Permalink
refactor/nll: improve approximation
Browse files: browse the repository at this point in the history
  • Loading branch information
drahnr committed Oct 31, 2021
1 parent 4bf7dda commit 96d0902
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions juice/src/layers/loss/negative_log_likelihood.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -82,11 +82,12 @@ impl<B: IBackend> ComputeOutput<f32, B> for NegativeLogLikelihood {
let native_probabilities = probabilities.read(native.device()).unwrap().as_slice::<f32>();

let mut writable_loss = Vec::<f32>::with_capacity(native_labels.len());
let mut offset = 0;
for &label_value in native_labels {
let probability_value = native_probabilities[offset + label_value as usize];
writable_loss.push(-probability_value);
offset += batch_size;
for (i, &label_value) in native_labels.iter().enumerate() {
let index = batch_size * i + label_value as usize;
let probability_value = native_probabilities[index];
let probability_value2 = probability_value * probability_value;
let probability_value3 = probability_value2 * probability_value;
writable_loss.push(-probability_value + probability_value2 / 2_f32 - probability_value3 / 3_f32);
}

let mut loss = writable_loss.iter().fold(0f32, |sum, &val| sum + val);
Expand All @@ -107,17 +108,21 @@ impl<B: IBackend> ComputeInputGradient<f32, B> for NegativeLogLikelihood {
input_data: &[&SharedTensor<f32>],
input_gradients: &mut [&mut SharedTensor<f32>],
) {
let probabilities = input_data[0];
let labels = input_data[1];
let batch_size = Self::batch_size(input_data[0].desc());
let num_classes = self.num_classes;

let native = native_backend();
let native_labels = labels.read(native.device()).unwrap().as_slice::<f32>();
let native_probabilities = probabilities.read(native.device()).unwrap().as_slice::<f32>();
let mut writable_gradient = vec![0f32; input_gradients[0].desc().size()];

for (batch_n, &label_value) in native_labels.iter().enumerate() {
let index = (num_classes * batch_n) + label_value as usize;
writable_gradient[index] = -1f32;
let probability_value = native_probabilities[index];
let probability_value2 = probability_value * probability_value;
writable_gradient[index] = -1_f32 + probability_value - probability_value2;
}
crate::util::write_to_memory(
input_gradients[0].write_only(native.device()).unwrap(),
Expand Down

0 comments on commit 96d0902

Please sign in to comment.