
Candle.js

https://github.com/shaunabanana/candle-js

Intro

While there's no Torch in the browser, there's still Candle.

Candle.js is a GPU-accelerated library that makes running PyTorch models in the browser a bit easier. It also serves as a tensor computation library. Currently, Candle.js supports these PyTorch layers:

  • Linear
  • Conv2d (no dilation support)
  • ConvTranspose2d (no padding or dilation support)
  • ReLU
  • Sigmoid
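To make the list concrete, here is a plain-JavaScript CPU sketch of what each layer computes. It follows PyTorch's weight layouts: `[out_features, in_features]` for Linear and `[out_channels, in_channels, kH, kW]` for Conv2d. The function names, the flat row-major `Float32Array` layout, the missing batch dimension, and the square kernel are assumptions for illustration; this is not the Candle.js API. Every output element is computed independently, which is why these layers map well onto a GPU kernel that runs one thread per output element.

```javascript
// Hypothetical CPU reference for the listed layers; NOT the Candle.js API.
// Tensors are flat Float32Arrays in row-major order, without a batch dimension.

// Linear: y[o] = bias[o] + sum_i weight[o][i] * x[i]; weight is [outF, inF] as in PyTorch.
function linear(x, weight, bias, inF, outF) {
  const y = new Float32Array(outF);
  for (let o = 0; o < outF; o++) {
    let acc = bias ? bias[o] : 0;
    for (let i = 0; i < inF; i++) acc += weight[o * inF + i] * x[i];
    y[o] = acc;
  }
  return y;
}

// Element-wise activations (TypedArray.map keeps the Float32Array type).
const relu = (x) => x.map((v) => Math.max(0, v));
const sigmoid = (x) => x.map((v) => 1 / (1 + Math.exp(-v)));

// Conv2d without dilation (square kernel assumed).
// x is [cIn, h, w]; weight is [cOut, cIn, k, k]; zero padding on all sides.
function conv2d(x, weight, bias, cIn, h, w, cOut, k, stride = 1, padding = 0) {
  const outH = Math.floor((h + 2 * padding - k) / stride) + 1;
  const outW = Math.floor((w + 2 * padding - k) / stride) + 1;
  const y = new Float32Array(cOut * outH * outW);
  // Each (co, oy, ox) is independent: on a GPU these three loops become the thread grid.
  for (let co = 0; co < cOut; co++) {
    for (let oy = 0; oy < outH; oy++) {
      for (let ox = 0; ox < outW; ox++) {
        let acc = bias ? bias[co] : 0;
        for (let ci = 0; ci < cIn; ci++) {
          for (let ky = 0; ky < k; ky++) {
            for (let kx = 0; kx < k; kx++) {
              const iy = oy * stride - padding + ky;
              const ix = ox * stride - padding + kx;
              if (iy < 0 || iy >= h || ix < 0 || ix >= w) continue; // padded zeros
              acc += x[(ci * h + iy) * w + ix] * weight[((co * cIn + ci) * k + ky) * k + kx];
            }
          }
        }
        y[(co * outH + oy) * outW + ox] = acc;
      }
    }
  }
  return { data: y, outH, outW };
}
```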

Behind the scenes

Candle.js started as part of FontNN, a quick, one-shot project I did in May 2020 to try my hand at deep learning and at bringing it to the web. I used PyTorch to train a modified Variational Autoencoder (VAE) that extracts a font's style from a few example characters and generates the entire alphabet in that style.

The need for an easy way to move PyTorch models onto the web didn't occur to me until I had finished training the model. As far as I knew, there was no straightforward solution. I could, of course, convert the model to ONNX, learn TensorFlow.js, and port the model to it. However, being lazy and stubborn, I decided to hack together my own solution.

I basically wanted to bring PyTorch into the browser. Layer names should remain the same. Defining a model should be as syntactically close as possible. A Candle.js tensor should provide mechanisms similar to those of a PyTorch tensor.
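As a rough idea of what "similar mechanisms" can mean: a PyTorch-style tensor is a flat buffer plus a shape and strides, so indexing and reshaping are just arithmetic on offsets. The class below is a hypothetical minimal sketch of that idea, not Candle.js's actual Tensor class; its method names only mirror PyTorch's.

```javascript
// Hypothetical minimal tensor in the spirit of torch.Tensor; NOT Candle.js's Tensor class.
class Tensor {
  constructor(data, shape) {
    const size = shape.reduce((a, b) => a * b, 1);
    if (data.length !== size) {
      throw new Error(`got ${data.length} values for shape [${shape}] (needs ${size})`);
    }
    // Keep Float32Array storage as-is so views can share it; copy plain arrays.
    this.data = data instanceof Float32Array ? data : Float32Array.from(data);
    this.shape = shape.slice();
    // Row-major strides, e.g. shape [2, 3, 4] -> strides [12, 4, 1].
    this.strides = new Array(shape.length);
    let stride = 1;
    for (let d = shape.length - 1; d >= 0; d--) {
      this.strides[d] = stride;
      stride *= shape[d];
    }
  }

  numel() {
    return this.data.length;
  }

  // Like t[i][j]... in PyTorch: turn a full multi-index into a flat offset.
  get(...idx) {
    if (idx.length !== this.shape.length) throw new Error("expected one index per dimension");
    let offset = 0;
    for (let d = 0; d < idx.length; d++) {
      if (idx[d] < 0 || idx[d] >= this.shape[d]) throw new RangeError(`index ${idx[d]} out of range`);
      offset += idx[d] * this.strides[d];
    }
    return this.data[offset];
  }

  // Like torch.Tensor.view: same storage, new shape; a single -1 is inferred.
  view(...shape) {
    const known = shape.filter((n) => n !== -1).reduce((a, b) => a * b, 1);
    const resolved = shape.map((n) => (n === -1 ? this.numel() / known : n));
    if (!resolved.every(Number.isInteger)) {
      throw new Error(`cannot view [${this.shape}] as [${shape}]`);
    }
    return new Tensor(this.data, resolved);
  }
}
```

Because `view` reuses the same `Float32Array` (as `torch.Tensor.view` does for contiguous tensors), reshaping costs no copy.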

So, if I defined my model like this:

def __init__(...):
    ...
    self.dconv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2)
    self.dconv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
    self.dconv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2)
    self.dconv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2)
    self.proj_mu = nn.Linear(1024, 64)
    self.proj_logvar = nn.Linear(1024, 64)

    self.proj_z = nn.Linear(126, 1024)
    self.uconv1 = nn.ConvTranspose2d(1024, 128, kernel_size=5, stride=2)
    self.uconv2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2)
    self.uconv3 = nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2)
    self.uconv4 = nn.ConvTranspose2d(32, 1, kernel_size=6, stride=2)

Then I wanted my JavaScript model to look exactly the same:

constructor (...) {
    ...
    this.dconv1 = new Conv2d(1, 32, 4, 2);
    this.dconv2 = new Conv2d(32, 64, 4, 2);
    this.dconv3 = new Conv2d(64, 128, 4, 2);
    this.dconv4 = new Conv2d(128, 256, 4, 2);
    this.proj_mu = new Linear(1024, 64);
    this.proj_logvar = new Linear(1024, 64);

    this.proj_z = new Linear(126, 1024);
    this.uconv1 = new ConvTranspose2d(1024, 128, 5, 2);
    this.uconv2 = new ConvTranspose2d(128, 64, 5, 2);
    this.uconv3 = new ConvTranspose2d(64, 32, 6, 2);
    this.uconv4 = new ConvTranspose2d(32, 1, 6, 2);
}
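As a sanity check on the architecture above, the layer shapes can be traced with the output-size formulas from the PyTorch documentation. They are specialised here to dilation 1, and for ConvTranspose2d also to zero padding and zero output padding, which matches the limitations listed earlier. The 64x64 input size and the reshape of `proj_z`'s output to `[1024, 1, 1]` are assumptions, since neither is stated in the text; the helper names are hypothetical.

```javascript
// PyTorch output-size formulas, specialised to dilation = 1.
const conv2dOut = (size, k, stride = 1, padding = 0) =>
  Math.floor((size + 2 * padding - k) / stride) + 1;
// ConvTranspose2d with padding = 0 and output_padding = 0.
const convTranspose2dOut = (size, k, stride = 1) => (size - 1) * stride + k;

// Apply a list of [kernel_size, stride] layers and record each spatial size.
function trace(start, layers, outFn) {
  const sizes = [];
  let s = start;
  for (const [k, stride] of layers) {
    s = outFn(s, k, stride);
    sizes.push(s);
  }
  return sizes;
}

// Assumption: 64x64 single-channel input images.
const encSizes = trace(64, [[4, 2], [4, 2], [4, 2], [4, 2]], conv2dOut); // dconv1..4
const flat = 256 * encSizes[encSizes.length - 1] ** 2; // dconv4 has 256 channels

// Assumption: proj_z's 1024 outputs are reshaped to [1024, 1, 1] before uconv1.
const decSizes = trace(1, [[5, 2], [5, 2], [6, 2], [6, 2]], convTranspose2dOut); // uconv1..4

// encSizes: [31, 14, 6, 2], flat: 1024, decSizes: [5, 13, 30, 64]
```

Under those assumptions, the encoder ends at 256 x 2 x 2 = 1024 features, which is exactly what `Linear(1024, 64)` expects, and the decoder climbs back to 64x64.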

With this goal set, I implemented by hand all the layers the model uses (Linear, ReLU, Sigmoid, Conv2d, and ConvTranspose2d) using GPU.js [1]. A Python helper script then strips the PyTorch model down to only the necessary data, joins it into a compact binary format, and slices the bytes into browser-manageable chunks. On the browser side, a loader script loads the chunks according to a model description and fills the parameters into the layers, et voilà, you have your PyTorch model in the browser.
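The packing and loading steps can be sketched as follows. The format here is an assumption for illustration, not Candle.js's actual format: every parameter is float32, stored little-endian and back to back in the order of a hypothetical model description (a list of `{ name, shape }` entries), and the byte stream is cut into fixed-size chunks. In Candle.js the packing side is a Python script; it is written in JavaScript here only to keep one language and make the round trip testable.

```javascript
// Hypothetical chunked weight format; NOT Candle.js's actual format.
// Assumptions: float32 values, little-endian, stored in description order.

// Packing side (a Python script in Candle.js; shown in JS for symmetry).
function packParameters(params, chunkBytes) {
  const total = params.reduce((n, p) => n + p.data.length, 0);
  const bytes = new Uint8Array(total * 4);
  const view = new DataView(bytes.buffer);
  let offset = 0;
  for (const p of params) {
    for (const v of p.data) {
      view.setFloat32(offset, v, true); // true = little-endian
      offset += 4;
    }
  }
  // Fixed-size chunks: a single float may straddle two chunks.
  const chunks = [];
  for (let start = 0; start < bytes.length; start += chunkBytes) {
    chunks.push(bytes.slice(start, start + chunkBytes).buffer);
  }
  return chunks;
}

// Loading side: join the chunks first, then cut out each parameter by its shape.
function loadParameters(chunks, description) {
  const total = chunks.reduce((n, c) => n + c.byteLength, 0);
  const bytes = new Uint8Array(total);
  let pos = 0;
  for (const c of chunks) {
    bytes.set(new Uint8Array(c), pos);
    pos += c.byteLength;
  }
  const view = new DataView(bytes.buffer);
  const params = {};
  let offset = 0;
  for (const { name, shape } of description) {
    const n = shape.reduce((a, b) => a * b, 1);
    const data = new Float32Array(n);
    for (let i = 0; i < n; i++, offset += 4) data[i] = view.getFloat32(offset, true);
    params[name] = { shape, data };
  }
  if (offset !== total) throw new Error(`description covers ${offset} bytes, chunks hold ${total}`);
  return params;
}
```

In a browser, the chunks would typically arrive as `ArrayBuffer`s, e.g. `await Promise.all(urls.map((u) => fetch(u).then((r) => r.arrayBuffer())))`, where `urls` is a hypothetical list of chunk URLs. Joining the chunks before decoding matters because fixed-size chunks can split a float across two files.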

Next steps

I'm in the process of revamping the project. To make further development and collaboration possible, I first need to structure it properly.

After that, I will start by improving the Tensor class, which I consider the heart of Candle.js. It can be very useful on its own, even just for tensor computation, and the performance of the whole library depends on it.

I'm trying to come up with a way to make contributing layers easier. This will come later, certainly not before the revamping is finished.