# app.py: Hugging Face Space by osbm
# Last commit: f702a0a "Update app.py" (verified), 1.63 kB
import functools

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models
from torchvision import transforms
@functools.lru_cache(maxsize=1)
def _load_model(weights_path="best_f1.pth"):
    """Build the ResNet-50 cat/dog classifier and load its fine-tuned weights (once).

    Args:
        weights_path: Path to the saved ``state_dict`` checkpoint.

    Returns:
        The model on CPU, in eval mode, with a single-logit ``fc`` head.
    """
    # No ImageNet download is needed: load_state_dict (strict by default)
    # replaces every parameter and buffer, including the new 1-logit head.
    model = models.resnet50()
    model.fc = nn.Linear(model.fc.in_features, 1)
    # NOTE: torch.load unpickles the file; only point this at a trusted checkpoint.
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model


@functools.lru_cache(maxsize=1)
def _get_transform():
    """Return the evaluation preprocessing pipeline (built once and reused)."""
    return transforms.Compose([
        transforms.Resize((224, 224)),  # fixed 224x224 network input
        transforms.ToTensor(),  # PIL image -> float CHW tensor in [0, 1]
        transforms.Normalize(  # ImageNet mean/std; must match the training transform
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])


def predict(image):
    """Classify an image as ``"cat"`` or ``"dog"``.

    Args:
        image: Image array from the Gradio ``"image"`` input. Presumably a NumPy
            array of shape (H, W, 3), though (H, W) or (H, W, 4) are also
            accepted here. Confirm against the Gradio version in use.

    Returns:
        ``"dog"`` if the model's sigmoid probability is above 0.5, else ``"cat"``.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Let PIL infer the mode from the array shape, then force 3-channel RGB so
    # grayscale or RGBA uploads don't break the 3-channel model input.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    input_batch = _get_transform()(pil_image).unsqueeze(0)  # (1, 3, 224, 224)

    with torch.no_grad():
        logit = _load_model()(input_batch)  # shape (1, 1): single "dog" logit
    probability = torch.sigmoid(logit).item()

    int2label = {0: "cat", 1: "dog"}
    predicted = 1 if probability > 0.5 else 0
    return int2label[predicted]
# Web UI: a single image upload mapped to a label, with bundled sample images.
demo = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="label",
    title="Cats vs Dogs",
    description="This model predicts whether an image contains a cat or a dog.",
    examples=[
        "assets/7.jpg",
        "assets/44.jpg",
        "assets/82.jpg",
        "assets/83.jpg",
    ],
)

# Start serving the app.
demo.launch()