CNN(卷積神經網路)在iOS上的使用
阿新 • 發佈:2018-12-31
在iOS11上推出了CoreML和架構在CoreML之上的Vision, 這樣為CNN(卷積神經網路)在iOS裝置上的應用鋪平了道路。
將CoreML模型載入到App
讓你的App整合CoreML模型非常簡單, 將模型檔案(*.mlmodel)拖進工程即可. 在Xcode中可以看到此模型的描述.
Xcode可以為此模型檔案自動生成一個可以被使用的物件, 此預測人年齡的CNN的自動生成程式碼如下(Swift)
//
// AgeNet.swift
//
// This file was automatically generated and should not be edited.
//
import CoreML
/// Model Prediction Input Type
// Fix: the original declared `@available(OSX 13.0, …)`, which would defer
// macOS availability to macOS 13 (2022). Xcode's generated CoreML wrappers
// use `macOS 10.13`, the release that introduced CoreML.
@available(macOS 10.13, iOS 11.0, tvOS 11.0, watchOS 4.0, *)
class AgeNetInput : MLFeatureProvider {

    /// An image with a face, as a color (kCVPixelFormatType_32BGRA) image
    /// buffer, 227 pixels wide by 227 pixels high.
    var data: CVPixelBuffer

    /// The set of feature names this provider can answer for.
    var featureNames: Set<String> {
        return ["data"]
    }

    /// Returns the feature value for `featureName`, or `nil` when the name
    /// is not one this provider supplies.
    func featureValue(for featureName: String) -> MLFeatureValue? {
        if featureName == "data" {
            return MLFeatureValue(pixelBuffer: data)
        }
        return nil
    }

    init(data: CVPixelBuffer) {
        self.data = data
    }
}
/// Model Prediction Output Type
// Fix: `@available(OSX 13.0, …)` -> `macOS 10.13`, matching Xcode's
// generated CoreML wrapper (CoreML itself requires macOS 10.13+).
@available(macOS 10.13, iOS 11.0, tvOS 11.0, watchOS 4.0, *)
class AgeNetOutput : MLFeatureProvider {

    /// The probabilities for each age bracket, for the given input,
    /// as a dictionary of strings to doubles.
    let prob: [String : Double]

    /// The most likely age bracket, for the given input, as a string value.
    let classLabel: String

    /// The set of feature names this provider can answer for.
    var featureNames: Set<String> {
        return ["prob", "classLabel"]
    }

    /// Returns the feature value for `featureName`, or `nil` when the name
    /// is not one this provider supplies.
    func featureValue(for featureName: String) -> MLFeatureValue? {
        if featureName == "prob" {
            // NOTE(review): the force-try mirrors Xcode's generated code;
            // `MLFeatureValue(dictionary:)` presumably only throws for
            // non-numeric values, which a `[String : Double]` cannot contain
            // — confirm against the CoreML API reference.
            return try! MLFeatureValue(dictionary: prob as [NSObject : NSNumber])
        }
        if featureName == "classLabel" {
            return MLFeatureValue(string: classLabel)
        }
        return nil
    }

    init(prob: [String : Double], classLabel: String) {
        self.prob = prob
        self.classLabel = classLabel
    }
}
/// Class for model loading and prediction
// Fix: `@available(OSX 13.0, …)` -> `macOS 10.13`, matching Xcode's
// generated CoreML wrapper (CoreML itself requires macOS 10.13+).
@available(macOS 10.13, iOS 11.0, tvOS 11.0, watchOS 4.0, *)
class AgeNet {

    /// The underlying CoreML model the wrapper delegates to.
    var model: MLModel

    /**
        Construct a model with an explicit path to a compiled model file.
        - parameter url: the file url of the model
        - throws: an NSError object that describes the problem
    */
    init(contentsOf url: URL) throws {
        self.model = try MLModel(contentsOf: url)
    }

    /// Construct a model that automatically loads the compiled model
    /// ("AgeNet.mlmodelc") from the app's bundle.
    ///
    /// The force-unwrap and `try!` are deliberate (and match Xcode's
    /// generated code): a missing or unloadable bundled model is a
    /// packaging error that should crash at startup rather than be hidden.
    convenience init() {
        let bundle = Bundle(for: AgeNet.self)
        let assetPath = bundle.url(forResource: "AgeNet", withExtension: "mlmodelc")
        try! self.init(contentsOf: assetPath!)
    }

    /**
        Make a prediction using the structured interface.
        - parameter input: the input to the prediction as AgeNetInput
        - throws: an NSError object that describes the problem
        - returns: the result of the prediction as AgeNetOutput
    */
    func prediction(input: AgeNetInput) throws -> AgeNetOutput {
        let outFeatures = try model.prediction(from: input)
        // The generated wrapper force-unwraps/-casts here: the model's output
        // description guarantees both features exist with these types.
        let result = AgeNetOutput(prob: outFeatures.featureValue(for: "prob")!.dictionaryValue as! [String : Double], classLabel: outFeatures.featureValue(for: "classLabel")!.stringValue)
        return result
    }

    /**
        Make a prediction using the convenience interface.
        - parameter data: An image with a face, as a color
          (kCVPixelFormatType_32BGRA) image buffer, 227 pixels wide by
          227 pixels high
        - throws: an NSError object that describes the problem
        - returns: the result of the prediction as AgeNetOutput
    */
    func prediction(data: CVPixelBuffer) throws -> AgeNetOutput {
        let input_ = AgeNetInput(data: data)
        return try self.prediction(input: input_)
    }
}
載入CNN, 並且建立分析請求(Image Analysis Request)
// Age-classification model wrapper; the convenience init loads the compiled
// "AgeNet.mlmodelc" from the app bundle (traps if the resource is missing).
let ageModel = AgeNet()
/// Wraps the CoreML age model in a Vision request and stores it in
/// `ageRequest` for later use by the image-analysis handler.
func setupVision() {
    guard let vnAgeModel = try? VNCoreMLModel(for: ageModel.model) else {
        NSLog("Load age model fail")
        return
    }
    // `[weak self]` because the closure is stored on `ageRequest` (a
    // property of `self`), which would otherwise create a retain cycle.
    ageRequest = VNCoreMLRequest(model: vnAgeModel) { [weak self] request, _ in
        // Bug fix: the original checked `observations.count > 1`, which
        // silently discarded the result whenever exactly one classification
        // came back. Only a non-empty result set is required to read the
        // top observation.
        guard let self = self,
              let observations = request.results as? [VNClassificationObservation],
              let top = observations.first,
              top.confidence > 0.5 else {
            return
        }
        // UI mutation must happen on the main thread.
        DispatchQueue.main.async {
            self.mInfo.text = "Your age is " + top.identifier + "/" + String(top.confidence)
        }
    }
    ageRequest?.imageCropAndScaleOption = .scaleFit
}
執行分析
/// Runs the age and gender classification requests against one camera frame.
/// - parameter pixelBuffer: the frame to analyze.
func predict(pixelBuffer : CVPixelBuffer) {
    // Bug fix: `ageRequest` is an Optional property (see the `ageRequest?`
    // access in setupVision), so passing `[ageRequest]` hands an optional to
    // `perform(_:)`'s `[VNRequest]` parameter. Unwrap both requests first.
    guard let ageRequest = ageRequest, let genderRequest = genderRequest else {
        return
    }
    // One handler can run several requests over the same buffer; the
    // original created a second handler for the same frame unnecessarily.
    let handler = VNImageRequestHandler(cvPixelBuffer: pixelBuffer)
    do {
        try handler.perform([ageRequest, genderRequest])
    } catch {
        // Surface the failure instead of silently swallowing it with `try?`.
        NSLog("Vision request failed: \(error)")
    }
}