Source code for noctis.repository.neo4j.neo4j_repository

from contextlib import contextmanager
from typing import Any, Callable, Iterator, Optional, Type, Union

import pandas as pd
from neo4j import Driver, GraphDatabase, Session
from neo4j import Result, Record
from neo4j.exceptions import Neo4jError

from noctis.data_architecture.graph_schema import GraphSchema
from noctis.data_architecture.datacontainer import DataContainer
from noctis.data_transformation.neo4j.neo4j_formatter import format_result
from noctis.repository.neo4j.neo4j_queries import (
    AbstractQuery,
    Neo4jQueryRegistry,
    CustomQuery,
)
from noctis.utilities import console_logger
from noctis import settings

# Module-level logger, named after this module, shared by everything defined below.
logger = console_logger(__name__)


class Neo4jRepository:
    """
    Repository class for interacting with a Neo4j database using defined query strategies.

    Every query has a ``query_type`` which selects a pair of callables: a
    *strategy* (how the raw result is post-processed) and a *transaction
    function* (whether the strategy runs in a read or a write transaction).

    Attributes:
        schema (GraphSchema): The schema used for graph operations.
        _driver (Driver): Neo4j driver for database connections.
        _query_strategies (dict[str, tuple[Callable, Callable]]): Mapping of query
            types to strategies and transaction functions.
    """

    def __init__(
        self,
        uri: Optional[str] = settings.NEO4J_DEV_URL,
        username: Optional[str] = settings.NEO4J_DEV_USER,
        password: Optional[str] = settings.NEO4J_DEV_PASSWORD,
        database: Optional[str] = None,
        schema: Optional[GraphSchema] = None,
    ):
        """
        Initialize the Neo4jRepository with connection details and schema.

        Args:
            uri (Optional[str]): URI of the Neo4j database.
                Defaults to settings.NEO4J_DEV_URL.
            username (Optional[str]): Username for authentication.
                Defaults to settings.NEO4J_DEV_USER.
            password (Optional[str]): Password for authentication.
                Defaults to settings.NEO4J_DEV_PASSWORD.
            database (Optional[str]): Database name, if applicable. Defaults to None.
            schema (Optional[GraphSchema]): Schema for graph operations.
                When omitted (or None), a new GraphSchema instance is created
                for this repository.
        """
        # Build the default schema per instance: a ``GraphSchema()`` default in
        # the signature would be evaluated once and shared by every repository.
        self.schema = GraphSchema() if schema is None else schema
        self._driver = GraphDatabase.driver(
            uri, auth=(username, password), database=database
        )
        self._query_strategies = {
            "retrieve_graph": (self._retrieve_graph_strategy, self._execute_read),
            "modify_graph": (self._modify_graph_strategy, self._execute_write),
            "retrieve_stats": (self._retrieve_stats_strategy, self._execute_read),
        }

    def close(self) -> None:
        """
        Close the connection to the Neo4j database.
        """
        self._driver.close()

    @contextmanager
    def _session_context(self) -> Iterator[Session]:
        """
        Context manager for Neo4j session handling.

        The session is always closed on exit. Exceptions raised while opening
        the session, or inside the ``with`` body that uses it, are logged and
        then re-raised unchanged.

        Yields:
            Session: A Neo4j session for executing queries.

        Raises:
            Neo4jError: If a Neo4j-specific error occurs.
            Exception: For other types of errors.
        """
        session = None
        try:
            session = self._driver.session()
            yield session
        except Neo4jError as e:
            logger.error(f"Neo4j error occurred: {e}")
            raise
        except Exception as e:
            logger.error(f"Something went wrong: {e}")
            raise
        finally:
            # ``session`` stays None when opening it failed; nothing to close then.
            if session is not None:
                session.close()

    @classmethod
    def info(cls):
        """Returns info about all available queries, their required and optional arguments, and the query templates"""
        return Neo4jQueryRegistry.info()

    def _run_constraint_queries(self, query_name: str) -> None:
        """
        Run every statement produced by a registered constraint query.

        Args:
            query_name (str): Name under which the query is registered in
                Neo4jQueryRegistry.
        """
        query_class = Neo4jQueryRegistry.get_query_object(query_name)
        with self._session_context() as session:
            for statement in query_class().get_query():
                # Consume the result so that a failure is reported for the
                # statement that caused it instead of being deferred.
                session.run(statement).consume()

    def create_constraints(self):
        """
        Create uniqueness constraints in the Neo4j database using predefined queries.
        """
        self._run_constraint_queries("create_uniqueness_constraints")

    def drop_constraints(self):
        """
        Drop uniqueness constraints in the Neo4j database using predefined queries.
        """
        self._run_constraint_queries("drop_uniqueness_constraints")

    def show_constraints(self):
        """
        Show the current uniqueness constraints in the Neo4j database.

        Returns:
            any: Result of the query showing constraints.
        """
        return self.execute_query("show_uniqueness_constraints")

    def execute_custom_query_from_yaml(
        self, yaml_file: str, query_name: str, **kwargs: Any
    ) -> Any:
        """
        Execute a custom query defined in a YAML file.

        After the custom query has run, the schema currently present in the
        graph database is fetched and compared with the user schema; any
        mismatch is logged as a warning.

        Args:
            yaml_file (str): Path to the YAML file containing query definitions.
            query_name (str): Name of the query to execute.
            **kwargs: Additional arguments for the query.

        Returns:
            any: Result of the custom query execution.
        """
        query = CustomQuery.from_yaml(yaml_file=yaml_file, query_name=query_name)
        result = self._build_query_execution_strategy(query, **kwargs)
        gdb_schema = self.execute_query("get_gdb_schema")
        self._compare_user_schema_and_gdb_schema(gdb_schema)
        return result

    def _compare_user_schema_and_gdb_schema(self, gdb_schema: pd.DataFrame) -> None:
        """
        Compare the user-defined schema with the graph database schema.

        Args:
            gdb_schema (pd.DataFrame): Schema from the graph database. The first
                row of its "nodes" and "relationships" columns is compared
                with the user schema.

        Note:
            Logs warnings if there are discrepancies between the schemas.
        """
        node_labels = self.schema.get_nodes_labels()
        relationship_types = self.schema.get_relationships_types()

        # Both columns must exist and hold at least one row; otherwise the
        # ``.iloc[0]`` lookups below would raise instead of warning.
        columns_missing = (
            "nodes" not in gdb_schema.columns
            or "relationships" not in gdb_schema.columns
        )
        if columns_missing or gdb_schema.empty:
            logger.warning(
                "Nodes and/or relationships are not defined in the GDB schema DataFrame."
            )
            return

        gdb_node_labels = gdb_schema["nodes"].iloc[0]
        gdb_relationship_types = gdb_schema["relationships"].iloc[0]

        # Order is irrelevant, so compare as sets.
        nodes_match = set(node_labels) == set(gdb_node_labels)
        relationships_match = set(relationship_types) == set(gdb_relationship_types)
        if not (nodes_match and relationships_match):
            logger.warning(
                "The user schema does not match the graph database schema "
                "after running custom query.\n"
                "Graph DB Schema:\n"
                f"    Nodes: {gdb_node_labels}\n"
                f"    Relationships: {gdb_relationship_types}\n"
                "User Schema:\n"
                f"    Nodes: {node_labels}\n"
                f"    Relationships: {relationship_types}"
            )

    def execute_query(self, query_name: str, **kwargs: Any) -> Any:
        """
        Execute a predefined query by name.

        Args:
            query_name (str): Name of the query to execute.
            **kwargs: Additional arguments for the query.

        Returns:
            any: Result of the query execution.
        """
        query = Neo4jQueryRegistry.get_query_object(query_name)
        return self._build_query_execution_strategy(query, **kwargs)

    def _build_query_execution_strategy(
        self, query: Union[Type[AbstractQuery], CustomQuery], **kwargs: Any
    ) -> Any:
        """
        Build and execute a query strategy based on the query type.

        Args:
            query (Union[Type[AbstractQuery], CustomQuery]): The query object.
            **kwargs: Additional arguments for the query.

        Returns:
            any: Result of the query execution.

        Raises:
            ValueError: If the query type is unsupported.
        """
        query_type = query.query_type
        strategy, transaction_function = self._query_strategies.get(
            query_type, (None, None)
        )
        if strategy is None or transaction_function is None:
            raise ValueError(f"Unsupported query type: {query_type}")

        with self._session_context() as session:
            return transaction_function(session, strategy, query, **kwargs)

    @staticmethod
    def _execute_read(
        session: Session, strategy: Callable, query: AbstractQuery, **kwargs: Any
    ) -> Any:
        """
        Execute a read transaction using the specified strategy.

        Args:
            session (Session): Neo4j session for executing the transaction.
            strategy (Callable): Strategy function for query execution.
            query (AbstractQuery): Query object.
            **kwargs: Additional arguments for the query.

        Returns:
            any: Result of the read transaction.
        """
        return session.execute_read(lambda tx: strategy(tx, query, **kwargs))

    @staticmethod
    def _execute_write(
        session: Session, strategy: Callable, query: AbstractQuery, **kwargs: Any
    ) -> Any:
        """
        Execute a write transaction using the specified strategy.

        Args:
            session (Session): Neo4j session for executing the transaction.
            strategy (Callable): Strategy function for query execution.
            query (AbstractQuery): Query object.
            **kwargs: Additional arguments for the query.

        Returns:
            any: Result of the write transaction.
        """
        return session.execute_write(lambda tx: strategy(tx, query, **kwargs))

    def _execute_query(
        self, tx, query: Type[AbstractQuery], **kwargs: Any
    ) -> Union[Result, list[Record]]:
        """
        Execute a query using a transaction.

        The query is instantiated with ``kwargs``, given the repository schema,
        and asked for its query text. Depending on ``query.parameters_embedded``
        the text is run as-is or with ``kwargs`` passed as runtime parameters.

        Args:
            tx: Transaction object.
            query (Type[AbstractQuery]): Query object.
            **kwargs: Additional arguments for the query.

        Returns:
            Union[Result, list[Record]]: Result of the query execution.
        """
        query_object = query(**kwargs)
        query_object.graph_schema = self.schema
        query_string = query_object.get_query()

        if query.parameters_embedded:
            return self._execute_parameters_embedded_query(tx, query_string)
        return self._execute_parameters_not_embedded_query(
            tx, query_string, **kwargs
        )

    @staticmethod
    def _execute_parameters_not_embedded_query(tx, query: str, **kwargs: Any) -> Result:
        """
        Execute a query with parameters passed at runtime.

        Args:
            tx: Transaction object.
            query (str): Query string.
            **kwargs: Additional arguments for the query.

        Returns:
            Result: Result of the query execution.
        """
        return tx.run(query, **kwargs)

    @staticmethod
    def _execute_parameters_embedded_query(tx, query: list[str]) -> list[Record]:
        """
        Execute a query with parameters embedded in the query string.

        Args:
            tx: Transaction object.
            query (list[str]): Query strings with embedded parameters; each
                entry is run separately, in order.

        Returns:
            list[Record]: Records of all executed statements, concatenated.
        """
        records = []
        for statement in query:
            records.extend(tx.run(statement).to_eager_result().records)
        return records

    def _retrieve_graph_strategy(
        self, tx, query: Type[AbstractQuery], **kwargs: Any
    ) -> DataContainer:
        """
        Strategy for retrieving graph data.

        Args:
            tx: Transaction object.
            query (Type[AbstractQuery]): Query object.
            **kwargs: Additional arguments for the query.

        Returns:
            DataContainer: Formatted result as a data container.
        """
        result = self._execute_query(tx, query, **kwargs)
        return format_result(result, self.schema.base_nodes["chemical_equation"])

    def _modify_graph_strategy(
        self, tx, query: Type[AbstractQuery], **kwargs: Any
    ) -> list:
        """
        Strategy for modifying graph data.

        Args:
            tx: Transaction object.
            query (Type[AbstractQuery]): Query object.
            **kwargs: Additional arguments for the query.

        Returns:
            list: List of records resulting from the modification.
        """
        result = self._execute_query(tx, query, **kwargs)
        # Materialise the records before the strategy returns.
        return list(result)

    def _retrieve_stats_strategy(
        self, tx, query: Type[AbstractQuery], **kwargs: Any
    ) -> pd.DataFrame:
        """
        Strategy for retrieving statistical data.

        Args:
            tx: Transaction object.
            query (Type[AbstractQuery]): Query object.
            **kwargs: Additional arguments for the query.

        Returns:
            pd.DataFrame: DataFrame containing statistical results, one row
                per record.
        """
        result = self._execute_query(tx, query, **kwargs)
        return pd.DataFrame([dict(record) for record in result])