From 96d09027436cbbca032f23aa91d93a9cd0e220d1 Mon Sep 17 00:00:00 2001 From: Bernhard Schuster Date: Sun, 31 Oct 2021 08:37:57 +0100 Subject: [PATCH] refactor/nll: improve approximation --- .../src/layers/loss/negative_log_likelihood.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/juice/src/layers/loss/negative_log_likelihood.rs b/juice/src/layers/loss/negative_log_likelihood.rs index a56f6c428..de7450d1c 100644 --- a/juice/src/layers/loss/negative_log_likelihood.rs +++ b/juice/src/layers/loss/negative_log_likelihood.rs @@ -82,11 +82,12 @@ impl ComputeOutput for NegativeLogLikelihood { let native_probabilities = probabilities.read(native.device()).unwrap().as_slice::<f32>(); let mut writable_loss = Vec::<f32>::with_capacity(native_labels.len()); - let mut offset = 0; - for &label_value in native_labels { - let probability_value = native_probabilities[offset + label_value as usize]; - writable_loss.push(-probability_value); - offset += batch_size; + for (i, &label_value) in native_labels.iter().enumerate() { + let index = batch_size * i + label_value as usize; + let probability_value = native_probabilities[index]; + let probability_value2 = probability_value * probability_value; + let probability_value3 = probability_value2 * probability_value; + writable_loss.push(-probability_value + probability_value2 / 2_f32 - probability_value3 / 3_f32); } let mut loss = writable_loss.iter().fold(0f32, |sum, &val| sum + val); @@ -107,17 +108,21 @@ impl ComputeInputGradient for NegativeLogLikelihood { input_data: &[&SharedTensor<f32>], input_gradients: &mut [&mut SharedTensor<f32>], ) { + let probabilities = input_data[0]; let labels = input_data[1]; let batch_size = Self::batch_size(input_data[0].desc()); let num_classes = self.num_classes; let native = native_backend(); let native_labels = labels.read(native.device()).unwrap().as_slice::<f32>(); + let native_probabilities = probabilities.read(native.device()).unwrap().as_slice::<f32>(); let mut writable_gradient =
vec![0f32; input_gradients[0].desc().size()]; for (batch_n, &label_value) in native_labels.iter().enumerate() { let index = (num_classes * batch_n) + label_value as usize; - writable_gradient[index] = -1f32; + let probability_value = native_probabilities[index]; + let probability_value2 = probability_value * probability_value; + writable_gradient[index] = -1_f32 + probability_value - probability_value2; } crate::util::write_to_memory( input_gradients[0].write_only(native.device()).unwrap(),