OpenAI-CLIP-Powered Visual Search Engine

- Action: search for text, objects, or scenes in photos and videos (see the minimal CLIP sketch below)
- OS: macOS
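
Both demos rely on CLIP embedding images and text into the same vector space and ranking matches by similarity. Below is a minimal sketch of that idea using the openai/CLIP package that the Dockerfile later installs; the image path and candidate prompts are made-up placeholders, not part of this project.

import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Hypothetical example image and candidate text prompts
image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
texts = clip.tokenize(
    ["a tennis player", "people working on laptops", "a dog"]
).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(texts)
    # Normalize so the dot product is a cosine similarity
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    similarity = (image_features @ text_features.T).softmax(dim=-1)

print(similarity)  # Higher score = better match between the image and that prompt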
Demo

Results:
- The tennis player's feet leave the ground as he hits the ball.
- A tennis player wearing all white reaches his racket up to a ball.
- Tennis player in a white outfit swinging his racket on the court.
- The tennis player is running to hit the ball.
- A tennis player running to get to the ball.

Results:
- A person looks at a very big computer monitor.
- An apple building and a white bus and people.
- A number of people around a large table working on laptop computers.
- A group of people at a table working on small laptops.
- Two apple computers are on a white desk.
Video Search
Main
video_search_app.py
from tempfile import TemporaryDirectory
from typing import Tuple, List, Any
import cv2
import streamlit as st
import torch
import youtube_dl
from PIL import Image
from image_embeddings import compute_image_embeddings
from text_to_image_search import text_to_image_search
from image_to_image_search import image_to_image_search
st.set_page_config(
    page_title="Search Videos for Things",
    page_icon="🎥",
    layout="wide",
    initial_sidebar_state="collapsed",
)


@st.cache
def download_video_cached(
    video_url: str, frame_frequency: int = 1
) -> List[Image.Image]:
    """
    Download video from Youtube and return frames from it sampled every frame_frequency
    seconds

    Args:
        video_url: URL link to a Youtube video
        frame_frequency: How often to sample frames from the video, as a rate in seconds

    Returns:
        List[Image.Image]: List of displayable images from the video
    """
    with TemporaryDirectory() as download_directory:
        # Download the video as 'video.mp4' in the temporary directory
        # See: https://stackoverflow.com/a/63002071 for how to get more info
        youtube_dl_options = {
            "outtmpl": f"{download_directory}/video.mp4",
            "extractaudio": False,
        }
        # Perform the download
        with youtube_dl.YoutubeDL(youtube_dl_options) as ydl:
            ydl.download([video_url])

        # Iterate through the video, extracting a frame every frame_frequency seconds
        path_to_video = f"{download_directory}/video.mp4"
        video = cv2.VideoCapture(path_to_video)
        # Get frames per second to know how many frames to skip between samples
        fps = video.get(cv2.CAP_PROP_FPS)
        multiplier = int(fps * frame_frequency)
        images = []
        while True:
            success, frame = video.read()
            if not success:
                break
            frame_number = int(video.get(cv2.CAP_PROP_POS_FRAMES))
            if frame_number % multiplier == 0:
                images.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        video.release()
    return images
@st.cache(hash_funcs={torch.Tensor: id})
def compute_image_embeddings_cached(
    images: List[Image.Image],
) -> torch.Tensor:
    """ Cached version of compute_image_embeddings() """
    return compute_image_embeddings(images=images)


@st.cache(hash_funcs={torch.Tensor: id})
def text_to_image_search_cached(
    search_query: str,
    list_of_images: List[Image.Image],
    image_features: torch.Tensor,
    top_k: int = 1,
) -> Tuple[List[Image.Image], List[int], List[float]]:
    """ Cached version of text_to_image_search() """
    return text_to_image_search(
        search_query=search_query,
        list_of_images=list_of_images,
        image_features=image_features,
        top_k=top_k,
    )


@st.cache(hash_funcs={torch.Tensor: id})
def image_to_image_search_cached(
    search_query: Image.Image,
    list_of_images: List[Image.Image],
    image_features: torch.Tensor,
    top_k: int = 1,
) -> Tuple[List[Image.Image], List[int], List[float]]:
    """ Cached version of image_to_image_search() """
    return image_to_image_search(
        search_query=search_query,
        list_of_images=list_of_images,
        image_features=image_features,
        top_k=top_k,
    )


@st.cache
def load_image_cached(image_file: Any) -> Image.Image:
    """ Cached function to load an uploaded image file as PIL.Image.Image """
    img = Image.open(image_file)
    return img
# Initialize queries to None so if statements don't break
image_query = None
text_query = None

# Header and sidebar
st.title("Search for stuff in Youtube videos")
sample_frequency = st.sidebar.selectbox(
    label="Sample every how many seconds?",
    options=[2, 1, 0.5],
    help="Smaller numbers will increase runtime",
)
number_matches_to_show = st.sidebar.slider(
    label="How many matches to show", min_value=1, max_value=10, value=3
)

# Initial inputs - video link from Youtube and whether to search with text or an image
video_to_download = st.text_input(label="Paste a link to a Youtube video")
search_type = st.selectbox(
    label="Choose to search by entering text or uploading an image",
    options=["Text", "Image"],
    help="This is like doing a normal Google search, or a Google image search.",
)

if video_to_download:
    # Format the correct link URL depending on the video link type
    if "youtu.be" in video_to_download:
        timestamp_link = f"{video_to_download}?t="
    else:
        timestamp_link = f"{video_to_download}&t="

    # Display the video
    st.video(video_to_download)

    # Process video frames and create CLIP embeddings
    images_from_video = download_video_cached(
        video_url=video_to_download, frame_frequency=sample_frequency
    )
    image_embeddings = compute_image_embeddings_cached(images_from_video)

    # Prompt user for what to search for (text vs. image)
    if search_type == "Text":
        text_query = st.text_input(
            label="Type what you're searching for (e.g. a cat)",
            help="It can be complex too! E.g. a cat sitting wearing a hoodie",
        )
    else:
        image_query = st.file_uploader(
            label="Upload an image for your search",
            type=["jpg", "jpeg", "png"],
            help="This searches the video for frames similar to your uploaded image",
        )

    # Do text search
    if text_query:
        matching_images, matching_indices, probabilities = text_to_image_search_cached(
            search_query=text_query,
            list_of_images=images_from_video,
            image_features=image_embeddings,
            top_k=number_matches_to_show,
        )
        for image, index, _ in zip(matching_images, matching_indices, probabilities):
            st.write(
                f"Video time for this match: {timestamp_link}{index * sample_frequency}s"
            )
            st.image(image)

    # Do image search
    elif image_query:
        # Display the image you're searching for
        st.image(image_query)
        st.write(
            "Searching parts of the Youtube video similar to the above uploaded image"
        )
        st.header("Results:")
        image_query = load_image_cached(image_file=image_query)
        matching_images, matching_indices, probabilities = image_to_image_search_cached(
            search_query=image_query,
            list_of_images=images_from_video,
            image_features=image_embeddings,
            top_k=number_matches_to_show,
        )
        for image, index, _ in zip(matching_images, matching_indices, probabilities):
            st.write(
                f"Video time for this match: {timestamp_link}{index * sample_frequency}s"
            )
            st.image(image)
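
The three helper modules imported at the top of video_search_app.py (image_embeddings, text_to_image_search, image_to_image_search) are not listed in this section. The sketch below shows one way they could be implemented with the openai/CLIP package installed in the Dockerfile; the signatures mirror the calls above, but the bodies are assumptions rather than the project's actual code.

from typing import List, Tuple
import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)


def compute_image_embeddings(images: List[Image.Image]) -> torch.Tensor:
    """Encode video frames into L2-normalized CLIP image embeddings."""
    with torch.no_grad():
        batch = torch.stack([preprocess(image) for image in images]).to(device)
        features = model.encode_image(batch)
    return features / features.norm(dim=-1, keepdim=True)


def text_to_image_search(
    search_query: str,
    list_of_images: List[Image.Image],
    image_features: torch.Tensor,
    top_k: int = 1,
) -> Tuple[List[Image.Image], List[int], List[float]]:
    """Rank frames against a text query and return (images, indices, probabilities).
    image_to_image_search would be analogous, encoding the query image with
    model.encode_image instead of model.encode_text."""
    with torch.no_grad():
        tokens = clip.tokenize([search_query]).to(device)
        text_features = model.encode_text(tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        scores = (100.0 * text_features @ image_features.T).softmax(dim=-1).squeeze(0)
        top_scores, top_indices = scores.topk(min(top_k, len(list_of_images)))
    return (
        [list_of_images[i] for i in top_indices.tolist()],
        top_indices.tolist(),
        [float(s) for s in top_scores],
    )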
Docker
docker-compose.yml
version: "3.8"

services:
  video-search-app:
    build:
      context: .
      dockerfile: Dockerfile
    image: video-search-app:latest
    container_name: video-search-app-container
    ports:
      - "80:80"
    command: streamlit run video_search_app.py --server.port 80
Dockerfile
FROM python:3.9.5-slim-buster
RUN apt-get update && apt-get install -y git && \
rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade pip
RUN pip install ftfy regex tqdm matplotlib
RUN pip install git+https://github.com/openai/CLIP.git
RUN pip install Pillow opencv-python-headless streamlit youtube_dl
COPY tasks/*.py .
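
With these two files in place (and the app's .py files under tasks/, as the COPY step expects), the image can typically be built and the app started with docker-compose up --build; Streamlit then serves the app on port 80 as configured in docker-compose.yml.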
Requirements
requirements.txt
aiobotocore==1.1.2
aiohttp==3.7.2
aioitertools==0.7.0
altair==4.1.0
appdirs==1.4.4
appnope==0.1.0
argcomplete==1.12.1
argon2-cffi==20.1.0
astor==0.8.1
async-generator==1.10
async-timeout==3.0.1
attrs==20.2.0
backcall==0.2.0
base58==2.1.0
black==20.8b1
bleach==3.2.1
blinker==1.4
botocore==1.17.44
cachetools==4.2.1
certifi==2020.6.20
cffi==1.14.3
chardet==3.0.4
click==7.1.2
clip==1.0
colorama==0.4.3
cycler==0.10.0
decorator==4.4.2
defusedxml==0.6.0
docutils==0.15.2
entrypoints==0.3
fsspec==0.8.4
ftfy==6.0.1
gitdb==4.0.7
GitPython==3.1.14
great-expectations==0.12.6
hjson==3.0.2
idna==2.10
importlib-metadata==2.0.0
ipykernel==5.3.4
ipython==7.18.1
ipython-genutils==0.2.0
ipywidgets==7.5.1
jedi==0.17.2
Jinja2==2.11.2
jmespath==0.10.0
jsonpatch==1.26
jsonpointer==2.0
jsonschema==3.2.0
jupyter-client==6.1.7
jupyter-core==4.6.3
jupyterlab-pygments==0.1.2
kiwisolver==1.3.1
MarkupSafe==1.1.1
matplotlib==3.4.1
mistune==0.8.4
multidict==5.0.0
mypy-extensions==0.4.3
nbclient==0.5.1
nbconvert==6.0.7
nbformat==5.0.8
nest-asyncio==1.4.1
notebook==6.1.4
numpy==1.19.2
opencv-python==4.5.1.48
packaging==20.4
pandas==1.1.3
pandocfilters==1.4.3
parso==0.7.1
pathspec==0.8.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.2.0
prometheus-client==0.8.0
prompt-toolkit==3.0.8
protobuf==3.15.8
ptyprocess==0.6.0
pyarrow==3.0.0
pycparser==2.20
pydeck==0.6.2
Pygments==2.7.1
pyparsing==2.4.7
pyrsistent==0.17.3
python-dateutil==2.8.1
pytz==2020.1
pyzmq==19.0.2
regex==2020.10.23
requests==2.23.0
ruamel.yaml==0.16.12
ruamel.yaml.clib==0.2.2
s3fs==0.5.1
scipy==1.5.3
Send2Trash==1.5.0
six==1.15.0
smmap==4.0.0
streamlit==0.80.0
termcolor==1.1.0
terminado==0.9.1
testpath==0.4.4
toml==0.10.1
toolz==0.11.1
torch==1.7.1
torchvision==0.8.2
tornado==6.0.4
tqdm==4.60.0
traitlets==5.0.5
typed-ast==1.4.1
typing-extensions==3.7.4.3
tzlocal==2.1
urllib3==1.25.11
validators==0.18.2
wcwidth==0.2.5
webencodings==0.5.1
widgetsnbextension==3.5.1
wrapt==1.12.1
yarl==1.6.2
youtube-dl==2021.4.26
zipp==3.3.2
Image Search
main.py
import io

import streamlit as st
import numpy as np
import pandas as pd
import torch
from PIL import Image
import requests
from transformers import CLIPProcessor, CLIPModel
from tqdm.auto import tqdm

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
TEXT_EMBED_URL = "https://www.dropbox.com/s/5gvkjyjkolehnn9/text_embeds.npy?dl=1"
CAPTION_URL = "https://www.dropbox.com/s/n6s30qh1ldycko7/url2caption.csv?dl=1"
DEFAULT_IMAGE = "https://content.fortune.com/wp-content/uploads/2014/09/pay26_b1.jpg"


@st.cache(hash_funcs={CLIPModel: lambda _: None, CLIPProcessor: lambda _: None})
def load_model():
    # Download the CSV file mapping image URLs to candidate captions
    captions = pd.read_csv(CAPTION_URL)
    # Download the precomputed CLIP text embeddings for those captions
    response = requests.get(TEXT_EMBED_URL)
    text_embeddings = torch.FloatTensor(np.load(io.BytesIO(response.content))).to(device)
    # Hugging Face CLIP model and processor, frozen and in eval mode
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", from_tf=False)
    model = model.eval().to(device)
    for p in model.parameters():
        p.requires_grad = False
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return model, processor, captions, text_embeddings


def get_image(url, model, processor):
    # Fetch the image and compute its CLIP embedding
    image = Image.open(requests.get(url, stream=True).raw)
    image_inputs = processor(images=image, return_tensors="pt")
    image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
    img_features = model.get_image_features(
        pixel_values=image_inputs["pixel_values"],
        output_attentions=False,
        output_hidden_states=False,
    )
    # L2-normalize so the dot product with text embeddings is a cosine similarity
    img_len = torch.sqrt((img_features ** 2).sum(dim=-1, keepdims=True))
    img_features = img_features / img_len
    return image, img_features


def get_best_captions(img_features, text_features, captions):
    # Score every caption against the image and show the 5 best matches
    similarity = img_features @ text_features.T
    img2text_similarity = similarity.softmax(dim=-1)
    _, idx = img2text_similarity.topk(dim=-1, k=5)
    st.write("## AI sees:")
    for caption in captions.loc[idx.cpu().numpy().ravel(), "caption"].values:
        st.write(caption)


model, processor, captions, text_embeddings = load_model()
st.header("Vision Demo")
url = st.text_input("Paste an image url", value=DEFAULT_IMAGE)
image, img_features = get_image(url, model, processor)
st.image(image)
get_best_captions(img_features, text_embeddings, captions)
st.write("### [Post](https://ale0sx.notion.site/Powerful-Visual-Search-Engine-8275890894ed43aca87ede94982a53f3) ; [Code on Github](https://github.com/ale0sx/sorting-search-demo)")
st.write("### If you find this demo useful please sign up [the search engine I'm building.](https://www.memento.so)")
Requirements
requirements.txt
streamlit==0.82.0
requests>=2.22.0
tqdm>=4.42.1
pandas==1.1.0
torch==1.7.1
transformers==4.6.1
numpy==1.19.5
Pillow==8.2.0