Creating a medical knowledge graph and exploring it with an LLM

In this exploration I used the Precision Medicine Knowledge Graph (PrimeKG) as the data source and loaded it into Neo4j so that it can be explored with an LLM.

Because of the size of the knowledge graph and the scope of the experiment, only disease nodes (with their detailed properties) and their direct connections were loaded.

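I used LangChain's Cypher QA chain (GraphCypherQAChain) with a custom prompt template for few-shot learning. The template contains the instructions, the graph schema and a few example question/Cypher pairs, and the user's actual question is appended to the end of it.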

Both OpenAI and Llama 3 (run locally through Ollama) models were tried, with satisfactory results; example answers are shown at the end.

Environment Setup

Import packages

import pandas as pd
from tqdm.notebook import tqdm
import os
import textwrap

# Neo4J
from neo4j import GraphDatabase

# Langchain
from langchain_community.graphs import Neo4jGraph
from langchain.chains import GraphCypherQAChain
from langchain_community.llms import Ollama
from langchain_openai import ChatOpenAI
from langchain.prompts.prompt import PromptTemplate

# Warning control
import warnings
warnings.filterwarnings("ignore")

Neo4j Setup

NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "xyz"
NEO4J_DATABASE = 'neo4j'
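Hard-coding credentials in a notebook makes them easy to leak when the notebook is shared. Below is a minimal sketch of reading the same settings from environment variables instead. The variable names simply mirror the constants above and are an assumption. So is the choice to fail when no password is set rather than fall back to a placeholder. `neo4j_config_from_env` is a hypothetical helper.

```python
import os


def neo4j_config_from_env() -> dict:
    """Read Neo4j connection settings from environment variables.

    Defaults match the local values used in this notebook. The password has
    no default, so a missing value raises instead of silently using a
    placeholder.
    """
    password = os.environ.get("NEO4J_PASSWORD")
    if not password:
        raise RuntimeError("Set the NEO4J_PASSWORD environment variable")
    return {
        "uri": os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        "user": os.environ.get("NEO4J_USERNAME", "neo4j"),
        "password": password,
        "database": os.environ.get("NEO4J_DATABASE", "neo4j"),
    }


# Usage:
# cfg = neo4j_config_from_env()
# NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD, NEO4J_DATABASE = (
#     cfg["uri"], cfg["user"], cfg["password"], cfg["database"])
```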

LLM Setup

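# Option 1: a local Llama 3 model served by Ollama
# llm = Ollama(model="llama3")

# Option 2: OpenAI (default ChatOpenAI model); replace the placeholder key.
# temperature=0 keeps the generated Cypher as deterministic as possible.
os.environ["OPENAI_API_KEY"] = "sk-xyz"
llm = ChatOpenAI(temperature=0)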

Load Data

Create base KG using only diseases and their connections

# This code took a few hours to run: it sends three queries per CSV row

# Reuse the connection settings from "Neo4j Setup" instead of hard-coding them again
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
file_path = 'prime_kg/dataverse_files/kg.csv'
df = pd.read_csv(file_path, low_memory=False)
# Keep only edges whose source node (x) is a disease
df_dis = df[(df['x_type'] == 'disease') | (df['x_type'] == 'Disease')]


# Function to create nodes and dynamic relationships using APOC
def create_nodes_and_relationships(session, df):
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing rows"):
        # Merge the source node; apoc.merge.node lets the label come from a parameter
        session.run("""
        CALL apoc.merge.node([$x_type], {index: $x_index, id: $x_id, name: $x_name, source: $x_source})
        YIELD node as x
        RETURN x
        """, x_type=row['x_type'], x_index=row['x_index'], x_id=row['x_id'], x_name=row['x_name'], x_source=row['x_source'])

        # Dynamically merge the target node the same way
        session.run("""
        CALL apoc.merge.node([$y_type], {index: $y_index, id: $y_id, name: $y_name, source: $y_source})
        YIELD node as y
        RETURN y
        """, y_type=row['y_type'], y_index=row['y_index'], y_id=row['y_id'], y_name=row['y_name'], y_source=row['y_source'])

        # Merge the relationship: display_relation becomes the relationship type and
        # relation is stored in its `type` property.
        # Note: the label-less (y {index: ...}) pattern cannot use a label-scoped index.
        session.run("""
        MATCH (x:disease {index: $x_index}), (y {index: $y_index})
        CALL apoc.merge.relationship(x, $display_relation, {type: $relation}, {}, y) YIELD rel
        RETURN rel
        """, x_index=row['x_index'], y_index=row['y_index'], relation=row['relation'], display_relation=row['display_relation'])

with driver.session() as session:
    create_nodes_and_relationships(session, df_dis)
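The loop above sends three separate queries per CSV row, which likely explains why the load took hours. A common alternative is to send rows in batches and expand each batch on the server with `UNWIND`. Merging both nodes and the relationship in one statement also lets the merged nodes go straight to `apoc.merge.relationship` without a second lookup.

The sketch below assumes the same `kg.csv` columns and APOC procedures used above. `BATCH_MERGE_QUERY`, `iter_batches`, `index_statements` and `load_in_batches` are hypothetical names. The batch size of 1,000 is an untuned starting point, and the index statement uses the Neo4j 4.x/5.x syntax.

```python
from itertools import islice
from typing import Any, Iterable, Iterator

# One round trip per batch: UNWIND turns the list parameter back into rows,
# and the nodes yielded by apoc.merge.node are passed directly to
# apoc.merge.relationship (no separate MATCH needed).
BATCH_MERGE_QUERY = """
UNWIND $rows AS row
CALL apoc.merge.node([row.x_type], {index: row.x_index, id: row.x_id, name: row.x_name, source: row.x_source})
YIELD node AS x
CALL apoc.merge.node([row.y_type], {index: row.y_index, id: row.y_id, name: row.y_name, source: row.y_source})
YIELD node AS y
CALL apoc.merge.relationship(x, row.display_relation, {type: row.relation}, {}, y)
YIELD rel
RETURN count(rel) AS merged
"""

# Columns of kg.csv that the query reads
COLUMNS = ["x_type", "x_index", "x_id", "x_name", "x_source",
           "y_type", "y_index", "y_id", "y_name", "y_source",
           "relation", "display_relation"]


def iter_batches(records: Iterable[Any], batch_size: int) -> Iterator[list]:
    """Yield consecutive lists of at most batch_size records."""
    it = iter(records)
    while batch := list(islice(it, batch_size)):
        yield batch


def index_statements(labels: Iterable[str]) -> list[str]:
    """One CREATE INDEX per node label on the `index` property.

    Labels are backtick-quoted because some contain special characters
    (e.g. effect/phenotype).
    """
    return [f"CREATE INDEX IF NOT EXISTS FOR (n:`{label}`) ON (n.index)"
            for label in sorted(set(labels))]


def load_in_batches(session, df, batch_size: int = 1000) -> None:
    """Merge nodes and relationships for every row of df, one query per batch."""
    records = df[COLUMNS].to_dict("records")  # one dict per CSV row
    for batch in iter_batches(records, batch_size):
        session.run(BATCH_MERGE_QUERY, rows=batch)


# Usage (not run here):
# with driver.session() as session:
#     for stmt in index_statements(set(df_dis["x_type"]) | set(df_dis["y_type"])):
#         session.run(stmt)
#     load_in_batches(session, df_dis)
```

Creating the indexes first should let each merge look nodes up by `index` instead of scanning every node with that label.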

At this point the graph schema looks like this:

Node properties:
disease {name: STRING, id: STRING, index: INTEGER, source: STRING}
effect/phenotype {name: STRING, id: STRING, index: INTEGER, source: STRING}
drug {name: STRING, source: STRING, id: STRING, index: INTEGER}
gene/protein {source: STRING, id: STRING, index: INTEGER, name: STRING}
exposure {name: STRING, id: STRING, index: INTEGER, source: STRING}
Relationship properties:
phenotype absent {type: STRING}
phenotype present {type: STRING}
parent-child {type: STRING}
contraindication {type: STRING}
indication {type: STRING}
off-label use {type: STRING}
associated with {type: STRING}
linked to {type: STRING}
The relationships:
(:disease)-[:parent-child]->(:disease)
(:disease)-[:phenotype absent]->(:effect/phenotype)
(:disease)-[:phenotype present]->(:effect/phenotype)
(:disease)-[:associated with]->(:gene/protein)
(:disease)-[:linked to]->(:exposure)
(:disease)-[:indication]->(:drug)
(:disease)-[:contraindication]->(:drug)
(:disease)-[:off-label use]->(:drug)

Load disease features

# Read csv with disease features
file_path = 'prime_kg/dataverse_files/disease_features.csv'
df_dis_feat = pd.read_csv(file_path, low_memory=False)
df_dis_feat[:10]
| | node_index | mondo_id | mondo_name | group_id_bert | group_name_bert | mondo_definition | umls_description | orphanet_definition | orphanet_prevalence | orphanet_epidemiology | orphanet_clinical_description | orphanet_management_and_treatment | mayo_symptoms | mayo_causes | mayo_risk_factors | mayo_complications | mayo_prevention | mayo_see_doc |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 27165 | 8019 | mullerian aplasia and hyperandrogenism | NaN | NaN | Deficiency of the glycoprotein WNT4, associate... | Deficiency of the glycoprotein wnt4, associate... | A rare syndrome with 46,XX disorder of sex dev... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 1 | 27165 | 8019 | mullerian aplasia and hyperandrogenism | NaN | NaN | Deficiency of the glycoprotein WNT4, associate... | Deficiency of the glycoprotein wnt4, associate... | A rare syndrome with 46,XX disorder of sex dev... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 2 | 27166 | 11043 | myelodysplasia, immunodeficiency, facial dysmo... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 3 | 27168 | 8878 | bone dysplasia, lethal Holmgren type | NaN | NaN | Bone dysplasia lethal Holmgren type (BDLH) is ... | A lethal bone dysplasia with characteristics o... | Bone dysplasia lethal Holmgren type (BDLH) is ... | <1/1000000 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 4 | 27169 | 8905 | predisposition to invasive fungal disease due ... | NaN | NaN | NaN | NaN | A rare, genetic primary immunodeficiency chara... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 5 | 27171 | 7162 | asymmetric short stature syndrome | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 6 | 27172 | 9699 | action myoclonus-renal failure syndrome | NaN | NaN | Action myoclonus-renal failure syndrome (AMRF)... | Syndrome with characteristics of episodes of m... | A rare epilepsy syndrome characterized by prog... | <1/1000000 | NaN | NaN | NaN | People with myoclonus often describe their sig... | Myoclonus may be caused by a variety of underl... | NaN | NaN | NaN | When to see a doctor, If your myoclonus sympto... |
| 7 | 27175 | 14897 | portal hypertension, noncirrhotic | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 8 | 27176 | 12035 | craniosynostosis-intracranial calcifications s... | NaN | NaN | Craniosynostosis-intracranial calcification is... | A form of syndromic craniosynostosis with char... | Craniosynostosis-intracranial calcifications s... | <1/1000000 | NaN | NaN | NaN | The signs of craniosynostosis are usually noti... | Often the cause of craniosynostosis is not kno... | If untreated, craniosynostosis may cause, for ... | NaN | NaN | When to see a doctor, Your doctor will routine... |
| 9 | 27176 | 12035 | craniosynostosis-intracranial calcifications s... | NaN | NaN | Craniosynostosis-intracranial calcification is... | A form of syndromic craniosynostosis with char... | Craniosynostosis-intracranial calcifications s... | <1/1000000 | NaN | NaN | NaN | The signs of craniosynostosis are usually noti... | Often the cause of craniosynostosis is not kno... | If untreated, craniosynostosis may cause, for ... | NaN | NaN | When to see a doctor, Your doctor will routine... |
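In the preview, rows 0 and 1 (and rows 8 and 9) share a `node_index` and look identical. Writing the same properties twice with `SET d += ...` is harmless, but it is wasted work. Only a truncated preview was inspected here, so whether every repeat is an exact copy is an assumption. The sketch below drops exact duplicate rows and then reports any `node_index` values that still repeat, meaning rows that differ in some column and need a decision about which one to keep. `dedupe_features` is a hypothetical helper.

```python
import pandas as pd


def dedupe_features(df: pd.DataFrame, key: str = "node_index") -> tuple[pd.DataFrame, pd.Series]:
    """Drop exact duplicate rows and report keys that still appear more than once.

    Returns (deduplicated frame, row counts for keys that still repeat).
    drop_duplicates treats NaN values as equal, so rows that are identical
    apart from sharing NaNs are collapsed too.
    """
    deduped = df.drop_duplicates().reset_index(drop=True)
    counts = deduped[key].value_counts()
    # Keys listed here have several rows that differ in at least one column
    still_repeated = counts[counts > 1]
    return deduped, still_repeated


# Usage:
# df_dis_feat, conflicts = dedupe_features(df_dis_feat)
# print(len(df_dis_feat), "rows;", len(conflicts), "node_index values with differing rows")
```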
# This code took about 8 min to run
import pandas as pd
from neo4j import GraphDatabase
from contextlib import contextmanager
from tqdm import tqdm

@contextmanager
def neo4j_session(uri, user, password):
    driver = GraphDatabase.driver(uri, auth=(user, password))
    session = driver.session()
    try:
        yield session
    finally:
        session.close()
        driver.close()

def update_disease_nodes(session, df):
    # Use tqdm to add a progress bar for the DataFrame iteration
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Updating Nodes"):
        # Collect every feature column except the key, replacing NaN with an empty string
        properties = {col: (row[col] if pd.notna(row[col]) else '') for col in df.columns if col != 'node_index'}

        # Add the features to the disease node with the same PrimeKG index
        cypher_query = """
        MATCH (d:disease {index: $node_index})
        SET d += $properties
        """
        session.run(cypher_query, node_index=row['node_index'], properties=properties)

# Reuse the connection settings defined in "Neo4j Setup"
with neo4j_session(NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD) as session:
    update_disease_nodes(session, df_dis_feat)
Updating Nodes: 100%|██████████| 44133/44133 [08:07<00:00, 90.50it/s]
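This update also runs one query per row (about 90 rows per second above). Below is a minimal sketch of the same update batched with `UNWIND`. It keeps the original rule of turning NaN into an empty string. It assumes an index on `:disease(index)` already exists, for example created with ``CREATE INDEX IF NOT EXISTS FOR (d:`disease`) ON (d.index)``, so that each `MATCH` is an index lookup. `BATCH_UPDATE_QUERY`, `feature_rows` and `update_in_batches` are hypothetical names.

```python
from typing import Any

import pandas as pd

# One round trip per batch instead of one per row
BATCH_UPDATE_QUERY = """
UNWIND $rows AS row
MATCH (d:disease {index: row.node_index})
SET d += row.properties
"""


def feature_rows(df: pd.DataFrame) -> list[dict[str, Any]]:
    """Convert each feature row to {"node_index": ..., "properties": {...}}.

    As in the loop above, NaN values become empty strings.
    """
    rows = []
    for record in df.to_dict("records"):
        node_index = record.pop("node_index")
        properties = {col: ("" if pd.isna(val) else val) for col, val in record.items()}
        rows.append({"node_index": node_index, "properties": properties})
    return rows


def update_in_batches(session, df: pd.DataFrame, batch_size: int = 1000) -> None:
    """Send the feature rows to Neo4j in slices of batch_size."""
    rows = feature_rows(df)
    for start in range(0, len(rows), batch_size):
        session.run(BATCH_UPDATE_QUERY, rows=rows[start:start + batch_size])


# Usage (not run here):
# with neo4j_session(NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD) as session:
#     update_in_batches(session, df_dis_feat)
```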
# Check the graph schema
kg = Neo4jGraph(
    url=NEO4J_URI, username=NEO4J_USERNAME, password=NEO4J_PASSWORD, database=NEO4J_DATABASE
)
kg.refresh_schema()
print(kg.schema)
Node properties:
disease {orphanet_management_and_treatment: STRING, mayo_risk_factors: STRING, mondo_definition: STRING, mondo_id: INTEGER, mayo_prevention: STRING, orphanet_epidemiology: STRING, mayo_symptoms: STRING, umls_description: STRING, orphanet_clinical_description: STRING, mayo_complications: STRING, orphanet_prevalence: STRING, mondo_name: STRING, group_name_bert: STRING, group_id_bert: STRING, mayo_causes: STRING, mayo_see_doc: STRING, orphanet_definition: STRING, name: STRING, id: STRING, index: INTEGER, source: STRING}
effect/phenotype {name: STRING, id: STRING, index: INTEGER, source: STRING}
drug {name: STRING, source: STRING, id: STRING, index: INTEGER}
gene/protein {source: STRING, id: STRING, index: INTEGER, name: STRING}
exposure {name: STRING, id: STRING, index: INTEGER, source: STRING}
Relationship properties:
phenotype absent {type: STRING}
phenotype present {type: STRING}
parent-child {type: STRING}
contraindication {type: STRING}
indication {type: STRING}
off-label use {type: STRING}
associated with {type: STRING}
linked to {type: STRING}
The relationships:
(:disease)-[:phenotype absent]->(:effect/phenotype)
(:disease)-[:phenotype present]->(:effect/phenotype)
(:disease)-[:parent-child]->(:disease)
(:disease)-[:associated with]->(:gene/protein)
(:disease)-[:contraindication]->(:drug)
(:disease)-[:off-label use]->(:drug)
(:disease)-[:indication]->(:drug)
(:disease)-[:linked to]->(:exposure)

Extra function to remove duplicates

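# Defined but not used: apoc.merge.relationship already merges on
# (start node, relationship type, {type: ...}, end node), so the load above
# should not create duplicates. Kept as an optional safety net.
def remove_duplicates(session):
    # Keep one relationship per (start node, relationship type, end node) and delete the rest.
    # Grouping by type(r) ensures that different relationship types between the same pair
    # (e.g. `indication` and `off-label use`) are not treated as duplicates.
    session.run("""
    MATCH (n)-[r]->(m)
    WITH n, m, type(r) AS relType, tail(collect(r)) AS rr
    FOREACH (r IN rr | DELETE r)
    """)

# Uncomment to run the clean-up
# with driver.session() as session:
#     remove_duplicates(session)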

Cypher generation and LangChain integration

CYPHER_GENERATION_TEMPLATE = """You are an expert in Neo4j knowledge graphs.
Task: Generate a Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
{schema}
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything other than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.
Ask follow-up questions to help clarify if unsure based on query result.
Use backticks, e.g. `xyz`, to quote label and relationship-type names from the schema that contain spaces or special characters, such as `phenotype present` or `effect/phenotype`.
Examples: Here are a few examples of generated Cypher statements for particular questions:


# What are the known effects of Porokeratosis?
  MATCH (d:`disease`)-[:`phenotype present`]->(e:`effect/phenotype`)
    WHERE toLower(d.name) CONTAINS 'porokeratosis'
  RETURN e.name


# What drugs are indicated for treating Porokeratosis?
  MATCH (d:`disease`)-[:`indication`]->(drug:`drug`)
    WHERE toLower(d.name) CONTAINS 'porokeratosis'
  RETURN drug.name


# What drugs are contraindicated for diseases with the phenotype "photosensitivity"?
  MATCH (p:`effect/phenotype`)<-[:`phenotype present`]-(d:`disease`)-[:`contraindication`]->(drug:`drug`)
    WHERE toLower(p.name) CONTAINS 'photosensitivity'
  RETURN d.name AS Disease, drug.name AS ContraindicatedDrug

# Which genes are associated with Porokeratosis and what drugs are used for its treatment?
  MATCH (d:`disease`)-[:`associated with`]->(g:`gene/protein`),
        (d)-[:`indication`]->(drug:`drug`)
    WHERE toLower(d.name) CONTAINS 'porokeratosis'
  RETURN g.name AS Gene, drug.name AS Drug

# What is the lineage of Porokeratosis and what drugs are indicated for each related disease?
  MATCH path = (d:`disease`)-[:`parent-child`*]->(descendant:`disease`)
    WHERE toLower(d.name) CONTAINS 'porokeratosis'
  MATCH (descendant)-[:`indication`]->(drug:`drug`)
  RETURN [node in nodes(path) | node.name] AS Lineage, descendant.name AS Descendant, collect(drug.name) AS Drugs

# What are common effects of diseases linked to radiation exposure?
MATCH (e:`exposure`)<-[:`linked to`]-(d:`disease`)-[:`phenotype present`]->(p:`effect/phenotype`)
  WHERE toLower(e.name) CONTAINS 'radiation'
RETURN d.name AS Disease, collect(p.name) AS Effects

The question is:
{question}"""
CYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], 
    template=CYPHER_GENERATION_TEMPLATE
)

cypherChain = GraphCypherQAChain.from_llm(
    llm=llm,
    graph=kg,
    verbose=False,
    cypher_prompt=CYPHER_GENERATION_PROMPT,
)

def prettyCypherChain(question: str) -> None:
    """Run a question through the Cypher QA chain and print the answer wrapped at 60 characters."""
    response = cypherChain.run(question)
    print(textwrap.fill(response, 60))
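The prompt asks the model to use only the relationship types in the schema, but nothing enforces that. One lightweight check is to inspect the generated Cypher (for example the query the chain prints when built with `verbose=True`) for relationship types that do not exist. `unknown_rel_types` is a hypothetical helper, not part of LangChain.

The sketch assumes relationship types are written in backticks, as in the prompt's examples. Unquoted types such as `[:indication]` and alternations such as ``[:`a`|`b`]`` are not fully covered.

```python
import re

# Relationship types from the schema printed by kg.schema above
KNOWN_REL_TYPES = {
    "phenotype absent", "phenotype present", "parent-child", "contraindication",
    "indication", "off-label use", "associated with", "linked to",
}

# A backtick-quoted type right after "[" and an optional variable name,
# e.g. [:`indication`] or [r:`parent-child`*]
REL_TYPE_PATTERN = re.compile(r"\[\w*:`([^`]+)`")


def unknown_rel_types(cypher: str) -> set[str]:
    """Return backtick-quoted relationship types in cypher that are not in the schema."""
    return set(REL_TYPE_PATTERN.findall(cypher)) - KNOWN_REL_TYPES


# Usage:
# bad = unknown_rel_types(generated_cypher)
# if bad:
#     print("Query uses relationship types not in the schema:", bad)
```

Node labels such as ``(d:`disease`)`` are ignored because the pattern only matches inside square brackets.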

Ask questions about the graph

print("\nporokeratosis")
prettyCypherChain("what are the phenotypes of porokeratosis")
print("\nVitiligo")
prettyCypherChain("what are the phenotypes of Vitiligo")
print("\nRefsum Disease")
prettyCypherChain("what are the phenotypes of Refsum Disease")
print("\nMantle Cell Lymphoma")
prettyCypherChain("tell me about Mantle Cell Lymphoma")
print("\nGeneral Query")
prettyCypherChain("how many diseases in the graph?")
porokeratosis
Papule, Nail dystrophy, Parakeratosis, Plagiocephaly, Thick
vermilion border, Intellectual disability, Hearing
impairment, Micrognathia, Wide mouth.

Vitiligo
Severe short stature, EEG abnormality, Hypopigmented skin
patches, Autosomal recessive inheritance, Short stature,
Hearing impairment, Skeletal muscle atrophy, Vitiligo,
Progressive vitiligo, Intellectual disability.

Refsum Disease
Nyctalopia, Sensory impairment, Ataxia, Limb muscle
weakness, Arrhythmia, Abnormal renal physiology,
Sensorimotor neuropathy, Miosis, Elevated levels of phytanic
acid, Rod-cone dystrophy are the phenotypes of Refsum
Disease.

Mantle Cell Lymphoma
Mantle Cell Lymphoma is a type of B-cell lymphoma that can
present symptoms such as splenomegaly, weight loss,
anorexia, lymphadenopathy, abnormality of bone marrow cell
morphology, abnormality of the gastrointestinal tract,
fatigue, and fever.

General Query
There are 17080 diseases in the graph.
prettyCypherChain("return all the properties about Mantle Cell Lymphoma and explain them")
Mantle Cell Lymphoma is a type of lymphoma that occurs when
a disease-fighting white blood cell called a lymphocyte
develops a genetic mutation. This mutation causes the cell
to multiply rapidly, leading to an overabundance of diseased
lymphocytes in the lymph nodes, spleen, and liver. Common
symptoms include painless swelling of lymph nodes,
persistent fatigue, fever, night sweats, shortness of
breath, unexplained weight loss, and itchy skin. Risk
factors for developing Mantle Cell Lymphoma include age,
gender (males are slightly more affected), impaired immune
system, and certain infections like the Epstein-Barr virus.
It is essential to see a doctor if any persistent signs or
symptoms are concerning.