
Binary Quantization from Scratch

Setup: Install Dependencies, Imports & Download Embeddings

!pip install matplotlib tqdm pandas numpy pyarrow --quiet
import numpy as np
import pandas as pd
from tqdm import tqdm

👨🏾‍💻 Code Walkthrough

The code is structured as follows:

  1. Loading Data: OpenAI embeddings are loaded from parquet files (we can load up to 1M embeddings) and concatenated into one array.
  2. Binary Conversion: A new array with the same shape is initialized with zeros, and the positions where the original vectors are positive are set to 1.
  3. Accuracy Function: For a given query index, limit, and oversampling rate, the accuracy function ranks all vectors by dot product on the original embeddings and by number of matching bits (computed with logical XOR) on the binary ones. It then measures what fraction of the exact top results also appear among the binary candidates.
  4. Testing: The accuracy is tested for different oversampling rates; with an oversampling of 4, recall reaches ~0.96.
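
The four steps above can be tried on a small scale before downloading the full dataset. The sketch below is only an illustration: it swaps the OpenAI embeddings for random Gaussian vectors (an assumption, so the recall numbers it prints say nothing about real embeddings), and the name `toy_recall` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# 1. "Loading": synthetic stand-in for the real embeddings (assumption).
vectors = rng.normal(size=(2000, 64))

# 2. Binary conversion: 1 where the component is positive, 0 otherwise.
binary = (vectors > 0).astype(np.int8)


def toy_recall(idx: int, limit: int, oversampling: int) -> float:
    # Exact top-`limit` neighbours by dot product on the original vectors.
    exact = np.argsort(vectors @ vectors[idx])[-limit:]
    # Number of matching bits = dimensions - number of differing bits (XOR).
    matches = binary.shape[1] - np.logical_xor(binary, binary[idx]).sum(axis=1)
    # Binary search returns `limit * oversampling` candidates.
    candidates = np.argsort(matches)[-(limit * oversampling):]
    # 3. Fraction of exact neighbours found among the binary candidates.
    return len(set(exact) & set(candidates)) / limit


# 4. Testing with a few oversampling rates.
for rate in (1, 2, 4):
    print(rate, toy_recall(0, limit=10, oversampling=rate))
```

Because a larger oversampling rate only adds candidates, the recall measured this way can never decrease as the rate grows.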

💿 Loading Data

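def get_openai_vectors(force_download: bool = False):
    """Load all 26 parquet shards and return one (n_vectors, dim) array."""
    res = []
    for i in tqdm(range(26)):
        if force_download:
            # Notebook shell escape: fetches shard {i} into the working directory
            !wget https://huggingface.co/api/datasets/KShivendu/dbpedia-entities-openai-1M/parquet/KShivendu--dbpedia-entities-openai-1M/train/{i}.parquet
        df = pd.read_parquet(f"{i}.parquet", engine="pyarrow")
        # The `openai` column holds one embedding per row; stack into a 2D array
        res.append(np.stack(df.openai))
        del df  # free the DataFrame before loading the next shard

    openai_vectors = np.concatenate(res)
    del res
    return openai_vectors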


openai_vectors = get_openai_vectors(force_download=False)
openai_vectors.shape
100%|██████████| 26/26 [00:10<00:00,  2.45it/s]

(1000000, 1536)

㆓ Binary Conversion

Here, we will use 0 as the threshold for the binary conversion. All values greater than 0 will be set to 1, and others will remain 0. This is a simple and effective way to convert continuous values into binary values for OpenAI embeddings.

openai_bin = np.zeros_like(openai_vectors, dtype=np.int8)
openai_bin[openai_vectors > 0] = 1
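
The `int8` array above still spends a full byte on every bit. As a hedged sketch of how the binary form could be stored compactly, the example below packs 8 dimensions into each byte with `np.packbits`. The synthetic `float32` input and the helper `hamming_packed` are assumptions for illustration, not part of the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
# Small synthetic stand-in for the embedding matrix (assumption).
vectors = rng.normal(size=(4, 1536)).astype(np.float32)

bits = vectors > 0                    # boolean matrix, shape (4, 1536)
packed = np.packbits(bits, axis=1)    # uint8, 8 dimensions per byte


def hamming_packed(a: np.ndarray, b: np.ndarray) -> int:
    # XOR the packed bytes, then count the set bits (differing dimensions).
    return int(np.unpackbits(np.bitwise_xor(a, b)).sum())


print(packed.shape)                         # (4, 192)
print(vectors.nbytes // packed.nbytes)      # 32
```

With `float32` input, each dimension shrinks from 32 bits to 1 bit, which is where the 32x factor comes from.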

🎯 Accuracy Function

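The accuracy function takes a query index, a limit, and an oversampling rate. It finds the top `limit` neighbours of the query by dot product on the original vectors, and the top `limit * oversampling` candidates by number of matching bits on the binary vectors (1536 minus the count of bits that differ under logical XOR). The returned value is the fraction of exact neighbours that appear among the binary candidates, i.e. the recall.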

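def accuracy(idx, limit: int, oversampling: int):
    # Exact search: top `limit` vectors by dot product with the query
    scores = np.dot(openai_vectors, openai_vectors[idx])
    dot_results = np.argsort(scores)[-limit:][::-1]

    # Binary search: score = number of matching bits (1536 - differing bits)
    bin_scores = 1536 - np.logical_xor(openai_bin, openai_bin[idx]).sum(axis=1)
    bin_results = np.argsort(bin_scores)[-(limit * oversampling) :][::-1]

    # Fraction of exact neighbours recovered by the oversampled binary search
    return len(set(dot_results).intersection(set(bin_results))) / limit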
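
The function above only checks whether the exact neighbours land among the oversampled binary candidates. In a search pipeline, those candidates would typically be rescored with the original vectors and cut back to `limit`. The following is a minimal sketch of that idea under stated assumptions: `search_with_rescoring` is a hypothetical helper, and the data is synthetic.

```python
import numpy as np


def search_with_rescoring(vectors, binary, query_idx, limit, oversampling):
    # Stage 1: cheap binary search, keep `limit * oversampling` candidates.
    dim = binary.shape[1]
    bin_scores = dim - np.logical_xor(binary, binary[query_idx]).sum(axis=1)
    candidates = np.argsort(bin_scores)[-(limit * oversampling):]
    # Stage 2: rescore only the candidates with the original vectors.
    exact_scores = vectors[candidates] @ vectors[query_idx]
    order = np.argsort(exact_scores)[::-1][:limit]
    return candidates[order]


rng = np.random.default_rng(0)
vectors = rng.normal(size=(500, 32))           # synthetic data (assumption)
binary = (vectors > 0).astype(np.int8)

top = search_with_rescoring(vectors, binary, query_idx=0, limit=5, oversampling=4)
print(top)
```

Only `limit * oversampling` full-precision dot products are computed in the second stage, so the cost of rescoring stays small compared with scanning every original vector.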

📊 Results

number_of_samples = 10
limits = [10, 100]
sampling_rate = [1, 2, 3, 5]
results = []


def mean_accuracy(number_of_samples, limit, sampling_rate):
    return np.mean([accuracy(i, limit=limit, oversampling=sampling_rate) for i in range(number_of_samples)])


for i in tqdm(sampling_rate):
    for j in tqdm(limits):
        result = {"sampling_rate": i, "limit": j, "recall": mean_accuracy(number_of_samples, j, i)}
        print(result)
        results.append(result)
{'sampling_rate': 1, 'limit': 10, 'recall': 0.8}
{'sampling_rate': 1, 'limit': 100, 'recall': 0.708}
{'sampling_rate': 2, 'limit': 10, 'recall': 0.95}
{'sampling_rate': 2, 'limit': 100, 'recall': 0.877}
{'sampling_rate': 3, 'limit': 10, 'recall': 0.96}
{'sampling_rate': 3, 'limit': 100, 'recall': 0.937}
{'sampling_rate': 5, 'limit': 10, 'recall': 0.9800000000000001}
{'sampling_rate': 5, 'limit': 100, 'recall': 0.977}
100%|██████████| 4/4 [02:12<00:00, 33.17s/it]

results = pd.DataFrame(results)
results
   sampling_rate  limit  recall
0              1     10   0.800
1              1    100   0.708
2              2     10   0.950
3              2    100   0.877
4              3     10   0.960
5              3    100   0.937
6              5     10   0.980
7              5    100   0.977

sampling_rate  limit  accuracy
            1     10     0.800
            1    100     0.708
            2     10     0.950
            2    100     0.877
            4     10     0.970
            4    100     0.956
            8     10     0.990
            8    100     0.990
           16     10     1.000
           16    100     0.998
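
To compare the two limits side by side, the long-format results can be pivoted so that each oversampling rate becomes one row. This is a small sketch that re-enters the recall values printed by the loop above (rounded to three decimals) rather than reusing the notebook's `results` variable.

```python
import pandas as pd

# Recall values copied from the run above (sampling rates 1, 2, 3, 5).
rows = [
    (1, 10, 0.800), (1, 100, 0.708),
    (2, 10, 0.950), (2, 100, 0.877),
    (3, 10, 0.960), (3, 100, 0.937),
    (5, 10, 0.980), (5, 100, 0.977),
]
results = pd.DataFrame(rows, columns=["sampling_rate", "limit", "recall"])

# One row per oversampling rate, one column per limit.
pivot = results.pivot(index="sampling_rate", columns="limit", values="recall")
print(pivot)
```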