TensorFlow Keras + Estimator示例

大家都在吐槽TensorFlow的文档和API接口。身为用户一员,与其抱怨,不如来个示例代码解惑备忘,来得实在。

模型训练

下面的代码使用tf.keras定义模型。使用Estimator来工程化,使用absl来处理logging和flag。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
from absl import logging

import tensorflow as tf
import numpy as np

def core_model():
  """Build the small CNN classifier used for MNIST.

  Returns:
    An uncompiled tf.keras.Sequential model producing raw logits over the
    10 digit classes (no softmax; the loss/predict paths apply it).
  """
  relu = tf.nn.relu
  layers = tf.keras.layers
  stack = [
      layers.Conv2D(32, 3, activation=relu),
      layers.Conv2D(64, 3, activation=relu),
      layers.MaxPooling2D(2),
      layers.Conv2D(128, 3, activation=relu),
      layers.Conv2D(256, 3, activation=relu),
      layers.MaxPooling2D(2),
      layers.Flatten(),
      layers.Dense(128, activation=relu),
      layers.Dense(10),  # logits, one per digit class
  ]
  return tf.keras.Sequential(stack)

def model_fn(features, labels, mode):
  """The model_fn argument for creating an Estimator.

  Args:
    features: input tensor, or a dict carrying the tensor under key 'image'
      (the serving input receiver passes a dict).
      NOTE(review): assumed to be float images of shape
      (batch, 28, 28, 1) — confirm against the input_fn / serving receiver.
    labels: integer class labels; used only in TRAIN and EVAL modes.
    mode: a tf.estimator.ModeKeys value selecting which graph to build.

  Returns:
    A tf.estimator.EstimatorSpec for the requested mode.
  """
  model = core_model()
  image = features
  if isinstance(image, dict):
    # Serving requests arrive as a feature dict; unwrap the image tensor.
    image = features['image']

  # Prediction graph: no loss/optimizer; exports a 'classify' signature
  # so TF Serving exposes both argmax classes and softmax probabilities.
  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  # Training graph: cross-entropy loss minimized by Adam, stepping the
  # global step counter so checkpoints resume correctly.
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer()

    logits = model(image, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(loss, 'cross_entropy')
    # accuracy[1] is the update op of the streaming metric.
    tf.identity(accuracy[1], name='train_accuracy')

    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))

  # Evaluation graph: same loss plus a streaming accuracy metric.
  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops={
            'accuracy':
                tf.metrics.accuracy(
                    labels=labels, predictions=tf.argmax(logits, axis=1)),
        })

# Set up training and evaluation input functions.
# Set up training and evaluation input functions.
def get_train_input_fn(x_train, y_train):
  """Return an input_fn that shuffles, repeats (2 epochs) and batches (64)."""
  def train_input_fn():
    """Prepare data for training."""
    ds = (tf.data.Dataset
          .from_tensor_slices((x_train, y_train))
          .shuffle(buffer_size=1000)
          .repeat(2)
          .batch(64))
    return ds.make_one_shot_iterator().get_next()
  return train_input_fn

def get_eval_input_fn(x_test, y_test):
  """Return an input_fn yielding test batches (64), no shuffle, one pass."""
  def eval_input_fn():
    """Prepare data for test."""
    ds = (tf.data.Dataset
          .from_tensor_slices((x_test, y_test))
          .batch(64))
    return ds.make_one_shot_iterator().get_next()
  return eval_input_fn

def load_data():
  """Load MNIST, add a channel axis and scale pixels to [0, 1].

  Returns:
    Tuple (x_train, y_train, x_test, y_test): images as float32 arrays of
    shape (N, 28, 28, 1), labels as int32 vectors.
  """
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  logging.info("Data loaded")
  # Lazy %-style args defer formatting until the record is actually emitted.
  logging.info("Shape: %s", x_train.shape)

  # Add the trailing channel dimension Conv2D expects and normalize pixels.
  x_train = np.expand_dims(x_train, axis=-1).astype('float32')/255.0
  y_train = y_train.astype('int32')
  x_test = np.expand_dims(x_test, axis=-1).astype('float32')/255.0
  y_test = y_test.astype('int32')
  logging.info("Shape: %s", x_train.shape)
  return x_train, y_train, x_test, y_test

def get_estimator(model_dir):
  """Create an Estimator around model_fn, checkpointing into model_dir."""
  return tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)

def serving_input_receiver_fn():
  """Build the ServingInputReceiver used when exporting the SavedModel.

  Declares the raw tensor a serving request provides ('image': flat float
  pixels) and how it is transformed into the features dict for model_fn.
  """
  # NOTE(review): shape=[None] declares a rank-1 tensor of flattened pixels;
  # confirm batched REST requests (which arrive as [batch, 784]) are accepted,
  # or whether this should be shape=[None, 784].
  system_input = tf.placeholder(dtype=tf.float32,
                                shape=[None],
                                name='input')
  # reshape (and other operators)
  model_input = tf.reshape(system_input, [-1, 28, 28, 1])

  # this is the dict that is then passed as "features" parameter to your model_fn
  features = {'image': model_input}
  receiver = {'image': system_input} # multiple input tensors may be listed here
  return tf.estimator.export.ServingInputReceiver(features, receiver)

def train_eval_export(model_dir, export_dir):
  """Train on MNIST, evaluate on the test split, then export a SavedModel.

  Returns the trained Estimator.
  """
  estimator = get_estimator(model_dir)
  x_train, y_train, x_test, y_test = load_data()

  estimator.train(input_fn=get_train_input_fn(x_train, y_train))
  metrics = estimator.evaluate(input_fn=get_eval_input_fn(x_test, y_test))
  logging.info('Evaluation results: {}'.format(metrics))

  export(estimator, export_dir)
  return estimator

def export(estimator, export_dir):
  """Export estimator as a SavedModel under export_dir for TF Serving."""
  exported_dir = estimator.export_saved_model(export_dir, serving_input_receiver_fn)
  # Lazy %-style args avoid eager string formatting.
  logging.info('Saved model to: %s', exported_dir)

def main(_):
  """Dispatch on --task.

  Raises:
    ValueError: if the task flag holds an unsupported value.
  """
  task = flags.FLAGS.task
  if task == 'train_eval_export':
    train_eval_export(flags.FLAGS.model_dir, flags.FLAGS.export_dir)
  elif task == 'export':
    estimator = get_estimator(flags.FLAGS.model_dir)
    export(estimator, flags.FLAGS.export_dir)
  else:
    # DEFINE_enum already validates the value, but fail loudly just in case.
    raise ValueError('unsupported task: {}'.format(task))

if __name__ == '__main__':
  # Command-line flags (absl): checkpoint directory, SavedModel export
  # directory, and which task to run.
  flags.DEFINE_string(
      name="model_dir", default="/tmp/mnist/estimator",
      help="Directory for estimator graph and checkpoints")
  flags.DEFINE_string(
      name="export_dir", default="/tmp/mnist/savedmodel",
      help="Directory for export saved model")
  flags.DEFINE_enum(
      name="task", default="train_eval_export",
      enum_values=['train_eval_export', 'export'],
      help="Task type")

  app.run(main)

前端演示

这个html文件允许用户画数字并给出推断结果。

<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8"/>
    <title>MNIST</title>
    <script>
      window.onload = init;
      // Stroke endpoints (previous and current mouse positions) plus a flag
      // that is true while the mouse button is held down over the canvas.
      var prevX = 0,
          currX = 0,
          prevY = 0,
          currY = 0,
          flag = false;

      function init() {
        console.log("window onload")
        reset();

        var canvas = document.getElementById("imageCanvas");
        // Route all four relevant mouse events through the same state machine.
        var states = {mousemove: 'move', mousedown: 'down',
                      mouseup: 'up', mouseout: 'out'};
        ["mousemove", "mousedown", "mouseup", "mouseout"].forEach(function (evt) {
          canvas.addEventListener(evt, function (e) {
            findxy(states[evt], e, canvas)
          }, false);
        });
        console.log("canvas inited")
      }

      function findxy(res, e, canvas) {
        // Update pen-down state and the previous/current stroke endpoints.
        switch (res) {
          case 'down':
            console.log('down')
            prevX = currX;
            prevY = currY;
            currX = e.clientX - canvas.offsetLeft;
            currY = e.clientY - canvas.offsetTop;
            flag = true;
            break;
          case 'up':
          case 'out':
            flag = false;
            break;
          case 'move':
            if (flag) {
              prevX = currX;
              prevY = currY;
              currX = e.clientX - canvas.offsetLeft;
              currY = e.clientY - canvas.offsetTop;
              draw(canvas.getContext("2d"));
            }
            break;
        }
      }

      function draw(ctx) {
        // Stroke one white, rounded, slightly blurred segment from the
        // previous mouse position to the current one.
        ctx.beginPath();
        ctx.moveTo(prevX, prevY);
        ctx.lineTo(currX, currY);
        ctx.strokeStyle = 'white';
        ctx.lineWidth = 20;
        ctx.lineJoin = 'round';
        ctx.lineCap = 'round';
        ctx.shadowBlur = 4;
        ctx.shadowColor = 'rgb(255, 255, 255)';
        ctx.stroke();
        ctx.closePath();
      }

      function clear(canvas) {
        // Paint the whole canvas opaque black, pixel by pixel.
        // Fix: `position` was an undeclared implicit global; declare it.
        var ctx = canvas.getContext("2d");
        var canvasData = ctx.getImageData(0, 0, canvas.width, canvas.height);
        for (var row = 0; row < canvas.height; row++) {
          for (var col = 0; col < canvas.width; col++) {
            var position = row * canvas.width + col;
            canvasData.data[position * 4 + 0] = 0;   // R
            canvasData.data[position * 4 + 1] = 0;   // G
            canvasData.data[position * 4 + 2] = 0;   // B
            canvasData.data[position * 4 + 3] = 255; // A (opaque)
          }
        }
        ctx.putImageData(canvasData, 0, 0);
      }

      function reset() {
        // Blank both canvases and drop any previous prediction text.
        ["imageCanvas", "debugImage"].forEach(function (id) {
          clear(document.getElementById(id));
        });
        document.getElementById("result").textContent = "";
      }

      function predict() {
        console.log('clicked button');

        // Downsample the 280x280 canvas to the 28x28 grid MNIST expects by
        // averaging each block of source pixels, then POST it to the server.
        var canvas = document.getElementById("imageCanvas");
        var ctx = canvas.getContext("2d");
        var canvasData = ctx.getImageData(0, 0, canvas.width, canvas.height);
        var arr = new Float32Array(28 * 28); // or Uint8Array for quantization
        // NOTE(review): row span derives from width and col span from height —
        // swapped, but harmless here because the canvas is square; confirm.
        var arrRowSpan = canvas.width / 28;
        var arrColSpan = canvas.height / 28;
        for (var row = 0; row < canvas.height; row++) {
          for (var col = 0; col < canvas.width; col++) {
            // Fix: these four were undeclared implicit globals.
            var arrRow = Math.floor(row / arrRowSpan);
            var arrCol = Math.floor(col / arrColSpan);
            var arrPosition = arrRow * 28 + arrCol;

            var position = row * canvas.width + col;
            // just use Red value as gray input
            // normalize to 0.0-1.0
            arr[arrPosition] = arr[arrPosition]
                + canvasData.data[position * 4] / 255.0 / arrRowSpan / arrColSpan;
          }
        }
        var arrJson = Array.from(arr);
        postRequest(arrJson);
        //showImage(arr);
        //showText(arrJson);
      }

      // not used, only for debug
      function showImage(arr) {
        // Render the 28x28 float array as grayscale on the debug canvas.
        var debugCtx = document.getElementById("debugImage").getContext("2d");
        var debugImageData = debugCtx.getImageData(0, 0, 28, 28);
        for (var px = 0; px < 28 * 28; px++) {
          var gray = Math.floor(arr[px] * 255);
          debugImageData.data[px * 4 + 0] = gray;
          debugImageData.data[px * 4 + 1] = gray;
          debugImageData.data[px * 4 + 2] = gray;
          debugImageData.data[px * 4 + 3] = 255;
        }
        debugCtx.putImageData(debugImageData, 0, 0);
      }

      // not used, only for debug
      function showText(arr) {
        // Dump the raw pixel array into the hidden debug div.
        var target = document.getElementById("debugText");
        target.textContent = arr;
      }

      function postRequest(arr) {
        // POST the 784-float image to TF Serving's REST predict endpoint and
        // display the predicted class in the #result span.
        const url = '/v1/models/mnist:predict';
        fetch(url, {
          method : "POST",
          // NOTE(review): 'no-cors' yields an opaque response on cross-origin
          // requests, which would break response.json(); this presumably only
          // works because the nginx proxy keeps it same-origin — confirm.
          mode: 'no-cors', // no-cors, cors, *same-origin
          //body : new FormData(document.getElementById("inputform")),
          body : JSON.stringify({
            "instances": [
              {"image": arr}
            ]
          })
        }).then(
          // same as function(response) {return response.text();}
          response => response.json()
        ).then(
          result => {
            console.log(result);
            document.getElementById("result").textContent
                = result.predictions[0].classes;
          }
        );
      }
    </script>
  </head>
  <body>
    <h1>MNIST Digits Samples</h1>
    <button onclick="reset()">清除</button>
    <button onclick="predict()">识别</button>
    <span>Result: <span id='result'></span></span>
    <div>
      <canvas id="imageCanvas" width="280" height="280">
        Your browser does not support the HTML5 canvas tag.
      </canvas>
    </div>
    <div id="debug" style="display:none">
      <div>
        <canvas id="debugImage" width="28" height="28">
        </canvas>
      </div>
      <div id="debugText"></div>
    </div>
  </body>
</html>

部署

模型部署

非常容易:直接使用tensorflow/serving并指定模型路径即可,然后就能接受Restful和gRPC的请求。

docker network create -d bridge mynet

docker run --rm \
    --net=mynet \
    -v "$(pwd)/models/:/models/" \
    -e MODEL_NAME=mnist \
    -d --name mnist \
    tensorflow/serving:1.15.0

Web部署

docker run --rm --name=web \
    --volume=$(pwd)/web/:/usr/share/nginx/html/ \
    --net=mynet \
    --detach \
    nginx:1.15.9

跨域处理(反向代理)

html前端连接后端Restful会存在跨域的问题。可以用Nginx来负责web部署以及反向代理到tf-serving。

server {
    listen 80;
    listen [::]:80;

    # server_name <somedomainname.com>;

    # Allow large request bodies (the JSON payload of pixel values).
    client_max_body_size 20m;

    # TF Serving: reverse-proxy model requests to the tensorflow/serving
    # container (name "mnist" on the shared docker network, REST port 8501).
    location /v1/models/mnist {
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header Host $http_host;
        proxy_pass http://mnist:8501;
    }

    # Web: everything else goes to the nginx container serving the page,
    # keeping the frontend and the model API same-origin (no CORS issues).
    location / {
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header Host $http_host;
        proxy_pass http://web:80;
    }
}
docker run --rm --name=proxy --publish=80:80 \
    --volume=$(pwd)/proxy/:/etc/nginx/conf.d/ \
    --net=mynet \
    --detach \
    nginx:1.15.9

Contents