Textinhalte in spezifische Klassen oder Kategorien einzuordnen ist ein Problem für das es bereits einige funktionierende Ansätze im Machine Learning gibt. Wenn man sich aber nicht mit dem Training von eigenen Klassifikatoren auseinandersetzen möchte, bieten Bibliotheken wie LangChain4j mithilfe des Structured Output Features eine Möglichkeit das „Textverständnis“ eines bereits trainierten LLMs zu verwenden um Textabschnitte in bestimmte Kategorien einzuordnen. Ob das auch zuverlässig funktioniert schauen wir uns anhand eines kleinen Beispielprojekts an.
Textklassifikation mit einem LLM
Wir nutzen wie im letzten Beitrag zu diesem Thema wieder ein Spring Boot Projekt mit LangChain4j Integration und gemma3:4b als Model in Ollama. Dieses mal versuchen wir fiktive Anfragen für ein Hotel in 3 Kategorien (Hotelbuchungen, Hotelstornierungen, Allgemeine Fragen) einzuteilen. Das LLM bekommt diese Aufgabe am Anfang der @UserMessage
mitgeteilt. Der Ausgabetyp unserer Java-Methode ist nun vom Enum-Typ RequestCategory
.
@AiService
public interface RequestClassifier {
@UserMessage("""
Ordne die folgende Nachricht einer Kategorie zu.
---
{{it}}
---
""")
RequestCategory classify(String message);
}
src/main/java/de/gedoplan/showcase/langchain4jdemo/RequestClassifier.java
public enum RequestCategory {
HOTEL_BOOKING,
HOTEL_CANCELLATION,
GENERAL_QUESTION;
}
src/main/java/de/gedoplan/showcase/langchain4jdemo/RequestCategory.java
Beim Testen der Klassifikationsergebnisse fallen hier ein paar Probleme auf. Zum einen ist das Ergebnis wie von LLMs zu erwarten nicht immer dasselbe, sondern kann sich von Ausführung zu Ausführung verändern. Dem könnte man entgegen wirken, indem man z.B. einen festen Seed wählt. Manchmal versucht das LLM auch eine Anfrage einer Kategorie zuzuordnen, die es nicht gibt oder es gibt die Kategorie HOTEL_QUESTION
statt GENERAL_QUESTION
zurück, was in einer Exception beim Parsen resultiert.
Es sind aber auch syntaktisch korrekte Einordnungen dabei, die wir anders eingeordnet erwartet hätten. Da müssten wir dem LLM irgendwie mitgeben, was unsere Vorstellung von Anfragen zu den einzelnen Kategorien ist. Da wir das Modell nicht neu trainieren wollen, können wir nur auf Inferenzebene Beispiele bei unserer Anfrage hinzufügen. Diese Taktik wird auch unter dem Begriff Few-Shot-Prompting verwendet.
@AiService
public interface RequestClassifier {
@UserMessage("""
Ordne die Nachricht einer Kategorie zu.
---
Beispiele für korrekte Zuordnungen sind:
- Ich möchte ein Zimmer für zwei Personen vom 15. bis 20. Juli buchen. -> HOTEL_BOOKING
- Haben Sie noch ein Doppelzimmer für das kommende Wochenende frei? -> HOTEL_BOOKING
- Ich muss meine Buchung für nächste Woche leider stornieren. -> HOTEL_CANCELLATION
- Bieten Sie einen Flughafentransfer an? -> GENERAL_QUESTION
- Wie spät ist das Frühstück? -> GENERAL_QUESTION
[...]
---
{{it}} -> \
""")
RequestCategory classifyWithExamples(String message);
}
src/main/java/de/gedoplan/showcase/langchain4jdemo/RequestClassifier.java
Das kann die Qualität der Antworten eines LLMs durchaus erhöhen. Auch das Aktivieren der Capability langchain4j.ollama.chat-model.supported-capabilities=response_format_json_schema
kann man hier mal ausprobieren.
Je nach Anwendungsfall kann man hiermit passable Ergebnisse erreichen, ein Problem ist allerdings, dass mit jedem weiteren Beispiel was der Anfrage an dieser Stelle hinzugefügt wird die Größe der Anfrage wächst, was die Verarbeitungszeit durch das LLM erhöht.
Textklassifikation mit Embeddings
Eine andere Möglichkeit Textklassifikationen vorzunehmen ohne ein Modell zu trainieren ist auf Basis von Embeddings. Die haben wir im Kontext der Retrieval-Augmented-Generation schonmal angesprochen. Die Embeddings werden in der Regel durch kleinere Embeddingmodelle bestimmt und stellen eine numerische Repräsentation der Semantik, also der Bedeutung, dar. Das Ergebnis ist ein Vektor in einem hochdimensionalen Raum. Optimalerweise stellt ein Vergleich von Embeddings in diesem Raum eine Approximation eines Vergleichs der Bedeutung zweier Textinhalte dar.
Die Idee dieses Ansatzes ist es nun für eine Menge von Beispielanfragen, deren Kategorie wir vorher manuell bestimmt haben, die Embeddings vorab zu bestimmen. Bei einer neuen Anfrage können wir dann auch hier das Embedding bestimmen und dies mit den vorhandenen Beispielembeddings vergleichen.
Dies ist eine schematische Darstellung eines 2D-Embedding-Raums. Die farbigen Punkte entsprechen den vorab bestimmten Beispielembeddings, der graue Punkt entspricht dem Embedding einer neuen Anfrage. Bei einer Klassifikation durch Vergleich würde der graue Punkt der blauen Kategorie zugeordnet werden.
Wir bestimmen unsere Beispielembeddings in diesem Fall der Einfachheit halber bei jedem Start der Anwendung und hinterlegen sie in einem InMemoryEmbeddingStore.
@Service
public class EmbeddingRequestClassifier {
private final EmbeddingModel embeddingModel;
private final EmbeddingStore<RequestCategory> embeddingStore;
public EmbeddingRequestClassifier(EmbeddingModel embeddingModel) {
this.embeddingModel = embeddingModel;
this.embeddingStore = new InMemoryEmbeddingStore<>();
initializeExamples();
}
private void initializeExamples() {
// Examples for HOTEL_BOOKING
addExample("Ich möchte ein Zimmer für zwei Personen vom 15. bis 20. Juli buchen.", RequestCategory.HOTEL_BOOKING);
addExample("Haben Sie noch ein Doppelzimmer für das kommende Wochenende frei?", RequestCategory.HOTEL_BOOKING);
[...]
// Examples for HOTEL_CANCELLATION
addExample("Ich muss meine Buchung für nächste Woche leider stornieren.", RequestCategory.HOTEL_CANCELLATION);
addExample("Kann ich meine Reservierung vom 5. August absagen?", RequestCategory.HOTEL_CANCELLATION);
[...]
// Examples for GENERAL_QUESTION
addExample("Bieten Sie einen Flughafentransfer an?", RequestCategory.GENERAL_QUESTION);
addExample("Wie spät ist das Frühstück?", RequestCategory.GENERAL_QUESTION);
[...]
}
private void addExample(String text, RequestCategory category) {
Embedding embedding = embeddingModel.embed(TextSegment.from(text)).content();
embeddingStore.add(embedding, category);
}
[...]
}
src/main/java/de/gedoplan/showcase/langchain4jdemo/EmbeddingRequestClassifier.java
Die eigentliche Klassifikation führen wir dann durch indem wir uns vom EmbeddingStore
das ähnlichste Beispiel zurückgeben lassen und die dahinter gespeicherte Kategorie zurückgeben.
@Service
public class EmbeddingRequestClassifier {
[...]
public RequestCategory classify(String message) {
Embedding messageEmbedding = embeddingModel.embed(TextSegment.from(message)).content();
// Create an EmbeddingSearchRequest with the message embedding and maxResults=1
EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(messageEmbedding)
.maxResults(1)
.build();
// Use the search method to find the most similar embedding
EmbeddingSearchResult<RequestCategory> searchResult = embeddingStore.search(searchRequest);
List<EmbeddingMatch<RequestCategory>> matches = searchResult.matches();
if (matches.isEmpty()) {
return RequestCategory.GENERAL_QUESTION;
}
// The first match is the most similar one
return matches.get(0).embedded();
}
}
src/main/java/de/gedoplan/showcase/langchain4jdemo/EmbeddingRequestClassifier.java
Wenn uns nun auffällt, dass eine Anfrage falsch eingeordnet wurde können wir auch hier die Beispielembeddings im EmbeddingStore
erweitern. Eine größere Menge an Beispielen führt hier zu besseren Klassifikationsergebnissen ohne den Aufwand für jede Anfrage enorm zu steigern wie im vorherigen Ansatz. Da wir zur Laufzeit auch nur das Embedding der Anfrage mithilfe eines vergleichsweise kleinen Modells bestimmen müssen, läuft die Anwendung auch ohne spezifische Hardwareanforderungen performanter.
Ein weiterer Vorteil für die Stabilität ist, dass wir hier nur vorher eingetragene Kategorien als Ergebnis erhalten können und uns kein LLM eine halluzinierte Kategorie zurückgeben kann. Wir vermeiden es also mit Parsing Fehlern umgeben zu müssen. Dadurch erhalten wir aber auch für Anfragen die keiner dieser Kategorien entsprechen eine unserer drei Kategorien zurück. Für diesen Fall können wir zusätzlich eine Grenze für die Ähnlichkeit einführen, ab der zum Beispiel ein null
Wert zurückgegeben wird.
[...]
public RequestCategory classify(String message) {
Embedding messageEmbedding = embeddingModel.embed(TextSegment.from(message)).content();
// Create an EmbeddingSearchRequest with the message embedding and maxResults=1
EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(messageEmbedding)
.maxResults(1)
.build();
// Use the search method to find the most similar embedding
EmbeddingSearchResult<RequestCategory> searchResult = embeddingStore.search(searchRequest);
List<EmbeddingMatch<RequestCategory>> matches = searchResult.matches();
if (matches.isEmpty() || matches.get(0).score() < 0.5) {
return null;
}
// The first match is the most similar one
return matches.get(0).embedded();
}
[...]
src/main/java/de/gedoplan/showcase/langchain4jdemo/EmbeddingRequestClassifier.java
Das Beispielprojekt gibt es wie immer auch auf GitHub zum selber ausprobieren.