"""Pokemon Card Image Service.

FastAPI application exposing a single ``POST /query`` endpoint. An uploaded
image is embedded with an OpenCLIP ``ViT-L-14`` model, the embedding is
L2-normalised, and the ``TOP_K`` nearest entries of a FAISS index loaded from
disk are returned as JSON.

The model, the FAISS index and the id array are loaded once, at import time.
"""

import io
import os

# Third-party imports keep the relative order the file originally used
# (torch before faiss).
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
import open_clip
import numpy as np
import faiss

app = FastAPI(title="Pokemon Card Image Service")

# --- Allow CORS for Bun / Flutter ---
# NOTE(review): every origin, method and header is allowed; tighten this if
# the service is ever reachable from outside a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

BASE = os.path.dirname(os.path.abspath(__file__))

# --- Configuration ---
FAISS_INDEX_FILE = os.path.join(BASE, "card_index.faiss")
EMBEDDINGS_FILE = os.path.join(BASE, "embeddings.npy")  # not read in this file
IDS_FILE = os.path.join(BASE, "ids.npy")
TOP_K = 5

# --- Load CLIP model ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading model on device:", device)
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-L-14", pretrained="laion2b_s32b_b82k"
)
model = model.to(device).eval()

# --- Load FAISS index ---
print("Loading FAISS index from disk...")
index = faiss.read_index(FAISS_INDEX_FILE)

# --- Load IDs / metadata ---
ids = np.load(IDS_FILE)
# ``tolist()`` turns NumPy scalars into native Python objects so the metadata
# can be JSON-encoded (a plain ``np.int64`` cannot). No separate name data is
# loaded, so "name" mirrors "id".
metadata = {
    position: {"id": card_id, "name": card_id}
    for position, card_id in enumerate(ids.tolist())
}

if index.ntotal != len(metadata):
    # Hits without a metadata entry are skipped at query time; flag the
    # mismatch here so it does not go unnoticed.
    print(
        "WARNING: FAISS index holds", index.ntotal,
        "vectors but", len(metadata), "ids were loaded.",
    )

print("Service ready! FAISS index contains", index.ntotal, "cards.")


# --- Helper: encode image ---
def encode_image_bytes(image_bytes):
    """Return the L2-normalised CLIP embedding of an encoded image.

    Args:
        image_bytes: Raw bytes of an image file in any format PIL can open.

    Returns:
        A C-contiguous ``float32`` NumPy array holding one embedding row,
        normalised in place by ``faiss.normalize_L2``.

    Raises:
        OSError: If PIL cannot identify or decode the image
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    print("Image loaded successfully:", img.size)

    with torch.no_grad():
        batch = preprocess(img).unsqueeze(0).to(device)
        emb = model.encode_image(batch)

    # FAISS requires C-contiguous float32 input; make that explicit instead
    # of relying on the dtype the model happens to produce.
    emb = np.ascontiguousarray(emb.cpu().numpy(), dtype=np.float32)
    faiss.normalize_L2(emb)
    print("Embedding shape:", emb.shape)
    return emb


# --- API endpoint ---
@app.post("/query")
async def query_image(file: UploadFile = File(...)):
    """Return metadata for the ``TOP_K`` index entries nearest to the upload.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        ``{"results": [...]}`` on success. ``{"error": "..."}`` with status
        400 when the upload cannot be decoded as an image, or with status 500
        on any other failure.
    """
    try:
        image_bytes = await file.read()
        print(f"Received {len(image_bytes)} bytes from {file.filename}")

        try:
            emb = encode_image_bytes(image_bytes)
        except OSError as e:
            # Undecodable, truncated or empty upload: a client error.
            print("ERROR during query:", e)
            return JSONResponse(content={"error": str(e)}, status_code=400)

        # --- FAISS search ---
        distances, indices = index.search(emb, TOP_K)
        print("FAISS distances:", distances)
        print("FAISS indices:", indices)

        # FAISS pads the result with -1 when the index holds fewer than
        # TOP_K vectors; keep only hits that have a metadata entry.
        results = [
            metadata[hit] for hit in map(int, indices[0]) if hit in metadata
        ]
        return JSONResponse(content={"results": results})

    except Exception as e:
        # Top-level boundary: report the failure rather than crash the worker.
        print("ERROR during query:", e)
        return JSONResponse(content={"error": str(e)}, status_code=500)


# --- Run server ---
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=5001)