Tags: python, git, flask, heroku, bert-language-model

Deploying a BERT PyTorch model on Heroku with Flask: ERROR: _pickle.UnpicklingError: invalid load key, 'v'


I'm trying to deploy a BERT model on Heroku with Flask:

import torch
import transformers
import numpy as np
from flask import Flask, render_template, request
from model import DISTILBERTBaseUncased

MAX_LEN = 320
TOKENIZER = transformers.DistilBertTokenizer.from_pretrained(
    "distilbert-base-uncased", do_lower_case=True
)
DEVICE = "cpu"
MODEL = DISTILBERTBaseUncased()
MODEL.load_state_dict(torch.load("weight.bin"))
MODEL.to(DEVICE)
MODEL.eval()

app = Flask(__name__)


def sentence_prediction(sentence):
    tokenizer = TOKENIZER
    max_len = MAX_LEN
    comment = str(sentence)
    comment = " ".join(comment.split())

    inputs = tokenizer.encode_plus(
        comment,
        None,
        add_special_tokens=True,
        max_length=max_len,
        pad_to_max_length=True,
    )

    ids = inputs["input_ids"]
    mask = inputs["attention_mask"]

    ids = torch.tensor(ids, dtype=torch.long).unsqueeze(0)
    mask = torch.tensor(mask, dtype=torch.long).unsqueeze(0)

    ids = ids.to(DEVICE, dtype=torch.long)
    mask = mask.to(DEVICE, dtype=torch.long)

    outputs = MODEL(ids=ids, mask=mask)

    outputs = torch.sigmoid(outputs).cpu().detach().numpy()
    return outputs[0][0]


@app.route("/")
def index_page():
    return render_template("index.html")


@app.route("/model")
def models():
    return render_template("model.html")


@app.route("/predict", methods=["POST", "GET"])
def predict():
    if request.method == "POST":
        sentence = request.form.get("text")
        Toxic_prediction = sentence_prediction(sentence)
        return render_template(
            "index.html", prediction_text=np.round((Toxic_prediction * 100), 2)
        )
    return render_template("index.html", prediction_text="")


if __name__ == "__main__":
    app.run(debug=True)

Error:

MODEL.load_state_dict(torch.load("weight.bin"))

2020-05-18T06:32:32.134536+00:00 app[web.1]: File "/app/.heroku/python/lib/python3.7/site-packages/torch/serialization.py", line 593, in load
2020-05-18T06:32:32.134536+00:00 app[web.1]: return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
2020-05-18T06:32:32.134536+00:00 app[web.1]: File "/app/.heroku/python/lib/python3.7/site-packages/torch/serialization.py", line 763, in _legacy_load
2020-05-18T06:32:32.134537+00:00 app[web.1]: magic_number = pickle_module.load(f, **pickle_load_args)
2020-05-18T06:32:32.134537+00:00 app[web.1]: _pickle.UnpicklingError: invalid load key, 'v'.
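The 'v' can be reproduced without PyTorch. The traceback shows torch's legacy loader calling pickle_module.load(f) on weight.bin, so the sketch below feeds a file straight to the standard-library pickle module. It assumes the deployed weight.bin is a Git LFS pointer file; the pointer's oid and size are hypothetical placeholders, and reproduce_error is an illustrative name.

```python
import pickle
import tempfile
from pathlib import Path

# Hypothetical Git LFS pointer file. The oid and size are made-up placeholders,
# not the real values for the asker's weight.bin.
POINTER = (
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:" + "0" * 64 + "\n"
    "size 267386880\n"
)


def reproduce_error() -> str:
    """Unpickle a pointer file the way torch's legacy loader reads its magic number."""
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "weight.bin"
        path.write_text(POINTER)
        with path.open("rb") as f:
            try:
                # The first byte is 'v', which is not a valid pickle opcode.
                pickle.load(f)
            except pickle.UnpicklingError as exc:
                return str(exc)
    return "no error"


print(reproduce_error())
```

On CPython this prints invalid load key, 'v'., the same message as in the Heroku log.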

  1. The code works fine locally.
  2. The app is deployed to Heroku from GitHub.
  3. weight.bin is 255 MB.
  4. The Flask API works fine on localhost.
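Because the same code works locally, one way to check whether the dyno really has the same 255 MB file is to hash it in both places. A minimal sketch using hashlib, reading in chunks so the whole file never sits in memory; the function name sha256_of is hypothetical.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 20):
    """Return the SHA-256 hex digest of a file, read in 1 MiB chunks by default."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps calling f.read() until it returns b"".
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

Running print(sha256_of("weight.bin")) locally and on the dyno (for example through heroku run python) should give identical digests if the same file was deployed. If the Heroku copy is a Git LFS pointer instead, the local digest should match the hex string after "oid sha256:" inside that pointer.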

Solution

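  • Checking the errors:

    1. MODEL.load_state_dict(torch.load("weight.bin")): this line follows the usual pattern below, where model_state_dict holds the path to the saved file, so just double-check that the file name is spelled correctly.

        model.load_state_dict(torch.load(model_state_dict))

    2. _pickle.UnpicklingError: invalid load key, 'v'.: I think git-lfs is not installed in your environment. Install it and try again. The 'v' is the first character of a Git LFS pointer file ("version https://git-lfs.github.com/spec/v1 ..."): if weight.bin is tracked with Git LFS and the LFS object is not fetched during the Heroku build, the dyno gets this small text pointer instead of the 255 MB weights, and unpickling fails on the very first byte.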
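To fail fast with a clearer message than the UnpicklingError, the app could check weight.bin before loading it. A minimal sketch assuming the Git LFS diagnosis above; read_lfs_pointer and ensure_real_weights are hypothetical helpers (not part of torch or git-lfs), and the 1 KB threshold relies on LFS pointer files being tiny text files.

```python
from pathlib import Path

# Every Git LFS pointer file starts with this version line.
LFS_HEADER = b"version https://git-lfs.github.com/spec/v1"


def read_lfs_pointer(path):
    """Hypothetical helper: return the pointer's fields if `path` is a Git LFS pointer, else None."""
    path = Path(path)
    # Assumption: pointers are well under 1 KB, real weights are hundreds of MB.
    if path.stat().st_size > 1024:
        return None
    data = path.read_bytes()
    if not data.startswith(LFS_HEADER):
        return None
    fields = {}
    for line in data.decode("utf-8").splitlines():
        # Each pointer line is "<key> <value>", e.g. "size 267386880".
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields


def ensure_real_weights(path):
    """Hypothetical helper: raise an actionable error instead of an UnpicklingError."""
    pointer = read_lfs_pointer(path)
    if pointer is not None:
        raise RuntimeError(
            f"{path} is a Git LFS pointer (expected {pointer.get('size', '?')} bytes), "
            "not the model weights; the LFS object was not fetched during the build."
        )
```

In the question's app.py this check would run just before MODEL.load_state_dict(torch.load("weight.bin", map_location=DEVICE)). map_location is a standard torch.load argument that also keeps a GPU-saved checkpoint loadable on a CPU-only dyno.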