Retrieval Augmented Classification: Improving Text Classification with External Knowledge

When and how to best use LLMs as text classifiers

The post Retrieval Augmented Classification: Improving Text Classification with External Knowledge appeared first on Towards Data Science.

May 7, 2025 - 04:40

Text classification stands as one of the most basic yet most important applications of natural language processing. It plays a vital role in many real-world applications, ranging from filtering unwanted emails such as spam to detecting product categories or classifying user intent in a chatbot application. The default way of building text classifiers is to gather large amounts of labeled data, meaning input texts and their corresponding labels, and then train a custom machine learning model. Things changed as LLMs became more powerful: you can often get decent performance by using general-purpose large language models as zero-shot or few-shot classifiers, significantly reducing the time-to-deployment of text classification services. However, their accuracy can lag behind that of custom-built models and is highly dependent on crafting custom prompts that better define the classification task for the LLM. In this post, we aim to narrow the gap between custom ML classification models and general-purpose LLMs, while also minimizing the effort needed to adapt the LLM prompt to your task.

LLMs vs Custom ML models for text classification

Let’s first explore the pros and cons of each of the two approaches to text classification.

Large language models as general-purpose classifiers:

Pros:

  1. High generalization ability given the vast pre-training corpus and reasoning abilities of the LLM.
  2. A single general-purpose LLM can handle multiple classification tasks without the need to deploy a model for each.
  3. As LLMs continue to improve, you can potentially enhance accuracy with minimal effort simply by adopting newer, more powerful models as they become available.
  4. The availability of most LLMs as managed services significantly reduces the deployment knowledge and effort required to get started.
  5. LLMs often outperform custom ML models in low-data scenarios where labeled data is limited or costly to obtain.
  6. LLMs generalize to multiple languages.
  7. LLMs can be cheaper for low or unpredictable prediction volumes if you pay per token.
  8. Class definitions can be changed dynamically without retraining by simply modifying the prompts.

Cons:

  1. LLMs are prone to hallucinations.
  2. LLMs can be slow, or at least slower than small custom ML models.
  3. They require prompt engineering effort.
  4. High-throughput applications using LLMs-as-a-service may quickly encounter quota limitations.
  5. This approach becomes less effective with a very large number of potential classes due to context size constraints. Defining all the classes would consume a significant portion of the available and effective input context.
  6. LLMs usually achieve lower accuracy than custom models in the high-data regime.
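To make the context-size concern above more concrete, here is a rough, hypothetical estimate of how much prompt space a full class list takes compared with a short candidate list. It assumes the common rule of thumb of about four characters per token for English text (real counts depend on the model's tokenizer), and the 220 class names and definitions are made up for illustration.

```python
def estimate_tokens(text: str, chars_per_token: float = 4.0) -> int:
    """Very rough token estimate (assumption: ~4 characters per token)."""
    return max(1, round(len(text) / chars_per_token))


def render_class_list(class_definitions: dict[str, str]) -> str:
    """One 'name: definition' line per class, as a classification prompt might list them."""
    return "\n".join(f"{name}: {desc}" for name, desc in class_definitions.items())


# Hypothetical label set: 220 classes, each with a one-sentence definition.
all_classes = {
    f"Category{i}": "A short, one-sentence description of what belongs in this class."
    for i in range(220)
}
# Hypothetical restricted set: only 5 candidate classes are kept.
candidates = dict(list(all_classes.items())[:5])

full_cost = estimate_tokens(render_class_list(all_classes))
restricted_cost = estimate_tokens(render_class_list(candidates))
print(f"full list: ~{full_cost} tokens, 5 candidates: ~{restricted_cost} tokens")
```

The class-list portion of the prompt shrinks roughly in proportion to the number of classes dropped, which is what makes restricting the candidate set attractive when there are hundreds of classes.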

Custom Machine Learning models:

Pros:

  1. Efficient and fast.
  2. More flexible in architecture choice, training and serving method.
  3. Ability to add interpretability and uncertainty estimation aspects to the model.
  4. Higher accuracy in the high-data regime.
  5. You keep control of your model and serving infrastructure.

Cons:

  1. Requires frequent retraining to adapt to new data or distribution changes.
  2. May need significant amounts of labeled data.
  3. Limited generalization.
  4. Sensitive to out-of-domain vocabulary or formulations.
  5. Requires MLOps knowledge for deployment.

Bridging the gap between custom text classifiers and LLMs:

Let’s work on a way to keep the pros of using LLMs for classification while alleviating some of the cons. We will take inspiration from RAG and use a prompting technique called few-shot prompting.

Let’s define both:

RAG

Retrieval Augmented Generation is a popular method that augments the LLM context with external knowledge before asking a question. This reduces the likelihood of hallucination and improves the quality of the responses.
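As a hedged illustration of the augmentation step, the sketch below retrieves passages with a toy word-overlap score (a stand-in assumption; real RAG systems typically use embeddings and a vector index, as this post does later) and prepends them to the question. The knowledge base and prompt wording are hypothetical.

```python
def retrieve(query: str, knowledge_base: list[str], k: int = 2) -> list[str]:
    """Toy retriever: rank passages by how many query words they share."""
    query_words = set(query.lower().split())
    ranked = sorted(
        knowledge_base,
        key=lambda passage: len(query_words & set(passage.lower().split())),
        reverse=True,  # most overlapping passages first; ties keep their original order
    )
    return ranked[:k]


def augment_prompt(query: str, passages: list[str]) -> str:
    """Prepend the retrieved passages as context before the question."""
    context = "\n".join(f"- {p}" for p in passages)
    return f"Use the following context to answer.\nContext:\n{context}\n\nQuestion: {query}"


kb = [
    "ChromaDB is an open-source vector database.",
    "Paris is the capital of France.",
    "Few-shot prompting shows the model worked examples.",
]
question = "What is the capital of France?"
prompt = augment_prompt(question, retrieve(question, kb, k=1))
print(prompt)
```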

Few-shot prompting

We show the LLM examples of inputs and their expected outputs as part of the prompt to help it understand the classification task.
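A few-shot prompt is often expressed as a chat history in which each example becomes a user turn followed by an assistant turn holding the expected label. The sketch below uses the widely used role/content message convention; the exact message format depends on your LLM provider, and the function name and spam examples are hypothetical.

```python
def build_few_shot_messages(
    instructions: str,
    examples: list[tuple[str, str]],
    query: str,
) -> list[dict[str, str]]:
    """Instructions, then one (input, label) user/assistant pair per example, then the query."""
    messages = [{"role": "system", "content": instructions}]
    for text, label in examples:
        messages.append({"role": "user", "content": text})
        messages.append({"role": "assistant", "content": label})
    messages.append({"role": "user", "content": query})
    return messages


msgs = build_few_shot_messages(
    "Classify the message as 'spam' or 'not_spam'. Answer with the label only.",
    [("You won a free cruise, click here!", "spam"), ("Lunch at noon tomorrow?", "not_spam")],
    "Claim your prize now!!!",
)
for m in msgs:
    print(f"{m['role']}: {m['content']}")
```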

The main idea of this project is to combine both. We dynamically fetch the labeled examples most similar to the text to be classified and inject them into the prompt as few-shot examples. We also dynamically restrict the set of possible classes to those of the K nearest neighbors. This frees up a significant number of input-context tokens when a classification problem has a large number of possible classes.
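The two dynamic pieces can be sketched in a few lines of standard-library Python, assuming the neighbors have already been retrieved as (text, category) pairs sorted from most to least similar. The function name and example data are hypothetical; ranking candidates by how often they occur among the neighbors is one reasonable ordering.

```python
from collections import Counter


def build_rac_inputs(
    neighbors: list[tuple[str, str]],
    k_few_shot: int,
) -> tuple[list[str], list[tuple[str, str]]]:
    """Return (candidate categories ranked by frequency, top-k_few_shot neighbors as examples)."""
    counts = Counter(category for _, category in neighbors)
    # most_common() breaks ties by first occurrence, i.e. by the closest neighbor.
    candidates = [category for category, _ in counts.most_common()]
    few_shot = neighbors[:k_few_shot]
    return candidates, few_shot


neighbors = [
    ("Text about a prime minister", "PrimeMinister"),
    ("Text about a senator", "Senator"),
    ("Another prime minister", "PrimeMinister"),
    ("A governor", "Governor"),
    ("Yet another prime minister", "PrimeMinister"),
]
candidates, few_shot = build_rac_inputs(neighbors, k_few_shot=2)
print(candidates)  # ['PrimeMinister', 'Senator', 'Governor']
print(len(few_shot))  # 2
```

Only the candidate list enters the prompt, so its length is bounded by the number of neighbors rather than by the total number of classes.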

Here is how that would work:

Let’s go through the practical steps of getting this approach to run:

  • Building a knowledge base of labeled input text / category pairs. This will be our source of external knowledge for the LLM. We will be using ChromaDB.
from typing import List
from uuid import uuid4

from langchain_core.documents import Document
from chromadb import PersistentClient
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import torch
from tqdm import tqdm
from chromadb.config import Settings
from retrieval_augmented_classification.logger import logger


class DatasetVectorStore:
    """ChromaDB vector store for labeled text/category pairs, using a BGE embedding model."""

    def __init__(
        self,
        db_name: str = "retrieval_augmented_classification",  # Kept for reference; Chroma uses collection_name
        collection_name: str = "classification_dataset",
        persist_directory: str = "chroma_db",  # Directory to persist ChromaDB
    ):
        self.db_name = db_name
        self.collection_name = collection_name
        self.persist_directory = persist_directory

        # Determine if CUDA is available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        self.embeddings = HuggingFaceBgeEmbeddings(
            model_name="BAAI/bge-small-en-v1.5",
            model_kwargs={"device": device},
            encode_kwargs={
                "device": device,
                "batch_size": 100,
            },  # Adjust batch_size as needed
        )

        # Initialize Chroma vector store
        self.client = PersistentClient(
            path=self.persist_directory, settings=Settings(anonymized_telemetry=False)
        )
        self.vector_store = Chroma(
            client=self.client,
            collection_name=self.collection_name,
            embedding_function=self.embeddings,
            persist_directory=self.persist_directory,
        )

    def add_documents(self, documents: List) -> None:
        """
        Add multiple documents to the vector store.

        Args:
            documents: List of dictionaries containing document data.  Each dict needs a "text" key.
        """

        local_documents = []
        ids = []

        for doc_data in documents:
            if not doc_data.get("id"):
                doc_data["id"] = str(uuid4())

            local_documents.append(
                Document(
                    page_content=doc_data["text"],
                    metadata={k: v for k, v in doc_data.items() if k != "text"},
                )
            )
            ids.append(doc_data["id"])

        batch_size = 100  # Adjust batch size as needed
        for i in tqdm(range(0, len(documents), batch_size)):
            batch_docs = local_documents[i : i + batch_size]
            batch_ids = ids[i : i + batch_size]

            # Pass explicit IDs so that re-inserting a document updates it instead of duplicating it.
            self._upsert_batch(batch_docs, batch_ids)

    def _upsert_batch(self, batch_docs: List[Document], batch_ids: List[str]):
        """Upsert a batch of documents into Chroma.  If the ID exists, it updates; otherwise, it creates."""
        texts = [doc.page_content for doc in batch_docs]
        metadatas = [doc.metadata for doc in batch_docs]

        self.vector_store.add_texts(texts=texts, metadatas=metadatas, ids=batch_ids)

This class handles creating a collection and embedding each document before inserting it into the vector index. We use BAAI/bge-small-en-v1.5, but any embedding model would work, including those available as a service from Gemini, OpenAI, or Nebius.
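To see what the vector store does without installing ChromaDB, here is an in-memory sketch of the same insert path: each record gets an ID (generated with uuid4 if missing), its text is embedded, and the remaining fields are stored as metadata. The bag-of-words `embed` function is a toy assumption standing in for BAAI/bge-small-en-v1.5, and `InMemoryStore` is a hypothetical class, not a Chroma API.

```python
import math
import re
from collections import Counter
from uuid import uuid4


def embed(text: str) -> dict[str, float]:
    """Toy embedding (assumption): L2-normalized bag-of-words counts."""
    counts = Counter(re.findall(r"\w+", text.lower()))
    norm = math.sqrt(sum(c * c for c in counts.values())) or 1.0
    return {word: c / norm for word, c in counts.items()}


class InMemoryStore:
    """Hypothetical stand-in for a vector collection: id -> (vector, text, metadata)."""

    def __init__(self) -> None:
        self.records: dict[str, tuple[dict[str, float], str, dict]] = {}

    def upsert(self, docs: list[dict]) -> list[str]:
        ids = []
        for doc in docs:
            doc_id = doc.get("id") or str(uuid4())  # keep a given ID, else generate one
            metadata = {k: v for k, v in doc.items() if k != "text"}
            # Writing to an existing ID replaces the old record: that is "upsert".
            self.records[doc_id] = (embed(doc["text"]), doc["text"], metadata)
            ids.append(doc_id)
        return ids


store = InMemoryStore()
store.upsert([{"id": "a", "text": "Prime minister of Italy", "category": "PrimeMinister"}])
store.upsert([{"id": "a", "text": "Prime minister of France", "category": "PrimeMinister"}])
store.upsert([{"text": "A river in Europe", "category": "River"}])
print(len(store.records))  # 2: the second call replaced record "a"
```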

  • Finding the K nearest neighbors for an input text (a method of DatasetVectorStore):
def search(self, query: str, k: int = 5) -> List[Document]:
    """Search documents by semantic similarity."""
    results = self.vector_store.similarity_search(query, k=k)
    return results

This method returns the documents in the vector database that are most similar to our input.
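Conceptually, `similarity_search` embeds the query and returns the stored documents whose embeddings are closest to it. The brute-force cosine-similarity version below is a hedged conceptual sketch with made-up 3-dimensional vectors; Chroma itself uses an approximate nearest-neighbor (HNSW) index, and its distance function depends on the collection configuration.

```python
import heapq
import math


def cosine(u: list[float], v: list[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norms = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norms if norms else 0.0


def top_k(query_vec: list[float], index: dict[str, list[float]], k: int = 5) -> list[tuple[str, float]]:
    """Exact (brute-force) nearest-neighbor search: score every vector, keep the k best."""
    scored = ((doc_id, cosine(query_vec, vec)) for doc_id, vec in index.items())
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


index = {
    "doc1": [1.0, 0.0, 0.0],
    "doc2": [0.9, 0.1, 0.0],
    "doc3": [0.0, 1.0, 0.0],
}
print(top_k([1.0, 0.05, 0.0], index, k=2))
```

Brute force costs one similarity computation per stored vector, which is fine for small collections; approximate indexes trade a little recall for much faster queries at the scale of the 240,000-sample dataset used later.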

  • Building the Retrieval Augmented Classifier
from typing import Optional
from pydantic import BaseModel, Field
from collections import Counter

from retrieval_augmented_classification.vector_store import DatasetVectorStore
from tenacity import retry, stop_after_attempt, wait_exponential
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage


class PredictedCategories(BaseModel):
    """
    Pydantic model for the predicted categories from the LLM.
    """

    reasoning: str = Field(description="Explain your reasoning")
    predicted_category: str = Field(description="Category")


class RAC:
    """
    A hybrid classifier combining K-Nearest Neighbors retrieval with an LLM for multi-class prediction.
    Finds top K neighbors, uses top few-shot for context, and uses all neighbor categories
    as potential prediction candidates for the LLM.
    """

    def __init__(
        self,
        vector_store: DatasetVectorStore,
        llm_client,
        knn_k_search: int = 30,
        knn_k_few_shot: int = 5,
    ):
        """
        Initializes the classifier.

        Args:
            vector_store: An instance of DatasetVectorStore with a search method.
            llm_client: An instance of the LLM client capable of structured output.
            knn_k_search: The number of nearest neighbors to retrieve from the vector store.
            knn_k_few_shot: The number of top neighbors to use as few-shot examples for the LLM.
                           Must be less than or equal to knn_k_search.
        """

        self.vector_store = vector_store
        self.llm_client = llm_client
        self.knn_k_search = knn_k_search
        self.knn_k_few_shot = knn_k_few_shot

    @retry(
        stop=stop_after_attempt(3),  # Retry LLM call a few times
        wait=wait_exponential(multiplier=1, min=2, max=5),  # Shorter waits for demo
    )
    def predict(self, document_text: str) -> Optional[str]:
        """
        Predicts the most relevant category for a given document text using KNN retrieval and an LLM.

        Args:
            document_text: The text content of the document to classify.

        Returns:
            The predicted category
        """
        neighbors = self.vector_store.search(document_text, k=self.knn_k_search)

        all_neighbor_categories = set()
        valid_neighbors = []  # Store neighbors that have metadata and categories
        for neighbor in neighbors:
            if (
                hasattr(neighbor, "metadata")
                and isinstance(neighbor.metadata, dict)
                and "category" in neighbor.metadata
            ):
                all_neighbor_categories.add(neighbor.metadata["category"])
                valid_neighbors.append(neighbor)
            else:
                pass  # Suppress warnings for cleaner demo output

        if not valid_neighbors:
            return None

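        # Count category occurrences over all neighbors; counting the deduplicated
        # set would give every category a count of 1 and make the ranking arbitrary.
        category_counts = Counter(
            neighbor.metadata["category"] for neighbor in valid_neighbors
        )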
        ranked_categories = [
            category for category, count in category_counts.most_common()
        ]

        if not ranked_categories:
            return None

        few_shot_neighbors = valid_neighbors[: self.knn_k_few_shot]

        messages = []

        system_prompt = f"""You are an expert multi-class classifier. Your task is to analyze the provided document text and assign the most relevant category from the list of allowed categories.
You MUST only return categories that are present in the following list: {ranked_categories}.
If none of the allowed categories are relevant, return an empty string.
Return a single category: the one you are most confident in.
Output your prediction as a JSON object matching the Pydantic schema: {PredictedCategories.model_json_schema()}.
"""
        messages.append(SystemMessage(content=system_prompt))

        for i, neighbor in enumerate(few_shot_neighbors):
            messages.append(
                HumanMessage(content=f"Document: {neighbor.page_content}")
            )
            expected_output_json = PredictedCategories(
                reasoning="Your reasoning here",
                predicted_category=neighbor.metadata["category"]
            ).model_dump_json()
            # Simulate the structure often used with tool calling/structured output

            ai_message_with_tool = AIMessage(
                content=expected_output_json,
            )

            messages.append(ai_message_with_tool)

        # Final user message: The document text to classify
        messages.append(HumanMessage(content=f"Document: {document_text}"))

        # Configure the client for structured output with the Pydantic schema
        structured_client = self.llm_client.with_structured_output(PredictedCategories)
        llm_response: PredictedCategories = structured_client.invoke(messages)

        predicted_category = llm_response.predicted_category

        return predicted_category if predicted_category in ranked_categories else None

The first part of the code defines the structure of the output we expect from the LLM. The Pydantic class has two fields: the reasoning, used for chain-of-thought prompting (https://www.promptingguide.ai/techniques/cot), and the predicted category.

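The predict method first retrieves the knn_k_search nearest neighbors and uses the categories of all of them as the list of allowed candidates. It then turns the top knn_k_few_shot neighbors into few-shot examples by building a synthetic message history, as if the LLM had already given the correct category for each of them, and finally injects the query text as the last human message.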

Finally, we check that the predicted category is one of the allowed candidates and return it if so; otherwise, predict returns None.
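In the post's code, LangChain's `with_structured_output` takes care of parsing the reply into `PredictedCategories`. If you call an LLM API directly and receive raw JSON text instead, a standard-library sketch of the same parse-and-validate step could look like this; the function name is hypothetical, and malformed output is treated as no prediction, mirroring the `None` return in `predict`.

```python
import json
from typing import Optional


def parse_prediction(raw_json: str, allowed: list[str]) -> Optional[str]:
    """Parse a {"reasoning": ..., "predicted_category": ...} reply and keep the
    category only if it is one of the allowed candidates."""
    try:
        payload = json.loads(raw_json)
    except json.JSONDecodeError:
        return None  # malformed output: no prediction
    if not isinstance(payload, dict):
        return None  # valid JSON but not an object
    category = payload.get("predicted_category")
    return category if category in allowed else None


allowed = ["PrimeMinister", "Senator"]
print(parse_prediction('{"reasoning": "Led the government.", "predicted_category": "PrimeMinister"}', allowed))
print(parse_prediction('{"reasoning": "...", "predicted_category": "Astronaut"}', allowed))
```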

  • Example prediction:
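# Assumes `store` is a populated DatasetVectorStore and `llm_client` is a
# LangChain chat model that supports `with_structured_output`.
_rac = RAC(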
    vector_store=store,
    llm_client=llm_client,
    knn_k_search=50,
    knn_k_few_shot=10,
)
print(
    f"Initialized rac with knn_k_search={_rac.knn_k_search}, knn_k_few_shot={_rac.knn_k_few_shot}."
)

text = """Ivanoe Bonomi [iˈvaːnoe boˈnɔːmi] (18 October 1873 – 20 April 1951) was an Italian politician and statesman before and after World War II. Bonomi was born in Mantua. He was elected to the Italian Chamber of Deputies in ...
"""
category = _rac.predict(text)

print(text)
print(category)

text = """Michel Rocard, né le 23 août 1930 à Courbevoie et mort le 2 juillet 2016 à Paris, est un haut fonctionnaire et ... 
"""
category = _rac.predict(text)

print(text)
print(category)

Both inputs return the prediction “PrimeMinister”, even though the second example is in French while the training dataset is entirely in English. This illustrates the generalization ability of this approach, even across similar languages.

  • Evaluation:

We use the l3 categories of the DBPedia Classes dataset (https://www.kaggle.com/datasets/danofer/dbpedia-classes, license CC BY-SA 3.0) for our evaluation. This dataset has more than 200 categories and 240,000 training samples.

We benchmark the Retrieval Augmented Classification approach against a simple KNN classifier with majority vote and an LLM-only classifier, and obtain the following results on the DBpedia dataset’s l3 categories:

| Method | Accuracy | Average latency | Throughput (multi-threaded) |
|---|---|---|---|
| KNN classifier | 87% | 24 ms | 108 predictions/s |
| LLM-only classifier | 88% | ~600 ms | 47 predictions/s |
| RAC | 96% | ~1 s | 27 predictions/s |
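The KNN baseline in the table predicts the most frequent category among the retrieved neighbors. Here is a minimal sketch of that voting step, assuming the neighbor labels are already retrieved and sorted from most to least similar; the post does not specify a tie-breaking rule, so favoring the label seen first (the closest neighbor) is an assumption.

```python
from collections import Counter


def knn_majority_vote(neighbor_labels: list[str]) -> str:
    """Most frequent label among the neighbors; ties go to the label that appears first."""
    if not neighbor_labels:
        raise ValueError("need at least one neighbor")
    # most_common(1) returns the first maximal entry, and Counter keeps first-seen order.
    return Counter(neighbor_labels).most_common(1)[0][0]


print(knn_majority_vote(["PrimeMinister", "Senator", "PrimeMinister"]))  # PrimeMinister
```

Unlike RAC, this baseline never reads the query text beyond the retrieval step, which helps explain both its speed and its lower accuracy.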

For reference, the best accuracy I found in Kaggle notebooks for this dataset’s l3 level was around 94%, using custom ML models.

We note that combining a KNN search with the reasoning abilities of an LLM gains 9 accuracy points over the KNN baseline (and 8 over the LLM-only classifier), at the cost of lower throughput and higher latency.
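On the cost side, RAC’s throughput in the table (27 predictions/s) is a quarter of the KNN classifier’s (108 predictions/s). To measure accuracy and latency for your own classifiers, a simple sequential harness like the hypothetical one below is enough to start. Note that the table’s throughput figures were measured multi-threaded, which this sketch does not attempt, and the toy classifier and samples are made up.

```python
import time
from typing import Callable, Optional


def evaluate(
    classify: Callable[[str], Optional[str]],
    samples: list[tuple[str, str]],
) -> dict[str, float]:
    """Run `classify` sequentially on (text, label) pairs; a None prediction counts as wrong."""
    if not samples:
        raise ValueError("need at least one sample")
    correct = 0
    latencies = []
    for text, label in samples:
        start = time.perf_counter()
        prediction = classify(text)
        latencies.append(time.perf_counter() - start)
        correct += int(prediction == label)
    return {
        "accuracy": correct / len(samples),
        "avg_latency_s": sum(latencies) / len(latencies),
    }


def toy_classifier(text: str) -> Optional[str]:
    """Hypothetical stand-in for RAC.predict or a KNN classifier."""
    return "PrimeMinister" if "minister" in text else "River"


samples = [("prime minister of Italy", "PrimeMinister"), ("a river", "River"), ("a senator", "Senator")]
print(evaluate(toy_classifier, samples))  # accuracy is 2/3; latency depends on the machine
```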

Conclusion

In this project, we built a text classifier that leverages “retrieval” to boost the ability of an LLM to find the correct category for the input content. This approach offers several advantages over traditional ML text classifiers: the ability to change the training dataset dynamically without retraining, higher generalization thanks to the reasoning and general knowledge of LLMs, easy deployment when using managed LLM services compared to custom ML models, and the capability to handle multiple classification tasks with a single base LLM. This comes at the cost of higher latency, lower throughput, and a risk of LLM vendor lock-in.

This method should not be your first go-to when working on a classification task, but it is still a useful part of your toolbox when your application can benefit from the flexibility of not having to retrain a classifier every time the data changes, or when you are working with a small amount of labeled data. It can also let you get a classification service up and running very quickly when a deadline is looming.