Update README.md
Browse files
README.md
CHANGED
@@ -49,19 +49,16 @@ All you need to do is to run this in the terminal: <code>pip install gradio</cod
|
|
49 |
Here’s we define our image classification model prediction function in PyTorch (any framework, like TensorFlow, scikit-learn, JAX, or a plain Python will work as well):
|
50 |
<pre>
|
51 |
<code>
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
inp = Image.fromarray(inp.astype('uint8'), 'RGB')
|
56 |
|
57 |
-
|
58 |
|
59 |
-
|
60 |
|
61 |
-
|
62 |
|
63 |
-
|
64 |
-
</pre>
|
65 |
</code>
|
66 |
</pre>
|
67 |
</p>
|
|
|
49 |
Here’s we define our image classification model prediction function in PyTorch (any framework, like TensorFlow, scikit-learn, JAX, or a plain Python will work as well):
|
50 |
<pre>
|
51 |
<code>
|
52 |
+
def predict(inp):
|
53 |
+
inp = Image.fromarray(inp.astype('uint8'), 'RGB')
|
|
|
|
|
54 |
|
55 |
+
inp = transforms.ToTensor()(inp).unsqueeze(0)
|
56 |
|
57 |
+
with torch.no_grad():
|
58 |
|
59 |
+
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
|
60 |
|
61 |
+
return {labels[i]: float(prediction[i]) for i in range(1000)}
|
|
|
62 |
</code>
|
63 |
</pre>
|
64 |
</p>
|