# File: pkmntcg-backend/pythonService/app.py
# Snapshot: 2026-01-16 18:07:23 +01:00 (82 lines, 2.4 KiB, Python)

import io
import os

# NOTE: torch is imported before faiss exactly as in the original file; the
# load order of native ML libraries can matter, so confirm before reordering.
import torch
import open_clip
import numpy as np
import faiss
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
app = FastAPI(title="Pokemon Card Image Service")

# --- Allow CORS for Bun / Flutter ---
# Fully open CORS policy: any origin, any HTTP method and any request header
# are accepted. No allow_credentials flag is passed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# --- Configuration ---
# All data files are resolved relative to this script's directory, so the
# service does not depend on the current working directory it is launched from.
BASE = os.path.dirname(os.path.abspath(__file__))
FAISS_INDEX_FILE = os.path.join(BASE, "card_index.faiss")
# Not referenced elsewhere in this file; kept for tooling that may rely on it.
EMBEDDINGS_FILE = os.path.join(BASE, "embeddings.npy")
IDS_FILE = os.path.join(BASE, "ids.npy")
# Number of nearest cards returned per /query request.
TOP_K = 3
# --- Load CLIP model ---
# Use the GPU when torch can see one; otherwise run inference on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Loading model on device:", device)

# create_model_and_transforms returns (model, train_transform, eval_transform);
# only the evaluation-time preprocessing is needed for querying.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-L-14",
    pretrained="laion2b_s32b_b82k",
)
# Move weights to the chosen device and switch to inference mode (eval()).
model = model.to(device)
model = model.eval()
# --- Load FAISS index ---
print("Loading FAISS index from disk...")
index = faiss.read_index(FAISS_INDEX_FILE)

# --- Load IDs / metadata ---
# The search endpoint maps FAISS result positions straight to keys of
# `metadata`, so it assumes row i of the index corresponds to ids[i].
ids = np.load(IDS_FILE)
# tolist() converts NumPy scalars (np.str_, np.int64, ...) to plain Python
# values; np.int64 in particular is not JSON-serialisable by JSONResponse.
# No separate name data is loaded, so "name" simply mirrors the id.
metadata = {
    idx: {"id": card_id, "name": card_id}
    for idx, card_id in enumerate(ids.tolist())
}
if index.ntotal != len(ids):
    # A mismatch means some search hits have no id, or ids point at the wrong
    # vectors; surface it at startup instead of failing per request.
    print(
        f"WARNING: FAISS index has {index.ntotal} vectors but "
        f"{len(ids)} ids were loaded from {IDS_FILE}"
    )
print("Service ready! FAISS index contains", index.ntotal, "cards.")
# --- Helper: encode image ---
def encode_image_bytes(image_bytes):
    """Encode raw image bytes into an L2-normalised CLIP embedding.

    Args:
        image_bytes: Encoded image data in any format PIL can open.

    Returns:
        A C-contiguous float32 NumPy array with one row (the image is encoded
        as a batch of one), L2-normalised so it can be passed directly to
        ``index.search``.

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a readable image.
    """
    # convert() returns a new image, so the opened source can be closed here.
    with Image.open(io.BytesIO(image_bytes)) as source:
        img = source.convert("RGB")
    print("Image loaded successfully:", img.size)
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension expected by encode_image.
        emb = model.encode_image(preprocess(img).unsqueeze(0).to(device))
    # faiss only accepts C-contiguous float32 arrays; this guards against e.g.
    # a half-precision model output and is a no-op for float32 results.
    emb = np.ascontiguousarray(emb.cpu().numpy(), dtype=np.float32)
    faiss.normalize_L2(emb)  # normalises in place
    print("Embedding shape:", emb.shape)
    return emb
# --- API endpoint ---
@app.post("/query")
async def query_image(file: UploadFile = File(...)):
    """Return the catalog cards whose embeddings are closest to an uploaded image.

    The upload is encoded with ``encode_image_bytes`` and searched against the
    FAISS index for the ``TOP_K`` nearest vectors.

    Returns:
        200 with ``{"results": [...]}``. Each entry comes from ``metadata``,
        in the order FAISS returns them. The list may hold fewer than
        ``TOP_K`` entries when the index holds fewer vectors.
        400 with ``{"error": ...}`` if the upload is not a readable image.
        500 with ``{"error": ...}`` for any other failure.
    """
    try:
        image_bytes = await file.read()
        print(f"Received {len(image_bytes)} bytes from {file.filename}")
        # Model inference and the FAISS search are blocking calls; run them in
        # a worker thread so the event loop keeps serving other requests.
        emb = await run_in_threadpool(encode_image_bytes, image_bytes)
        # --- FAISS search ---
        distances, labels = await run_in_threadpool(index.search, emb, TOP_K)
        print("FAISS distances:", distances)
        print("FAISS indices:", labels)
        # FAISS pads missing results with -1 (e.g. index smaller than TOP_K);
        # skip those instead of failing on metadata[-1].
        results = [metadata[int(i)] for i in labels[0] if i >= 0]
        return JSONResponse(content={"results": results})
    except UnidentifiedImageError as e:
        # A bad upload is a client error, not a server failure.
        print("ERROR during query: unreadable image:", e)
        return JSONResponse(content={"error": str(e)}, status_code=400)
    except Exception as e:
        # Top-level boundary: report any other failure to the client as a 500.
        print("ERROR during query:", e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
# --- Run server ---
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly, so it is
    # imported inside the entry-point guard.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=5001, host="0.0.0.0")