diff --git a/.gitignore b/.gitignore index 62bb9b0b7..b46833179 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ composer.lock /phpunit.xml .DS_Store Thumbs.db +wandb/ # IDE's .idea/ diff --git a/android/app/src/main/java/org/openbot/autopilot/AutopilotFragment.java b/android/app/src/main/java/org/openbot/autopilot/AutopilotFragment.java index ec08abf33..2271eb2a3 100644 --- a/android/app/src/main/java/org/openbot/autopilot/AutopilotFragment.java +++ b/android/app/src/main/java/org/openbot/autopilot/AutopilotFragment.java @@ -14,19 +14,16 @@ import android.view.View; import android.view.ViewGroup; import android.widget.AdapterView; -import android.widget.ArrayAdapter; import android.widget.Toast; import androidx.annotation.NonNull; import androidx.annotation.Nullable; import androidx.camera.core.ImageProxy; import androidx.navigation.Navigation; import com.google.android.material.bottomsheet.BottomSheetBehavior; -import java.io.File; import java.io.IOException; import java.util.List; import java.util.Locale; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; import org.jetbrains.annotations.NotNull; import org.openbot.R; import org.openbot.common.CameraFragment; @@ -34,25 +31,21 @@ import org.openbot.env.BorderedText; import org.openbot.env.Control; import org.openbot.env.ImageUtils; -import org.openbot.server.ServerCommunication; -import org.openbot.server.ServerListener; import org.openbot.tflite.Autopilot; import org.openbot.tflite.Model; import org.openbot.tflite.Network; import org.openbot.tracking.MultiBoxTracker; import org.openbot.utils.Constants; import org.openbot.utils.Enums; -import org.openbot.utils.FileUtils; import org.openbot.utils.PermissionUtils; import timber.log.Timber; -public class AutopilotFragment extends CameraFragment implements ServerListener { +public class AutopilotFragment extends CameraFragment { // options for drop down in object nav? private FragmentAutopilotBinding binding; private Handler handler; private HandlerThread handlerThread; - private ServerCommunication serverCommunication; private long lastProcessingTimeMs; private boolean computingNetwork = false; @@ -71,8 +64,6 @@ public class AutopilotFragment extends CameraFragment implements ServerListener private Network.Device device = Network.Device.CPU; private int numThreads = -1; - private ArrayAdapter modelAdapter; - @Override public void onCreate(@Nullable Bundle savedInstanceState) { super.onCreate(savedInstanceState); @@ -98,41 +89,12 @@ public void onViewCreated(@NonNull View view, @Nullable Bundle savedInstanceStat binding.cameraToggle.setOnClickListener(v -> toggleCamera()); List models = - masterList.stream() - .filter(f -> f.type.equals(Model.TYPE.AUTOPILOT) && f.pathType != Model.PATH_TYPE.URL) - .map(f -> FileUtils.nameWithoutExtension(f.name)) - .collect(Collectors.toList()); - modelAdapter = new ArrayAdapter<>(requireContext(), R.layout.spinner_item, models); - - modelAdapter.setDropDownViewResource(android.R.layout.simple_dropdown_item_1line); - binding.modelSpinner.setAdapter(modelAdapter); - if (!preferencesManager.getAutopilotModel().isEmpty()) - binding.modelSpinner.setSelection( - Math.max( - 0, - modelAdapter.getPosition( - FileUtils.nameWithoutExtension(preferencesManager.getAutopilotModel())))); + getModelNames( + f -> f.type.equals(Model.TYPE.AUTOPILOT) && f.pathType != Model.PATH_TYPE.URL); + initModelSpinner(binding.modelSpinner, models, preferencesManager.getAutopilotModel()); + initServerSpinner(binding.serverSpinner); setAnalyserResolution(Enums.Preview.HD.getValue()); - binding.modelSpinner.setOnItemSelectedListener( - new AdapterView.OnItemSelectedListener() { - @Override - public void onItemSelected(AdapterView parent, View view, int position, long id) { - String selected = parent.getItemAtPosition(position).toString(); - try { - masterList.stream() - .filter(f -> f.name.contains(selected)) - .findFirst() - .ifPresent(value -> setModel(value)); - - } catch (IllegalArgumentException e) { - e.printStackTrace(); - } - } - - @Override - public void onNothingSelected(AdapterView parent) {} - }); binding.deviceSpinner.setOnItemSelectedListener( new AdapterView.OnItemSelectedListener() { @Override @@ -296,8 +258,6 @@ private void recreateNetwork(Model model, Network.Device device, int numThreads) @Override public synchronized void onResume() { - serverCommunication = new ServerCommunication(requireContext(), this); - serverCommunication.start(); handlerThread = new HandlerThread("inference"); handlerThread.start(); handler = new Handler(handlerThread.getLooper()); @@ -314,7 +274,6 @@ public synchronized void onPause() { } catch (final InterruptedException e) { e.printStackTrace(); } - serverCommunication.stop(); super.onPause(); } @@ -452,51 +411,11 @@ public void onConnectionEstablished(String ipAddress) { requireActivity().runOnUiThread(() -> binding.ipAddress.setText(ipAddress)); } - @Override - public void onAddModel(String model) { - Model item = - new Model( - masterList.size() + 1, - Model.CLASS.AUTOPILOT_F, - Model.TYPE.AUTOPILOT, - model, - Model.PATH_TYPE.FILE, - requireActivity().getFilesDir() + File.separator + model, - "256x96"); - - if (modelAdapter != null && modelAdapter.getPosition(model) == -1) { - modelAdapter.add(model); - masterList.add(item); - FileUtils.updateModelConfig(requireActivity(), masterList); - } else { - if (model.equals(binding.modelSpinner.getSelectedItem())) { - setModel(item); - } - } - Toast.makeText( - requireContext().getApplicationContext(), - "AutopilotModel added: " + model, - Toast.LENGTH_SHORT) - .show(); - } - - @Override - public void onRemoveModel(String model) { - if (modelAdapter != null && modelAdapter.getPosition(model) != -1) { - modelAdapter.remove(model); - } - Toast.makeText( - requireContext().getApplicationContext(), - "AutopilotModel removed: " + model, - Toast.LENGTH_SHORT) - .show(); - } - protected Model getModel() { return model; } - private void setModel(Model model) { + protected void setModel(Model model) { if (this.model != model) { Timber.d("Updating model: %s", model); this.model = model; diff --git a/android/app/src/main/java/org/openbot/common/ControlsFragment.java b/android/app/src/main/java/org/openbot/common/ControlsFragment.java index 8aaf50a98..93b4ad165 100644 --- a/android/app/src/main/java/org/openbot/common/ControlsFragment.java +++ b/android/app/src/main/java/org/openbot/common/ControlsFragment.java @@ -7,13 +7,22 @@ import android.view.View; import android.view.animation.Animation; import android.view.animation.AnimationUtils; +import android.widget.AdapterView; +import android.widget.ArrayAdapter; +import android.widget.Spinner; +import android.widget.Toast; import androidx.activity.result.ActivityResultLauncher; import androidx.activity.result.contract.ActivityResultContracts; import androidx.annotation.NonNull; import androidx.annotation.Nullable; import androidx.fragment.app.Fragment; import androidx.lifecycle.ViewModelProvider; +import java.io.File; import java.util.List; +import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.jetbrains.annotations.NotNull; import org.json.JSONObject; import org.openbot.R; import org.openbot.env.AudioPlayer; @@ -24,6 +33,8 @@ import org.openbot.env.SharedPreferencesManager; import org.openbot.env.Vehicle; import org.openbot.main.MainViewModel; +import org.openbot.server.ServerCommunication; +import org.openbot.server.ServerListener; import org.openbot.tflite.Model; import org.openbot.utils.ConnectionUtils; import org.openbot.utils.Constants; @@ -33,7 +44,9 @@ import org.openbot.utils.PermissionUtils; import timber.log.Timber; -public abstract class ControlsFragment extends Fragment { +public abstract class ControlsFragment extends Fragment implements ServerListener { + private static final String NO_SERVER = "No server"; + protected MainViewModel mViewModel; protected Vehicle vehicle; protected Animation startAnimation; @@ -46,6 +59,13 @@ public abstract class ControlsFragment extends Fragment { protected final String voice = "matthew"; protected List masterList; + protected ServerCommunication serverCommunication; + + private ArrayAdapter modelAdapter; + private ArrayAdapter serverAdapter; + private Spinner modelSpinner; + private Spinner serverSpinner; + @Override public void onViewCreated(@NonNull View view, @Nullable Bundle savedInstanceState) { super.onViewCreated(view, savedInstanceState); @@ -58,6 +78,7 @@ public void onViewCreated(@NonNull View view, @Nullable Bundle savedInstanceStat preferencesManager = new SharedPreferencesManager(requireContext()); audioPlayer = new AudioPlayer(requireContext()); masterList = FileUtils.loadConfigJSONFromAsset(requireActivity()); + serverCommunication = new ServerCommunication(requireContext(), this); requireActivity() .getSupportFragmentManager() @@ -239,8 +260,17 @@ private void toggleIndicatorEvent(int value) { } }); + @NotNull + protected List getModelNames(Predicate filter) { + return masterList.stream() + .filter(filter) + .map(f -> FileUtils.nameWithoutExtension(f.name)) + .collect(Collectors.toList()); + } + @Override public void onResume() { + serverCommunication.start(); super.onResume(); } @@ -255,6 +285,7 @@ public void onDestroy() { @Override public synchronized void onPause() { Timber.d("onPause"); + serverCommunication.stop(); vehicle.setControl(0, 0); super.onPause(); } @@ -265,6 +296,127 @@ public void onStop() { super.onStop(); } + protected void initModelSpinner(Spinner spinner, List models, String selected) { + modelAdapter = new ArrayAdapter<>(requireContext(), R.layout.spinner_item, models); + modelAdapter.setDropDownViewResource(android.R.layout.simple_dropdown_item_1line); + modelSpinner = spinner; + modelSpinner.setAdapter(modelAdapter); + if (!selected.isEmpty()) + modelSpinner.setSelection( + Math.max(0, modelAdapter.getPosition(FileUtils.nameWithoutExtension(selected)))); + modelSpinner.setOnItemSelectedListener( + new AdapterView.OnItemSelectedListener() { + @Override + public void onItemSelected(AdapterView parent, View view, int position, long id) { + String selected = parent.getItemAtPosition(position).toString(); + try { + masterList.stream() + .filter(f -> f.name.contains(selected)) + .findFirst() + .ifPresent(value -> setModel(value)); + + } catch (IllegalArgumentException e) { + e.printStackTrace(); + } + } + + @Override + public void onNothingSelected(AdapterView parent) {} + }); + } + + protected void initServerSpinner(Spinner spinner) { + serverAdapter = new ArrayAdapter<>(requireContext(), R.layout.spinner_item); + serverAdapter.setDropDownViewResource(android.R.layout.simple_dropdown_item_1line); + serverSpinner = spinner; + serverSpinner.setAdapter(serverAdapter); + serverSpinner.setOnItemSelectedListener( + new AdapterView.OnItemSelectedListener() { + @Override + public void onItemSelected(AdapterView parent, View view, int position, long id) { + String selected = parent.getItemAtPosition(position).toString(); + if (selected.equals(NO_SERVER)) { + serverCommunication.disconnect(); + if (serverAdapter.getPosition(preferencesManager.getServer()) > -1) { + preferencesManager.setServer(selected); + } + } else { + serverCommunication.connect(selected); + preferencesManager.setServer(selected); + } + } + + @Override + public void onNothingSelected(AdapterView parent) { + serverCommunication.disconnect(); + } + }); + onServerListChange(serverCommunication.getServers()); + } + + @Override + public void onServerListChange(Set servers) { + if (serverAdapter == null) { + return; + } + requireActivity() + .runOnUiThread( + () -> { + serverAdapter.clear(); + serverAdapter.add(NO_SERVER); + serverAdapter.addAll(servers); + if (!preferencesManager.getServer().isEmpty()) { + serverSpinner.setSelection( + Math.max(0, serverAdapter.getPosition(preferencesManager.getServer()))); + } + }); + } + + @Override + public void onAddModel(String model) { + Model item = + new Model( + masterList.size() + 1, + Model.CLASS.AUTOPILOT_F, + Model.TYPE.AUTOPILOT, + model, + Model.PATH_TYPE.FILE, + requireActivity().getFilesDir() + File.separator + model, + "256x96"); + + if (modelAdapter != null && modelAdapter.getPosition(model) == -1) { + modelAdapter.add(model); + masterList.add(item); + FileUtils.updateModelConfig(requireActivity(), masterList); + } else { + if (model.equals(modelSpinner.getSelectedItem())) { + setModel(item); + } + } + Toast.makeText( + requireContext().getApplicationContext(), + "AutopilotModel added: " + model, + Toast.LENGTH_SHORT) + .show(); + } + + @Override + public void onRemoveModel(String model) { + if (modelAdapter != null && modelAdapter.getPosition(model) != -1) { + modelAdapter.remove(model); + } + Toast.makeText( + requireContext().getApplicationContext(), + "AutopilotModel removed: " + model, + Toast.LENGTH_SHORT) + .show(); + } + + @Override + public void onConnectionEstablished(String ipAddress) {} + + protected void setModel(Model model) {} + protected abstract void processControllerKeyData(String command); protected abstract void processUSBData(String data); diff --git a/android/app/src/main/java/org/openbot/env/SharedPreferencesManager.java b/android/app/src/main/java/org/openbot/env/SharedPreferencesManager.java index a5fa724e1..5d9dc5180 100644 --- a/android/app/src/main/java/org/openbot/env/SharedPreferencesManager.java +++ b/android/app/src/main/java/org/openbot/env/SharedPreferencesManager.java @@ -22,6 +22,7 @@ public class SharedPreferencesManager { private static final String DEFAULT_MODEL = "DEFAULT_MODEL_NAME"; private static final String OBJECT_NAV_MODEL = "OBJECT_NAV_MODEL_NAME"; private static final String AUTOPILOT_MODEL = "AUTOPILOT_MODEL_NAME"; + private static final String SERVER_NAME = "SERVER_NAME"; private static final String OBJECT_TYPE = "OBJECT_TYPE"; private static final String DEFAULT_OBJECT_TYPE = "person"; @@ -107,6 +108,14 @@ public String getAutopilotModel() { return preferences.getString(AUTOPILOT_MODEL, ""); } + public void setServer(String server) { + preferences.edit().putString(SERVER_NAME, server).apply(); + } + + public String getServer() { + return preferences.getString(SERVER_NAME, ""); + } + public void setObjectType(String model) { preferences.edit().putString(OBJECT_TYPE, model).apply(); } diff --git a/android/app/src/main/java/org/openbot/logging/LoggerFragment.java b/android/app/src/main/java/org/openbot/logging/LoggerFragment.java index f788a0a36..eb79a3d20 100644 --- a/android/app/src/main/java/org/openbot/logging/LoggerFragment.java +++ b/android/app/src/main/java/org/openbot/logging/LoggerFragment.java @@ -20,7 +20,6 @@ import android.view.View; import android.view.ViewGroup; import android.widget.AdapterView; -import android.widget.ArrayAdapter; import androidx.activity.result.ActivityResultLauncher; import androidx.activity.result.contract.ActivityResultContracts; import androidx.annotation.NonNull; @@ -34,15 +33,12 @@ import java.util.List; import java.util.Locale; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; import org.jetbrains.annotations.NotNull; import org.openbot.R; import org.openbot.common.CameraFragment; import org.openbot.databinding.FragmentLoggerBinding; import org.openbot.env.BotToControllerEventBus; import org.openbot.env.ImageUtils; -import org.openbot.server.ServerCommunication; -import org.openbot.server.ServerListener; import org.openbot.tflite.Model; import org.openbot.utils.ConnectionUtils; import org.openbot.utils.Constants; @@ -52,13 +48,12 @@ import org.zeroturnaround.zip.commons.FileUtils; import timber.log.Timber; -public class LoggerFragment extends CameraFragment implements ServerListener { +public class LoggerFragment extends CameraFragment { private FragmentLoggerBinding binding; private Handler handler; private HandlerThread handlerThread; private Intent intentSensorService; - private ServerCommunication serverCommunication; protected String logFolder; protected boolean loggingEnabled; @@ -113,30 +108,10 @@ public void onViewCreated(@NonNull View view, @Nullable Bundle savedInstanceStat binding.cameraToggle.setOnClickListener(v -> toggleCamera()); - List models = - masterList.stream() - .filter(f -> f.pathType != Model.PATH_TYPE.URL) - .map(f -> org.openbot.utils.FileUtils.nameWithoutExtension(f.name)) - .collect(Collectors.toList()); - - ArrayAdapter modelAdapter = - new ArrayAdapter<>(requireContext(), R.layout.spinner_item, models); - modelAdapter.setDropDownViewResource(android.R.layout.simple_dropdown_item_1line); - binding.modelSpinner.setAdapter(modelAdapter); - binding.modelSpinner.setOnItemSelectedListener( - new AdapterView.OnItemSelectedListener() { - @Override - public void onItemSelected(AdapterView parent, View view, int position, long id) { - String selected = parent.getItemAtPosition(position).toString(); - masterList.stream() - .filter(f -> f.name.contains(selected)) - .findFirst() - .ifPresent(f -> updateCropImageInfo(f)); - } + List models = getModelNames(f -> f.pathType != Model.PATH_TYPE.URL); + initModelSpinner(binding.modelSpinner, models, ""); + initServerSpinner(binding.serverSpinner); - @Override - public void onNothingSelected(AdapterView parent) {} - }); binding.resolutionSpinner.setOnItemSelectedListener( new AdapterView.OnItemSelectedListener() { @Override @@ -173,7 +148,8 @@ public void onNothingSelected(AdapterView parent) {} }); } - private void updateCropImageInfo(Model selected) { + @Override + protected void setModel(Model selected) { frameToCropTransform = null; binding.cropInfo.setText( String.format( @@ -200,8 +176,6 @@ private void updateCropImageInfo(Model selected) { @Override public synchronized void onResume() { - serverCommunication = new ServerCommunication(requireContext(), this); - serverCommunication.start(); handlerThread = new HandlerThread("logging"); handlerThread.start(); handler = new Handler(handlerThread.getLooper()); @@ -218,7 +192,6 @@ public synchronized void onPause() { } catch (final InterruptedException e) { e.printStackTrace(); } - serverCommunication.stop(); super.onPause(); } @@ -317,11 +290,11 @@ private void stopLogging() { // Pack and upload the collected data runInBackground( () -> { - String logZipFile = logFolder + ".zip"; - // Zip the log folder and then delete it - File folder = new File(logFolder); - File zip = new File(logZipFile); try { + String logZipFile = logFolder + ".zip"; + // Zip the log folder and then delete it + File folder = new File(logFolder); + File zip = new File(logZipFile); TimeUnit.MILLISECONDS.sleep(500); // These two lines below are messy and may cause bugs. needs to be looked into ZipUtil.pack(folder, zip); @@ -553,10 +526,4 @@ protected void processFrame(Bitmap bitmap, ImageProxy image) { public void onConnectionEstablished(String ipAddress) { requireActivity().runOnUiThread(() -> binding.ipAddress.setText(ipAddress)); } - - @Override - public void onAddModel(String model) {} - - @Override - public void onRemoveModel(String model) {} } diff --git a/android/app/src/main/java/org/openbot/objectNav/ObjectNavFragment.java b/android/app/src/main/java/org/openbot/objectNav/ObjectNavFragment.java index c93bbf984..dd78bb52c 100644 --- a/android/app/src/main/java/org/openbot/objectNav/ObjectNavFragment.java +++ b/android/app/src/main/java/org/openbot/objectNav/ObjectNavFragment.java @@ -28,7 +28,6 @@ import java.util.List; import java.util.Locale; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; import org.jetbrains.annotations.NotNull; import org.openbot.R; import org.openbot.common.CameraFragment; @@ -42,7 +41,6 @@ import org.openbot.tracking.MultiBoxTracker; import org.openbot.utils.Constants; import org.openbot.utils.Enums; -import org.openbot.utils.FileUtils; import org.openbot.utils.PermissionUtils; import timber.log.Timber; @@ -133,40 +131,10 @@ public void onNothingSelected(AdapterView parent) {} binding.cameraToggle.setOnClickListener(v -> toggleCamera()); List models = - masterList.stream() - .filter(f -> f.type.equals(Model.TYPE.DETECTOR) && f.pathType != Model.PATH_TYPE.URL) - .map(f -> FileUtils.nameWithoutExtension(f.name)) - .collect(Collectors.toList()); - ArrayAdapter modelAdapter = - new ArrayAdapter<>(requireContext(), R.layout.spinner_item, models); - - modelAdapter.setDropDownViewResource(android.R.layout.simple_dropdown_item_1line); - binding.modelSpinner.setAdapter(modelAdapter); - if (!preferencesManager.getObjectNavModel().isEmpty()) - binding.modelSpinner.setSelection( - Math.max( - 0, - modelAdapter.getPosition( - FileUtils.nameWithoutExtension(preferencesManager.getObjectNavModel())))); + getModelNames(f -> f.type.equals(Model.TYPE.DETECTOR) && f.pathType != Model.PATH_TYPE.URL); + initModelSpinner(binding.modelSpinner, models, preferencesManager.getObjectNavModel()); setAnalyserResolution(Enums.Preview.HD.getValue()); - binding.modelSpinner.setOnItemSelectedListener( - new AdapterView.OnItemSelectedListener() { - @Override - public void onItemSelected(AdapterView parent, View view, int position, long id) { - String selected = parent.getItemAtPosition(position).toString(); - try { - masterList.stream() - .filter(f -> f.name.contains(selected)) - .findFirst() - .ifPresent(value -> setModel(value)); - } catch (IllegalArgumentException e) { - } - } - - @Override - public void onNothingSelected(AdapterView parent) {} - }); binding.deviceSpinner.setOnItemSelectedListener( new AdapterView.OnItemSelectedListener() { @Override @@ -508,7 +476,8 @@ protected Model getModel() { return model; } - private void setModel(Model model) { + @Override + protected void setModel(Model model) { if (this.model != model) { Timber.d("Updating model: %s", model); this.model = model; diff --git a/android/app/src/main/java/org/openbot/original/CameraActivity.java b/android/app/src/main/java/org/openbot/original/CameraActivity.java index 2dcd6b065..70e5df2cf 100755 --- a/android/app/src/main/java/org/openbot/original/CameraActivity.java +++ b/android/app/src/main/java/org/openbot/original/CameraActivity.java @@ -77,6 +77,7 @@ import java.util.List; import java.util.Locale; import java.util.Objects; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import org.json.JSONObject; @@ -612,6 +613,9 @@ public void onRemoveModel(String model) { Toast.makeText(context, "AutopilotModel removed: " + model, Toast.LENGTH_SHORT).show(); } + @Override + public void onServerListChange(Set servers) {} + @Override public void onConnectionEstablished(String ipAddress) {} diff --git a/android/app/src/main/java/org/openbot/server/NsdService.java b/android/app/src/main/java/org/openbot/server/NsdService.java index 734f17009..f6c5d0b55 100644 --- a/android/app/src/main/java/org/openbot/server/NsdService.java +++ b/android/app/src/main/java/org/openbot/server/NsdService.java @@ -5,12 +5,11 @@ import android.net.nsd.NsdManager.DiscoveryListener; import android.net.nsd.NsdManager.ResolveListener; import android.net.nsd.NsdServiceInfo; -import android.util.Log; +import timber.log.Timber; class NsdService { - private static final String TAG = "NSD"; - private static final String SERVICE_TYPE = "_http._tcp."; + private static final String SERVICE_TYPE = "_openbot-server._tcp."; private final DiscoveryListener mDiscoveryListener = new DiscoveryListener() { @@ -23,13 +22,12 @@ public void onServiceFound(NsdServiceInfo service) { // A service was found! Do something with it. String name = service.getServiceName(); String type = service.getServiceType(); - Log.d(TAG, "Service Name=" + name); - Log.d(TAG, "Service Type=" + type); - if (type.equals(SERVICE_TYPE) && name.contains("Openbot")) { + Timber.d("Service Name=%s, Type=%s", name, type); + if (type.equals(SERVICE_TYPE)) { try { mNsdManager.resolveService(service, mResolveListener); } catch (IllegalArgumentException e) { - Log.w(TAG, e); + Timber.w(e, "Unable to resolve openbot server service"); } } } @@ -59,13 +57,13 @@ public void onStopDiscoveryFailed(String serviceType, int errorCode) { public void start(Context context, ResolveListener resolveListener) { this.mResolveListener = resolveListener; - Log.d(TAG, "Start discovery"); + Timber.d("Start discovery"); mNsdManager = (NsdManager) context.getSystemService(Context.NSD_SERVICE); mNsdManager.discoverServices(SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD, mDiscoveryListener); } public void stop() { - Log.d(TAG, "Stop discovery"); + Timber.d("Stop discovery"); mNsdManager.stopServiceDiscovery(mDiscoveryListener); } } diff --git a/android/app/src/main/java/org/openbot/server/ServerCommunication.java b/android/app/src/main/java/org/openbot/server/ServerCommunication.java index 83966df33..cf5e54b3a 100644 --- a/android/app/src/main/java/org/openbot/server/ServerCommunication.java +++ b/android/app/src/main/java/org/openbot/server/ServerCommunication.java @@ -11,7 +11,10 @@ import cz.msebera.android.httpclient.Header; import java.io.File; import java.io.FileNotFoundException; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; +import java.util.Set; import java.util.Timer; import java.util.TimerTask; import org.json.JSONArray; @@ -25,6 +28,7 @@ public class ServerCommunication { private final AsyncHttpClient client; private final Context context; private final NsdService nsdService; + private final Map servers = new HashMap<>(); private final NsdManager.ResolveListener resolveListener = new NsdManager.ResolveListener() { @Override @@ -35,13 +39,12 @@ public void onResolveFailed(NsdServiceInfo serviceInfo, int errorCode) { @Override public void onServiceResolved(NsdServiceInfo serviceInfo) { - nsdService.stop(); - serverUrl = - "http://" + serviceInfo.getHost().getHostAddress() + ":" + serviceInfo.getPort(); - Timber.d("Resolved address: %s", serverUrl); - - client.get(context, serverUrl + "/test", testResponseHandler); - serverListener.onConnectionEstablished(serverUrl); + servers.put(serviceInfo.getServiceName(), serviceInfo); + try { + serverListener.onServerListChange(servers.keySet()); + } catch (Exception e) { + Timber.w(e); + } } }; private final JsonHttpResponseHandler testResponseHandler = @@ -117,24 +120,31 @@ public void onSuccess(int statusCode, Header[] headers, File file) { } } - String[] list = dir.list((dir1, name) -> name.endsWith(".tflite")); - if (list != null) { - for (String name : list) { - if (!valid.contains(name)) { - File file = new File(dir + File.separator + name); - if (file.delete()) { - serverListener.onRemoveModel(name); - Timber.d("deleted: %s", name); - } else { - Timber.e("delete error: %s", name); - } - } - } - } + // TODO: Fix the commented code. + // Currently all models that were not added by server (e.g. object detection) are + // removed. The file delete should probably be handled in onRemoveModel in + // ControlsFragment. MasterList/ModelManager needs to be updated and it needs be checked + // that only autopilot models that were added from the server are removed. + + // String[] list = dir.list((dir1, name) -> name.endsWith(".tflite")); + // if (list != null) { + // for (String name : list) { + // if (!valid.contains(name)) { + // File file = new File(dir + File.separator + name); + // if (file.delete()) { + // serverListener.onRemoveModel(name); + // Timber.d("deleted: %s", name); + // } else { + // Timber.e("delete error: %s", name); + // } + // } + // } + // } + } }; - private final Timer timer; private final ServerListener serverListener; + private Timer timer; private String serverUrl; @@ -143,12 +153,12 @@ public ServerCommunication(Context context, ServerListener serverListener) { this.context = context; this.nsdService = new NsdService(); this.serverListener = serverListener; - this.timer = new Timer(); } public void start() { Timber.d("service started"); nsdService.start(context, resolveListener); + timer = new Timer(); timer.scheduleAtFixedRate( new TimerTask() { @Override @@ -163,6 +173,26 @@ public void run() { 10000); } + public void connect(String server) { + NsdServiceInfo serviceInfo = servers.get(server); + if (serviceInfo == null) { + Timber.e("Server not found: %s", server); + return; + } + String ipAddress = serviceInfo.getHost().getHostAddress(); + serverUrl = "http://" + ipAddress + ":" + serviceInfo.getPort(); + Timber.d("Resolved address: %s", serverUrl); + + client.get(context, serverUrl + "/test", testResponseHandler); + serverListener.onConnectionEstablished(ipAddress); + } + + public void disconnect() { + client.cancelRequests(context, true); + serverUrl = null; + serverListener.onConnectionEstablished(context.getString(R.string.ip_placeholder)); + } + public void upload(File file) { if (serverUrl == null) { return; @@ -202,9 +232,14 @@ public void uploadAll() { public void stop() { client.cancelRequests(context, true); + nsdService.stop(); timer.cancel(); } + public Set getServers() { + return servers.keySet(); + } + static class UploadResponseHandler extends JsonHttpResponseHandler { private final File file; diff --git a/android/app/src/main/java/org/openbot/server/ServerListener.java b/android/app/src/main/java/org/openbot/server/ServerListener.java index d2aa466d1..05bbe6106 100644 --- a/android/app/src/main/java/org/openbot/server/ServerListener.java +++ b/android/app/src/main/java/org/openbot/server/ServerListener.java @@ -1,9 +1,13 @@ package org.openbot.server; +import java.util.Set; + public interface ServerListener { void onAddModel(String model); void onRemoveModel(String model); void onConnectionEstablished(String ipAddress); + + void onServerListChange(Set servers); } diff --git a/android/app/src/main/res/layout-land/fragment_autopilot.xml b/android/app/src/main/res/layout-land/fragment_autopilot.xml index e91ea46aa..ddbef5928 100644 --- a/android/app/src/main/res/layout-land/fragment_autopilot.xml +++ b/android/app/src/main/res/layout-land/fragment_autopilot.xml @@ -71,25 +71,33 @@ android:layout_height="wrap_content"> + android:layout_weight="1.2" + android:gravity="center_vertical|start" + android:paddingHorizontal="8dp" + android:text="@string/ip_placeholder" + android:textColor="@android:color/black" /> + + + diff --git a/android/app/src/main/res/layout-land/fragment_logger.xml b/android/app/src/main/res/layout-land/fragment_logger.xml index 74a206171..ea84ceb14 100644 --- a/android/app/src/main/res/layout-land/fragment_logger.xml +++ b/android/app/src/main/res/layout-land/fragment_logger.xml @@ -32,8 +32,8 @@ android:id="@+id/usbToggle" android:layout_width="wrap_content" android:layout_height="wrap_content" - android:button="@drawable/usb_toggle" android:layout_marginEnd="16dp" + android:button="@drawable/usb_toggle" app:layout_constraintBottom_toBottomOf="@+id/camera_toggle" app:layout_constraintEnd_toStartOf="@+id/camera_toggle" app:layout_constraintTop_toTopOf="@+id/camera_toggle" /> @@ -90,7 +90,7 @@ android:layout_width="0dp" android:layout_height="0dp" android:layout_marginStart="16dp" - android:layout_marginEnd="16dp" + android:layout_marginEnd="8dp" android:entries="@array/preview_resolutions" android:gravity="center" android:prompt="@string/preview_resolution" @@ -103,14 +103,27 @@ android:id="@+id/model_spinner" android:layout_width="0dp" android:layout_height="0dp" - android:layout_marginStart="16dp" - android:layout_marginEnd="16dp" - tools:entries="@array/models" + android:layout_marginEnd="8dp" android:prompt="@string/model" app:layout_constraintBottom_toBottomOf="@+id/crop_info" app:layout_constraintEnd_toEndOf="parent" - app:layout_constraintStart_toEndOf="@+id/crop_info" - app:layout_constraintTop_toTopOf="@+id/crop_info" /> + app:layout_constraintStart_toStartOf="@+id/resolution_spinner" + app:layout_constraintTop_toTopOf="@+id/crop_info" + tools:entries="@array/models" /> + + @@ -135,6 +148,29 @@ app:layout_constraintStart_toEndOf="@+id/analyseText" app:layout_constraintTop_toTopOf="@+id/analyseText" /> + + + + + @@ -56,7 +62,7 @@ function TrainProgress({state, clear}: { state: ProgressState, clear: () => any - + @@ -71,10 +77,23 @@ function TrainProgress({state, clear}: { state: ProgressState, clear: () => any preview thumbnails )} - {state.model && ( + {!!state.model && ( + + preview thumbnails + + )} + {state.status === 'success' && ( preview thumbnails )} } + +function predictEndDate(state: ProgressState, now: Date) { + const start = state.startTime.getTime(); + const elapsed = now.getTime() - start; + const fullTime = elapsed / state.percent * 100; + + return new Date(start + fullTime); +} diff --git a/policy/frontend/src/utils/useProgress.ts b/policy/frontend/src/utils/useProgress.ts index 16f4539ae..d305e8738 100644 --- a/policy/frontend/src/utils/useProgress.ts +++ b/policy/frontend/src/utils/useProgress.ts @@ -14,6 +14,7 @@ export interface Hyperparametes { } export interface ProgressState { + startTime: Date; status: 'success' | 'fail' | 'active' | undefined; epoch: number; percent: number; @@ -26,6 +27,7 @@ export interface ProgressState { } const defaultState: ProgressState = { + startTime: new Date(), status: undefined, epoch: 0, percent: 0, @@ -51,6 +53,7 @@ function progressReducer(state: ProgressState, msg: any): ProgressState { switch (msg.event) { case 'started': return { + startTime: new Date(), status: 'active', epoch: 0, percent: 0, @@ -62,6 +65,11 @@ function progressReducer(state: ProgressState, msg: any): ProgressState { ...state, rnd: Date.now(), }; + case 'model': + return { + ...state, + model: msg.payload, + }; case 'logs': return { ...state, @@ -98,11 +106,11 @@ function progressReducer(state: ProgressState, msg: any): ProgressState { ...state, rnd: Date.now(), status: 'success', - model: msg.payload.model, message: 'Done', }; case 'clear': return { + startTime: new Date(), status: undefined, epoch: 0, percent: 0, diff --git a/policy/openbot/associate_frames.py b/policy/openbot/associate_frames.py index 1e345bc86..b62fb3bf0 100644 --- a/policy/openbot/associate_frames.py +++ b/policy/openbot/associate_frames.py @@ -38,13 +38,11 @@ Modified and extended by Matthias Mueller - Intel Intelligent Systems Lab - 2020 The controls are event-based and not synchronized to the frames. This script matches the control signals to frames. -Specifically, if there was no control signal event within some threshold (default: 1ms), the last control signal before the frame is used. +Specifically, if there was no control signal event within some threshold (default: 1ms), +the last control signal before the frame is used. """ -import argparse -import sys import os -import numpy from . import utils @@ -65,7 +63,8 @@ def read_file_list(filename): """ f = open(filename) - header = f.readline() # discard header + # discard header + header = f.readline() data = f.read() lines = data.replace(",", " ").replace("\t", " ").split("\n") data = [ @@ -73,7 +72,7 @@ def read_file_list(filename): for line in lines if len(line) > 0 and line[0] != "#" ] - data = [(int(l[0]), l[1:]) for l in data if len(l) > 1] + data = [(int(line[0]), line[1:]) for line in data if len(line) > 1] return dict(data) @@ -146,8 +145,13 @@ def match_frame_session( f.write("timestamp (frame),time_offset (ctrl-frame),frame,left,right\n") for a, b in matches: f.write( - "%d %d %s %s \n" - % (a, b - a, " ".join(frame_list[a]), " ".join(ctrl_list[b])) + "%d,%d,%s,%s\n" + % ( + a, + b - a, + ",".join(frame_list[a]), + ",".join(ctrl_list[b]), + ) ) print(" Frames and controls matched.") @@ -171,8 +175,8 @@ def match_frame_session( ) for a, b in matches: f.write( - "%d %d %s %s \n" - % (a, b - a, " ".join(frame_list[a]), " ".join(cmd_list[b])) + "%d,%d,%s,%s\n" + % (a, b - a, ",".join(frame_list[a]), ",".join(cmd_list[b])) ) print(" Frames and commands matched.") @@ -189,22 +193,21 @@ def match_frame_session( os.path.join(sensor_path, "matched_frame_ctrl_cmd_processed.txt"), "w" ) as f: f.write("timestamp,frame,left,right,cmd\n") + # max_ctrl = get_max_ctrl(frame_list) for timestamp in list(frame_list): - if len(frame_list[timestamp]) < 6: + frame = frame_list[timestamp] + if len(frame) < 6: continue - left = int(frame_list[timestamp][3]) - right = int(frame_list[timestamp][4]) - if remove_zeros and left - right == 0 and left + right == 0: - print( - " Removed timestamp:%s, left:%d, right:%d" - % (timestamp, left, right) - ) - del frame_list[timestamp] + left = int(frame[3]) + right = int(frame[4]) + # left = normalize(max_ctrl, frame[3]) + # right = normalize(max_ctrl, frame[4]) + if remove_zeros and left == 0 and right == 0: + print(f" Removed timestamp: {timestamp}") + del frame else: - frame_name = os.path.join( - img_path, frame_list[timestamp][2] + "_crop.jpeg" - ) - cmd = int(frame_list[timestamp][5]) + frame_name = os.path.join(img_path, frame[2] + "_crop.jpeg") + cmd = int(frame[5]) f.write( "%s,%s,%d,%d,%d\n" % (timestamp, frame_name, left, right, cmd) ) @@ -213,3 +216,21 @@ def match_frame_session( return read_file_list( os.path.join(sensor_path, "matched_frame_ctrl_cmd_processed.txt") ) + + +def normalize(max_ctrl, val): + return int(int(val) / max_ctrl * 255) + + +def get_max_ctrl(frame_list): + max_val = 0 + for timestamp in list(frame_list): + frame = frame_list[timestamp] + if len(frame) < 6: + continue + left = int(frame[3]) + right = int(frame[4]) + max_val = max(max_val, abs(left), abs(right)) + if max_val == 0: + max_val = 255 + return max_val diff --git a/policy/openbot/dataloader.py b/policy/openbot/dataloader.py index 50cec4a64..53844e6b6 100644 --- a/policy/openbot/dataloader.py +++ b/policy/openbot/dataloader.py @@ -1,19 +1,22 @@ # Created by Matthias Mueller - Intel Intelligent Systems Lab - 2020 import os +from typing import List import tensorflow as tf class dataloader: - def __init__(self, data_dir, datasets): + def __init__(self, data_dir: str, datasets: List[str]): self.data_dir = data_dir self.datasets = datasets self.labels = self.load_labels() self.index_table = self.lookup_table() self.label_values = tf.constant( - [(float(l[0]), float(l[1])) for l in self.labels.values()] + [(float(label[0]), float(label[1])) for label in self.labels.values()] + ) + self.cmd_values = tf.constant( + [(float(label[2])) for label in self.labels.values()] ) - self.cmd_values = tf.constant([(float(l[2])) for l in self.labels.values()]) # Load labels def load_labels(self): @@ -33,7 +36,8 @@ def load_labels(self): "matched_frame_ctrl_cmd_processed.txt", ) ) as f_input: - header = f_input.readline() # discard header + # discard header + header = f_input.readline() data = f_input.read() lines = ( data.replace(",", " ") @@ -48,7 +52,7 @@ def load_labels(self): if len(line) > 0 and line[0] != "#" ] # Tuples containing id: framepath and label: left,right,cmd - data = [(l[1], l[2:]) for l in data if len(l) > 1] + data = [(line[1], line[2:]) for line in data if len(line) > 1] corpus.extend(data) return dict(corpus) diff --git a/policy/openbot/models.py b/policy/openbot/models.py index 149129790..75ed0ecbc 100644 --- a/policy/openbot/models.py +++ b/policy/openbot/models.py @@ -93,7 +93,7 @@ def pilot_net(img_width, img_height, bn=False): ) # fuse input MLP and CNN - combinedInput = tf.keras.layers.concatenate([mlp.input, cnn.output]) + combinedInput = tf.keras.layers.concatenate([mlp.output, cnn.output]) # output MLP x = tf.keras.layers.Dense(50, activation="relu")(combinedInput) diff --git a/policy/openbot/server/api.py b/policy/openbot/server/api.py index 35720b52e..36fbaa7e6 100644 --- a/policy/openbot/server/api.py +++ b/policy/openbot/server/api.py @@ -1,15 +1,12 @@ import asyncio -import glob import os import shutil import threading - from aiohttp import web from aiohttp_json_rpc import JsonRpc import numpy as np from numpyencoder import NumpyEncoder - -from .dataset import get_dataset_list, get_dir_info, get_info +from .dataset import get_dataset_list, get_dir_info, get_info, redoMatching from .models import ( get_model_info, get_models, @@ -20,7 +17,7 @@ from .preview import handle_preview from .prediction import getPrediction from .upload import handle_file_upload -from .. import base_dir, dataset_dir, models_dir +from .. import base_dir, dataset_dir from ..train import ( CancelledException, Hyperparameters, @@ -94,6 +91,7 @@ async def init_api(app: web.Application): ("", getSession), ("", moveSession), ("", deleteSession), + ("", redoMatching), ("", start), ("", stop), ) diff --git a/policy/openbot/server/dataset.py b/policy/openbot/server/dataset.py index f41719d8e..6128e3789 100644 --- a/policy/openbot/server/dataset.py +++ b/policy/openbot/server/dataset.py @@ -45,8 +45,8 @@ def get_info(path, basename=None): if not os.path.isdir(real_path): return None - isSession = is_session(real_path) - if isSession: + is_session = os.path.isdir(real_path + "/images") + if is_session: try: max_offset = 1e3 frames = associate_frames.match_frame_session( @@ -72,7 +72,7 @@ def get_info(path, basename=None): return { "path": "/" + path, "name": basename, - "is_session": isSession, + "is_session": is_session, "ctrl": ctrl, "seconds": seconds, "error": error, @@ -86,14 +86,16 @@ def get_info(path, basename=None): return { "path": "/" + path, "name": basename, - "is_session": isSession, + "is_session": is_session, "files": file_count - dir_count, "dirs": dir_count, } -def is_session(path): - return os.path.isdir(path + "/images") +def redoMatching(path): + max_offset = 1e3 + associate_frames.match_frame_session(dataset_dir + path, max_offset, True, True) + return True def count_lines(path): diff --git a/policy/openbot/server/main.py b/policy/openbot/server/main.py index ff667481b..ccaa737d1 100644 --- a/policy/openbot/server/main.py +++ b/policy/openbot/server/main.py @@ -6,6 +6,12 @@ app = web.Application() + +async def up(app: web.Application): + print("Server is running, press Ctrl-C to exit...") + + app.on_startup.append(register) app.on_startup.append(init_api) app.on_startup.append(init_frontend) +app.on_startup.append(up) diff --git a/policy/openbot/server/zeroconf.py b/policy/openbot/server/zeroconf.py index d6d680ea7..49b29ef96 100644 --- a/policy/openbot/server/zeroconf.py +++ b/policy/openbot/server/zeroconf.py @@ -2,6 +2,7 @@ import asyncio import logging +import os import socket import sys @@ -9,6 +10,11 @@ from aiozeroconf import ServiceInfo, Zeroconf from netifaces import interfaces, ifaddresses, AF_INET +SERVICE_TYPE = "_openbot-server._tcp.local." + +loop = asyncio.get_event_loop() +zc = Zeroconf(loop) + async def register(app: web.Application): await run_test(zc) @@ -16,21 +22,24 @@ async def register(app: web.Application): async def run_test(zc): - global info, desc desc = {} local_ip = ip4_address() + name = ( + os.getenv("OPENBOT_NAME", socket.gethostname()) + .replace(".local", "") + .replace(".", "-") + ) info = ServiceInfo( - "_http._tcp.local.", - "Openbot Web Site._http._tcp.local.", + SERVICE_TYPE, + f"{name}.{SERVICE_TYPE}", address=socket.inet_aton(local_ip), port=8000, weight=0, priority=0, properties=desc, - server="openbot.local.", ) - print("Registration of a service, press Ctrl-C to exit...") + print("Registration of the service with name:", name) await zc.register_service(info) @@ -61,9 +70,6 @@ async def on_shutdown(app): await do_close(zc) -loop = asyncio.get_event_loop() -zc = Zeroconf(loop) - if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) if len(sys.argv) > 1: diff --git a/policy/openbot/train.py b/policy/openbot/train.py index 6b070ef33..c3c97abbf 100644 --- a/policy/openbot/train.py +++ b/policy/openbot/train.py @@ -52,6 +52,8 @@ class Hyperparameters: USE_LAST: bool = False + WANDB: bool = False + @classmethod def parse(cls, name): m = re.match( @@ -88,6 +90,8 @@ def __init__(self, params: Hyperparameters): self.test_data_dir = "" self.train_datasets = [] self.test_datasets = [] + self.redo_matching = False + self.remove_zeros = True self.image_count_train = 0 self.image_count_test = 0 self.train_ds = None @@ -137,15 +141,15 @@ def on_batch_end(self, batch, logs=None): self.broadcast( "progress", dict( - epoch=int(100 * self.step / steps), - train=int( - 100 * (self.epoch * steps + self.step) / (epochs * steps) + epoch=round(100 * self.step / steps, 1), + train=round( + 100 * (self.epoch * steps + self.step) / (epochs * steps), 1 ), ), ) -def process_data(tr: Training, redo_matching=False, remove_zeros=True): +def process_data(tr: Training): tr.train_datasets = utils.list_dirs(tr.train_data_dir) tr.test_datasets = utils.list_dirs(tr.test_data_dir) @@ -158,15 +162,15 @@ def process_data(tr: Training, redo_matching=False, remove_zeros=True): tr.train_data_dir, tr.train_datasets, max_offset, - redo_matching=redo_matching, - remove_zeros=remove_zeros, + redo_matching=tr.redo_matching, + remove_zeros=tr.remove_zeros, ) test_frames = associate_frames.match_frame_ctrl_cmd( tr.test_data_dir, tr.test_datasets, max_offset, - redo_matching=redo_matching, - remove_zeros=remove_zeros, + redo_matching=tr.redo_matching, + remove_zeros=tr.remove_zeros, ) tr.image_count_train = len(train_frames) @@ -312,14 +316,26 @@ def do_training(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): "direction_metric": metrics.direction_metric, "angle_metric": metrics.angle_metric, } + model_path = os.path.join(models_dir, tr.model_name, "model") + + if tr.hyperparameters.WANDB: + import wandb + from wandb.keras import WandbCallback + + wandb.init(project="openbot") + + config = wandb.config + config.epochs = tr.hyperparameters.NUM_EPOCHS + config.learning_rate = tr.hyperparameters.LEARNING_RATE + config.batch_size = tr.hyperparameters.TRAIN_BATCH_SIZE + config["model_name"] = tr.model_name + append_logs = False model: tf.keras.Model if tr.hyperparameters.USE_LAST: append_logs = True - dirs = utils.list_dirs(tr.checkpoint_path) - last_checkpoint = sorted(dirs)[-1] model = tf.keras.models.load_model( - os.path.join(tr.checkpoint_path, last_checkpoint), + model_path, custom_objects=tr.custom_objects, compile=False, ) @@ -329,6 +345,10 @@ def do_training(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): tr.NETWORK_IMG_HEIGHT, tr.hyperparameters.BATCH_NORM, ) + dot_img_file = os.path.join(models_dir, tr.model_name, "model.png") + tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True) + + callback.broadcast("model", tr.model_name) tr.loss_fn = losses.sq_weighted_mse_angle tr.metric_list = [ @@ -350,61 +370,69 @@ def do_training(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): tr.image_count_train / tr.hyperparameters.TRAIN_BATCH_SIZE ) callback.broadcast("message", "Fit model...") + callback_list = [ + callbacks.checkpoint_cb(tr.checkpoint_path), + callbacks.tensorboard_cb(tr.log_path), + callbacks.logger_cb(tr.log_path, append_logs), + callback, + ] + + if tr.hyperparameters.WANDB: + callback_list += [WandbCallback()] + tr.history = model.fit( tr.train_ds, epochs=tr.hyperparameters.NUM_EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, validation_data=tr.test_ds, verbose=verbose, - callbacks=[ - callbacks.checkpoint_cb(tr.checkpoint_path), - callbacks.tensorboard_cb(tr.log_path), - callbacks.logger_cb(tr.log_path, append_logs), - callback, - ], + callbacks=callback_list, ) + model.save(model_path) + + if tr.hyperparameters.WANDB: + wandb.save(model_path) + wandb.finish() def do_evaluation(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): callback.broadcast("message", "Generate plots...") - history = tr.history - log_path = tr.log_path - plt.plot(history.history["mean_absolute_error"], label="mean_absolute_error") + plt.plot(tr.history.history["mean_absolute_error"], label="mean_absolute_error") plt.plot( - history.history["val_mean_absolute_error"], label="val_mean_absolute_error" + tr.history.history["val_mean_absolute_error"], label="val_mean_absolute_error" ) plt.xlabel("Epoch") plt.ylabel("Mean Absolute Error") plt.legend(loc="lower right") - savefig(os.path.join(log_path, "error.png")) + savefig(os.path.join(tr.log_path, "error.png")) - plt.plot(history.history["direction_metric"], label="direction_metric") - plt.plot(history.history["val_direction_metric"], label="val_direction_metric") + plt.plot(tr.history.history["direction_metric"], label="direction_metric") + plt.plot(tr.history.history["val_direction_metric"], label="val_direction_metric") plt.xlabel("Epoch") plt.ylabel("Direction Metric") plt.legend(loc="lower right") - savefig(os.path.join(log_path, "direction.png")) + savefig(os.path.join(tr.log_path, "direction.png")) - plt.plot(history.history["angle_metric"], label="angle_metric") - plt.plot(history.history["val_angle_metric"], label="val_angle_metric") + plt.plot(tr.history.history["angle_metric"], label="angle_metric") + plt.plot(tr.history.history["val_angle_metric"], label="val_angle_metric") plt.xlabel("Epoch") plt.ylabel("Angle Metric") plt.legend(loc="lower right") - savefig(os.path.join(log_path, "angle.png")) + savefig(os.path.join(tr.log_path, "angle.png")) - plt.plot(history.history["loss"], label="loss") - plt.plot(history.history["val_loss"], label="val_loss") + plt.plot(tr.history.history["loss"], label="loss") + plt.plot(tr.history.history["val_loss"], label="val_loss") plt.xlabel("Epoch") plt.ylabel("Loss") plt.legend(loc="lower right") - savefig(os.path.join(log_path, "loss.png")) + savefig(os.path.join(tr.log_path, "loss.png")) callback.broadcast("message", "Generate tflite models...") checkpoint_path = tr.checkpoint_path print("checkpoint_path", checkpoint_path) best_index = np.argmax( - np.array(history.history["val_angle_metric"]) - + np.array(history.history["val_direction_metric"]) + np.array(tr.history.history["val_angle_metric"]) + + np.array(tr.history.history["val_direction_metric"]) ) best_checkpoint = str("cp-%04d.ckpt" % (best_index + 1)) best_tflite = utils.generate_tflite(checkpoint_path, best_checkpoint) @@ -412,8 +440,8 @@ def do_evaluation(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0 print( "Best Checkpoint (val_angle: %s, val_direction: %s): %s" % ( - history.history["val_angle_metric"][best_index], - history.history["val_direction_metric"][best_index], + tr.history.history["val_angle_metric"][best_index], + tr.history.history["val_direction_metric"][best_index], best_checkpoint, ) ) @@ -424,8 +452,8 @@ def do_evaluation(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0 print( "Last Checkpoint (val_angle: %s, val_direction: %s): %s" % ( - history.history["val_angle_metric"][-1], - history.history["val_direction_metric"][-1], + tr.history.history["val_angle_metric"][-1], + tr.history.history["val_direction_metric"][-1], last_checkpoint, ) ) @@ -456,7 +484,7 @@ def do_evaluation(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0 utils.show_test_batch( image_batch.numpy(), cmd_batch.numpy(), label_batch.numpy(), pred_batch ) - savefig(os.path.join(log_path, "test_preview.png")) + savefig(os.path.join(tr.log_path, "test_preview.png")) utils.compare_tf_tflite(best_model, best_tflite) @@ -552,6 +580,9 @@ def create_tfrecord(callback: MyCallback): parser.add_argument( "--resume", action="store_true", help="resume previous training" ) + parser.add_argument( + "--wandb", action="store_true", help="training logs with weights & biases" + ) args = parser.parse_args() @@ -565,6 +596,7 @@ def create_tfrecord(callback: MyCallback): params.FLIP_AUG = args.flip_aug params.CMD_AUG = args.cmd_aug params.USE_LAST = args.resume + params.WANDB = args.wandb def broadcast(event, payload=None): print() diff --git a/policy/policy_learning.ipynb b/policy/policy_learning.ipynb index c09ac0661..65f2e2b70 100644 --- a/policy/policy_learning.ipynb +++ b/policy/policy_learning.ipynb @@ -111,7 +111,8 @@ "params.BATCH_NORM = True\n", "params.FLIP_AUG = False\n", "params.CMD_AUG = False\n", - "params.USE_LAST = False" + "params.USE_LAST = False\n", + "params.WANDB = False" ] }, { @@ -139,7 +140,7 @@ "id": "abfc9b9b", "metadata": {}, "source": [ - "Running this for the first time will take some time. This code will match image frames to the controls (labels) and indicator signals (commands). By default, data samples where the vehicle was stationary will be removed. If this is not desired, you need to pass `remove_zeros=False`. If you have made any changes to the sensor files, changed `remove_zeros` or moved your dataset to a new directory, you need to pass `redo_matching=True`. " + "Running this for the first time will take some time. This code will match image frames to the controls (labels) and indicator signals (commands). By default, data samples where the vehicle was stationary will be removed. If this is not desired, you need to set `tr.remove_zeros = False`. If you have made any changes to the sensor files, changed `remove_zeros` or moved your dataset to a new directory, you need to set `tr.redo_matching = True`. " ] }, { @@ -149,7 +150,9 @@ "metadata": {}, "outputs": [], "source": [ - "train.process_data(tr, redo_matching=False, remove_zeros=True)" + "tr.redo_matching = False\n", + "tr.remove_zeros = True\n", + "train.process_data(tr)" ] }, { diff --git a/policy/requirements.txt b/policy/requirements.txt index 6293f481a..96cc73474 100644 --- a/policy/requirements.txt +++ b/policy/requirements.txt @@ -8,6 +8,6 @@ aiohttp_json_rpc aiozeroconf imageio aiohttp_json_rpc -openbot_frontend==0.6.0 +openbot_frontend==0.7.0 numpyencoder black[jupyter]>=21.8b0