Skip to content

Commit

Permalink
Enhance CoreML Model Compilation Process
Browse files Browse the repository at this point in the history
- Refined Functionality: Updated private functions to encapsulate specific tasks, improving readability and modularity.
  - getApplicationSupportURL(): Simplified directory access with a more direct approach.
  - getDigest(modelURL:): Introduced a new function to encapsulate SHA256 digest computation.
  - checkShouldCompileModel(...): Revised logic for checking model compilation necessity, including digest comparison and resource reachability.
  - compileAndSaveModel(...): Streamlined model compilation and saving process, enhancing code structure.
  - loadModel(...): Optimized model loading with configuration settings.
- Code Organization: The refactoring focuses on breaking down the compileMLModel function into smaller, more manageable functions, each responsible for a distinct part of the process. This approach enhances the maintainability and scalability of the code.
- Improved Logging: Enhanced logging throughout the process for better traceability and debugging.
  • Loading branch information
ChinChangYang committed Nov 11, 2023
1 parent 0c25374 commit 6abd89b
Showing 1 changed file with 93 additions and 80 deletions.
173 changes: 93 additions & 80 deletions cpp/coremlmodel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,20 @@ class KataGoModel {
// Get default file manager
let fileManager = FileManager.default

Logger().info("Removing old CoreML model in Application Support directory \(appModelURL)");

// Remove the old model in Application Support directory
try fileManager.removeItem(at: appModelURL)
do {
if try appModelURL.checkResourceIsReachable() {
Logger().info("Removing old CoreML model in Application Support directory \(appModelURL)");

do {
// Remove the old model in Application Support directory
try fileManager.removeItem(at: appModelURL)
} catch {
Logger().warning("Unable to remove the old CoreML model in Application Support directory \(appModelURL): \(error)")
}
}
} catch {
Logger().warning("Unable to check if the old CoreML model is reachable in Application Support directory \(appModelURL)")
}

Logger().info("Copying bundle CoreML model to Application Support directory \(appModelURL)")

Expand All @@ -199,32 +209,17 @@ class KataGoModel {
return mlmodel;
}

class func compileMLModel(modelName: String, modelURL: URL) throws -> MLModel {
// Get compiled model name
let compiledModelName = "\(modelName).mlmodelc"

// Set the directory for KataGo models
let directory = "KataGoModels"

// Get path component
let pathComponent = "\(directory)/\(compiledModelName)"

private class func getApplicationSupportURL() throws -> URL {
// Get default file manager
let fileManager = FileManager.default

// Get application support directory
// Create the directory if it does not already exist
let appSupportURL = try fileManager.url(for: .applicationSupportDirectory,
in: .userDomainMask,
appropriateFor: nil,
create: true)

// Create the URL for the permanent compiled model file
let permanentURL = appSupportURL.appending(component: pathComponent)

// Initialize model
var model: MLModel
return try fileManager.url(for: .applicationSupportDirectory,
in: .userDomainMask,
appropriateFor: nil,
create: true)
}

private class func getDigest(modelURL: URL) throws -> String {
// Create the URL for the model data file
let dataURL = modelURL.appending(component: "Data/com.apple.CoreML/model.mlmodel")

Expand All @@ -237,23 +232,25 @@ class KataGoModel {
// Get hash digest
let digest = hashData.map { String(format: "%02x", $0) }.joined()

// Set digest path
let savedDigestPath = "\(directory)/\(modelName).digest"
return digest
}

// Get digest URL
let savedDigestURL = appSupportURL.appending(component: savedDigestPath)
private class func checkShouldCompileModel(permanentURL: URL,
savedDigestURL: URL,
modelURL: URL,
digest: String) -> Bool {
// Model should be compiled if the compiled model is not reachable or the digest changes
var shouldCompile = true

// Get saved digest
var isChangedDigest = true

do {
if (try savedDigestURL.checkResourceIsReachable()) {
let savedDigest = try String(contentsOf: savedDigestURL, encoding: .utf8)

// Check the saved digest is changed or not
isChangedDigest = digest != savedDigest
shouldCompile = digest != savedDigest

if (isChangedDigest) {
if (shouldCompile) {
Logger().info("Compiling CoreML model because the digest has changed");
}
} else {
Expand All @@ -263,59 +260,82 @@ class KataGoModel {
Logger().warning("Compiling CoreML model because it is unable to get the saved digest from: \(savedDigestURL)")
}

// Check permanent compiled model is reachable
let reachableModel = try permanentURL.checkResourceIsReachable()

if (!reachableModel) {
Logger().info("Compiling CoreML model because it is not reachable");
}

// Model should be compiled if the compiled model is not reachable or the digest changes
let shouldCompile = !reachableModel || isChangedDigest;
if !shouldCompile {
// Check permanent compiled model is reachable
do {
shouldCompile = try !permanentURL.checkResourceIsReachable()

if (shouldCompile) {
Logger().info("Compiling CoreML model at \(modelURL)");
if (shouldCompile) {
Logger().info("Compiling CoreML model because the permanent URL is not reachable: \(permanentURL)");
}
} catch {
shouldCompile = true

// Compile the model
let compiledURL = try MLModel.compileModel(at: modelURL)
Logger().warning("Compiling CoreML model because it is unable to check the resource at: \(permanentURL)")
}
}

Logger().info("Copying the compiled CoreML model to the permanent location \(permanentURL)");
return shouldCompile
}

// Create the directory for KataGo models
try fileManager.createDirectory(at: appSupportURL.appending(component: directory),
withIntermediateDirectories: true)
private class func compileAndSaveModel(permanentURL: URL,
savedDigestURL: URL,
modelURL: URL,
digest: String) throws {
// Get default file manager
let fileManager = FileManager.default

// Copy the file to the to the permanent location, replacing it if necessary
try fileManager.replaceItem(at: permanentURL,
withItemAt: compiledURL,
backupItemName: nil,
options: .usingNewMetadataOnly,
resultingItemURL: nil)
Logger().info("Compiling CoreML model at \(modelURL)");

// Update the digest
try digest.write(to: savedDigestURL, atomically: true, encoding: .utf8)
}
// Compile the model
let compiledURL = try MLModel.compileModel(at: modelURL)

// Initialize the model configuration
let configuration = MLModelConfiguration()
Logger().info("Creating the directory for the permanent location: \(permanentURL)");

// Set the compute units to CPU and Neural Engine
configuration.computeUnits = MLComputeUnits.cpuAndNeuralEngine
// Create the directory for KataGo models
try fileManager.createDirectory(at: permanentURL.deletingLastPathComponent(),
withIntermediateDirectories: true)

// Set the model display name
configuration.modelDisplayName = modelName;
Logger().info("Copying the compiled CoreML model to the permanent location \(permanentURL)");

Logger().info("Creating CoreML model with contents \(permanentURL)");
// Copy the file to the to the permanent location, replacing it if necessary
try fileManager.replaceItem(at: permanentURL,
withItemAt: compiledURL,
backupItemName: nil,
options: .usingNewMetadataOnly,
resultingItemURL: nil)

// Create the model
model = try MLModel(contentsOf: permanentURL, configuration: configuration)
// Update the digest
try digest.write(to: savedDigestURL, atomically: true, encoding: .utf8)
}

let description: String = model.modelDescription.metadata[MLModelMetadataKey.description] as! String? ?? "Unknown"
private class func loadModel(permanentURL: URL, modelName: String) throws -> MLModel {
let configuration = MLModelConfiguration()
configuration.computeUnits = .cpuAndNeuralEngine
configuration.modelDisplayName = modelName
Logger().info("Creating CoreML model with contents \(permanentURL)")
return try MLModel(contentsOf: permanentURL, configuration: configuration)
}

Logger().info("Created CoreML model: \(description)");
class func compileMLModel(modelName: String, modelURL: URL) throws -> MLModel {
let appSupportURL = try getApplicationSupportURL()
let permanentURL = appSupportURL.appending(component: "KataGoModels/\(modelName).mlmodelc")
let savedDigestURL = appSupportURL.appending(component: "KataGoModels/\(modelName).digest")
let digest = try getDigest(modelURL: modelURL)

let shouldCompileModel = checkShouldCompileModel(permanentURL: permanentURL,
savedDigestURL: savedDigestURL,
modelURL: modelURL,
digest: digest)

if shouldCompileModel {
try compileAndSaveModel(permanentURL: permanentURL,
savedDigestURL: savedDigestURL,
modelURL: modelURL,
digest: digest)
}

// Return the model
return model;
return try loadModel(permanentURL: permanentURL, modelName: modelName);
}

init(model: MLModel) {
Expand All @@ -336,13 +356,6 @@ class KataGoModel {
out_moremiscvalue: out_moremiscvalue,
out_ownership: out_ownership)
}

func prediction(from input: KataGoModelInput,
options: MLPredictionOptions) throws -> KataGoModelOutput {

let outFeatures = try model.prediction(from: input, options: options)
return createOutput(from: outFeatures)
}

func prediction(from inputBatch: KataGoModelInputBatch,
options: MLPredictionOptions) throws -> KataGoModelOutputBatch {
Expand Down

0 comments on commit 6abd89b

Please sign in to comment.