#! /usr/bin/env python
"""Flask demo: draw a digit in the browser and classify it with a small MNIST CNN.

Routes:
    /            -- renders ``to_train.html``
    /loadmodel/  -- loads trained weights from ``Meetup_MNIST.pt`` into ``model``
    /index/      -- drawing page; a POST saves the drawing as a thumbnail image
    /result/     -- classifies the last saved drawing and renders the prediction
"""
import ast
import io
import os
import time
from binascii import a2b_base64

import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from flask import Flask, Response, render_template, request
from PIL import Image
from torchvision import datasets, transforms

# Module-level state kept so other code can access it later.
model_states = ['Not Trained']
nb_epoch = 5

# File holding the trained state_dict.
CHECKPOINT_PATH = "Meetup_MNIST.pt"
# The /index/ POST handler writes the latest drawing here and /result/ reads it back.
DRAWING_DIR = "data_img"
DRAWING_PATH = "data_img/draw.png"

app = Flask(__name__)
# None until /loadmodel/ (or the first /result/ request) loads the checkpoint.
model = None


class NN(nn.Module):
    """Two conv/max-pool stages followed by two fully connected layers, 10 outputs.

    The shape comments in ``forward`` assume a 1x28x28 input; the hard-coded
    flatten size (50*5*5) only works for that input size.
    """

    def __init__(self):
        super(NN, self).__init__()
        self.conv1L = nn.Conv2d(1, 20, 3, 1)
        self.conv2L = nn.Conv2d(20, 50, 3, 1)
        self.FC1 = nn.Linear(5 * 5 * 50, 500)
        self.FC2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1L(x))   # 20 x 26 x 26
        x = F.max_pool2d(x, (2, 2))  # 20 x 13 x 13
        x = F.relu(self.conv2L(x))   # 50 x 11 x 11
        x = F.max_pool2d(x, (2, 2))  # 50 x 5 x 5
        x = x.view(-1, 50 * 5 * 5)   # flatten
        x = self.FC1(x)              # 500
        x = self.FC2(x)              # 10
        return F.log_softmax(x, dim=1)


def _load_model(path=CHECKPOINT_PATH):
    """Build an ``NN``, load the state_dict stored at *path*, and return it in eval mode."""
    net = NN()
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host;
    # the parameters of NN() live on the CPU either way.
    state_dict = torch.load(path, map_location="cpu")
    net.load_state_dict(state_dict)
    net.eval()
    return net


# page to_train
@app.route('/')
def to_train():
    """Render ``to_train.html`` with the configured epoch count."""
    return render_template('to_train.html', nb_epoch=nb_epoch)


# load the trained model
@app.route("/loadmodel/", methods=['GET'])
def load():
    """Load the trained weights into the global ``model``; return a plain-text status."""
    global model
    model = _load_model()
    print("model loaded")
    return "Loading done"


# page where you draw the number
@app.route('/index/', methods=['GET', 'POST'])
def index():
    """Render the drawing page; on POST, save the submitted drawing to ``DRAWING_PATH``.

    NOTE(review): the POST body looks like a canvas data URL
    (``data:<mime>;base64,<payload>``) -- confirm against the front end.
    """
    prediction = '?'
    if request.method == 'POST':
        data_url = request.get_data()
        # Drop the "data:<mime>;base64," header whatever its length; a fixed
        # 22-byte slice only matched "data:image/png;base64,".
        encoded = data_url.split(b',', 1)[-1]
        binary_data = a2b_base64(encoded)
        img = Image.open(io.BytesIO(binary_data))
        # thumbnail() keeps the aspect ratio, so the saved image is exactly
        # 28x28 only when the drawing canvas is square.
        img.thumbnail((28, 28))
        os.makedirs(DRAWING_DIR, exist_ok=True)
        img.save(DRAWING_PATH)
    return render_template('index.html', prediction=prediction)


# display prediction
@app.route('/result/')
def result():
    """Classify the last saved drawing and render ``index.html`` with the predicted digit."""
    global model
    # Presumably gives the /index/ POST time to finish writing the drawing.
    time.sleep(0.2)
    if model is None:
        # Previously this crashed with a 500 if /loadmodel/ had not been visited.
        model = _load_model()
    # Context manager closes the file handle; convert() returns an independent image.
    with Image.open(DRAWING_PATH) as raw:
        img = raw.convert("1")
    transform = transforms.Compose([transforms.ToTensor()])
    img = transform(img)
    img = torch.unsqueeze(img, 0)  # add the batch dimension
    prediction = inference(model, img)
    print(prediction)
    return render_template("index.html", prediction=prediction)


def inference(model, img):
    """Run *model* on the batch *img* and return the top class index of its first row as an int."""
    # No gradients are needed for prediction.
    with torch.no_grad():
        output = model(img)
    output = torch.exp(output)  # log-probabilities -> probabilities
    top_prob, top_class = output.topk(1, dim=1)
    return top_class.item()


if __name__ == "__main__":
    # debug=True enables the Werkzeug interactive debugger, which permits
    # arbitrary code execution; only run this way on a trusted local interface.
    app.run(debug=True, threaded=True)