16. Coding Techniques for Computer Vision in TensorFlow.js – AI and Machine Learning for Coders

Chapter 16. Coding Techniques for Computer Vision in TensorFlow.js

In Chapters 2 and 3 you saw how TensorFlow can be used to create models for computer vision, which can be trained to recognize the content in images. In this chapter you’ll do the same, but with JavaScript. You’ll build a handwriting recognizer that runs in the browser and is trained on the MNIST dataset. You can see it in Figure 16-1.

Figure 16-1. A handwriting classifier in the browser

There are a few crucial implementation details to be aware of when you’re working with TensorFlow.js, particularly if you are building applications in the browser. Perhaps the biggest and most important of these is how training data is handled. When using a browser, every time you open a resource at a URL, you’re making an HTTP connection. You use this connection to pass commands to a server, which will then dispatch the results for you to parse. When it comes to machine learning, you generally have a lot of training data—for example, in the case of MNIST and Fashion MNIST, even though they are small learning datasets they still each contain 70,000 images, which would be 70,000 HTTP connections! You’ll see how to deal with this later in this chapter.

Additionally, as you saw in the last chapter, even for a very simple scenario like training for Y = 2X – 1, nothing appeared to happen during the training cycle unless you opened the debug console, where you could see the epoch-by-epoch loss. If you’re training something much more sophisticated, which takes longer, it can be difficult to understand what’s going on during training. Fortunately there are built-in visualization tools that you can use, as seen on the right side of Figure 16-1; you’ll also explore them in this chapter.

There are also syntactical differences to be aware of when defining a convolutional neural network in JavaScript, some of which we touched on in the previous chapter. We’ll start by considering these. If you need a refresher on CNNs, see Chapter 3.

JavaScript Considerations for TensorFlow Developers

When building a full (or close to it) application in JavaScript like you will in this chapter, there are a number of things that you’ll have to take into account. JavaScript is very different from Python, and, as such, while the TensorFlow.js team has worked hard to keep the experience as close to “traditional” TensorFlow as possible, there are some changes.

First is the syntax. While in many respects TensorFlow code in JavaScript (especially Keras code) is quite similar to that in Python, there are a few syntactic differences—most notably, as mentioned in the previous chapter, the use of JSON in parameter lists.

Next is synchronicity. Especially when running in the browser, you can’t lock up the UI thread when training and instead need to perform many operations asynchronously, using JavaScript Promises and await calls. It’s not the intention of this chapter to go into depth teaching these concepts; if you aren’t already familiar with them, you can think of them as asynchronous functions that, instead of waiting to finish executing before returning, will go off and do their own thing and “call you back” when they’re done. The tfjs-vis library was created to help you debug your code when training models asynchronously with TensorFlow.js. The visualization tools give you a separate sidebar in the browser, not interfering with your current page, in which visualizations like training progress can be plotted; we’ll talk more about them in “Using Callbacks for Visualization”.
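
As a rough illustration of that “call you back” behavior, here is a minimal sketch in plain JavaScript. The names trainAsync and nextTick are hypothetical stand-ins for a long-running training loop, not TensorFlow.js APIs; the point is only that the caller keeps running while the work continues in the background.

```javascript
// Hypothetical helper: resolves on a later turn of the event loop.
function nextTick() {
  return new Promise((resolve) => setTimeout(resolve, 0));
}

// Hypothetical stand-in for a training loop. Awaiting between epochs
// hands control back to the browser so the UI can keep responding.
async function trainAsync(epochs, onEpochEnd) {
  for (let epoch = 0; epoch < epochs; epoch++) {
    // ...one epoch of work would happen here...
    onEpochEnd(epoch);
    await nextTick();
  }
  return 'done';
}

const log = [];
const pending = trainAsync(3, (epoch) => log.push(`epoch ${epoch}`));

// This line runs before training has finished.
log.push('UI still responsive');

// The promise "calls you back" once all epochs are complete.
pending.then((status) => log.push(status));
```

The call to trainAsync returns a Promise immediately, so the line after it runs while two of the three epochs are still outstanding.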

Resource usage is also an important consideration. As the browser is a shared environment, you may have multiple tabs open in which you’re doing different things, or you might be performing multiple operations within the same web app. Therefore, it’s important to control how much memory you use. ML training can be memory-intensive, as lots of data is required to understand and distinguish the patterns that map features to labels. As a result, you should take care to tidy up after yourself. The tidy API is designed for just that and should be used as much as possible: wrapping a function in tidy ensures that all tensors not returned by the function will be cleaned up and released from memory.

While not a TensorFlow API, the ArrayBuffer in JavaScript is another handy construct. It’s analogous to a ByteBuffer, letting you manage data as if it were low-level memory. In machine learning applications, it’s often easiest to use a very sparse encoding, as you’ve seen already with one-hot encoding. Because processing in JavaScript can tie up the thread and you don’t want to lock up the browser, it can be easier to have a sparse encoding of the data that doesn’t require processor power to decode. In the example from this chapter, the labels are encoded in this way: for each of the 10 classes, 9 of them will have a 0x00 byte and the other, representing the matching class for that feature, will have a 0x01 byte. This means 10 bytes, or 80 bits, are used for each label, where as a coder you might think that only 4 bits would be necessary to encode a digit between 0 and 9. But of course, if you did it that way you would have to decode the results, once for each of the 65,000 labels. Thus, having a sparsely encoded file that’s easily represented in bytes by an ArrayBuffer can be quicker, albeit with a larger file size.
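
To make the size-versus-decoding trade-off concrete, here is a small sketch comparing the two encodings. The helper names (oneHotEncode, packNibbles, unpackNibble) are made up for this illustration and are not part of TensorFlow.js or the book’s code.

```javascript
const NUM_CLASSES = 10;

// One-hot: 10 bytes per label. Each 10-byte row can be used as-is.
function oneHotEncode(digits) {
  const bytes = new Uint8Array(digits.length * NUM_CLASSES);
  digits.forEach((digit, i) => {
    bytes[i * NUM_CLASSES + digit] = 1;
  });
  return bytes;
}

// Packed: two 4-bit digits per byte. Compact, but every label has to be
// unpacked with shifts and masks before it can be used.
function packNibbles(digits) {
  const bytes = new Uint8Array(Math.ceil(digits.length / 2));
  digits.forEach((digit, i) => {
    const shift = i % 2 === 0 ? 4 : 0;
    bytes[i >> 1] |= digit << shift;
  });
  return bytes;
}

function unpackNibble(bytes, i) {
  const shift = i % 2 === 0 ? 4 : 0;
  return (bytes[i >> 1] >> shift) & 0x0f;
}

const digits = [7, 0, 3];
const oneHot = oneHotEncode(digits);
const packed = packNibbles(digits);

console.log(oneHot.byteLength);       // 30
console.log(packed.byteLength);       // 2
console.log(unpackNibble(packed, 2)); // 3
```

The packed form is far smaller, but the one-hot bytes already have the layout the output layer expects, so nothing needs to be computed per label at load time.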

Also worthy of mention are the tf.browser APIs, which are helpful for dealing with images. At the time of writing there are two methods, tf.browser.toPixels and tf.browser.fromPixels, which, as their names suggest, are used for translating pixels between browser-friendly formats and tensor formats. You’ll use these later when you want to draw a picture and have it interpreted by the model.

Building a CNN in JavaScript

When building any neural network with TensorFlow Keras, you define a number of layers. In the case of a convolutional neural network, you’ll typically have a series of convolutional layers followed by pooling layers, whose output is flattened and fed into a dense layer. For example, here’s a CNN that was defined for classifying the MNIST dataset back in Chapter 3:

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu', 
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)])

Let’s break down line by line how this could be implemented in JavaScript. We’ll start by defining the model as a sequential:

model = tf.sequential();

Next, we’ll define the first layer as a 2D convolution that learns 64 filters, with a kernel size of 3 × 3 and an input shape of 28 × 28 × 1. The syntax here is very different from Python, but you can see the similarity:

model.add(tf.layers.conv2d({inputShape: [28, 28, 1], 
          kernelSize: 3, filters: 64, activation: 'relu'}));

The following layer was a MaxPooling2D, with a pool size of 2 × 2. In JavaScript it’s implemented like this:

model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));

This was followed by another convolutional layer and max pooling layer. The difference here is that there is no input shape, as it isn’t an input layer. In JavaScript this looks like this:

model.add(tf.layers.conv2d({filters: 64, 
          kernelSize: 3, activation: 'relu'}));

model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));

After this, the output was flattened, and in JavaScript the syntax for that is:

model.add(tf.layers.flatten());

The model was then completed by two dense layers, one with 128 neurons activated by relu, and the output layer of 10 neurons activated by softmax:

model.add(tf.layers.dense({units: 128, activation: 'relu'}));

model.add(tf.layers.dense({units: 10, activation: 'softmax'}));

As you can see, the JavaScript APIs look very similar to the Python ones, but there are syntactical differences that can be gotchas: the names of APIs follow camel case convention but start with a lowercase letter, as expected in JavaScript (i.e., maxPooling2d instead of MaxPooling2D), parameters are defined in JSON instead of comma-separated lists, etc. Keep an eye on these differences as you code your neural networks in JavaScript.

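For convenience, here’s the complete JavaScript definition of the model. Note that this listing uses fewer filters in its convolutional layers (8 and 16) than the 64 used in the walkthrough above: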

model = tf.sequential();

model.add(tf.layers.conv2d({inputShape: [28, 28, 1], 
          kernelSize: 3, filters: 8, activation: 'relu'}));

model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));

model.add(tf.layers.conv2d({filters: 16, 
          kernelSize: 3, activation: 'relu'}));

model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));

model.add(tf.layers.flatten());

model.add(tf.layers.dense({units: 128, activation: 'relu'}));

model.add(tf.layers.dense({units: 10, activation: 'softmax'}));

Similarly, when compiling the model, consider the differences between Python and JavaScript. Here’s the Python:

model.compile(optimizer='adam', 
              loss='sparse_categorical_crossentropy', 
              metrics=['accuracy'])

And the equivalent JavaScript:

model.compile({
    optimizer: tf.train.adam(),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
});

While they’re very similar, keep in mind the JSON syntax for the parameters (parameter: value, not parameter=value) and that the list of parameters is enclosed in curly braces ({}).

Using Callbacks for Visualization

In Chapter 15, when you were training the simple neural network, you logged the loss to the console when each epoch ended. You then used the browser’s developer tools to view the progress in the console, looking at the changes in the loss over time. A more sophisticated approach is to use the TensorFlow.js visualization tools, created specifically for in-browser development. These include tools for reporting on training metrics, model evaluation, and more. The visualization tools appear in a separate area of the browser window that doesn’t interfere with the rest of your web page. The term used for this is a visor. It will default to showing at the very least the model architecture.

To use the tfjs-vis library in your page, you can include it with a script:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>

Then, to see visualizations while training, you need to specify a callback in your model.fit call. Here’s an example:

return model.fit(trainXs, trainYs, {
    batchSize: BATCH_SIZE,
    validationData: [testXs, testYs],
    epochs: 20,
    shuffle: true,
    callbacks: fitCallbacks
});

The callbacks are defined as a const, using tfvis.show.fitCallbacks. This takes two parameters—a container and the desired metrics. These are also defined using consts, as shown here:

const metrics = ['loss', 'val_loss', 'accuracy', 'val_accuracy'];

const container = { name: 'Model Training', styles: { height: '640px' }, 
                    tab: 'Training Progress' };

const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

The container const has parameters that define the visualization area. All visualizations are shown in a single tab by default. By using a tab parameter (set to “Training Progress” here), you can split the training progress out into a separate tab. Figure 16-2 illustrates what the preceding code will show in the visualization area at runtime.

Next, let’s explore how to manage the training data. As mentioned earlier, handling thousands of images through URL connections is bad for the browser because it will lock up the UI thread. But there are some tricks that you can use from the world of game development!

Figure 16-2. Using the visualization tools

Training with the MNIST Dataset

Instead of downloading every image one by one, a useful way to handle training data in TensorFlow.js is to append all the images together into a single image, often called a sprite sheet. This technique is commonly used in game development, where the graphics of a game are stored in a single file instead of multiple smaller ones for file storage efficiency. If we were to store all the images for training in a single file, we’d just need to open one HTTP connection to it in order to download them all in a single shot.

For the purposes of learning, the TensorFlow team has created sprite sheets from the MNIST and Fashion MNIST datasets that we can use here. For example, the MNIST images are available in a file called mnist_images.png (see Figure 16-3).

Figure 16-3. An excerpt from mnist_images.png in an image viewer

If you explore the dimensions of this image, you’ll see that it has 65,000 lines, each with 784 (28 × 28) pixels in it. If those dimensions look familiar, you might recall that MNIST images are 28 × 28 monochrome. So, you can download this image, read it line by line, and then take each of the lines and separate it into a 28 × 28-pixel image.

You can do this in JavaScript by loading the image, and then defining a canvas on which you can draw the individual lines after extracting them from the original image. The bytes from these canvases can then be extracted into a dataset that you’ll use for training. This might seem a bit convoluted, but given that JavaScript is an in-browser technology, it wasn’t really designed for data and image processing like this. That said, it works really well, and runs really quickly! Before we get into the details of that, however, you should also look at the labels and how they’re stored.
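
The row-per-image layout can be sketched without any browser APIs. Assuming the sprite sheet has already been decoded into a flat array with one image per 784-value row, a hypothetical helper (extractImage is a name invented for this example) can slice out a single image and view it as 28 rows of 28 pixels:

```javascript
const IMAGE_WIDTH = 28;
const IMAGE_SIZE = IMAGE_WIDTH * IMAGE_WIDTH; // 784 pixels per sprite row

// Return image number `index` as 28 rows of 28 values.
// `sprite` is a flat typed array holding one image per 784-value row.
function extractImage(sprite, index) {
  const flat = sprite.subarray(index * IMAGE_SIZE, (index + 1) * IMAGE_SIZE);
  const rows = [];
  for (let r = 0; r < IMAGE_WIDTH; r++) {
    rows.push(flat.subarray(r * IMAGE_WIDTH, (r + 1) * IMAGE_WIDTH));
  }
  return rows;
}

// Tiny fake sprite with two "images": the first all 0s, the second all 1s.
const sprite = new Float32Array(2 * IMAGE_SIZE);
sprite.fill(1, IMAGE_SIZE);

const second = extractImage(sprite, 1);
console.log(second.length, second[0].length); // 28 28
```

Because subarray returns views rather than copies, slicing an image this way allocates no new pixel storage.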

First, set up constants for the training and test data, bearing in mind that the MNIST image has 65,000 lines, one for each image. The ratio of training to testing data can be defined as 5:1, and from this you can calculate the number of elements for training and the number for testing:

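const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

Note that all of this code is in the repo for this book, so please feel free to adapt it from there!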

Next up, you need to create some constants for the image control that will hold the sprite sheet and the canvas that can be used for slicing it up:

const img = new Image();
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');

To load the image, you simply set the src of the img control to the path of the sprite sheet, held here in a constant assumed to be named MNIST_IMAGES_SPRITE_PATH:

img.src = MNIST_IMAGES_SPRITE_PATH;

Once the image is loaded, you can set up a buffer to hold its pixel values. Each pixel will be stored as one 32-bit float, which takes 4 bytes, so you’ll need to reserve 65,000 (number of images) × 784 (number of pixels in a 28 × 28 image) × 4 (number of bytes per float) bytes for the buffer. You don’t need to split the file image by image, but can split it in chunks. Take five thousand images at a time by specifying the chunkSize as shown here:

img.onload = () => {
    img.width = img.naturalWidth;
    img.height = img.naturalHeight;

    const datasetBytesBuffer =
        new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

    const chunkSize = 5000;
    canvas.width = img.width;
    canvas.height = chunkSize;

Now you can create a loop to go through the image in chunks, creating a set of bytes for each chunk and drawing it to the canvas. This will decode the PNG into the canvas, giving you the ability to get the raw bytes from the image. As the individual images in the dataset are monochrome, the PNG will have the same levels for the R, G, and B bytes, so you can just take any of them:

for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
    const datasetBytesView = new Float32Array(
        datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
        IMAGE_SIZE * chunkSize);
    ctx.drawImage(
        img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
        chunkSize);

    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

    for (let j = 0; j < imageData.data.length / 4; j++) {
        // All channels hold an equal value since the image is grayscale, so
        // just read the red channel.
        datasetBytesView[j] = imageData.data[j * 4] / 255;
    }
}

The images can now be loaded into a dataset with:

this.datasetImages = new Float32Array(datasetBytesBuffer);
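
The chunked decoding is easier to follow with toy numbers. The sketch below runs outside the browser by replacing the canvas with a hypothetical fakeChunkRGBA function that returns RGBA bytes, as getImageData would; the sizes are deliberately tiny and are assumptions for illustration only.

```javascript
// Toy sizes so the sketch runs anywhere; the chapter uses 65,000 images
// of 784 pixels each, decoded in chunks of 5,000.
const NUM_IMAGES = 4;
const IMAGE_SIZE = 3;  // pixels per image (one sprite row)
const CHUNK_SIZE = 2;  // images decoded per pass
const BYTES_PER_FLOAT = Float32Array.BYTES_PER_ELEMENT; // 4

// Hypothetical stand-in for ctx.getImageData(...).data: RGBA bytes for
// one chunk. Chunk 0 is black, every other chunk is white.
function fakeChunkRGBA(chunkIndex) {
  const rgba = new Uint8ClampedArray(CHUNK_SIZE * IMAGE_SIZE * 4);
  const gray = chunkIndex === 0 ? 0 : 255;
  for (let p = 0; p < CHUNK_SIZE * IMAGE_SIZE; p++) {
    rgba.set([gray, gray, gray, 255], p * 4);
  }
  return rgba;
}

// One buffer for the whole dataset: one 4-byte float per pixel.
const buffer = new ArrayBuffer(NUM_IMAGES * IMAGE_SIZE * BYTES_PER_FLOAT);

for (let i = 0; i < NUM_IMAGES / CHUNK_SIZE; i++) {
  // A view onto this chunk's region of the shared buffer (no copy is made).
  const view = new Float32Array(
      buffer, i * IMAGE_SIZE * CHUNK_SIZE * BYTES_PER_FLOAT,
      IMAGE_SIZE * CHUNK_SIZE);
  const rgba = fakeChunkRGBA(i);
  for (let j = 0; j < rgba.length / 4; j++) {
    view[j] = rgba[j * 4] / 255; // red channel, scaled to the range 0..1
  }
}

// Viewing the whole buffer gives every pixel of every image in order.
const allPixels = new Float32Array(buffer);
console.log(Array.from(allPixels)); // six 0s followed by six 1s
```

Each pass writes through its own Float32Array view, so the chunks end up side by side in the single underlying buffer without any concatenation step.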

Similar to the images, the labels are stored in a single file. This is a binary file with a sparse encoding of the labels. Each label is represented by 10 bytes, with one of those bytes having the value 01 to represent the class. This is easier to understand with a visualization, so take a look at Figure 16-4.

This shows a hex view of the file with the first 10 bytes highlighted. Here, the eighth byte is 01, while the rest are all 00. This indicates that the first image belongs to the eighth class. Given that MNIST has 10 classes, for the digits 0 through 9, the eighth class is the digit 7.

Figure 16-4. Exploring the labels file

So, as well as downloading and decoding the bytes for the images line by line, you’ll also need to decode the labels. You download these alongside the image by fetching the URL, and then decode the labels into integer arrays using arrayBuffer:

const labelsRequest = fetch(MNIST_LABELS_PATH);
const [imgResponse, labelsResponse] =
    await Promise.all([imgRequest, labelsRequest]);

this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

The sparseness of the encoding of the labels greatly simplifies the code: with this one line you can get all the labels into a buffer. If you were wondering why such an inefficient storage method was used for the labels, that was the trade-off: a larger file but simpler decoding!
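
If you want to inspect the labels as digits, for example while debugging, a short sketch like the following could do it. The buffer here is built by hand as a stand-in for the one returned by arrayBuffer(), and labelToDigit is a hypothetical helper rather than part of the book’s code.

```javascript
const NUM_CLASSES = 10;

// Stand-in for `await labelsResponse.arrayBuffer()`: two one-hot labels.
const raw = new Uint8Array([
  0, 0, 0, 0, 0, 0, 0, 1, 0, 0,  // eighth byte set: digit 7
  0, 0, 1, 0, 0, 0, 0, 0, 0, 0,  // third byte set: digit 2
]).buffer;

// Wrapping the buffer in a typed array is all the decoding training needs.
const datasetLabels = new Uint8Array(raw);

// For display, map a 10-byte row back to the digit it represents.
function labelToDigit(labels, index) {
  const row = labels.subarray(index * NUM_CLASSES, (index + 1) * NUM_CLASSES);
  return row.indexOf(1);
}

console.log(labelToDigit(datasetLabels, 0)); // 7
console.log(labelToDigit(datasetLabels, 1)); // 2
```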

The images and labels can then be split into training and test sets:

this.trainImages =
    this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);

this.trainLabels =
    this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
this.testLabels =
    this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
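
The slice boundaries are counted in array elements, not images, which is why each one is multiplied by the per-item size. Here is a toy version with assumed sizes (12 images of 4 pixels) that makes the arithmetic easy to check by hand:

```javascript
// Toy sizes; the chapter's values are 65,000 images of 784 pixels each.
const IMAGE_SIZE = 4;
const NUM_DATASET_ELEMENTS = 12;
const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

// Fake pixel data: element i simply holds the value i.
const datasetImages = Float32Array.from(
    {length: NUM_DATASET_ELEMENTS * IMAGE_SIZE}, (_, i) => i);

// The first 10 images (40 values) go to training, the rest to testing.
const trainImages = datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
const testImages = datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);

console.log(NUM_TRAIN_ELEMENTS, NUM_TEST_ELEMENTS);      // 10 2
console.log(trainImages.length, testImages.length);      // 40 8
```

With the full 65,000-image dataset, the same calculation gives 54,166 training images and 10,834 test images.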

For training, the data can also be batched. The images will be in Float32Arrays and the labels in Uint8Arrays. They’re then converted into tensor2d types called xs and labels:

nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
        const idx = index();

        const image =
            data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
        batchImagesArray.set(image, i * IMAGE_SIZE);

        const label =
            data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
        batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
}

The training data can then use this batch function to return shuffled training batches of the desired batch size:

nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
            this.shuffledTrainIndex =
                (this.shuffledTrainIndex + 1) % this.trainIndices.length;
            return this.trainIndices[this.shuffledTrainIndex];
        });
}

Test data can be batched and shuffled in exactly the same way.
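
The text doesn’t show how the trainIndices array is built, so the following is only one plausible way to get the same effect in plain JavaScript: shuffle the indices once with a Fisher-Yates shuffle, then step through them with a counter that wraps around. The function names shuffledIndices and makeIndexCycler are invented for this sketch.

```javascript
// Fisher-Yates shuffle of the indices 0..n-1. This is a plain-JavaScript
// stand-in; it is an assumption, not the book's own implementation.
function shuffledIndices(n, random = Math.random) {
  const indices = Array.from({length: n}, (_, i) => i);
  for (let i = n - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1));
    [indices[i], indices[j]] = [indices[j], indices[i]];
  }
  return indices;
}

// Returns a function that walks the shuffled order and wraps at the end,
// mirroring the modulo arithmetic in nextTrainBatch.
function makeIndexCycler(indices) {
  let position = -1;
  return () => {
    position = (position + 1) % indices.length;
    return indices[position];
  };
}

const nextIndex = makeIndexCycler(shuffledIndices(5));
const drawn = Array.from({length: 7}, () => nextIndex());
console.log(drawn); // 7 indices; the 6th and 7th repeat the 1st and 2nd
```

Shuffling the indices rather than the data means the large image and label arrays never have to be moved; only a small array of integers is reordered.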

Now, to get ready for training, you can set up some parameters for the metrics you want to capture, what the visualization will look like, and details like the batch sizes. To get the batches for training, call nextTrainBatch and reshape the Xs to the correct tensor size. You can then do exactly the same for the test data:

const metrics = ['loss', 'val_loss', 'accuracy', 'val_accuracy'];
const container = { name: 'Model Training', styles: { height: '640px' }, 
                    tab: 'Training Progress' };
const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

const BATCH_SIZE = 512;
const TRAIN_DATA_SIZE = 5500;
const TEST_DATA_SIZE = 1000;

const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [
        d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
        d.labels
    ];
});

const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [
        d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
        d.labels
    ];
});

Note the tf.tidy call. With TensorFlow.js this will, as its name suggests, tidy up, cleaning up all intermediate tensors except those that the function returns. It’s essential when using TensorFlow.js to prevent memory leaks in the browser.

Now that you have everything set up, it’s easy to do the training, giving it your training Xs and Ys (labels) as well as the validation Xs and Ys:

return model.fit(trainXs, trainYs, {
    batchSize: BATCH_SIZE,
    validationData: [testXs, testYs],
    epochs: 20,
    shuffle: true,
    callbacks: fitCallbacks
});

As you train, the callbacks will give you visualizations in the visor, as you saw back in Figure 16-1.

Running Inference on Images in TensorFlow.js

To run inference, you’ll first need an image. In Figure 16-1, you saw an interface where an image could be drawn by hand and have inference performed on it. This uses a 280 × 280 canvas that is set up like this:

rawImage = document.getElementById('canvasimg');
ctx = canvas.getContext("2d");
ctx.fillStyle = "black";
ctx.fillRect(0, 0, 280, 280);

Note that the canvas is called rawImage. After the user has drawn in the image (code for that is in the GitHub repo for this book), you can then run inference on it by grabbing its pixels using the tf.browser.fromPixels API:

var raw = tf.browser.fromPixels(rawImage,1);

It’s 280 × 280, so it needs to be resized to 28 × 28 for inference. This is done using the tf.image.resize APIs:

var resized = tf.image.resizeBilinear(raw, [28,28]);

The model expects a batch of 28 × 28 × 1 images, so you need to expand the dimensions to add a leading batch dimension, giving a 1 × 28 × 28 × 1 tensor:

var tensor = resized.expandDims(0);

Now you can predict, using model.predict and passing it the tensor. The output of the model is a set of probabilities, so you can pick the biggest one using TensorFlow’s argMax function:

var prediction = model.predict(tensor);
var pIndex = tf.argMax(prediction, 1).dataSync();
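
To see what picking the biggest probability means, here is the same idea in plain JavaScript. The probabilities array is made-up example data, and the local argMax function is a hand-written illustration, not the TensorFlow.js implementation.

```javascript
// Return the index of the largest value in an array of probabilities.
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) {
      best = i;
    }
  }
  return best;
}

// Made-up softmax output for one image: one probability per digit 0-9.
const probabilities =
    [0.01, 0.02, 0.05, 0.80, 0.02, 0.02, 0.03, 0.02, 0.02, 0.01];

console.log(argMax(probabilities)); // 3
```

Since the classes are the digits 0 through 9 in order, the winning index is itself the predicted digit.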

The full code, including all the HTML for the page, the JavaScript for the drawing functions, as well as the TensorFlow.js model training and inference, is available in the book’s GitHub repository.


Summary

JavaScript is a very powerful browser-based language that can be used for many scenarios. In this chapter you took a tour of what it takes to train an image-based classifier in the browser, and then put that together with a canvas on which the user could draw. The input could then be parsed into a tensor that could be classified, with the results returned to the user. It’s a useful demonstration that puts together many of the pieces of programming in JavaScript, illustrating some of the constraints that you might encounter in training, such as needing to reduce the number of HTTP connections, and how to take advantage of built-in decoders to handle data management, like you saw with the sparsely encoded labels.

You may not always want to train a new model in the browser, but instead want to reuse existing ones that you’ve created in TensorFlow using Python. You’ll explore how to do that in the next chapter.