Skip to content

Commit c8194d2

Browse files
Akshat Shrivastavafacebook-github-bot
Akshat Shrivastava
authored andcommitted
XLMR Document Classification Server + Console (facebookresearch#1358)
Summary: Pull Request resolved: facebookresearch#1358 This diff adds open source support for our XLM-R document classification server along with a simple console to visualize instances and provide additional annotations. Reviewed By: geof90 Differential Revision: D21482314 fbshipit-source-id: e5e3fd5695f17876ec5db5329d0875b5a8c2afd5
1 parent 5869b83 commit c8194d2

File tree

9 files changed

+942
-0
lines changed

9 files changed

+942
-0
lines changed

demo/xlm_server/README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# TorchScript Server for XLM-R
2+
In this directory, we provide a model server and a client console for DocNN and XLM-R text classifications models. These models were trained with PyTorch and exported to TorchScript
3+
4+
5+
## Server
6+
For monolingual DocNN,
7+
```
8+
$ mkdir build && cd build
9+
$ # Copy the downloaded models into this directory
10+
$ echo -e 'FROM pytext/predictor_service_torchscript:who\nCOPY *.torchscript /app/\nCMD ["./server","mono.model.pt.torchscript"]' >> Dockerfile
11+
$ docker build -t server .
12+
$ docker run -it -p 8080:8080 server
13+
$ curl -d '{"text": "hi"}' -H 'Content-Type: application/json' localhost:8080
14+
```
15+
For multilingual XLM-R,
16+
```
17+
echo -e 'FROM pytext/predictor_service_torchscript:who\nCOPY *.torchscript /app/\nCMD ["./server","multi.model.pt.torchscript", "multi.vocab.model.pt.torchscript"]' >> Dockerfile
18+
```
19+
20+
21+
## Console
22+
The console provides a front-end webpage to view the predictions of the classification model. It also allows for exploring model predictions interactively, and for logging corrections back to the server.
23+
24+
### Console Setup
25+
26+
```
27+
$ python3 -m venv env
28+
$ source env/bin/activate
29+
$ (env) pip install -r requirements.txt
30+
```
31+
32+
running the server:
33+
```
34+
$ python server.py --modelserver http://localhost:8080
35+
```
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
aniso8601==8.0.0
2+
certifi==2020.4.5.1
3+
chardet==3.0.4
4+
click==7.1.1
5+
Flask==1.1.2
6+
Flask-Cors==3.0.8
7+
Flask-RESTful==0.3.8
8+
idna==2.9
9+
itsdangerous==1.1.0
10+
Jinja2==2.11.1
11+
MarkupSafe==1.1.1
12+
numpy==1.18.2
13+
pandas==1.0.3
14+
python-dateutil==2.8.1
15+
pytz==2019.3
16+
requests==2.23.0
17+
six==1.14.0
18+
tqdm==4.45.0
19+
urllib3==1.25.8
20+
Werkzeug==1.0.1

demo/xlm_server/console/server.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3+
import argparse
4+
import csv
5+
import json
6+
import logging
7+
import threading
8+
from typing import Any, Dict, Tuple
9+
10+
import requests
11+
from flask import Blueprint, Flask, jsonify, render_template, request
12+
from flask_cors import CORS
13+
14+
15+
# a lock for writing to the output
16+
# data file
17+
DATA_FILE_LOCK = threading.Lock()
18+
19+
20+
def get_args() -> argparse.ArgumentParser:
21+
"""
22+
Return CLI configuration for running server
23+
"""
24+
parser = argparse.ArgumentParser(description="Run the html hosting server")
25+
parser.add_argument(
26+
"--port", default=5000, type=int, help="the port to launch server on"
27+
)
28+
29+
parser.add_argument("--modelserver", default="http://localhost:8080")
30+
31+
parser.add_argument(
32+
"--debug", action="store_true", help="Launch debug version of server"
33+
)
34+
35+
parser.add_argument(
36+
"--datafile", default="data.csv", help="the file to write data to"
37+
)
38+
39+
parser.add_argument(
40+
"--console_address",
41+
default="http://localhost:5000",
42+
help="the location of the console server",
43+
)
44+
return parser
45+
46+
47+
logger = logging.getLogger(__name__)
48+
49+
50+
def get_key_from_data(data, key):
51+
if key in data:
52+
return data[key]
53+
return None
54+
55+
56+
def create_app(config_filename: str):
57+
api_bp = Blueprint("api", __name__)
58+
app = Flask(__name__, static_url_path="/static")
59+
CORS(app, resources=r"/api/*")
60+
61+
app.register_blueprint(api_bp, url_prefix="/api")
62+
63+
return app
64+
65+
66+
def setup_app(app: Flask, args: object):
67+
@app.route("/api/model/", methods=["GET"])
68+
def get_predictions():
69+
query = get_key_from_data(request.args, "query")
70+
payload = {"text": str(query)}
71+
r = requests.post(args.modelserver, json=payload)
72+
response = json.loads(r.text)
73+
intent_scores = response["intent_ranking"]
74+
75+
intent_scores = list(filter(lambda inp: inp, intent_scores))
76+
77+
def convert_to_tuple(intent_score: Dict[str, Any]) -> Tuple[str, float]:
78+
intent_name = intent_score["name"]
79+
intent_score = intent_score["confidence"]
80+
return intent_name, intent_score
81+
82+
intent_scores = list(map(convert_to_tuple, intent_scores))
83+
intent_scores = sorted(intent_scores, reverse=True, key=lambda tup: tup[1])
84+
return (
85+
jsonify(
86+
{
87+
"query": query,
88+
"prediction": intent_scores[0],
89+
"raw_scores": intent_scores,
90+
}
91+
),
92+
200,
93+
)
94+
95+
@app.route("/api/add_data/", methods=["GET"])
96+
def upload_data():
97+
data_point = get_key_from_data(request.args, "data_point")
98+
query, label = data_point.split(",")
99+
with DATA_FILE_LOCK:
100+
with open(args.datafile, "a+") as csv_file:
101+
csvwriter = csv.writer(csv_file, delimiter="\t")
102+
csvwriter.writerow([query, label])
103+
104+
return jsonify({"query": query, "label": label}), 200
105+
106+
@app.route("/", methods=["GET"])
107+
def root():
108+
return render_template("index.html", console_address=args.console_address)
109+
110+
111+
def main():
112+
args = get_args().parse_args()
113+
114+
app = create_app("config")
115+
setup_app(app, args)
116+
try:
117+
app.run(host="0.0.0.0", debug=args.debug, port=args.port)
118+
except KeyboardInterrupt:
119+
print("Received Keyboard interrupt, exiting")
120+
121+
122+
if __name__ == "__main__":
123+
main()
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
<doctype html>
2+
3+
<head>
4+
<title>Document Classification Console</title>
5+
6+
<style>
7+
.button {
8+
margin-top: 20px;
9+
height: 30px;
10+
width: 100px;
11+
font-size: 12pt;
12+
}
13+
14+
#query {
15+
height: 50px;
16+
width: 700px;
17+
font-size: 14pt;
18+
padding-left: 10px;
19+
}
20+
21+
#submit_button {
22+
height: 50px;
23+
width: 70px;
24+
font-size: 14pt;
25+
}
26+
27+
.selection_area {
28+
padding: 20px;
29+
}
30+
31+
input[type=radio]:checked+label {
32+
color: red;
33+
}
34+
</style>
35+
36+
<script>
37+
var PYTHON_SERVER_ADDRESS = "{{ console_address }}";
38+
window.onload = function () {
39+
function makeScoreString(raw_scores) {
40+
return raw_scores[0] + ": " + (raw_scores[1] * 100).toFixed(2) + "%";
41+
}
42+
function makeUL(query, array) {
43+
// Create the list element:
44+
var prediction_container = document.createElement('div');
45+
var list = document.createElement('ul');
46+
47+
var header = document.createElement("h4");
48+
header.appendChild(document.createTextNode("Help us gather more data!"));
49+
prediction_container.appendChild(header);
50+
prediction_container.appendChild(document.createTextNode("Select the correct prediction from below and submit to provide us more data to learn from!"))
51+
52+
// prediction_container.appendChild(list);
53+
54+
var selection_area = document.createElement("div");
55+
selection_area.setAttribute("class", "selection_area");
56+
57+
58+
for (var i = 0; i < array.length; i++) {
59+
// Create the list item:
60+
var item = document.createElement('li');
61+
var intent_name = array[i][0];
62+
var intent_score = array[i][1];
63+
64+
var choiceSelection = document.createElement('input')
65+
choiceSelection.setAttribute('type', 'radio');
66+
choiceSelection.setAttribute('name', 'choice');
67+
choiceSelection.setAttribute('value', query + "," + intent_name);
68+
choiceSelection.setAttribute('id', "choice_" + intent_name);
69+
70+
var choiceLabel = document.createElement('label')
71+
choiceLabel.setAttribute('for', "choice_" + intent_name);
72+
choiceLabel.appendChild(document.createTextNode(makeScoreString(array[i])));
73+
74+
selection_area.appendChild(choiceSelection);
75+
selection_area.appendChild(choiceLabel);
76+
selection_area.appendChild(document.createElement("br"));
77+
}
78+
79+
// add submit button
80+
81+
var submit_data_button = document.createElement("button");
82+
submit_data_button.setAttribute("type", "submit");
83+
submit_data_button.setAttribute("class", "button");
84+
submit_data_button.onclick = function () {
85+
var data_point = document.querySelector('input[name="choice"]:checked').value;
86+
var xhr = new XMLHttpRequest();
87+
xhr.onreadystatechange = function () {
88+
if (this.readyState == 4) {
89+
if (this.status == 200) {
90+
var data = JSON.parse(xhr.responseText);
91+
// Display the returned data in browser
92+
var success_message = "The data point (query=" + data.query + ", label=" + data.label + ") has been added, thank you for your help!";
93+
document.getElementById('raw_scores').innerHTML = success_message;
94+
} else {
95+
console.error('Error: ' + this.status);
96+
}
97+
}
98+
};
99+
100+
xhr.open('GET', PYTHON_SERVER_ADDRESS + 'api/add_data/?data_point=' + encodeURI(data_point));
101+
xhr.send();
102+
};
103+
104+
submit_data_button.appendChild(document.createTextNode("Submit"));
105+
106+
selection_area.appendChild(submit_data_button);
107+
prediction_container.appendChild(selection_area);
108+
return prediction_container;
109+
}
110+
111+
// Add the contents of options[0] to #foo:
112+
document.getElementById('submit_button').onclick = function () {
113+
var query_box_val = document.getElementById('query').value;
114+
var xhr = new XMLHttpRequest();
115+
xhr.onreadystatechange = function () {
116+
if (this.readyState == 4) {
117+
if (this.status == 200) {
118+
var data = JSON.parse(xhr.responseText);
119+
// Display the returned data in browser
120+
document.getElementById("best_prediction").innerHTML = makeScoreString(data.prediction);
121+
document.getElementById('raw_scores').innerHTML = "";
122+
document.getElementById('raw_scores').appendChild(makeUL(data.query, data.raw_scores));
123+
} else {
124+
console.error('Error: ' + this.status);
125+
}
126+
}
127+
};
128+
129+
xhr.open('GET', PYTHON_SERVER_ADDRESS + "api/model/?query=" + encodeURI(query_box_val));
130+
xhr.send();
131+
};
132+
133+
// Query on enter
134+
var query_box = document.getElementById("query");
135+
query_box.addEventListener("keyup", function (event) {
136+
// Number 13 is the "Enter" key on the keyboard
137+
if (event.keyCode === 13) {
138+
event.preventDefault();
139+
document.getElementById("submit_button").click();
140+
}
141+
});
142+
};
143+
</script>
144+
</head>
145+
146+
<body>
147+
148+
<h2 id="title">Document Classification Console</h2>
149+
<input type="text" name="doc" id="query" value="Sample query">
150+
<button type="submit" id="submit_button">Enter</button>
151+
152+
<br />
153+
<h3>Top Model Prediction</h3>
154+
<div id="best_prediction"></div>
155+
156+
<br />
157+
<div id="raw_scores"></div>
158+
</body>
159+
</doctype>

0 commit comments

Comments
 (0)