圖片 flask json_想要部署深度學習模型?試試 FLASK 構建 REST API 部署
技術標籤:圖片 flask json
想必大家都訓練出過比較好玩的模型,但是是不是想要向別人提供下介面或者自己試著玩下,這時候就需要涉及到部署模型了,這裡,我們將使用 Flask 部署 PyTorch 模型,並構建用於模型推理的REST API。
要注意的是:使用 Flask 是為 PyTorch 模型提供服務的最簡單方法,但不適用於具有高效能要求的場景。
對高效能有要求的場景,可以使用 TorchScript,下次再說。
環境安裝:
pip install Flask==1.0.3 torch==1.2.0 torchvision-0.3.0
```
假設我們的場景是上傳圖片進行返回圖片的分類結果,那麼我們定義下 API 形式,請求和響應型別。
將 API endpoint 將位於 /predict,接受帶有包含影象的檔案引數的 HTTP POST 請求。響應將是包含預測結果的 JSON 響應:
{"class_id": "xx", "class_name": "yy"}
首先先複習下,構建一個簡單的 Web 伺服器
```
from flask import Flaskapp=Flask(__name__)@app.route('/')defhello(): return 'welcome to http://towardsdeeplearning.com !'
```
執行
```
FLASK_ENV=development FLASK_APP=app.py flask run
訪問 http://localhost:5000/ 可以看到 welcome to http://towardsdeeplearning.com !
可以檢視 flask 文件,熟悉下 post。為了符合上邊 api 的定義,我們需要修改下程式碼:
fromflaskimportFlask,jsonifyapp = Flask(__name__)@app.route('/predict',methods=['POST'])def predict(): return jsonify({'class_id':'IMAGE_NET_XXX','class_name':'Cat'})
`
到此,骨幹網路已經搭建完畢。
還缺少什麼呢?上邊這個是返回的json是寫死的,但是實際上要根據 post 的圖片進行預測。
圖片通過 HTTP POST 請求傳遞過來, 可以通過下面這個方式獲取
@app.route('/predict',methods=['POST'])defpredict():ifrequest.method=='POST':#wewillgetthefilefromtherequestfile=request.files['file']
搭建下預測的程式碼,這裡使用了 mnasnet ,可以在 torchvision 匯入預訓模型。mnasnet 的輸入圖片是 3 通道的 RGB 模型,大小為 224 x 224。
其實熟悉 pytorch 的同學應該很容易寫出前向預測的程式碼的。
importioimporttorchvision.transformsastransformsfromPILimport Imagedeftransform_image(image_bytes):my_transforms=transforms.Compose([transforms.Resize(255),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])]) #接收的圖片是bytes轉成圖片格式,再進行轉換 image=Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0)fromtorchvisionimportmodelsmodel=models.mnasnet1_0(pretrained=True)model.eval()defpredict(image_bytes):tensor=transform_image(image_bytes=image_bytes)outputs=model.forward(tensor)_,pred=outputs.max(1) return pred
predict 的結果是類別的id,為了方便顯示,我們需要進行轉成文字, 就是具體的類別,狗狗啊這樣人類可讀性好的。
import jsonimagenet_class_index=json.load(open('imagenet_class_index.json'))defpredict(image_bytes):tensor=transform_image(image_bytes=image_bytes)outputs=model.forward(tensor)_,y_hat=outputs.max(1)predicted_idx=str(y_hat.item()) return imagenet_class_index[predicted_idx]
最後,整理的程式碼如下
importioimportjsonimport torchvision.transforms as transformsfromPILimportImagefromflaskimportFlask,jsonify,requestfrom torchvision import modelsapp=Flask(__name__)imagenet_class_index=json.load(open('./imagenet_class_index.json',"r"))model=models.mnasnet1_0(pretrained=True)model.eval()deftransform_image(image_bytes):my_transforms=transforms.Compose([transforms.Resize(255),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])image=Image.open(io.BytesIO(image_bytes))returnmy_transforms(image).unsqueeze(0)def get_prediction(image_bytes):tensor=transform_image(image_bytes=image_bytes)outputs=model.forward(tensor)_,y_hat=outputs.max(1)predicted_idx=str(y_hat.item()) return imagenet_class_index[predicted_idx]@app.route('/predict',methods=['POST'])def predict():ifrequest.method=='POST':file=request.files['file']img_bytes=file.read()class_id,class_name=get_prediction(image_bytes=img_bytes)returnjsonify({'class_id':class_id,'class_name':class_name})if__name__=='__main__':app.run()
使用下面的命令執行
FLASK_ENV=development FLASK_APP=app.py flask run
使用下面的測試程式碼,進行測試。
importrequestsresp=requests.post("http://localhost:5000/predict",files={"file":open('dog.jpg','rb')})print(resp.json()#{"class_id":"xx","class_name":"xx"}
完。