
OpenAI-CLIP-Powered Visual Search Engine

  • Action: extract text, objects, and scenes from photos/videos
  • OS: macOS

Demo

[Query image]

Results:

  • The tennis player's feet leave the ground as he hits the ball.
  • A tennis player wearing all white reaches his racket up to a ball.
  • Tennis player in white outfit swinging his racket on the court.
  • The tennis player is running to hit the ball.
  • A tennis player running to get to the ball.
[Query image]

Results:

  • A person looks at a very big computer monitor.
  • An apple building and a white bus and people.
  • A number of people around a large table working on laptop computers.
  • A group of people at a table working on small laptops.
  • Two apple computers are on a white desk.
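
These matches are produced by CLIP, which embeds images and text in a shared space and scores candidates by cosine similarity. Below is a minimal sketch of ranking candidate captions against a query image with the openai/CLIP package; the query image path and candidate captions are illustrative placeholders, not from this project.

import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Illustrative query image and candidate captions
image = preprocess(Image.open("query.jpg")).unsqueeze(0).to(device)
captions = ["a tennis player swinging a racket", "people working on laptops"]
text = clip.tokenize(captions).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

# Normalize so the dot product is cosine similarity, then rank captions
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
similarity = (image_features @ text_features.T).squeeze(0)
print([captions[i] for i in similarity.argsort(descending=True).tolist()])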

Video Search

Main

video_search_app.py


from tempfile import TemporaryDirectory
from typing import Tuple, List, Any

import cv2
import streamlit as st
import torch
import youtube_dl

from PIL import Image

from image_embeddings import compute_image_embeddings
from text_to_image_search import text_to_image_search
from image_to_image_search import image_to_image_search

st.set_page_config(
    page_title="Search Videos for Things",
    page_icon="🎥",
    layout="wide",
    initial_sidebar_state="collapsed",
)


@st.cache
def download_video_cached(
    video_url: str, frame_frequency: float = 1
) -> List[Image.Image]:
    """
    Download a video from YouTube and return frames sampled every frame_frequency
    seconds.

    Args:
        video_url: URL link to a YouTube video
        frame_frequency: How often to sample frames from the video, in seconds

    Returns:
        List[Image.Image]: List of displayable images from the video
    """
    with TemporaryDirectory() as download_directory:
        # Download the video as 'video.mp4' in the temporary directory
        # See: https://stackoverflow.com/a/63002071 for how to get more info
        youtube_dl_options = {
            "outtmpl": f"{download_directory}/video.mp4",
            "extractaudio": False,
        }

        # Perform the download
        with youtube_dl.YoutubeDL(youtube_dl_options) as ydl:
            ydl.download([video_url])

        # Iterate through the video, extracting a frame every frame_frequency seconds
        path_to_video = f"{download_directory}/video.mp4"
        video = cv2.VideoCapture(path_to_video)

        # Use the video's FPS to sample one frame every frame_frequency seconds
        fps = video.get(cv2.CAP_PROP_FPS)
        multiplier = max(1, int(fps * frame_frequency))

        images = []

        while True:
            success, frame = video.read()

            if not success:
                break

            frame_number = int(video.get(cv2.CAP_PROP_POS_FRAMES))

            if frame_number % multiplier == 0:
                images.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))

        video.release()

    return images


@st.cache(hash_funcs={torch.Tensor: id})
def compute_image_embeddings_cached(
    images: List[Image.Image]
) -> torch.Tensor:
    """ Cached version of compute_image_embeddings() """
    return compute_image_embeddings(images=images)


@st.cache(hash_funcs={torch.Tensor: id})
def text_to_image_search_cached(
    search_query: str,
    list_of_images: List[Image.Image],
    image_features: torch.Tensor,
    top_k: int = 1,
) -> Tuple[List[Image.Image], List[int], List[float]]:
    """ Cached version of text_to_image_search() """
    return text_to_image_search(
        search_query=search_query,
        list_of_images=list_of_images,
        image_features=image_features,
        top_k=top_k,
    )


@st.cache(hash_funcs={torch.Tensor: id})
def image_to_image_search_cached(
    search_query: Image.Image,
    list_of_images: List[Image.Image],
    image_features: torch.Tensor,
    top_k: int = 1,
) -> Tuple[List[Image.Image], List[int], List[float]]:
    """ Cached version of text_to_image_search() """
    return image_to_image_search(
        search_query=search_query,
        list_of_images=list_of_images,
        image_features=image_features,
        top_k=top_k,
    )


@st.cache
def load_image_cached(image_file: Any) -> Image.Image:
    """ Cached function to load uploaded image file as PIL.Image.Image """
    img = Image.open(image_file)
    return img


# Initialize queries to None so if statements don't break
image_query = None
text_query = None

# Header and sidebar
st.title("Search for stuff in Youtube videos")
sample_frequency = st.sidebar.selectbox(
    label="Sample every how many seconds?",
    options=[2, 1, 0.5],
    help="Smaller numbers will increase runtime",
)
number_matches_to_show = st.sidebar.slider(
    label="How many matches to show", min_value=1, max_value=10, value=3
)

# Initial inputs - video link from Youtube and whether to search with text or an image
video_to_download = st.text_input(label="Paste a link to a Youtube video")
search_type = st.selectbox(
    label="Choose to search by entering text or uploading an image",
    options=["Text", "Image"],
    help="This is like doing a normal Google search, " "or a Google image search.",
)

if video_to_download:

    # Format the correct link URL depending on the video link type
    if "youtu.be" in video_to_download:
        timestamp_link = f"{video_to_download}?t="
    else:
        timestamp_link = f"{video_to_download}&t="

    # Display the video
    st.video(video_to_download)

    # Process video frames and create CLIP embeddings
    images_from_video = download_video_cached(
        video_url=video_to_download, frame_frequency=sample_frequency
    )
    image_embeddings = compute_image_embeddings_cached(images_from_video)

    # Prompt user for what to search for (text vs. image)
    if search_type == "Text":
        text_query = st.text_input(
            label="Type what you're searching for (e.g. a cat)",
            help="It can be complex too! E.g. " "a cat sitting wearing a hoodie",
        )
    else:
        image_query = st.file_uploader(
            label="Upload an image for your search",
            type=["jpg", "jpeg", "png"],
            help="This searches the video for frames similar to your uploaded image"
        )

    # Do text search
    if text_query:
        matching_images, matching_indices, probabilities = text_to_image_search_cached(
            search_query=text_query,
            list_of_images=images_from_video,
            image_features=image_embeddings,
            top_k=number_matches_to_show,
        )

        for image, index, _ in zip(matching_images, matching_indices, probabilities):
            st.write(
                f"Video time for this match: {timestamp_link}{index * sample_frequency}s"
            )
            st.image(image)

    # Do image search
    elif image_query:
        # Display the image you're searching for
        st.image(image_query)
        st.write("Searching parts of the Youtube video similar to the above uploaded "
                 "image")
        st.header("Results:")

        image_query = load_image_cached(image_file=image_query)
        matching_images, matching_indices, probabilities = image_to_image_search_cached(
            search_query=image_query,
            list_of_images=images_from_video,
            image_features=image_embeddings,
            top_k=number_matches_to_show,
        )

        for image, index, _ in zip(matching_images, matching_indices, probabilities):
            st.write(
                f"Video time for this match: {timestamp_link}{index * sample_frequency}s"
            )
            st.image(image)
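
The three modules imported at the top (image_embeddings, text_to_image_search, image_to_image_search) are not listed in this post. Below is a minimal sketch of the first two, assuming the openai/CLIP package; the signatures follow the calls above, while the model choice and internals are illustrative.

# image_embeddings.py / text_to_image_search.py (illustrative sketch)
from typing import List, Tuple

import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)


def compute_image_embeddings(images: List[Image.Image]) -> torch.Tensor:
    """Encode frames into L2-normalized CLIP image embeddings."""
    batch = torch.stack([preprocess(image) for image in images]).to(device)
    with torch.no_grad():
        features = model.encode_image(batch)
    return features / features.norm(dim=-1, keepdim=True)


def text_to_image_search(
    search_query: str,
    list_of_images: List[Image.Image],
    image_features: torch.Tensor,
    top_k: int = 1,
) -> Tuple[List[Image.Image], List[int], List[float]]:
    """Rank frames by cosine similarity to the encoded text query."""
    tokens = clip.tokenize([search_query]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    similarity = (image_features @ text_features.T).squeeze(-1)
    probabilities, indices = similarity.softmax(dim=-1).topk(top_k)
    index_list = indices.tolist()
    return (
        [list_of_images[i] for i in index_list],
        index_list,
        probabilities.tolist(),
    )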


Docker

docker-compose.yml

version: "3.8"

services:

  video-search-app:
    build:
      context: .
      dockerfile: Dockerfile
    image: video-search-app:latest
    container_name: video-search-app-container
    ports:
      - 80:80
    command: streamlit run video_search_app.py --server.port 80
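
With both files in place, docker-compose up --build builds the image and serves the app on port 80.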

Dockerfile


FROM python:3.9.5-slim-buster

# git is needed to pip-install CLIP from GitHub
RUN apt-get update && apt-get install -y git && \
    rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip

# CLIP and its dependencies
RUN pip install ftfy regex tqdm matplotlib
RUN pip install git+https://github.com/openai/CLIP.git

# App dependencies (headless OpenCV avoids pulling in system GUI libraries)
RUN pip install Pillow opencv-python-headless streamlit youtube_dl

WORKDIR /app
COPY tasks/*.py .

Requirements

requirements.txt


aiobotocore==1.1.2
aiohttp==3.7.2
aioitertools==0.7.0
altair==4.1.0
appdirs==1.4.4
appnope==0.1.0
argcomplete==1.12.1
argon2-cffi==20.1.0
astor==0.8.1
async-generator==1.10
async-timeout==3.0.1
attrs==20.2.0
backcall==0.2.0
base58==2.1.0
black==20.8b1
bleach==3.2.1
blinker==1.4
botocore==1.17.44
cachetools==4.2.1
certifi==2020.6.20
cffi==1.14.3
chardet==3.0.4
click==7.1.2
clip==1.0
colorama==0.4.3
cycler==0.10.0
decorator==4.4.2
defusedxml==0.6.0
docutils==0.15.2
entrypoints==0.3
fsspec==0.8.4
ftfy==6.0.1
gitdb==4.0.7
GitPython==3.1.14
great-expectations==0.12.6
hjson==3.0.2
idna==2.10
importlib-metadata==2.0.0
ipykernel==5.3.4
ipython==7.18.1
ipython-genutils==0.2.0
ipywidgets==7.5.1
jedi==0.17.2
Jinja2==2.11.2
jmespath==0.10.0
jsonpatch==1.26
jsonpointer==2.0
jsonschema==3.2.0
jupyter-client==6.1.7
jupyter-core==4.6.3
jupyterlab-pygments==0.1.2
kiwisolver==1.3.1
MarkupSafe==1.1.1
matplotlib==3.4.1
mistune==0.8.4
multidict==5.0.0
mypy-extensions==0.4.3
nbclient==0.5.1
nbconvert==6.0.7
nbformat==5.0.8
nest-asyncio==1.4.1
notebook==6.1.4
numpy==1.19.2
opencv-python==4.5.1.48
packaging==20.4
pandas==1.1.3
pandocfilters==1.4.3
parso==0.7.1
pathspec==0.8.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.2.0
prometheus-client==0.8.0
prompt-toolkit==3.0.8
protobuf==3.15.8
ptyprocess==0.6.0
pyarrow==3.0.0
pycparser==2.20
pydeck==0.6.2
Pygments==2.7.1
pyparsing==2.4.7
pyrsistent==0.17.3
python-dateutil==2.8.1
pytz==2020.1
pyzmq==19.0.2
regex==2020.10.23
requests==2.23.0
ruamel.yaml==0.16.12
ruamel.yaml.clib==0.2.2
s3fs==0.5.1
scipy==1.5.3
Send2Trash==1.5.0
six==1.15.0
smmap==4.0.0
streamlit==0.80.0
termcolor==1.1.0
terminado==0.9.1
testpath==0.4.4
toml==0.10.1
toolz==0.11.1
torch==1.7.1
torchvision==0.8.2
tornado==6.0.4
tqdm==4.60.0
traitlets==5.0.5
typed-ast==1.4.1
typing-extensions==3.7.4.3
tzlocal==2.1
urllib3==1.25.11
validators==0.18.2
wcwidth==0.2.5
webencodings==0.5.1
widgetsnbextension==3.5.1
wrapt==1.12.1
yarl==1.6.2
youtube-dl==2021.4.26
zipp==3.3.2
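
To run outside Docker, install the pinned dependencies and launch the app (assuming the list above is saved as requirements.txt): pip install -r requirements.txt, then streamlit run video_search_app.py.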



Image Search

main.py


import io
import streamlit as st
import numpy as np
import pandas as pd
import torch
from PIL import Image
import requests

from transformers import CLIPProcessor, CLIPModel
from tqdm.auto import tqdm

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
TEXT_EMBED_URL = "https://www.dropbox.com/s/5gvkjyjkolehnn9/text_embeds.npy?dl=1"
CAPTION_URL = "https://www.dropbox.com/s/n6s30qh1ldycko7/url2caption.csv?dl=1"
DEFAULT_IMAGE = "https://content.fortune.com/wp-content/uploads/2014/09/pay26_b1.jpg"


@st.cache(hash_funcs={CLIPModel: lambda _: None, CLIPProcessor: lambda _: None})
def load_model():
    # wget csv file with captions
    captions = pd.read_csv(CAPTION_URL)
    # wget text embeddings of above
    response = requests.get(TEXT_EMBED_URL)
    text_embeddings = torch.FloatTensor(np.load(io.BytesIO(response.content))).to(device)
    # huggingface model and processor, moved to the same device as the inputs
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", from_tf=False).eval().to(device)
    for p in model.parameters():
        p.requires_grad = False
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    return model, processor, captions, text_embeddings


def get_image(url, model, processor):
    image = Image.open(requests.get(url, stream=True).raw)
    image_inputs = processor(images=image, return_tensors="pt",)
    image_inputs = {k: v.to(device) for k, v in image_inputs.items()}

    img_features = model.get_image_features(
        pixel_values=image_inputs["pixel_values"],
        output_attentions=False,
        output_hidden_states=False,
    )

    # L2-normalize so the dot product with text features is cosine similarity
    img_features = img_features / img_features.norm(dim=-1, keepdim=True)
    return image, img_features


def get_best_captions(img_features, text_features, captions):
    similarity = img_features @ text_features.T
    img2text_similarity = similarity.softmax(dim=-1)
    _, idx = img2text_similarity.topk(dim=-1, k=5)

    st.write("## AI sees:")
    for caption in captions.loc[idx.cpu().numpy().ravel(), "caption"].values:
        st.write(caption)


model, processor, captions, text_embeddings = load_model()
st.header("Vision Demo")
url = st.text_input("Paste an image url", value=DEFAULT_IMAGE)
image, img_features = get_image(url, model, processor)
st.image(image)

get_best_captions(img_features, text_embeddings, captions)

st.write("### [Post](https://ale0sx.notion.site/Powerful-Visual-Search-Engine-8275890894ed43aca87ede94982a53f3) ; [Code on Github](https://github.com/ale0sx/sorting-search-demo)")
st.write("### If you find this demo useful please sign up [the search engine I'm building.](https://www.memento.so)")



Requirements

requirements.txt

streamlit==0.82.0
requests>=2.22.0
tqdm>=4.42.1
pandas==1.1.0
torch==1.7.1
transformers==4.6.1
numpy==1.19.5
Pillow==8.2.0