大家都在吐槽TensorFlow的文档和API接口。身为其中一员,与其空发牢骚,不如来个示例代码解惑备忘来得实在。
模型训练
下面的代码使用tf.keras定义模型。使用Estimator来工程化,使用absl来处理logging和flag。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
import numpy as np
def core_model():
    """Build the CNN classifier: two conv stacks + a dense head, 10 logits."""
    layers = tf.keras.layers
    return tf.keras.Sequential([
        layers.Conv2D(32, 3, activation=tf.nn.relu),
        layers.Conv2D(64, 3, activation=tf.nn.relu),
        layers.MaxPooling2D(2),
        layers.Conv2D(128, 3, activation=tf.nn.relu),
        layers.Conv2D(256, 3, activation=tf.nn.relu),
        layers.MaxPooling2D(2),
        layers.Flatten(),
        layers.Dense(128, activation=tf.nn.relu),
        layers.Dense(10),  # raw logits; softmax is applied by loss / predict path
    ])
def model_fn(features, labels, mode):
    """The model_fn argument for creating an Estimator.

    Builds one EstimatorSpec depending on mode (PREDICT / TRAIN / EVAL).
    """
    model = core_model()
    # Serving passes features as a dict keyed by 'image'; train/eval pass
    # the image tensor directly.
    image = features['image'] if isinstance(features, dict) else features

    if mode == tf.estimator.ModeKeys.PREDICT:
        logits = model(image, training=False)
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer()
        logits = model(image, training=True)
        loss = tf.losses.sparse_softmax_cross_entropy(
            labels=labels, logits=logits)
        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(logits, axis=1))
        # Name tensors so they can be logged with LoggingTensorHook.
        tf.identity(loss, 'cross_entropy')
        tf.identity(accuracy[1], name='train_accuracy')
        # Save the accuracy scalar for TensorBoard.
        tf.summary.scalar('train_accuracy', accuracy[1])
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=optimizer.minimize(
                loss, tf.train.get_or_create_global_step()))

    if mode == tf.estimator.ModeKeys.EVAL:
        logits = model(image, training=False)
        loss = tf.losses.sparse_softmax_cross_entropy(
            labels=labels, logits=logits)
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metric_ops={
                'accuracy': tf.metrics.accuracy(
                    labels=labels, predictions=tf.argmax(logits, axis=1)),
            })
# Set up training and evaluation input functions.
def get_train_input_fn(x_train, y_train, batch_size=64, epochs=2,
                       shuffle_buffer=1000):
    """Return an input_fn that feeds shuffled, batched training data.

    Generalized: batch size, epoch count, and shuffle buffer were hard-coded;
    they are now parameters whose defaults preserve the original behavior.

    Args:
      x_train: training images (assumed (N, 28, 28, 1) float32 — see load_data).
      y_train: int labels, shape (N,).
      batch_size: examples per batch (default 64, as before).
      epochs: passes over the dataset (default 2, as before).
      shuffle_buffer: shuffle buffer size (default 1000, as before).
    """
    def train_input_fn():
        """Prepare data for training."""
        dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
        dataset = dataset.repeat(epochs)
        dataset = dataset.batch(batch_size)
        # NOTE: make_one_shot_iterator is TF1-style; under TF2 Estimators
        # you would return `dataset` directly.
        iterator = dataset.make_one_shot_iterator()
        return iterator.get_next()
    return train_input_fn
def get_eval_input_fn(x_test, y_test, batch_size=64):
    """Return an input_fn that feeds batched (unshuffled) eval data.

    Generalized: the batch size was hard-coded; it is now a parameter whose
    default preserves the original behavior.

    Args:
      x_test: eval images (assumed (N, 28, 28, 1) float32 — see load_data).
      y_test: int labels, shape (N,).
      batch_size: examples per batch (default 64, as before).
    """
    def eval_input_fn():
        """Prepare data for test."""
        dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
        dataset = dataset.batch(batch_size)
        # TF1-style iterator; under TF2 return `dataset` directly.
        iterator = dataset.make_one_shot_iterator()
        return iterator.get_next()
    return eval_input_fn
def load_data():
    """Load MNIST, add a channel axis, and scale pixels into [0, 1]."""
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    logging.info("Data loaded")
    logging.info("Shape:{}".format(x_train.shape))

    def _prep(images, labels):
        # (N, 28, 28) uint8 -> (N, 28, 28, 1) float32 in [0, 1]; int32 labels.
        scaled = np.expand_dims(images, axis=-1).astype('float32') / 255.0
        return scaled, labels.astype('int32')

    x_train, y_train = _prep(x_train, y_train)
    x_test, y_test = _prep(x_test, y_test)
    logging.info("Shape:{}".format(x_train.shape))
    return x_train, y_train, x_test, y_test
def get_estimator(model_dir):
    """Create an Estimator around model_fn, checkpointing into model_dir."""
    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)
    return estimator
def serving_input_receiver_fn():
    """Map raw serving input onto the features dict expected by model_fn."""
    # NOTE(review): shape=[None] declares a flat rank-1 float tensor; the web
    # client sends each image as a flat list of 784 floats. Confirm this
    # matches how TF Serving batches REST "instances".
    system_input = tf.placeholder(dtype=tf.float32, shape=[None], name='input')
    # Reshape the flat pixels into NHWC images (other preprocessing goes here).
    model_input = tf.reshape(system_input, [-1, 28, 28, 1])
    # `features` is what model_fn receives; `receiver_tensors` lists the raw
    # inputs (multiple tensors may be specified).
    return tf.estimator.export.ServingInputReceiver(
        features={'image': model_input},
        receiver_tensors={'image': system_input})
def train_eval_export(model_dir, export_dir):
    """Train, evaluate, then export a SavedModel; return the estimator."""
    estimator = get_estimator(model_dir)
    x_train, y_train, x_test, y_test = load_data()
    estimator.train(input_fn=get_train_input_fn(x_train, y_train))
    eval_results = estimator.evaluate(input_fn=get_eval_input_fn(x_test, y_test))
    logging.info('Evaluation results: {}'.format(eval_results))
    export(estimator, export_dir)
    return estimator
def export(estimator, export_dir):
    """Export the estimator as a SavedModel under export_dir and log the path."""
    saved_path = estimator.export_saved_model(export_dir,
                                              serving_input_receiver_fn)
    logging.info('Saved model to: {}'.format(saved_path))
def main(_):
    """Dispatch on the --task flag.

    Raises:
      ValueError: if --task is not one of the supported task names.
    """
    task = flags.FLAGS.task
    if task == 'train_eval_export':
        train_eval_export(flags.FLAGS.model_dir, flags.FLAGS.export_dir)
    elif task == 'export':
        estimator = get_estimator(flags.FLAGS.model_dir)
        export(estimator, flags.FLAGS.export_dir)
    else:
        # Fix: a specific exception type with the offending value instead of
        # a bare, message-less Exception (still an Exception subclass, so any
        # existing `except Exception` caller keeps working).
        raise ValueError('not supported task: {}'.format(task))
if __name__ == '__main__':
    # Command-line flags (absl): checkpoint dir, export dir, and which task
    # to run. Defined here so importing this module does not register flags.
    flags.DEFINE_string(
        name="model_dir", default="/tmp/mnist/estimator",
        help="Directory for estimator graph and checkpoints")
    flags.DEFINE_string(
        name="export_dir", default="/tmp/mnist/savedmodel",
        help="Directory for export saved model")
    flags.DEFINE_enum(
        name="task", default="train_eval_export",
        enum_values=['train_eval_export', 'export'],
        help="Task type")
    app.run(main)
前端演示
这个HTML文件允许用户手写数字,并展示模型的推断结果。
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8"/>
<title>MNIST</title>
<script>
window.onload = init;
var prevX = 0,
currX = 0,
prevY = 0,
currY = 0,
flag = false;
function init() {
    // Wire up the drawing canvas once the page has loaded.
    console.log("window onload")
    reset();
    var canvas = document.getElementById("imageCanvas");
    // One factory instead of four near-identical inline handlers.
    var forward = function (kind) {
        return function (e) { findxy(kind, e, canvas); };
    };
    canvas.addEventListener("mousemove", forward('move'), false);
    canvas.addEventListener("mousedown", forward('down'), false);
    canvas.addEventListener("mouseup", forward('up'), false);
    canvas.addEventListener("mouseout", forward('out'), false);
    console.log("canvas inited")
}
function findxy(res, e, canvas) {
    // Track the mouse position relative to the canvas; draw segments while
    // the button is held (the module-level `flag` is true).
    var x = e.clientX - canvas.offsetLeft;
    var y = e.clientY - canvas.offsetTop;
    switch (res) {
        case 'down':
            console.log('down')
            prevX = currX;
            prevY = currY;
            currX = x;
            currY = y;
            flag = true;
            break;
        case 'up':
        case 'out':
            flag = false;
            break;
        case 'move':
            if (flag) {
                prevX = currX;
                prevY = currY;
                currX = x;
                currY = y;
                draw(canvas.getContext("2d"));
            }
            break;
    }
}
function draw(ctx) {
    // Stroke one white, rounded, slightly blurred segment from the previous
    // point (prevX, prevY) to the current point (currX, currY).
    Object.assign(ctx, {
        strokeStyle: 'white',
        lineWidth: 20,
        lineJoin: 'round',
        lineCap: 'round',
        shadowBlur: 4,
        shadowColor: 'rgb(255, 255, 255)'
    });
    ctx.beginPath();
    ctx.moveTo(prevX, prevY);
    ctx.lineTo(currX, currY);
    ctx.stroke();
    ctx.closePath();
}
function clear(canvas) {
    // Paint the whole canvas opaque black by rewriting every pixel's RGBA.
    var ctx = canvas.getContext("2d");
    var canvasData = ctx.getImageData(0, 0, canvas.width, canvas.height);
    for (var row = 0; row < canvas.height; row++) {
        for (var col = 0; col < canvas.width; col++) {
            // Fix: `position` was assigned without var/let, leaking an
            // implicit global (and breaking under strict mode).
            var position = row * canvas.width + col;
            canvasData.data[position * 4 + 0] = 0;   // R
            canvasData.data[position * 4 + 1] = 0;   // G
            canvasData.data[position * 4 + 2] = 0;   // B
            canvasData.data[position * 4 + 3] = 255; // A (opaque)
        }
    }
    ctx.putImageData(canvasData, 0, 0);
}
function reset() {
    // Blank both canvases and the result label.
    ["imageCanvas", "debugImage"].forEach(function (id) {
        clear(document.getElementById(id));
    });
    document.getElementById("result").textContent = "";
}
function predict() {
    // Downsample the drawing canvas to a 28x28 float image by box-averaging
    // each span of source pixels, then POST it to the model server.
    console.log('clicked button');
    var canvas = document.getElementById("imageCanvas");
    var ctx = canvas.getContext("2d");
    var canvasData = ctx.getImageData(0, 0, canvas.width, canvas.height);
    var arr = new Float32Array(28 * 28); // or Uint8Array for quantization
    // NOTE(review): row span is derived from width and col span from height
    // (looks swapped); harmless while the canvas is square (280x280).
    var arrRowSpan = canvas.width / 28;
    var arrColSpan = canvas.height / 28;
    for (var row = 0; row < canvas.height; row++) {
        for (var col = 0; col < canvas.width; col++) {
            // Fix: these four were assigned without var/let, leaking
            // implicit globals (and breaking under strict mode).
            var arrRow = Math.floor(row / arrRowSpan);
            var arrCol = Math.floor(col / arrColSpan);
            var arrPosition = arrRow * 28 + arrCol;
            var position = row * canvas.width + col;
            // just use Red value as gray input
            // nomalize to 0.0-1.0
            arr[arrPosition] = arr[arrPosition]
                + canvasData.data[position * 4] / 255.0 / arrRowSpan / arrColSpan;
        }
    }
    var arrJson = Array.from(arr);
    postRequest(arrJson);
    //showImage(arr);
    //showText(arrJson);
}
// not used, only for debug
function showImage(arr) {
var debugCtx = document.getElementById("debugImage").getContext("2d");
var debugImageData = debugCtx.getImageData(0, 0, 28, 28);
for (var item = 0; item < 28 * 28; item++) {
debugImageData.data[item * 4 + 0] = Math.floor(arr[item] * 255);
debugImageData.data[item * 4 + 1] = Math.floor(arr[item] * 255);
debugImageData.data[item * 4 + 2] = Math.floor(arr[item] * 255);
debugImageData.data[item * 4 + 3] = 255;
}
debugCtx.putImageData(debugImageData, 0, 0);
}
// not used, only for debug
function showText(arr) {
document.getElementById("debugText").textContent = arr;
}
function postRequest(arr) {
    // POST the image to TF Serving's REST predict API and display the top
    // class in the #result span. The page is served same-origin through the
    // nginx reverse proxy, so the 'no-cors' mode is harmless here (on a true
    // cross-origin request it would make the response opaque and break
    // .json()).
    const url = '/v1/models/mnist:predict';
    const payload = {
        "instances": [
            {"image": arr}
        ]
    };
    fetch(url, {
        method : "POST",
        mode: 'no-cors', // no-cors, cors, *same-origin
        body : JSON.stringify(payload)
    }).then(function (response) {
        return response.json();
    }).then(function (result) {
        console.log(result);
        document.getElementById("result").textContent
            = result.predictions[0].classes;
    });
}
</script>
</head>
<body>
<h1>MNIST Digits Sample</h1>
<button onclick="reset()">清除</button>
<button onclick="predict()">识别</button>
<span>Result: <span id='result'></span></span>
<div>
<canvas id="imageCanvas" width="280" height="280">
Your browser does not support the HTML5 canvas tag.
</canvas>
</div>
<div id="debug" style="display:none">
<div>
<canvas id="debugImage" width="28" height="28">
</canvas>
</div>
<div id="debugText"></div>
</div>
</body>
</html>
部署
模型部署
非常容易,直接运行tensorflow/serving镜像并指定模型路径即可,之后就能接受RESTful和gRPC请求。
# Create a user-defined bridge network so containers can reach each other by name.
docker network create -d bridge mynet
# Run TF Serving with the exported model mounted under /models/.
# MODEL_NAME=mnist makes it serve /v1/models/mnist. No port is published to
# the host; the nginx proxy below reaches it inside "mynet" as host "mnist".
docker run --rm \
--net=mynet \
-v "$(pwd)/models/:/models/" \
-e MODEL_NAME=mnist \
-d --name mnist \
tensorflow/serving:1.15.0
Web部署
# Serve the static HTML from nginx; reachable as host "web" on the mynet
# network (not published to the host — traffic goes through the proxy).
docker run --rm --name=web \
--volume=$(pwd)/web/:/usr/share/nginx/html/ \
--net=mynet \
--detach \
nginx:1.15.9
跨域处理(反向代理)
html前端连接后端Restful会存在跨域的问题。可以用Nginx来负责web部署以及反向代理到tf-serving。
# Reverse proxy: one origin for both the web page and the model API, which
# avoids browser cross-origin (CORS) issues.
server {
listen 80;
listen [::]:80;
# server_name <somedomainname.com>;
# Allow larger request bodies (the JSON image payload).
client_max_body_size 20m;
# TF Serving
# Forward model REST calls to the tf-serving container; "mnist" resolves via
# the docker bridge network, 8501 is its REST port. The X-* headers preserve
# the original client address and scheme for logging.
location /v1/models/mnist {
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $http_host;
proxy_pass http://mnist:8501;
}
# Web
# Everything else goes to the static-site nginx container ("web").
location / {
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $http_host;
proxy_pass http://web:80;
}
}
# The reverse proxy is the only container published to the host (port 80);
# it mounts the config above from ./proxy/ into nginx's conf.d.
docker run --rm --name=proxy --publish=80:80 \
--volume=$(pwd)/proxy/:/etc/nginx/conf.d/ \
--net=mynet \
--detach \
nginx:1.15.9