restructure backend
This commit is contained in:
61
pythonScripts/build_index.py
Normal file
61
pythonScripts/build_index.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import open_clip
|
||||
import faiss
|
||||
|
||||
# --- Configuration ---
# Input folder with the card images, followed by the artifacts this script writes.
CARDS_FOLDER = "cards_old"
EMBEDDINGS_FILE = "embeddings.npy"
IDS_FILE = "ids.npy"
FAISS_INDEX_FILE = "card_index.faiss"

# --- Device ---
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# --- Load CLIP model ---
# The call returns a 3-tuple; the second element is discarded here.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-L-14',
    pretrained='laion2b_s32b_b82k',
)
# Move the weights to the selected device, then switch to inference mode.
model = model.to(device)
model = model.eval()
|
||||
|
||||
# --- Helper: encode image ---
def encode_image(path):
    """Return the CLIP embedding of the image file at *path*.

    The image is opened, converted to RGB, passed through ``preprocess`` and
    ``model.encode_image`` on ``device`` with gradient tracking disabled.

    Args:
        path: Path of the image file to encode.

    Returns:
        A float32 NumPy array holding the embedding of a batch of one image
        (the batch axis added by ``unsqueeze(0)`` is kept).
    """
    # Open via a context manager so the file handle is closed promptly
    # instead of lingering until garbage collection; convert() returns a
    # fully loaded copy, so the source file can be closed right after it.
    with Image.open(path) as raw:
        img = raw.convert("RGB")

    with torch.no_grad():
        batch = preprocess(img).unsqueeze(0).to(device)
        emb = model.encode_image(batch)

    # FAISS only accepts float32 vectors; .float() is a no-op when the model
    # already produces float32.
    return emb.float().cpu().numpy()
|
||||
|
||||
# --- Build embeddings ---
embeddings = []
ids = []

# os.listdir returns entries in arbitrary order; sorting keeps the row order
# of the embeddings (and therefore of the index) reproducible across runs.
for fname in sorted(os.listdir(CARDS_FOLDER)):
    # Case-insensitive extension filter; ".jpeg" is accepted alongside ".jpg".
    if not fname.lower().endswith((".jpg", ".jpeg", ".png")):
        continue
    path = os.path.join(CARDS_FOLDER, fname)
    embeddings.append(encode_image(path))
    # Row i of the stacked embeddings corresponds to ids[i].
    ids.append(fname)
    print("Encoded:", fname)

# np.vstack fails with an opaque "need at least one array" error on an empty
# list, so report the actual problem instead.
if not embeddings:
    raise ValueError(
        f"No .jpg/.jpeg/.png images found in {CARDS_FOLDER!r}; nothing to index."
    )

# Stack the per-image arrays into a single 2-D array, one row per image.
embeddings = np.vstack(embeddings)
|
||||
|
||||
# --- Save embeddings & IDs ---
# The embeddings are written out before the in-place normalization below,
# so the saved file holds the un-normalized vectors.
np.save(EMBEDDINGS_FILE, embeddings)
np.save(IDS_FILE, np.array(ids))
print("Saved embeddings and IDs.")

# --- Normalize embeddings ---
# Rescales every row to unit L2 norm, modifying the array in place.
faiss.normalize_L2(embeddings)

# --- Build FAISS index ---
# With unit-length rows, the inner product used by IndexFlatIP equals
# cosine similarity.
d = embeddings.shape[-1]  # embedding dimension
index = faiss.IndexFlatIP(d)
index.add(embeddings)
print("FAISS index built with", index.ntotal, "cards.")

# --- Save FAISS index ---
faiss.write_index(index, FAISS_INDEX_FILE)
print("FAISS index saved:", FAISS_INDEX_FILE)
|
||||
Reference in New Issue
Block a user