diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java index c4c75c35e79..848d4bc7956 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java @@ -738,6 +738,10 @@ public void execute(ExecutionContext ec) private void executeLocalParFor( ExecutionContext ec, IntObject from, IntObject to, IntObject incr ) throws InterruptedException { + if (DMLScript.USE_ACCELERATOR) { + _numThreads = Math.min(_numThreads, ec.getNumGPUContexts()); + } + LOG.trace("Local Par For (multi-threaded) with degree of parallelism : " + _numThreads); /* Step 1) init parallel workers, task queue and threads * start threads (from now on waiting for tasks) @@ -808,6 +812,12 @@ private void executeLocalParFor( ExecutionContext ec, IntObject from, IntObject LineageCacheConfig.setReuseLineageTraces(false); //disable lineage trace reuse for( Thread thread : threads ) thread.join(); + + if (DMLScript.USE_ACCELERATOR) { + for(LocalParWorker worker : workers) { + LOG.trace("The worker of GPU " + worker.getExecutionContext().getGPUContext(0).toString() + " has executed " + worker.getExecutedTasks() + " tasks."); + } + } if( _monitor ) StatisticMonitor.putPFStat(_ID, Stat.PARFOR_WAIT_EXEC_T, time.stop()); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java index b70e5837650..0c1ba31af0f 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java @@ -22,8 +22,8 @@ import java.io.File; import java.io.IOException; import java.lang.ref.SoftReference; -import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; import org.apache.commons.lang3.mutable.MutableBoolean; @@ -214,7 +214,7 @@ public enum CacheStatus { //for lazily evaluated RDDs, and (2) as abstraction for environments that do not necessarily have spark libraries available private RDDObject _rddHandle = null; //RDD handle private BroadcastObject _bcHandle = null; //Broadcast handle - protected HashMap _gpuObjects = null; //Per GPUContext object allocated on GPU + protected ConcurrentHashMap _gpuObjects = null; //Per GPUContext object allocated on GPU private LineageItem _lineage = null; @@ -229,7 +229,7 @@ protected CacheableData(DataType dt, ValueType vt) { _uniqueID = _seq.getNextID(); _cacheStatus = CacheStatus.EMPTY; _numReadThreads = 0; - _gpuObjects = DMLScript.USE_ACCELERATOR ? new HashMap<>() : null; + _gpuObjects = DMLScript.USE_ACCELERATOR ? 
new ConcurrentHashMap<>() : null; } /** @@ -472,7 +472,7 @@ public synchronized GPUObject getGPUObject(GPUContext gCtx) { public synchronized void setGPUObject(GPUContext gCtx, GPUObject gObj) { if( _gpuObjects == null ) - _gpuObjects = new HashMap<>(); + _gpuObjects = new ConcurrentHashMap<>(); GPUObject old = _gpuObjects.put(gCtx, gObj); if (old != null) throw new DMLRuntimeException("GPU : Inconsistent internal state - this CacheableData already has a GPUObject assigned to the current GPUContext (" + gCtx + ")"); diff --git a/src/test/config/SystemDS-SingleGPU-config.xml b/src/test/config/SystemDS-SingleGPU-config.xml new file mode 100644 index 00000000000..e8a950c9164 --- /dev/null +++ b/src/test/config/SystemDS-SingleGPU-config.xml @@ -0,0 +1,29 @@ + + + + + 2 + + 2 + + 128 + + 0 + \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/gpu/multigpu/GPUTest.java b/src/test/java/org/apache/sysds/test/gpu/multigpu/GPUTest.java new file mode 100644 index 00000000000..32f4234ed87 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/gpu/multigpu/GPUTest.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.gpu.multigpu; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.apache.log4j.AppenderSkeleton; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.apache.log4j.spi.LoggingEvent; +import org.apache.sysds.runtime.controlprogram.ParForProgramBlock; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public abstract class GPUTest extends AutomatedTestBase { + protected static final String TEST_DIR = "gpu/"; + protected static final String TEST_CLASS_DIR = TEST_DIR + MultiGPUTest.class.getSimpleName() + "/"; + protected static final String SINGLE_GPU_TEST = "SingleGPUTest"; + protected static final String MULTI_GPUS_TEST = "MultiGPUsTest"; + protected static final String TEST_NAME = "InferenceScript"; + protected static final String TRAIN_SCRIPT = "TrainScript"; + protected static final String DATA_SET = DATASET_DIR + "MNIST/mnist_test.csv"; + protected static final String SINGLE_TEST_CONFIG = CONFIG_DIR + "SystemDS-SingleGPU-config.xml"; + protected static final String MULTI_TEST_CONFIG = CONFIG_DIR + "SystemDS-config.xml"; + + @Override + public void setUp() { + TEST_GPU = true; + VERBOSE_STATS = true; + addTestConfiguration(SINGLE_GPU_TEST, + new TestConfiguration(TEST_CLASS_DIR, SINGLE_GPU_TEST, new String[] { "R" })); + addTestConfiguration(MULTI_GPUS_TEST, + new TestConfiguration(TEST_CLASS_DIR, MULTI_GPUS_TEST, new String[] { "R" })); + } + + /** + * Run the test with multiple GPUs + * + * @param multiGPUs whether to run the test with multiple GPUs + */ + protected void runMultiGPUsTest(boolean multiGPUs, int numTestImages) { + getAndLoadTestConfiguration(multiGPUs ? MULTI_GPUS_TEST : SINGLE_GPU_TEST); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] { "-args", DATA_SET, output("R"), Integer.toString(numTestImages), "-config", + multiGPUs ? MULTI_TEST_CONFIG : SINGLE_TEST_CONFIG }; + fullRScriptName = HOME + TEST_NAME + ".R"; + + rCmd = null; + InMemoryAppender appender = configureLog4j(); + + runTest(true, false, null, -1); + + List logs = appender.getLogMessages(); + int numRealThread = 0; + for (String log : logs) { + if (log.contains("has executed") && extractNumTasks(log) > 0) { + numRealThread ++; + } + } + if (multiGPUs) { + assertTrue(numRealThread > 1); + } else { + assertEquals(1, numRealThread); + } + + appender.clearLogMessages(); + } + + /** + * Run the training script + */ + protected void runTrainingScript(boolean multiGPUs, int numTestImages) { + getAndLoadTestConfiguration(multiGPUs ? MULTI_GPUS_TEST : SINGLE_GPU_TEST); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TRAIN_SCRIPT + ".dml"; + programArgs = new String[] { "-args", DATA_SET, output("R"), Integer.toString(numTestImages), "-config", + multiGPUs ? 
MULTI_TEST_CONFIG : SINGLE_TEST_CONFIG }; + fullRScriptName = HOME + TEST_NAME + ".R"; + + rCmd = null; + InMemoryAppender appender = configureLog4j(); + + runTest(true, false, null, -1); + } + + protected static InMemoryAppender configureLog4j() { + Logger rootLogger = Logger.getRootLogger(); + rootLogger.setLevel(Level.ERROR); + + Logger logger = Logger.getLogger(ParForProgramBlock.class.getName()); + logger.setLevel(Level.TRACE); + + InMemoryAppender inMemoryAppender = new InMemoryAppender(); + inMemoryAppender.setThreshold(Level.TRACE); + logger.addAppender(inMemoryAppender); + + return inMemoryAppender; + } + + protected static int extractNumTasks(String logMessage) { + String regex = "has executed (\\d+) tasks"; + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(logMessage); + if (matcher.find()) { + return Integer.parseInt(matcher.group(1)); + } + throw new IllegalArgumentException("No _numTasks value found in log message"); + } + + protected static class InMemoryAppender extends AppenderSkeleton { + + protected final List logMessages = new ArrayList<>(); + + @Override + protected void append(LoggingEvent event) { + if (event.getLevel().isGreaterOrEqual(Level.TRACE)) { + logMessages.add(event.getRenderedMessage()); + } + } + + @Override + public void close() { + // No resources to release + } + + @Override + public boolean requiresLayout() { + return false; + } + + public List getLogMessages() { + return new ArrayList<>(logMessages); + } + + public void clearLogMessages() { + logMessages.clear(); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/gpu/multigpu/MultiGPUTest.java b/src/test/java/org/apache/sysds/test/gpu/multigpu/MultiGPUTest.java new file mode 100644 index 00000000000..46a7dca38ae --- /dev/null +++ b/src/test/java/org/apache/sysds/test/gpu/multigpu/MultiGPUTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.gpu.multigpu; + +import org.junit.AfterClass; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import java.util.ArrayList; +import java.util.List; + +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +public class MultiGPUTest extends GPUTest { + + private static List executionTimes = new ArrayList<>(); + + @Test + public void test01_gpuTest_10k() { + runMultiGPUsTest(true, 10000); + } + + @Test + public void test01_gpuTest_20k() { + runMultiGPUsTest(true, 20000); + } + + @Test + public void test01_gpuTest_50k() { + runMultiGPUsTest(true, 50000); + } + + @Test + public void test01_gpuTest_100k() { + runMultiGPUsTest(true, 100000); + } + + @Test + public void test01_gpuTest_200k() { + runMultiGPUsTest(true, 200000); + } + + @Test + public void test01_gpuTest_500k() { + runMultiGPUsTest(true, 500000); + } + + @Override + protected void runMultiGPUsTest(boolean multiGPUs, int numTestImages) { + long startTime = System.nanoTime(); + super.runMultiGPUsTest(multiGPUs, numTestImages); + long endTime = System.nanoTime(); + double executionTime = (endTime - startTime) / 1e9; + executionTimes.add(executionTime); + } + + @AfterClass + public static void printExecutionTimes() { + System.out.println("Execution times for each test:"); + for (int i = 0; i < executionTimes.size(); i++) { + System.out.printf("Test %d: %.3f sec\n", i + 1, executionTimes.get(i)); + } + } +} + diff --git a/src/test/java/org/apache/sysds/test/gpu/multigpu/SingleGPUTest.java b/src/test/java/org/apache/sysds/test/gpu/multigpu/SingleGPUTest.java new file mode 100644 index 00000000000..51a45737aba --- /dev/null +++ b/src/test/java/org/apache/sysds/test/gpu/multigpu/SingleGPUTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.gpu.multigpu; + +import org.junit.AfterClass; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import java.util.ArrayList; +import java.util.List; + +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +public class SingleGPUTest extends GPUTest { + + private static List executionTimes = new ArrayList<>(); + + @Test + public void test01_gpuTest_10k() { + runMultiGPUsTest(false, 10000); + } + + @Test + public void test01_gpuTest_20k() { + runMultiGPUsTest(false, 20000); + } + + @Test + public void test01_gpuTest_50k() { + runMultiGPUsTest(false, 50000); + } + + @Test + public void test01_gpuTest_100k() { + runMultiGPUsTest(false, 100000); + } + + @Test + public void test01_gpuTest_200k() { + runMultiGPUsTest(false, 200000); + } + + @Test + public void test01_gpuTest_500k() { + runMultiGPUsTest(false, 500000); + } + + @Override + protected void runMultiGPUsTest(boolean multiGPUs, int numTestImages) { + // Train the model first + super.runTrainingScript(multiGPUs, numTestImages); + + long startTime = System.nanoTime(); + super.runMultiGPUsTest(multiGPUs, numTestImages); + long endTime = System.nanoTime(); + double executionTime = (endTime - startTime) / 1e9; + executionTimes.add(executionTime); + } + + @AfterClass + public static void printExecutionTimes() { + System.out.println("Execution times for each test:"); + for (int i = 0; i < executionTimes.size(); i++) { + System.out.printf("Test %d: %.3f sec\n", i + 1, executionTimes.get(i)); + } + } +} diff --git a/src/test/scripts/applications/nn/component/efficientNet.dml b/src/test/scripts/applications/nn/component/efficientNet.dml new file mode 100644 index 00000000000..1bb177c735a --- /dev/null +++ b/src/test/scripts/applications/nn/component/efficientNet.dml @@ -0,0 +1,337 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# TODO move to builtin functions (needs fix for imports in builtin functions) +# TODO scale up to real EfficientNet-B0 + +# Trains a partial Efficient-Net B0 model +# This script trains the top and bottom part of the Efficient-Net B0 +# The original Efficient-Net B0 has the following Layers +#---------------------------------------------------------------- +# Layers Dimension Filters Nr Repeats +#---------------------------------------------------------------- +# 1. Conv3x3 224x224 32 1 +# 2. MBConv1, k3x3 112x112 16 1 +# 3. MBConv6, k3x3 56x 56 24 2 +# 4. MBConv6, k5x5 28x 28 40 2 +# 5. MBConv6, k3x3 14x 14 80 3 +# 6. MBConv6, k5x5 14x 14 112 3 +# 7. MBConv6, k5x5 7x 7 192 4 +# 8. MBConv6, k3x3 7x 7 320 1 +# 9. 
Conv1x1 & Pooling & FC 7x 7 1280 1 +#---------------------------------------------------------------- +# In this partial implementation we implement the layers number 1, 2 and the prediction layer 9 +# This init-Method is purely for convenience reasons there is not problem with a manual initialization of weight and +# biases. To extend the current implementation to a full EfficientNet-B0 only the intermediate MBConv need to be extended +# Both stem and top part are already complete as is the first MBConv layer. +# The number after MBConv is the corresponding ExpansionFactor and is followed +# by the kernel size stride and padding can be calculated from the dimension. If the layer is repeated +# The skip connection is activated otherwise not. +#---------------------------------------------------------------- + +source("scripts/nn/layers/batch_norm2d.dml") as batchnorm +source("scripts/nn/layers/conv2d_builtin.dml") as conv2d +source("scripts/nn/layers/conv2d_depthwise.dml") as depthwise +source("scripts/nn/layers/global_avg_pool2d.dml") as global_avg_pool +source("scripts/nn/layers/silu.dml") as silu +source("scripts/nn/layers/upsample2d.dml") as upsample +source("scripts/nn/layers/mbconv.dml") as mbconv +source("scripts/nn/layers/affine.dml") as affine +source("scripts/nn/layers/softmax.dml") as softmax +source("scripts/nn/optim/sgd.dml") as sgd +source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss + +initNetwork = function(int InputChannels, int NumberOutputClasses, int seed) + return(list[unknown] model) +{ + /* + * Convenience function for initialization of all required weights and biases. + * + * Inputs: + * - InputChannels: Number of Input Channels for the model (Cin) + * - NumberOutputClasses: Number of classes for the network + * - seed: seed for the random generation of the weights + * + * Outputs: + * - model: A list containing the total of 36 matrices needed for the computation of the + * Mini EfficientNet + */ + + # Layer 1 + [CW_stem, Cb_stem] = conv2d::init(32, InputChannels, 3, 3, seed) + seed = ifelse(seed==-1, -1, seed + 1); + [Gamma_stem, Beta_stem, EmaMean_stem, EmaVar_stem] = batchnorm::init(32) + + # Layer 2 + [mb_parameters] = mbconv::init(32, 16, 3, 3, 1, 0.25, seed) + seed = ifelse(seed==-1, -1, seed + 1); + + # Layer 9 + [CW_top, Cb_top] = conv2d::init(1280, 16, 1, 1, seed) + seed = ifelse(seed==-1, -1, seed + 1); + [Gamma_top, Beta_top, EmaMean_top, EmaVar_top] = batchnorm::init(1280) + [DW_top, Db_top] = affine::init(1280, NumberOutputClasses, seed) + + model = list(CW_stem, Cb_stem, Gamma_stem, Beta_stem, EmaMean_stem, EmaVar_stem, + as.matrix(mb_parameters[1]), + as.matrix(mb_parameters[2]), + as.matrix(mb_parameters[3]), + as.matrix(mb_parameters[4]), + as.matrix(mb_parameters[5]), + as.matrix(mb_parameters[6]), + as.matrix(mb_parameters[7]), + as.matrix(mb_parameters[8]), + as.matrix(mb_parameters[9]), + as.matrix(mb_parameters[10]), + as.matrix(mb_parameters[11]), + as.matrix(mb_parameters[12]), + as.matrix(mb_parameters[13]), + as.matrix(mb_parameters[14]), + as.matrix(mb_parameters[15]), + as.matrix(mb_parameters[16]), + as.matrix(mb_parameters[17]), + as.matrix(mb_parameters[18]), + as.matrix(mb_parameters[19]), + as.matrix(mb_parameters[20]), + as.matrix(mb_parameters[21]), + as.matrix(mb_parameters[22]), + CW_top, Cb_top, Gamma_top, Beta_top, EmaMean_top, EmaVar_top, DW_top, Db_top) +} + + +netPredict = function(matrix[double] X, list[unknown] model, int Cin, int Hin, int Win) + return(matrix[double] pred) +{ + /* + * This function 
generates the prediction of the model for a input X + * + * Inputs: + * - X: Input features of format (N, Cin * Hin * Win) + * - model: the list of length 36 containing the matrices generated from the initNetwork function + * - Cin: Number of input channels (dimensionality of depth). + * - Hin: Input height. + * - Win: Input width. + * + * Outputs: + * - pred: The output of the final softmax layer of the Mini Efficient-Net + */ + CW_stem = as.matrix(model[1]) + Cb_stem = as.matrix(model[2]) + Gamma_stem = as.matrix(model[3]) + Beta_stem = as.matrix(model[4]) + EmaMean_stem = as.matrix(model[5]) + EmaVar_stem = as.matrix(model[6]) + MBConv_params = model[7:28] + CW_top = as.matrix(model[29]) + Cb_top = as.matrix(model[30]) + Gamma_top = as.matrix(model[31]) + Beta_top = as.matrix(model[32]) + EmaMean_top = as.matrix(model[33]) + EmaVar_top = as.matrix(model[34]) + DW_top = as.matrix(model[35]) + Db_top = as.matrix(model[36]) + + padh = (Hin + 1) %% 2 + padw = (Win + 1) %% 2 + + [stem_out, stem_h, stem_w] = conv2d::forward(X, CW_stem, Cb_stem, Cin, Hin, Win, 3, 3, 2, 2, padh, padw) + [bn_stem_out, update_EmaMean_stem, update_EmaVar_stem, cache_EmaMean_stem, cache_EmaVar_stem] = batchnorm::forward( + stem_out, Gamma_stem, Beta_stem, 32, stem_h, stem_w, "train", EmaMean_stem, EmaVar_stem, 0.9, 1e-5) + silu_out = silu::forward(bn_stem_out) + + [mbconv_out, intermediate_mbconv, mbconvbatchnorm_updates, mbconv_h, mbconv_w] = mbconv::forward( + silu_out, MBConv_params, 32, 16, stem_h, stem_w, 3, 3, 2, 2, padh, padw, FALSE, 1, "train", 0.25) + + [top_out, outh, outw] = conv2d::forward(mbconv_out, CW_top, Cb_top, 16, mbconv_h, mbconv_w, 1, 1, 1, 1, 0, 0) + [bntop_out, update_EmaMean_top, update_EmaVar_top, cache_EmaMean_top, cache_EmaVar_top] = batchnorm::forward( + top_out, Gamma_top, Beta_top, 1280, outh, outw, "train", EmaMean_top, EmaVar_top, 0.9, 1e-5) + silu_out2 = silu::forward(bntop_out) + [pool_out, None, None] = global_avg_pool::forward(silu_out2, 1280, outh, outw) + dense_out = affine::forward(pool_out, DW_top, Db_top) + pred = softmax::forward(dense_out) +} + +netTrain = function(list[unknown] model, matrix[double] X, int Cin, int Hin, int Win, + matrix[double] Y, int epochs, int batch_size, double learning_rate, double lr_decay, boolean verbose) + return(list[unknown] trained_model) +{ + /* + * This function trains the given model with an sgd optimizer with the given batch_size for a number of + * epochs. + * + * Inputs: + * - model: the list of length 36 containing the matrices generated from the initNetwork function + * - X: Input features of format (N, Cin * Hin * Win) + * - Cin: Number of input channels (dimensionality of depth). + * - Hin: Input height. + * - Win: Input width. + * - Y: The true labels used for learning in a OneHotEncoding(N, NumberOutClasses) + * - epochs: Number of epochs to train for + * - batch_size: Size of batch used for a single update step + * - learning_rate: Size of batch used for a single update step + * - lr_decay: The learning rate is multiplied with lr_decay after each epoch. 
+ * - verbose: Whether the accuracy and the cross-entropy loss should be printed after each update step + * + * Outputs: + * - trained_model: The new list of the updated 36 matrices + */ + CW_stem = as.matrix(model[1]) + Cb_stem = as.matrix(model[2]) + Gamma_stem = as.matrix(model[3]) + Beta_stem = as.matrix(model[4]) + EmaMean_stem = as.matrix(model[5]) + EmaVar_stem = as.matrix(model[6]) + MBConv_params = model[7:28] + CW_top = as.matrix(model[29]) + Cb_top = as.matrix(model[30]) + Gamma_top = as.matrix(model[31]) + Beta_top = as.matrix(model[32]) + EmaMean_top = as.matrix(model[33]) + EmaVar_top = as.matrix(model[34]) + DW_top = as.matrix(model[35]) + Db_top = as.matrix(model[36]) + + padh = (Hin + 1) %% 2 + padw = (Win + 1) %% 2 + + N = nrow(X) + lr = learning_rate + + # Optimize + iters = ceil(N / batch_size) + for (e in 1:epochs) { + for(i in 1:iters) { + # Get next batch + beg = ((i-1) * batch_size) %% N + 1 + end = min(N, beg + batch_size - 1) + X_batch = X[beg:end,] + y_batch = Y[beg:end,] + + # Compute forward pass + [stem_out, stem_h, stem_w] = conv2d::forward(X_batch, CW_stem, Cb_stem, Cin, Hin, Win, 3, 3, 2, 2, padh, padw) + [bn_stem_out, update_EmaMean_stem, update_EmaVar_stem, cache_EmaMean_stem, cache_EmaVar_stem] = batchnorm::forward(stem_out, Gamma_stem, Beta_stem, 32, stem_h, stem_w, "train", EmaMean_stem, EmaVar_stem, 0.9, 1e-5) + silu_out = silu::forward(bn_stem_out) + + [mbconv_out, intermediate_mbconv, mbconvbatchnorm_updates, mbconv_h, mbconv_w] = mbconv::forward(silu_out, MBConv_params, 32, 16, stem_h, stem_w, 3, 3, 2, 2, padh, padw, FALSE, 1, "train", 0.25) + + [top_out, outh, outw] = conv2d::forward(mbconv_out, CW_top, Cb_top, 16, mbconv_h, mbconv_w, 1, 1, 1, 1, 0, 0) + [bntop_out, update_EmaMean_top, update_EmaVar_top, cache_EmaMean_top, cache_EmaVar_top] = batchnorm::forward(top_out, Gamma_top, Beta_top, 1280, outh, outw, "train", EmaMean_top, EmaVar_top, 0.9, 1e-5) + silu_out2 = silu::forward(bntop_out) + [pool_out, None, None] = global_avg_pool::forward(silu_out2, 1280, outh, outw) + dense_out = affine::forward(pool_out, DW_top, Db_top) + pred = softmax::forward(dense_out) + + # Compute loss & accuracy for training + loss = cross_entropy_loss::forward(pred, y_batch) + if(verbose) { + accuracy = mean(rowIndexMax(pred) == rowIndexMax(y_batch)) + print("Epoch: " + e + ", Iter: " + i + ", Train Loss: " + loss + ", Train Accuracy: " + accuracy) + } + + # Compute backward pass + ## loss: + dprobs = cross_entropy_loss::backward(pred, y_batch) + + ## TOP + d_softmax = softmax::backward(dprobs, dense_out) + [d_dense_back, dDenseW_top, dDenseb_top] = affine::backward(d_softmax, pool_out, DW_top, Db_top) + d_pool_back = global_avg_pool::backward(d_dense_back, silu_out2, 1280, outh, outw) + d_silu2_back = silu::backward(d_pool_back, bntop_out) + [d_bntop_back, dGamma_top, dBeta_top] = batchnorm::backward(d_silu2_back, cache_EmaMean_top, cache_EmaVar_top, top_out, Gamma_top, 1280, outh, outw, 1e-5) + [dtop_back, d_ConvW_top, d_Convb_top] = conv2d::backward(d_bntop_back, outh, outw, mbconv_out, CW_top, Cb_top, 16, mbconv_h, mbconv_w, 1, 1, 1, 1, 0, 0) + + # MBCONV + [d_mbconv_back, mbconv_gradients] = mbconv::backward(dtop_back, silu_out, MBConv_params, intermediate_mbconv, mbconvbatchnorm_updates, 32, 16, stem_h, stem_w, 3, 3, 2, 2, padh, padw, FALSE, 1, "train", 0.25) + + ## STEM + d_silu_back = silu::backward(d_mbconv_back, bn_stem_out) + [d_bn_stem_back, dGamma_stem, dBeta_stem] = batchnorm::backward(d_silu_back, cache_EmaMean_stem, cache_EmaVar_stem, stem_out, 
Gamma_stem, 32, stem_h, stem_w, 1e-5) + [dconv_back, dW_stem, db_stem] = conv2d::backward(d_bn_stem_back, stem_h, stem_w, X_batch, CW_stem, Cb_stem, Cin, Hin, Win, 3, 3, 2, 2, padh, padw) + + #Optimize with SGD + # Update Stem + CW_stem = sgd::update(CW_stem, dW_stem, lr) + Cb_stem = sgd::update(Cb_stem, db_stem, lr) + Gamma_stem = sgd::update(Gamma_stem, dGamma_stem, lr) + Beta_stem = sgd::update(Beta_stem, dBeta_stem, lr) + EmaMean_stem = update_EmaMean_stem + EmaVar_stem = update_EmaVar_stem + + # Update MBConv + update_depth_W = sgd::update(as.matrix(MBConv_params[7]), as.matrix(mbconv_gradients[11]), lr) + update_depth_b = sgd::update(as.matrix(MBConv_params[8]), as.matrix(mbconv_gradients[12]), lr) + update_gamma_depth = sgd::update(as.matrix(MBConv_params[9]), as.matrix(mbconv_gradients[9]), lr) + update_beta_depth = sgd::update(as.matrix(MBConv_params[10]), as.matrix(mbconv_gradients[10]), lr) + update_ema_mean_depth = as.matrix(mbconvbatchnorm_updates[5]) + update_ema_var_depth = as.matrix(mbconvbatchnorm_updates[6]) + update_squeeze_W = sgd::update(as.matrix(MBConv_params[13]), as.matrix(mbconv_gradients[7]), lr) + update_squeeze_b = sgd::update(as.matrix(MBConv_params[14]), as.matrix(mbconv_gradients[8]), lr) + update_excite_W = sgd::update(as.matrix(MBConv_params[15]), as.matrix(mbconv_gradients[5]), lr) + update_excite_b = sgd::update(as.matrix(MBConv_params[16]), as.matrix(mbconv_gradients[6]), lr) + update_out_W = sgd::update(as.matrix(MBConv_params[17]), as.matrix(mbconv_gradients[3]), lr) + update_out_b = sgd::update(as.matrix(MBConv_params[18]), as.matrix(mbconv_gradients[4]), lr) + update_out_gamma = sgd::update(as.matrix(MBConv_params[19]), as.matrix(mbconv_gradients[1]), lr) + update_out_beta = sgd::update(as.matrix(MBConv_params[20]), as.matrix(mbconv_gradients[2]), lr) + update_ema_mean_out = as.matrix(mbconvbatchnorm_updates[9]) + update_ema_var_out = as.matrix(mbconvbatchnorm_updates[10]) + + MBConv_params = list( + as.matrix(model[7]), as.matrix(model[8]), + as.matrix(model[9]), as.matrix(model[10]), + as.matrix(model[11]), as.matrix(model[12]), + update_depth_W, update_depth_b, + update_gamma_depth, update_beta_depth, + update_ema_mean_depth, update_ema_var_depth, + update_squeeze_W, update_squeeze_b, + update_excite_W, update_excite_b, + update_out_W, update_out_b, + update_out_gamma, update_out_beta, + update_ema_mean_out, update_ema_var_out) + + # Update Top + CW_top = sgd::update(CW_top, d_ConvW_top, lr) + Cb_top = sgd::update(Cb_top, d_Convb_top, lr) + Gamma_top = sgd::update(Gamma_top, dGamma_top, lr) + Beta_top = sgd::update(Beta_top, dBeta_top, lr) + EmaMean_top = update_EmaMean_top + EmaVar_top = update_EmaVar_top + DW_top = sgd::update(DW_top, dDenseW_top, lr) + Db_top = sgd::update(Db_top, dDenseb_top, lr) + } + # Decay learning rate + lr = lr * lr_decay + } + + # Pack everything into model format + trained_model = list(CW_stem, Cb_stem, Gamma_stem, Beta_stem, EmaMean_stem, EmaVar_stem, + as.matrix(MBConv_params[1]), as.matrix(MBConv_params[2]), + as.matrix(MBConv_params[3]), as.matrix(MBConv_params[4]), + as.matrix(MBConv_params[5]), as.matrix(MBConv_params[6]), + as.matrix(MBConv_params[7]), as.matrix(MBConv_params[8]), + as.matrix(MBConv_params[9]), as.matrix(MBConv_params[10]), + as.matrix(MBConv_params[11]), as.matrix(MBConv_params[12]), + as.matrix(MBConv_params[13]), as.matrix(MBConv_params[14]), + as.matrix(MBConv_params[15]), as.matrix(MBConv_params[16]), + as.matrix(MBConv_params[17]), as.matrix(MBConv_params[18]), + 
as.matrix(MBConv_params[19]), as.matrix(MBConv_params[20]), + as.matrix(MBConv_params[21]), as.matrix(MBConv_params[22]), + CW_top, Cb_top, Gamma_top, Beta_top, EmaMean_top, EmaVar_top, DW_top, Db_top) +} diff --git a/src/test/scripts/gpu/GPUTest.dml b/src/test/scripts/gpu/GPUTest.dml new file mode 100644 index 00000000000..a9d9f0eb447 --- /dev/null +++ b/src/test/scripts/gpu/GPUTest.dml @@ -0,0 +1,61 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +source("src/test/scripts/applications/nn/component/efficientNet.dml") as eff + +# Read training data +data = read($1, format="csv") + +N = nrow(data) + +# Extract images and labels +images = data[,2:ncol(data)] +labels = data[,1] + +# Scale images to [0,1], and one-hot encode the labels +images = images / 255.0 +labels = table(seq(1, N), labels+1, N, 10) + +model = eff::initNetwork(1, 10, -1) + +# Train +epochs = 1 +batch_size = 256 +model = eff::netTrain(model, images, 1, 28, 28, labels, epochs, batch_size, 0.025, 0.9, TRUE) + +# Read num_test_images from arguments +num_test_images = as.integer($3) + +# Also Predict in Batches since otherwise we can run into Memory Issues +# Could be unnecessary on more powerful machines :) +iters = ceil(num_test_images / batch_size) +# Start timing the parfor loop +parfor_start_time = time() +parfor(i in 1:iters) { + # Generate random data for predicting + X_batch = rand(rows=batch_size, cols=28*28, min=0, max=1, sparsity=1.0, pdf="uniform", seed=42) + + pred = eff::netPredict(X_batch, model, 1, 28, 28) +} +# End timing the parfor loop +parfor_end_time = time() +parfor_execution_time = floor((parfor_end_time-parfor_start_time)/1000000000) +print("Parfor Execution Time: " + parfor_execution_time) diff --git a/src/test/scripts/gpu/InferenceScript.dml b/src/test/scripts/gpu/InferenceScript.dml new file mode 100644 index 00000000000..5321653b754 --- /dev/null +++ b/src/test/scripts/gpu/InferenceScript.dml @@ -0,0 +1,73 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +source("src/test/scripts/applications/nn/component/efficientNet.dml") as eff + +# Read training data +data = read($1, format="csv") + +N = nrow(data) + +# Extract images and labels +images = data[,2:ncol(data)] +labels = data[,1] + +# Scale images to [0,1], and one-hot encode the labels +images = images / 255.0 +labels = table(seq(1, N), labels+1, N, 10) + +# Load the trained model +model = read("output/model.txt", format="text") + +# Generate random data for predicting +num_test_images = as.integer($3) +test_images = rand(rows=num_test_images, cols=28*28, min=0, max=1, sparsity=1.0, pdf="uniform", seed=42) +test_labels = rand(rows=num_test_images, cols=1, min=0, max=9, sparsity=1.0, pdf="uniform", seed=42) +test_labels = round(test_labels) + +# One-hot encode the test labels +test_labels = table(seq(1, num_test_images), test_labels+1, num_test_images, 10) + +# Materialize intermediates by computing and printing their sums +print("Sum of test_images: " + sum(test_images)) +print("Sum of test_labels: " + sum(test_labels)) + +# Also Predict in Batches since otherwise we can run into Memory Issues +# Could be unnecessary on more powerful machines :) +batch_size = 1024 # Adjust the batch size to a larger value +iters = ceil(num_test_images / batch_size) +partial_accuracies = matrix(0, rows=iters, cols=1) + +# Start timing the parfor loop +parfor_start_time = time() +parfor(i in 1:iters) { + beg = ((i-1) * batch_size) %% num_test_images + 1 + end = min(num_test_images, beg + batch_size - 1) + X_batch = test_images[beg:end,] + y_batch = test_labels[beg:end,] + + pred = eff::netPredict(X_batch, model, 1, 28, 28) + partial_accuracies[i,1] = mean(rowIndexMax(pred) == rowIndexMax(y_batch)) +} +# End timing the parfor loop +parfor_end_time = time() +parfor_execution_time = floor((parfor_end_time-parfor_start_time)/1000000000) +print("Parfor Execution Time: " + parfor_execution_time) diff --git a/src/test/scripts/gpu/TrainScript.dml b/src/test/scripts/gpu/TrainScript.dml new file mode 100644 index 00000000000..c0711398ab7 --- /dev/null +++ b/src/test/scripts/gpu/TrainScript.dml @@ -0,0 +1,45 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/applications/nn/component/efficientNet.dml") as eff
+
+# Read training data
+data = read($1, format="csv")
+
+N = nrow(data)
+
+# Extract images and labels
+images = data[,2:ncol(data)]
+labels = data[,1]
+
+# Scale images to [0,1], and one-hot encode the labels
+images = images / 255.0
+labels = table(seq(1, N), labels+1, N, 10)
+
+# Initialize and train the model
+model = eff::initNetwork(1, 10, -1)
+epochs = 1
+batch_size = 256
+model = eff::netTrain(model, images, 1, 28, 28, labels, epochs, batch_size, 0.025, 0.9, TRUE)
+
+# Save the trained model to disk
+write(model, "output/model.txt", format="text")
+print("Trained model saved to output/model.txt")
\ No newline at end of file
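
Note on the CacheableData change above: once the parfor extension distributes tasks across GPUs, several LocalParWorker threads may register per-GPUContext GPUObjects on the same shared CacheableData at once, which is why the handle map moves from HashMap to ConcurrentHashMap. The standalone sketch below uses only plain JDK types (class and variable names are illustrative, not SystemDS APIs) to show the concurrent put pattern the change makes safe.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Toy illustration: several worker threads each register a per-context handle
// on one shared map, analogous to parfor workers attaching per-GPUContext
// GPUObjects to the same CacheableData. With a plain HashMap these unsynchronized
// put() calls could lose entries or corrupt the table; ConcurrentHashMap keeps
// each put atomic without external locking.
public class SharedHandleDemo {
    public static void main(String[] args) throws InterruptedException {
        Map<Integer, String> handles = new ConcurrentHashMap<>();
        Thread[] workers = new Thread[4];
        for (int i = 0; i < workers.length; i++) {
            final int ctxId = i; // stand-in for a GPUContext id
            workers[i] = new Thread(() -> handles.put(ctxId, "gpuObject-" + ctxId));
            workers[i].start();
        }
        for (Thread t : workers)
            t.join();
        System.out.println("registered handles: " + handles.size()); // expect 4
    }
}

The alternative of keeping HashMap and synchronizing every caller would also work, but the concurrent map localizes the thread-safety guarantee in the data structure itself, matching how setGPUObject is invoked from independent worker threads.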