72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
import os
|
|
import numpy as np
|
|
import torch
|
|
import open_clip
|
|
import faiss
|
|
from PIL import Image
|
|
|
|
# Folder scanned for card images; each file's name without its extension
# becomes the card id stored alongside its embedding.
CARDS_FOLDER: str = "/images/cards"

# Persisted state. The three artefacts are appended in lockstep, so row i of
# each one describes the same card:
#   EMBEDDINGS_FILE - float32 matrix of L2-normalised image embeddings
#   IDS_FILE        - card id for each embedding row
#   FAISS_FILE      - serialised inner-product index over the same rows
EMBEDDINGS_FILE: str = "/pythonService/embeddings.npy"
IDS_FILE: str = "/pythonService/ids.npy"
FAISS_FILE: str = "/pythonService/card_index.faiss"
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# OpenCLIP ViT-L/14 with LAION-2B weights. The middle return value (the
# training-time transform) is unused; `preprocess` is the inference transform.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-L-14",
    pretrained="laion2b_s32b_b82k",
)

# nn.Module.to() and .eval() both act on the module in place.
model.to(device)
model.eval()
# ---- load existing or initialize ----

# Width of the vectors returned by model.encode_image. OpenCLIP's ViT-L-14
# projects to 768 dimensions (1024 is its transformer width, and the output
# size of ViT-H-14), so the size is read from the model instead of being
# hard-coded; a mismatch makes faiss reject index.add().
embed_dim = model.visual.output_dim

# All three artefacts are needed to extend the store. If any is missing, the
# index is rebuilt from the image folder, since everything here is derived data.
if all(os.path.exists(p) for p in (FAISS_FILE, EMBEDDINGS_FILE, IDS_FILE)):
    print("Loading existing FAISS index...")
    index = faiss.read_index(FAISS_FILE)
    embeddings = np.load(EMBEDDINGS_FILE)
    ids = np.load(IDS_FILE)

    # Row i of the index, the embedding matrix and the id array must describe
    # the same card; refuse to extend a store whose parts have drifted apart.
    if index.d != embed_dim:
        raise ValueError(
            f"{FAISS_FILE} has dimension {index.d}, but the model produces "
            f"{embed_dim}-d embeddings"
        )
    if not (index.ntotal == len(ids) == len(embeddings)):
        raise ValueError(
            f"Inconsistent store: index has {index.ntotal} vectors, "
            f"{IDS_FILE} has {len(ids)} ids, "
            f"{EMBEDDINGS_FILE} has {len(embeddings)} rows"
        )
else:
    print("Creating new FAISS index...")
    embeddings = np.zeros((0, embed_dim), dtype="float32")
    ids = np.array([], dtype="<U100")
    # Inner product over L2-normalised vectors is cosine similarity.
    index = faiss.IndexFlatIP(embed_dim)
# ---- find images not yet processed ----

# Extensions accepted as card images (compared case-insensitively).
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

existing_ids = set(ids.tolist())

# Card id = file name with its last extension removed. Sorting makes the order
# in which rows are appended reproducible (os.listdir order is arbitrary).
# Tracking claimed ids stops two files with the same stem (e.g. "a.png" and
# "a.jpg") from both being added under one id.
_claimed_ids = set(existing_ids)
new_files = []
for fname in sorted(os.listdir(CARDS_FOLDER)):
    if not fname.lower().endswith(IMAGE_EXTENSIONS):
        continue
    card_id = fname.rsplit(".", 1)[0]
    if card_id in _claimed_ids:
        if card_id not in existing_ids:
            print("Skipping duplicate card id:", fname)
        continue
    _claimed_ids.add(card_id)
    new_files.append(fname)

print(f"Found {len(new_files)} new cards to add")
# Per-image embedding rows (each shaped (1, dim)) and their card ids, kept in
# the same order so they can be appended to the store together.
new_embeddings = []
new_ids = []

# Inference only: no autograd graph is needed for any image.
with torch.no_grad():
    for fname in new_files:
        path = os.path.join(CARDS_FOLDER, fname)
        try:
            # PIL opens files lazily and keeps the handle open; the `with`
            # block closes it once the decoded RGB copy has been made.
            with Image.open(path) as src:
                img = src.convert("RGB")
        except OSError as err:
            # Covers missing/corrupt files and PIL.UnidentifiedImageError.
            # Skip the file instead of aborting and losing all work so far.
            print(f"Skipping unreadable image {fname}: {err}")
            continue

        emb = model.encode_image(preprocess(img).unsqueeze(0).to(device))
        new_embeddings.append(emb.cpu().numpy())
        new_ids.append(fname.rsplit(".", 1)[0])
        print("Encoded:", fname)
if new_embeddings:
    # Stack the per-image rows into one float32 matrix and L2-normalise it in
    # place, so inner-product search over the index is cosine similarity.
    batch = np.vstack(new_embeddings).astype("float32")
    faiss.normalize_L2(batch)

    # add to FAISS
    index.add(batch)

    # append to numpy arrays, keeping them row-aligned with the index
    embeddings = np.vstack([embeddings, batch])
    ids = np.concatenate([ids, np.array(new_ids)])

    # save everything
    np.save(EMBEDDINGS_FILE, embeddings)
    np.save(IDS_FILE, ids)
    faiss.write_index(index, FAISS_FILE)

    # Report the rows actually added. The candidate list may contain files
    # that never produced an embedding.
    print(f"Added {len(new_ids)} cards. Total now:", index.ntotal)
else:
    print("No new cards found — nothing to update.")