82 lines
2.4 KiB
Python
82 lines
2.4 KiB
Python
from fastapi import FastAPI, File, UploadFile
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from PIL import Image
|
|
import torch, open_clip, numpy as np, faiss
|
|
import io
|
|
|
|
app = FastAPI(title="Pokemon Card Image Service")

# --- Allow CORS for Bun / Flutter ---
# Every origin, HTTP method and request header is accepted (wildcards).
_CORS_WILDCARD_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_WILDCARD_OPTIONS)
|
|
|
|
import os

# Absolute path of the directory that contains this module.
BASE = os.path.dirname(os.path.abspath(__file__))

# --- Configuration ---
# Data artefacts are resolved relative to BASE, i.e. they are expected to sit
# next to this module.
FAISS_INDEX_FILE, EMBEDDINGS_FILE, IDS_FILE = (
    os.path.join(BASE, filename)
    for filename in ("card_index.faiss", "embeddings.npy", "ids.npy")
)

# Number of nearest neighbours requested from the index per query.
TOP_K = 5
|
|
|
|
# --- Load CLIP model ---
# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Loading model on device:", device)

# Only the first and third elements of the returned triple are used; the
# middle one is discarded.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-L-14",
    pretrained="laion2b_s32b_b82k",
)

# Move the weights to the chosen device, then switch to evaluation mode.
model = model.to(device)
model.eval()
|
|
|
|
# --- Load FAISS index ---
print("Loading FAISS index from disk...")
index = faiss.read_index(FAISS_INDEX_FILE)

# --- Load IDs / metadata ---
ids = np.load(IDS_FILE)

# Map each position in the ids array to the record returned to clients.
# ``tolist()`` converts NumPy scalars (e.g. np.int64) to native Python values;
# raw NumPy integers are not JSON-serialisable and would make every response
# fail when the ids are numeric.
metadata = {
    position: {"id": card_id, "name": card_id}
    for position, card_id in enumerate(ids.tolist())
}

# Search results are looked up in ``metadata`` by index position, so a size
# mismatch between the two files means some hits cannot be resolved.
if len(metadata) != index.ntotal:
    print(
        "WARNING: ids file has", len(metadata),
        "entries but FAISS index has", index.ntotal, "vectors.",
    )

print("Service ready! FAISS index contains", index.ntotal, "cards.")
|
|
|
|
# --- Helper: encode image ---
def encode_image_bytes(image_bytes: bytes) -> np.ndarray:
    """Embed an encoded image with the CLIP model and L2-normalise the result.

    Args:
        image_bytes: Raw contents of an image file in a format PIL can open.

    Returns:
        The embedding as a C-contiguous float32 NumPy array for a batch of
        one image, normalised in place with ``faiss.normalize_L2``.

    Raises:
        Any exception PIL raises when the bytes cannot be decoded as an image.
    """
    # ``convert`` returns an independent image, so the source handle can be
    # closed as soon as the conversion is done.
    with Image.open(io.BytesIO(image_bytes)) as source_image:
        img = source_image.convert("RGB")
    print("Image loaded successfully:", img.size)

    # unsqueeze(0) adds the batch axis in front of the preprocessed image.
    batch = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(batch)

    # FAISS works on C-contiguous float32 arrays. Coerce explicitly so a
    # reduced-precision model output does not break normalisation or search.
    emb = np.ascontiguousarray(features.float().cpu().numpy(), dtype=np.float32)
    faiss.normalize_L2(emb)
    print("Embedding shape:", emb.shape)
    return emb
|
|
|
|
# --- API endpoint ---
@app.post("/query")
async def query_image(file: UploadFile = File(...)) -> JSONResponse:
    """Return the cards whose embeddings are closest to the uploaded image.

    Args:
        file: Uploaded image file (multipart form field ``file``).

    Returns:
        ``{"results": [...]}`` with at most ``TOP_K`` metadata records, in the
        order FAISS returned them. On failure, ``{"error": "<message>"}`` with
        status 400 for an empty upload and 500 for any other error.
    """
    # NOTE(review): encoding and search run synchronously inside this
    # coroutine, so other requests wait while one image is processed.
    try:
        image_bytes = await file.read()
        print(f"Received {len(image_bytes)} bytes from {file.filename}")

        # An empty body can never be decoded; report it as a client error
        # instead of letting the image decoder fail with a 500.
        if not image_bytes:
            return JSONResponse(
                content={"error": "Uploaded file is empty"},
                status_code=400,
            )

        emb = encode_image_bytes(image_bytes)

        # --- FAISS search ---
        distances, indices = index.search(emb, TOP_K)
        print("FAISS distances:", distances)
        print("FAISS indices:", indices)

        # FAISS pads the result with -1 labels when the index holds fewer
        # than TOP_K vectors; those are not real hits and must be skipped.
        results = [metadata[int(i)] for i in indices[0] if i >= 0]
        return JSONResponse(content={"results": results})

    except Exception as e:
        # Top-level request boundary: log the failure and report it to the
        # caller rather than letting the exception propagate.
        print("ERROR during query:", e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
|
|
|
|
# --- Run server ---
if __name__ == "__main__":
    # Imported here because it is only needed when the module is run as a script.
    from uvicorn import run as run_server

    # Bind to all interfaces on port 5001.
    run_server(app, host="0.0.0.0", port=5001)
|