import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from torchvision import models

# Rebuild the fine-tuned ResNet-50 classifier and load its weights once at startup,
# so the checkpoint is not re-read from disk on every prediction.
model = models.resnet50(weights=None)  # ImageNet weights are unnecessary; the checkpoint below overwrites them
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)  # single logit for the binary cat/dog head
model.load_state_dict(torch.load("best_f1.pth", map_location="cpu"))
model.eval()

def predict(image):
    # Gradio's "image" input passes a NumPy array; convert it to a PIL image
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    
    # Preprocess the image: resize, convert to tensor, normalize with ImageNet statistics
    valid_transform = transforms.Compose([
        transforms.Resize((224, 224)),         # Resize to the model's 224x224 input size
        transforms.ToTensor(),                 # Convert the PIL image to a PyTorch tensor
        transforms.Normalize(                  # Normalize with ImageNet mean and std
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    
    input_batch = valid_transform(image).unsqueeze(0)  # add a batch dimension

    # Make prediction: sigmoid over the single logit gives the probability of class 1 ("dog")
    with torch.no_grad():
        output = model(input_batch)
        prob = torch.sigmoid(output).squeeze().item()
        predicted = 1 if prob > 0.5 else 0

    int2label = {0: "cat", 1: "dog"}
    return int2label[predicted]

demo = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="label",
    title="Cats vs Dogs",
    description="This model predicts whether an image contains a cat or a dog.",
    examples=["assets/7.jpg", "assets/44.jpg", "assets/82.jpg", "assets/83.jpg"],
)

demo.launch()
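
# Usage note: this assumes the fine-tuned checkpoint best_f1.pth and the example images
# under assets/ sit next to this script. With gradio, torch, torchvision and pillow
# installed, running the file starts a local web demo; demo.launch(share=True) would
# additionally create a temporary public link.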