How to build a quick image labelling tool to train a neural network?

Published Aug 23, 2019
Intro

Let's say you want to train a neural network to recognise certain types of objects in your images (cars, pedestrians and road signs, for instance).

For that purpose, you will need to label the bounding box of each object of interest in each image of your training set.

This is where a simple web application in Dash comes in handy.

What is Dash?

Dash by Plotly is a simple and efficient way to build interactive analytic apps in Python.

Dash Canvas is a module for image annotation and image processing using Dash. This is what we will use to label the objects inside an image.

Installation

First, you need to install Dash Canvas:

$ pip install dash-canvas

Building the app

Once the installation is done, it is time to build the app. I will go into more detail for each section, but if you want to explore further, I highly recommend following the Dash tutorial, which really starts at its second chapter once you have installed Dash.

Imports

Here are the different imports we will need for our app to work.

import dash
import dash_canvas
import dash_table

import pandas as pd
import dash_html_components as html

from dash.dependencies import Input, Output, State
from dash_canvas.utils import parse_jsonstring_rectangle

The imports are mostly self-explanatory, but two points are worth mentioning:

  1. Input, Output and State are the basic building blocks of your app's callbacks. They specify:
    • the part(s) of your app that trigger the callback (Input),
    • the part(s) of your app whose values the callback reads without being triggered by them (State),
    • the part of your app that the callback updates (Output).
  2. parse_jsonstring_rectangle is a helper function from dash_canvas.utils that converts the JSON string produced by the canvas (the selected areas) into box coordinates.
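
To give an idea of what such a parsing step involves, here is a minimal standard-library sketch. It is not the dash_canvas implementation: parse_rectangles is a hypothetical name, and the sketch assumes that the canvas serialises its drawings as a JSON object with an 'objects' list in which each rectangle has 'type': 'rect' plus 'width', 'height', 'left' and 'top' entries. It also ignores any image scaling the real helper may account for.

```python
import json


def parse_rectangles(json_string):
    """Return [width, height, left, top] for each rectangle in the canvas JSON.

    Hypothetical helper; the JSON layout is an assumption, see the text above.
    """
    try:
        data = json.loads(json_string)
    except (TypeError, ValueError):
        # Nothing drawn yet (None) or malformed JSON: no boxes.
        return []

    boxes = []
    for obj in data.get('objects', []):
        if obj.get('type') == 'rect':
            boxes.append([int(obj['width']), int(obj['height']),
                          int(obj['left']), int(obj['top'])])
    return boxes


# Made-up canvas content: one background image and one drawn rectangle.
sample = json.dumps({
    'objects': [
        {'type': 'image', 'width': 500, 'height': 300},
        {'type': 'rect', 'width': 120, 'height': 80, 'left': 30, 'top': 45},
    ]
})

print(parse_rectangles(sample))
```

The column order (width, height, left, top) matches the first four entries of list_columns used later in the app.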

Some parameters

First, let's introduce some parameters:

filename = 'https://scontent.flux1-1.fna.fbcdn.net/v/t31.0-8/12967364_10156756123340156_4744155697204378544_o.jpg?_nc_cat=105&_nc_oc=AQlhs2-REQnUOid86bHftpz5qanIlzFkiB2jeF3EEeH8EYSXp7wTHsTFI-PzHCEyBT9gY_ZlLALaPK97_rhUcuB1&_nc_ht=scontent.flux1-1.fna&oh=39ab45df1db97aa64bdb289c7d94666d&oe=5DCBE290'

app = dash.Dash(__name__)

app.config.suppress_callback_exceptions = True

list_columns = ['width', 'height', 'left', 'top', 'animal']
list_animals = ['penguin', 'iguana', 'turtle', 'pelican']

columns = [{'name': i, "id": i} for i in list_columns]
columns[-1]['presentation'] = 'dropdown'

animals = [{'label': i, 'value': i} for i in list_animals]

The filename can be a URL pointing to the image to be labelled (here, a photo I took in the Galapagos). You then create the Dash app, name the table columns (for the boxes), list the possible labels (list_animals), and build a list of dictionaries for both the columns and the types of objects you want to label. Note that we add the key-value pair 'presentation': 'dropdown' to the last column. This lets us choose which object (here, which animal) we are labelling for the corresponding box coordinates.
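
If the list comprehensions above feel abstract, the following standalone snippet (plain Python, no Dash needed) rebuilds the same two lists and prints a few of their entries:

```python
list_columns = ['width', 'height', 'left', 'top', 'animal']
list_animals = ['penguin', 'iguana', 'turtle', 'pelican']

# One dictionary per table column; the last one is rendered as a dropdown.
columns = [{'name': i, 'id': i} for i in list_columns]
columns[-1]['presentation'] = 'dropdown'

# One dictionary per selectable label.
animals = [{'label': i, 'value': i} for i in list_animals]

print(columns[0])   # {'name': 'width', 'id': 'width'}
print(columns[-1])  # {'name': 'animal', 'id': 'animal', 'presentation': 'dropdown'}
print(animals[0])   # {'label': 'penguin', 'value': 'penguin'}
```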

App Layout

Now, it is time to specify the html tags to build the app:

app.layout = html.Div([
    html.Div([
        html.H3('Label images with bounding boxes'),
        dash_canvas.DashCanvas(
            id='canvas',
            width=500,
            tool='rectangle',
            lineWidth=2,
            lineColor='rgba(0, 255, 0, 0.5)',
            filename=filename,
            hide_buttons=['pencil', 'line'],
            goButtonTitle='Label',
        ),
    ]),
    html.Div([
        dash_table.DataTable(
            id='table',
            columns=columns,
            editable=True,
            dropdown={
                'animal': {
                    'options': animals,
                },
            },
        ),
    ]),
])

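  • The DashCanvas displays the image to label. Its arguments are:
    • id: the id of the component (used later in the callback function)
    • width: the width of the canvas
    • tool: the selection tool (here, a rectangle)
    • lineWidth: the width of the line
    • lineColor: the color (and transparency) of the line
    • filename: the path or URL of the image
    • hide_buttons: the toolbar buttons to hide
    • goButtonTitle: the text of the button that triggers the labelling
  • The DataTable displays the boxes. Its arguments are:
    • id: the id of the component (used later in the callback function)
    • columns: the column definitions
    • editable: whether the cells can be edited (needed to change the animal)
    • dropdown: the dropdown options (for the label)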

The Callbacks

Of course, if you want your app to be responsive, you need to write some callbacks:

@app.callback(Output('table', 'data'), [Input('canvas', 'json_data')])
def show_string(json_data):
    box_coordinates = parse_jsonstring_rectangle(json_data)
    df = pd.DataFrame(box_coordinates, columns=list_columns[:-1])
    df['animal'] = 'penguin'
    return df.to_dict('records')

The @app.callback decorator (a wrapper, if you prefer) registers the function below it as a callback, which is called whenever something changes in one of its Inputs. The function receives the values collected from the input(s) and, if there are any, the state(s). Its return value is then written to the Output.

Don't worry if you didn't get the last paragraph; walking through this example should clarify things.

When the property 'json_data' of 'canvas' changes (i.e. when the selection boxes drawn around the objects are submitted), the function show_string(json_data) is called. It translates the JSON of the selection into a list of records (one dictionary per box), setting the default animal to 'penguin'.
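
To see the shape of the data handed to the table, here is a plain-Python sketch of the same transformation, without Dash or pandas so that it can be run on its own. The function name boxes_to_records and the box coordinates are made up for illustration; in the app, the coordinates come from parse_jsonstring_rectangle.

```python
list_columns = ['width', 'height', 'left', 'top', 'animal']


def boxes_to_records(box_coordinates, default_label='penguin'):
    """Turn [width, height, left, top] rows into one dictionary per box."""
    records = []
    for box in box_coordinates:
        # Pair the four coordinate columns with the four values of the box.
        record = dict(zip(list_columns[:-1], box))
        # Every new box starts with the default label.
        record[list_columns[-1]] = default_label
        records.append(record)
    return records


# Made-up coordinates for two boxes.
example_boxes = [[120, 80, 30, 45], [60, 40, 200, 150]]

for record in boxes_to_records(example_boxes):
    print(record)
# {'width': 120, 'height': 80, 'left': 30, 'top': 45, 'animal': 'penguin'}
# {'width': 60, 'height': 40, 'left': 200, 'top': 150, 'animal': 'penguin'}
```

This list of dictionaries is the same structure that df.to_dict('records') produces in the callback.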

Let's run the code!!!

You just need to add the following lines at the end of the script

if __name__ == '__main__':
    app.run_server(debug=True)

and you are good to go.

Running the script starts a local server and prints its address (by default http://127.0.0.1:8050/), which you can open in your browser. You can then draw boxes around the animals and click the 'Label' button to get the coordinates of all the boxes. Finally, you can select an animal for each box.

Enjoy!

Complete Code

import dash
import dash_canvas
import dash_table

import pandas as pd
import dash_html_components as html

from dash.dependencies import Input, Output, State
from dash_canvas.utils import parse_jsonstring_rectangle




filename = 'https://scontent.flux1-1.fna.fbcdn.net/v/t31.0-8/12967364_10156756123340156_4744155697204378544_o.jpg?_nc_cat=105&_nc_oc=AQlhs2-REQnUOid86bHftpz5qanIlzFkiB2jeF3EEeH8EYSXp7wTHsTFI-PzHCEyBT9gY_ZlLALaPK97_rhUcuB1&_nc_ht=scontent.flux1-1.fna&oh=39ab45df1db97aa64bdb289c7d94666d&oe=5DCBE290'

app = dash.Dash(__name__)

app.config.suppress_callback_exceptions = True

list_columns = ['width', 'height', 'left', 'top', 'animal']
list_animals = ['penguin', 'iguana', 'turtle', 'pelican']

columns = [{'name': i, "id": i} for i in list_columns]
columns[-1]['presentation'] = 'dropdown'

animals = [{'label': i, 'value': i} for i in list_animals]


app.layout = html.Div([
    html.Div([
        html.H3('Label images with bounding boxes'),
        dash_canvas.DashCanvas(
            id='canvas',
            width=500,
            tool='rectangle',
            lineWidth=2,
            lineColor='rgba(0, 255, 0, 0.5)',
            filename=filename,
            hide_buttons=['pencil', 'line'],
            goButtonTitle='Label',
        ),
    ]),
    html.Div([
        dash_table.DataTable(
            id='table',
            columns=columns,
            editable=True,
            dropdown={
                'animal': {
                    'options': animals,
                },
            },
        ),
    ]),
])


@app.callback(Output('table', 'data'), [Input('canvas', 'json_data')])
def show_string(json_data):
    box_coordinates = parse_jsonstring_rectangle(json_data)
    df = pd.DataFrame(box_coordinates, columns=list_columns[:-1])
    df['animal'] = 'penguin'
    return df.to_dict('records')


if __name__ == '__main__':
    app.run_server(debug=True)

Final thoughts and next steps

I deliberately opted for a 'simple' version of the code, which is very easy to follow. A more robust version should not hard-code any label name: everything could be passed via argparse or a YAML file. I encourage you to write this part as an exercise. That way, your code will be portable to many other situations.
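
As a starting point for that exercise, here is one possible sketch using argparse. The option names (--image, --labels, --label-column) and their defaults are my own choices for illustration, not something Dash requires.

```python
import argparse


def parse_args(argv=None):
    """Read the image location and the label names from the command line."""
    parser = argparse.ArgumentParser(
        description='Label images with bounding boxes.')
    parser.add_argument('--image', required=True,
                        help='URL or path of the image to label')
    parser.add_argument('--labels', nargs='+', default=['object'],
                        help='label names offered in the dropdown')
    parser.add_argument('--label-column', default='label',
                        help='name of the table column holding the label')
    return parser.parse_args(argv)


# Equivalent to running:
#   python app.py --image my_image.jpg --labels penguin iguana
args = parse_args(['--image', 'my_image.jpg', '--labels', 'penguin', 'iguana'])

# The hard-coded values of the app can now be derived from the arguments.
filename = args.image
list_columns = ['width', 'height', 'left', 'top', args.label_column]
labels = [{'label': name, 'value': name} for name in args.labels]

print(list_columns)
print(labels)
```

In the real script you would call parse_args() without an argument so that the values are read from the command line. The dropdown key in the DataTable and the default label in the callback would then use args.label_column and args.labels[0] instead of 'animal' and 'penguin'.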

I also recommend adding an upload box for the image you want to label (see the drag-and-drop Upload component from Dash Core Components, aka dcc) and exporting the boxes' coordinates to a file, for future training of your network.
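
For the export part, a minimal sketch could look like the following. It assumes the records have the same shape as the table's data (one dictionary per box). The function names, the CSV format and the sample records are illustrative choices rather than part of the original app.

```python
import csv
import io

FIELDNAMES = ['width', 'height', 'left', 'top', 'animal']


def records_to_csv(records, fieldnames=FIELDNAMES):
    """Serialise the table records (one dictionary per box) as CSV text."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(records)
    return buffer.getvalue()


def save_records(records, path):
    """Write the records to a CSV file at the given path."""
    # newline='' keeps the line endings produced by the csv module untouched.
    with open(path, 'w', newline='') as f:
        f.write(records_to_csv(records))


# Made-up records, shaped like the data shown in the table.
example_records = [
    {'width': 120, 'height': 80, 'left': 30, 'top': 45, 'animal': 'penguin'},
    {'width': 60, 'height': 40, 'left': 200, 'top': 150, 'animal': 'iguana'},
]

print(records_to_csv(example_records))
```

In the app, save_records could be called from a callback that reads the table's data property, for instance when a (hypothetical) 'Export' button is clicked.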

Enjoy the labelling!!

BONUS

Code in action

Discover and read more posts from Bertrand Delvaux, PhD