RAG on FHIR Data Powered by Knowledge Graphs

While fine-tuning and RAG are powerful methods for adapting pre-trained AI models such as Llama 3, neither takes advantage of the underlying connectedness of the data. In healthcare, R&D, and other knowledge-oriented fields, the data is inherently connected: a patient has encounters, and each encounter produces observations, procedures, claims, and so on.

Here we explore using a Knowledge Graph as an augmented RAG approach to retrieve a patient's blood pressure readings. Without any reference data, the model cannot give a useful answer, because it has never seen this patient's records; with the KG-augmented RAG input, it answers with the readings recorded in the patient's data.
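To make the connectedness argument concrete, here is a minimal in-memory sketch of graph-augmented retrieval: find seed nodes whose text matches the query, then pull in their direct neighbors so the context also carries related facts, such as the encounter an observation belongs to. A naive keyword match stands in for vector search, the node ids, texts, and edges are hypothetical toy data, and the one-hop expansion is an illustrative assumption rather than code from this notebook.

```python
from collections import defaultdict

# Hypothetical toy graph: node id -> text, plus undirected edges between ids.
NODES = {
    "obs-1": "Observation: blood pressure panel, systolic 133 mm[Hg], diastolic 88 mm[Hg]",
    "enc-1": "Encounter: general examination on 02/09/2014",
    "pat-1": "Patient: synthetic patient generated by Synthea",
    "imm-1": "Immunization: influenza vaccine",
}
EDGES = [("obs-1", "enc-1"), ("enc-1", "pat-1"), ("imm-1", "pat-1")]


def build_adjacency(edges):
    """Turn an edge list into an undirected adjacency map."""
    adjacency = defaultdict(set)
    for a, b in edges:
        adjacency[a].add(b)
        adjacency[b].add(a)
    return adjacency


def graph_augmented_retrieve(query, nodes, edges):
    """Return seed nodes matching the query, followed by their one-hop neighbors."""
    terms = query.lower().split()
    # Naive keyword match standing in for a vector similarity search.
    seeds = [nid for nid, text in nodes.items()
             if any(term in text.lower() for term in terms)]
    adjacency = build_adjacency(edges)
    context_ids = list(seeds)
    for seed in seeds:
        for neighbor in sorted(adjacency[seed]):
            if neighbor not in context_ids:
                context_ids.append(neighbor)
    return context_ids


print(graph_augmented_retrieve("blood pressure", NODES, EDGES))  # ['obs-1', 'enc-1']
```

The neighbor step is what plain vector search lacks: the encounter node never mentions blood pressure, yet it reaches the context because it is linked to the matching observation.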

The data for the patient used in this notebook came from Synthea, which generates artificial FHIR data for synthetic patients.
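Synthea writes each patient as a FHIR Bundle: a JSON object whose entry array wraps one resource per element, each tagged with a resourceType. As a small illustration, the sketch below counts resource types in a hand-written miniature bundle, skipping Provenance entries the same way the conversion loop later in this notebook does. The bundle content is made up, and real Synthea bundles carry far more fields per resource.

```python
import json
from collections import Counter

# A tiny, hand-written stand-in for a Synthea FHIR Bundle (illustrative only).
bundle_json = """
{
  "resourceType": "Bundle",
  "type": "transaction",
  "entry": [
    {"resource": {"resourceType": "Patient", "id": "p1"}},
    {"resource": {"resourceType": "Encounter", "id": "e1"}},
    {"resource": {"resourceType": "Observation", "id": "o1"}},
    {"resource": {"resourceType": "Observation", "id": "o2"}},
    {"resource": {"resourceType": "Provenance", "id": "pr1"}}
  ]
}
"""


def count_resource_types(bundle):
    """Count resources per type in a Bundle, ignoring Provenance entries."""
    counts = Counter()
    for entry in bundle.get("entry", []):
        resource_type = entry["resource"]["resourceType"]
        if resource_type != "Provenance":
            counts[resource_type] += 1
    return counts


bundle = json.loads(bundle_json)
# Sort ascending by count, giving [type, count]-style pairs
print(sorted(count_resource_types(bundle).items(), key=lambda kv: kv[1]))
```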

Special Thanks To

This work was heavily inspired by the work of Sam Schifman. Much of the underlying code that reads and parses FHIR data (the local NEO4J_Graph and FHIR_to_graph modules imported below) comes from him and is not included here. Neo4j has some excellent talks on using KGs with RAG: the Neo4j Going Meta series, Session 23: Advanced RAG patterns with Knowledge Graphs.
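Because the FHIR_to_graph helpers are not included, here is a hypothetical stand-in that shows the general shape of what resource_to_node has to produce: a Cypher CREATE statement for a node carrying the generic resource label (the label the vector index is built on later), a label for the FHIR resource type, the resource id, and a text property. The function name, the property set, and the crude text rendering are all assumptions; the real helper writes much richer sentences, like the Observation text printed near the end of this notebook. A string is built only to match how graph.query(node) is called later; parameterized queries would be safer.

```python
import json


def resource_to_node_sketch(resource):
    """Hypothetical stand-in for FHIR_to_graph.resource_to_node.

    Builds a Cypher CREATE statement for one FHIR resource. Only top-level
    scalar fields are rendered into the text property.
    """
    resource_type = resource["resourceType"]
    # Crude text rendering: one sentence per top-level scalar field.
    sentences = [f"The type of information in this entry is {resource_type.lower()}."]
    for key, value in resource.items():
        if key not in ("resourceType", "id") and isinstance(value, (str, int, float)):
            sentences.append(f"The {key} for this {resource_type.lower()} is {value}.")
    text = " ".join(sentences)
    # json.dumps yields a double-quoted, backslash-escaped string literal,
    # which is also valid syntax for a Cypher string literal.
    return (
        f"CREATE (:resource:{resource_type} "
        f"{{id: {json.dumps(resource['id'])}, text: {json.dumps(text)}}})"
    )


print(resource_to_node_sketch(
    {"resourceType": "Observation", "id": "obs-1", "status": "final"}
))
```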

Setup and Imports

# Standard-library imports
import glob
import json
import os
import re

from pprint import pprint

# LangChain imports (legacy paths; newer releases move most integrations to langchain_community)
from langchain.llms import Ollama
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOllama
from langchain.prompts import PromptTemplate

# Imports from other local Python files (the FHIR helpers credited above, not included here)
from NEO4J_Graph import Graph
from FHIR_to_graph import resource_to_node, resource_to_edges

Connect to the database

NEO4J_URI = "bolt://localhost:7687" #os.getenv('FHIR_GRAPH_URL')
USERNAME = "neo4j" #os.getenv('FHIR_GRAPH_USER')
PASSWORD = "fhir_pass" #os.getenv('FHIR_GRAPH_PASSWORD')
DATABASE = "neo4j" #os.getenv('FHIR_GRAPH_DATABASE')

graph = Graph(NEO4J_URI, USERNAME, PASSWORD, DATABASE)

Convert the FHIR data into the graph format

synthea_bundles = glob.glob("../../FHIR Data/use_data/*.json")
synthea_bundles.sort()

nodes = []
edges = []
dates = set() # set is used here to make sure dates are unique
for bundle_file_name in synthea_bundles:
    with open(bundle_file_name) as raw:
        bundle = json.load(raw)
        for entry in bundle['entry']:
            resource_type = entry['resource']['resourceType']
            if resource_type != 'Provenance':
                # generate the Cypher for creating the resource node
                nodes.append(resource_to_node(entry['resource']))
                # generate the Cypher for the reference & date edges, and capture the dates
                node_edges, node_dates = resource_to_edges(entry['resource'])
                edges += node_edges
                dates.update(node_dates)

# create the nodes for resources
for node in nodes:
    graph.query(node)


date_pattern = re.compile(r'([0-9]+)/([0-9]+)/([0-9]+)')

# create the nodes for dates (captured dates are MM/DD/YYYY strings)
for date in dates:
    month, day, year = date_pattern.findall(date)[0]
    # Cypher's date() expects ISO format, so pad month and day to two digits
    cypher_date = f'{year}-{month.zfill(2)}-{day.zfill(2)}'
    cypher = 'CREATE (:Date {name:"' + date + '", id: "' + date + '", date: date("' + cypher_date + '")})'
    graph.query(cypher)

# create the edges; report failures instead of stopping the whole load
for edge in edges:
    try:
        graph.query(edge)
    except Exception as error:
        print(f'Failed to create edge: {edge} ({error})')
# print out some information to show that the graph is populated.
print(graph.resource_metrics())
[['Patient', 1], ['Device', 2], ['CarePlan', 7], ['CareTeam', 7], ['Immunization', 12], ['MedicationRequest', 19], ['SupplyDelivery', 20], ['Condition', 46], ['Procedure', 104], ['DocumentReference', 106], ['Encounter', 106], ['Claim', 125], ['ExplanationOfBenefit', 125], ['DiagnosticReport', 167], ['Observation', 542]]
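The date loop above assumes every captured date is an MM/DD/YYYY string and assembles the Cypher by string concatenation. A slightly more defensive sketch parses each date with datetime.strptime (which validates the date and also accepts unpadded months and days) and returns a parameterized query plus its parameters. The helper name is hypothetical, and whether the local Graph.query method accepts a parameters argument is an assumption; check its signature before swapping this in.

```python
from datetime import datetime


def date_node_query(date_string):
    """Return (cypher, params) for a :Date node from an MM/DD/YYYY string.

    Hypothetical helper: query parameters replace string concatenation.
    """
    # strptime raises ValueError on malformed dates and accepts "2/9/2014".
    parsed = datetime.strptime(date_string, "%m/%d/%Y").date()
    cypher = "CREATE (:Date {name: $name, id: $name, date: date($iso)})"
    return cypher, {"name": date_string, "iso": parsed.isoformat()}


print(date_node_query("02/09/2014"))
print(date_node_query("2/9/2014")[1]["iso"])  # 2014-02-09
```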

Create the Vector Embedding Index in the Graph

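This cell creates a vector index in Neo4j. It embeds the text property (a plain-English rendering of each FHIR resource) of every node labeled resource with the BAAI/bge-small-en-v1.5 model, stores each vector in the node's embedding property, and indexes the vectors under the name fhir_text.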

Warning: This cell may take some time to run, because every resource node has to be embedded.

Neo4jVector.from_existing_graph(
    HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5"),
    url=NEO4J_URI,
    username=USERNAME,
    password=PASSWORD,
    database=DATABASE,
    index_name='fhir_text',
    node_label="resource",
    text_node_properties=['text'],
    embedding_node_property='embedding',
)
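As far as the LangChain implementation goes, from_existing_graph creates the fhir_text index if it does not exist, then looks for resource nodes that have no embedding yet, embeds their text in batches, and writes each vector back to the embedding property. Nodes that already have an embedding are skipped, so most of the cost is paid on the first run. The in-memory sketch below mimics that skip-if-present pattern; toy_embed is a hypothetical hash-based stand-in for the BGE model, and the list of dictionaries stands in for Neo4j nodes.

```python
import math
import zlib


def toy_embed(text, dims=8):
    """Toy embedding: hash each word into a bucket, then L2-normalize.

    Stands in for a real model such as BAAI/bge-small-en-v1.5.
    """
    vector = [0.0] * dims
    for word in text.lower().split():
        vector[zlib.crc32(word.encode("utf-8")) % dims] += 1.0
    norm = math.sqrt(sum(v * v for v in vector)) or 1.0
    return [v / norm for v in vector]


def embed_missing(nodes, text_key="text", embedding_key="embedding"):
    """Embed only nodes that have text but no embedding; return how many were embedded."""
    embedded = 0
    for node in nodes:
        if node.get(embedding_key) is None and node.get(text_key):
            node[embedding_key] = toy_embed(node[text_key])
            embedded += 1
    return embedded


nodes = [
    {"id": "obs-1", "text": "Blood pressure panel 133/88 mm[Hg]"},
    {"id": "imm-1", "text": "Influenza vaccine administered"},
]
print(embed_missing(nodes))  # 2: first run embeds both nodes
print(embed_missing(nodes))  # 0: second run finds nothing left to embed
```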

Load the Existing Vector Index

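This cell connects to the fhir_text index created above instead of building a new one, reusing the stored embeddings. It must use the same embedding model, so that questions are embedded into the same vector space as the nodes.

It is a separate cell because the cell above is slow and only needs to run once, when the database is first populated; in later sessions you can skip it and run only this cell.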

vector_index = Neo4jVector.from_existing_index(
    HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5"),
    url=NEO4J_URI,
    username=USERNAME,
    password=PASSWORD,
    database=DATABASE,
    index_name='fhir_text'
)
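The vector_index object answers a query by embedding the question with the same BGE model and returning the k stored entries whose embeddings are most similar; Neo4jVector uses cosine similarity by default. The brute-force ranking below illustrates that idea over made-up 3-dimensional vectors. Neo4j's vector index performs an approximate nearest-neighbor search rather than a full scan, so this is a conceptual sketch, not its implementation.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def top_k(query_vec, docs, k=4):
    """Rank (text, vector) pairs by cosine similarity to the query, best first."""
    scored = [(cosine(query_vec, vec), text) for text, vec in docs]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [text for _, text in scored[:k]]


# Made-up vectors; a real index holds BGE embeddings of each node's text.
docs = [
    ("BP panel 02/09/2014: 133/88 mm[Hg]", [0.9, 0.1, 0.0]),
    ("BP panel 08/30/2023: 140/99 mm[Hg]", [0.8, 0.2, 0.1]),
    ("Influenza immunization", [0.0, 0.1, 0.9]),
]
print(top_k([1.0, 0.1, 0.0], docs, k=2))
```

With k larger than the number of stored entries, the sketch simply returns everything it has, ranked.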

Set Up the Prompt Template

in_prompt='''
System: The context below contains entries about the patient's healthcare. 
Please limit your answer to the information provided in the context. Do not make up facts. 
If you don't know the answer, just say that you don't know, don't try to make up an answer.
If you are asked about the patient's name and one of the entries is of type patient, you should look for the first given name and family name and answer with: [given] [family]
----------------
{context}
Human: {question}
'''

prompt = PromptTemplate.from_template(in_prompt)
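PromptTemplate.from_template treats the string as an f-string-style template and infers its input variables from the {...} placeholders, here context and question. The standard-library sketch below shows the same inference with string.Formatter and fills a shortened copy of the template; it illustrates the mechanism only and does not use LangChain.

```python
import string

# Shortened version of the template above (illustrative).
template = """System: Answer only from the context.
----------------
{context}
Human: {question}
"""


def template_variables(text):
    """Collect the named {placeholders} in a str.format-style template."""
    names = {field for _, field, _, _ in string.Formatter().parse(text) if field}
    return sorted(names)


print(template_variables(template))  # ['context', 'question']
print(template.format(context="Blood pressure panel: 133/88 mm[Hg]",
                      question="What was the blood pressure?"))
```

Square brackets such as [given] [family] in the real prompt are plain text, not placeholders, so they do not become input variables.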

Pick the LLM to use

ollama_model = 'llama3'

Ask the question with and without KG-RAG

question = "What was the blood pressure?"
llm = Ollama(model=ollama_model)
no_rag_answer = llm.invoke(question)  # the bare question, with no patient context
print(no_rag_answer)
There is no mention of blood pressure in our previous conversation. We only discussed the topic of "what's on your mind?" and I provided some suggestions to help you clarify your thoughts. If you'd like to discuss something specific, such as blood pressure or any other health-related topics, I'm here to listen and provide general information. However, please note that I am not a medical professional, and it's always best to consult with a healthcare expert for personalized advice.
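The answer above comes from a model with no access to the patient's data. For reference, the same bare-question call can be made directly against a local Ollama server through its /api/generate endpoint (default port 11434), sending a JSON body with model, prompt, and stream; with stream set to false, the reply is a single JSON object whose response field holds the text. This standard-library sketch assumes a server on the default address with llama3 already pulled; the helper names are hypothetical.

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434/api/generate"  # Ollama's default address


def build_generate_request(model, prompt, url=OLLAMA_URL):
    """Build a non-streaming POST request for Ollama's /api/generate endpoint."""
    payload = json.dumps({"model": model, "prompt": prompt, "stream": False})
    return urllib.request.Request(
        url,
        data=payload.encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def generate(model, prompt, timeout=120):
    """Send the request and return the generated text (needs a running server)."""
    request = build_generate_request(model, prompt)
    with urllib.request.urlopen(request, timeout=timeout) as reply:
        return json.loads(reply.read())["response"]


# Requires a local Ollama server with the model pulled:
# print(generate("llama3", "What was the blood pressure?"))
```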
response = vector_index.similarity_search(question)  # returns the top k=4 matches by default
print(response[0].page_content)
print(len(response))
The type of information in this entry is observation. The status for this observation is final. The category of this observation is Vital signs. The code for this observation is Blood pressure panel with all children optional. This observation was effective date time on 02/09/2014 at 11:51:24. This observation was issued on 02/09/2014 at 11:51:24. This observation contains 2 components. The 1st component's code for this observation is Diastolic Blood Pressure. The 1st component's value quantity for this observation is 88 mm[Hg]. The 2nd component's code for this observation is Systolic Blood Pressure. The 2nd component's value quantity for this observation is 133 mm[Hg].
4
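Each node's text follows a fixed sentence pattern, as the Observation printed above shows, so the numeric readings can also be extracted deterministically, for example to cross-check an LLM answer against the retrieved context. The regular expressions below are written against that sample text and are an assumption about its format, not a general FHIR parser; the function name is hypothetical.

```python
import re

# Excerpt of the node text printed by similarity_search above.
TEXT = (
    "This observation was effective date time on 02/09/2014 at 11:51:24. "
    "The 1st component's code for this observation is Diastolic Blood Pressure. "
    "The 1st component's value quantity for this observation is 88 mm[Hg]. "
    "The 2nd component's code for this observation is Systolic Blood Pressure. "
    "The 2nd component's value quantity for this observation is 133 mm[Hg]."
)

DATE_RE = re.compile(r"effective date time on (\d{2}/\d{2}/\d{4})")
CODE_RE = re.compile(r"The (\d+)\w\w component's code for this observation is ([^.]+)\.")
VALUE_RE = re.compile(
    r"The (\d+)\w\w component's value quantity for this observation is "
    r"([\d.]+) mm\[Hg\]"
)


def parse_blood_pressure(text):
    """Return {'date': ..., '<component name>': value} from one observation text."""
    codes = dict(CODE_RE.findall(text))    # component index -> component name
    values = dict(VALUE_RE.findall(text))  # component index -> value string
    date_match = DATE_RE.search(text)
    reading = {"date": date_match.group(1) if date_match else None}
    for index, name in codes.items():
        if index in values:
            reading[name] = float(values[index])
    return reading


print(parse_blood_pressure(TEXT))
```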
vector_qa = RetrievalQA.from_chain_type(
    llm=ChatOllama(model=ollama_model), chain_type="stuff", retriever=vector_index.as_retriever(search_kwargs={'k': 2}), 
    verbose=True, chain_type_kwargs={"verbose": True, "prompt": prompt}
)

pprint(vector_qa.run(question))
> Entering new RetrievalQA chain...


> Entering new StuffDocumentsChain chain...


> Entering new LLMChain chain...
Prompt after formatting:

System: The context below contains entries about the patient's healthcare. 
Please limit your answer to the information provided in the context. Do not make up facts. 
If you don't know the answer, just say that you don't know, don't try to make up an answer.
If you are asked about the patient's name and one the entries is of type patient, you should look for the first given name and family name and answer with: [given] [family]
----------------
The type of information in this entry is observation. The status for this observation is final. The category of this observation is Vital signs. The code for this observation is Blood pressure panel with all children optional. This observation was effective date time on 02/09/2014 at 11:51:24. This observation was issued on 02/09/2014 at 11:51:24. This observation contains 2 components. The 1st component's code for this observation is Diastolic Blood Pressure. The 1st component's value quantity for this observation is 88 mm[Hg]. The 2nd component's code for this observation is Systolic Blood Pressure. The 2nd component's value quantity for this observation is 133 mm[Hg].

The type of information in this entry is observation. The status for this observation is final. The category of this observation is Vital signs. The code for this observation is Blood pressure panel with all children optional. This observation was effective date time on 08/30/2023 at 11:51:24. This observation was issued on 08/30/2023 at 11:51:24. This observation contains 2 components. The 1st component's code for this observation is Diastolic Blood Pressure. The 1st component's value quantity for this observation is 99 mm[Hg]. The 2nd component's code for this observation is Systolic Blood Pressure. The 2nd component's value quantity for this observation is 140 mm[Hg].
Human: What was the blood pressure?


> Finished chain.

> Finished chain.

> Finished chain.
("According to the observations, the patient's blood pressure components are:\n"
 '\n'
 '* Diastolic Blood Pressure: 88 mm[Hg] (02/09/2014) and 99 mm[Hg] '
 '(08/30/2023)\n'
 '* Systolic Blood Pressure: 133 mm[Hg] (02/09/2014) and 140 mm[Hg] '
 '(08/30/2023)\n'
 '\n'
 'So, the blood pressure is:\n'
 '\n'
 '* Diastolic: 88-99 mm[Hg]\n'
 '* Systolic: 133-140 mm[Hg]')
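The verbose log above shows what chain_type="stuff" does: RetrievalQA takes the two retrieved documents, concatenates their text into the {context} slot of the prompt (separated by a blank line), and sends the whole prompt to the model in a single call. A dependency-free sketch of that assembly step follows; the function name, the shortened template, and the retrieved strings are illustrative.

```python
PROMPT = """System: Limit your answer to the information provided in the context.
----------------
{context}
Human: {question}
"""


def stuff_prompt(template, documents, question, separator="\n\n"):
    """Join retrieved document texts and place them in a single prompt."""
    context = separator.join(documents)
    return template.format(context=context, question=question)


retrieved = [
    "Observation 02/09/2014: diastolic 88 mm[Hg], systolic 133 mm[Hg].",
    "Observation 08/30/2023: diastolic 99 mm[Hg], systolic 140 mm[Hg].",
]
print(stuff_prompt(PROMPT, retrieved, "What was the blood pressure?"))
```

Stuffing is the simplest strategy and works here because two short entries fit easily in the model's context window; with many or long documents, the prompt grows with every retrieved entry.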