# python_service/app.py — FastAPI backend for Pokemon card image search
# (introduced in the "init backend" commit).
import io
import os

import faiss
import numpy as np
import open_clip
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
# ASGI application for visual Pokemon-card lookup.
app = FastAPI(title="Pokemon Card Image Service")

# Permit cross-origin calls from the Bun backend / Flutter client.
# NOTE(review): wildcard origins/methods/headers is wide open — fine for
# development, tighten before exposing this service publicly.
_cors_policy = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_policy)
# --- Paths & configuration ---
# All data artifacts live next to this module.
BASE = os.path.dirname(os.path.abspath(__file__))

FAISS_INDEX_FILE = os.path.join(BASE, "card_index.faiss")
# Not read by this service; kept because offline indexing tooling writes it
# alongside the index and other code may reference the constant.
EMBEDDINGS_FILE = os.path.join(BASE, "embeddings.npy")
IDS_FILE = os.path.join(BASE, "ids.npy")
TOP_K = 5  # number of nearest neighbours returned per query

# --- Load CLIP model ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading model on device:", device)
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-L-14", pretrained="laion2b_s32b_b82k"
)
model = model.to(device).eval()

# --- Load FAISS index ---
print("Loading FAISS index from disk...")
index = faiss.read_index(FAISS_INDEX_FILE)

# --- Load IDs / metadata ---
ids = np.load(IDS_FILE)
# .tolist() converts numpy scalars to native Python values so the lookup
# table is always JSON-serializable (raw np.int64 / np.bytes_ elements
# would make JSONResponse fail in the /query endpoint).
_native_ids = ids.tolist()
# Maps FAISS row position -> card info; "name" duplicates "id" because the
# ids file is the only metadata available here.
metadata = {row: {"id": card_id, "name": card_id} for row, card_id in enumerate(_native_ids)}

print("Service ready! FAISS index contains", index.ntotal, "cards.")
# --- Helper: encode image ---
def encode_image_bytes(image_bytes):
    """Embed raw image bytes with CLIP, L2-normalized for cosine search.

    Returns a (1, dim) numpy array ready to pass to ``index.search``.
    """
    pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    print("Image loaded successfully:", pil_img.size)
    batch = preprocess(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(batch)
    vec = features.cpu().numpy()
    # In-place normalization: inner product == cosine similarity afterwards.
    faiss.normalize_L2(vec)
    print("Embedding shape:", vec.shape)
    return vec
# --- API endpoint ---
@app.post("/query")
async def query_image(file: UploadFile = File(...)):
    """Match an uploaded card image against the FAISS index.

    Returns ``{"results": [{"id": ..., "name": ...}, ...]}`` with up to
    TOP_K matches, or ``{"error": ...}`` with HTTP 500 on failure.
    """
    try:
        image_bytes = await file.read()
        print(f"Received {len(image_bytes)} bytes from {file.filename}")

        emb = encode_image_bytes(image_bytes)

        # --- FAISS search ---
        D, I = index.search(emb, TOP_K)
        print("FAISS distances:", D)
        print("FAISS indices:", I)

        # FAISS pads the result with -1 when the index holds fewer than
        # TOP_K vectors; skip those sentinels instead of letting
        # metadata[-1] raise KeyError and turn a valid query into a 500.
        results = [metadata[int(i)] for i in I[0] if int(i) >= 0]
        return JSONResponse(content={"results": results})

    except Exception as e:
        # Boundary handler: surface the failure to the client rather than
        # letting FastAPI return an opaque 500 with no body.
        print("ERROR during query:", e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
# --- Run server ---
if __name__ == "__main__":
    # Imported lazily so importing this module (e.g. from an external ASGI
    # runner) does not require uvicorn to be installed.
    import uvicorn

    # Listen on all interfaces, port 5001.
    uvicorn.run(app, host="0.0.0.0", port=5001)