Source code for kgx.transformer

import copy
import itertools
import multiprocessing
import os
import shutil
import tempfile
from os.path import exists
from sys import stderr
from typing import Dict, Generator, List, Optional, Callable, Set

from kgx.config import get_logger
from kgx.error_detection import ErrorType, MessageLevel, ErrorDetecting
from kgx.source import (
    GraphSource,
    Source,
    TsvSource,
    JsonSource,
    JsonlSource,
    ObographSource,
    TrapiSource,
    NeoSource,
    ArangoSource,
    RdfSource,
    OwlSource,
    SssomSource,
    DuckDbSource,
)
from kgx.sink import (
    Sink,
    GraphSink,
    JsonSink,
    JsonlSink,
    NeoSink,
    ArangoSink,
    NullSink,
    RdfSink,
    SqlSink,
    TsvSink,
    ParquetSink
)
from kgx.utils.kgx_utils import (
    apply_graph_operations,
    GraphEntityType,
    knowledge_provenance_properties,
)

SOURCE_MAP = {
    "tsv": TsvSource,
    "csv": TsvSource,
    "graph": GraphSource,
    "json": JsonSource,
    "jsonl": JsonlSource,
    "obojson": ObographSource,
    "obo-json": ObographSource,
    "trapi-json": TrapiSource,
    "neo4j": NeoSource,
    "arangodb": ArangoSource,
    "duckdb": DuckDbSource,
    "nt": RdfSource,
    "jelly": RdfSource,
    "owl": OwlSource,
    "sssom": SssomSource,
    "parquet": GraphSource,
}

SINK_MAP = {
    "csv": TsvSink,
    "graph": GraphSink,
    "json": JsonSink,
    "jsonl": JsonlSink,
    "neo4j": NeoSink,
    "arangodb": ArangoSink,
    "nt": RdfSink,
    "jelly": RdfSink,
    "null": NullSink,
    "sql": SqlSink,
    "tsv": TsvSink,
    "parquet": ParquetSink,
}


log = get_logger()


def _parallel_worker(args):
    """
    Worker entry point for Transformer parallel mode. Runs in a child
    process spawned by Transformer._run_parallel. Builds a fresh
    Transformer and runs a single-threaded streaming transform over a
    pre-computed source partition, writing to a part file.
    """
    input_args, output_args = args
    t = Transformer(stream=True)
    t.transform(input_args=input_args, output_args=output_args)


[docs] class Transformer(ErrorDetecting): """ The Transformer class is responsible for transforming data from one form to another. Parameters ---------- stream: bool Whether or not to stream (default: False) infores_catalog: Optional[str] Optional dump of a TSV file of InfoRes CURIE to Knowledge Source mappings error_log: Where to write any graph processing error message (stderr, by default). """ def __init__( self, stream: bool = False, infores_catalog: Optional[str] = None, error_log=None ): """ stream: bool Whether or not to stream infores_catalog: Optional[str] Optional dump of a TSV file of InfoRes CURIE to Knowledge Source mappings error_log: Where to write any graph processing error message (stderr, by default). """ ErrorDetecting.__init__(self, error_log) self.stream = stream self.node_filters = {} self.edge_filters = {} self.inspector: Optional[Callable[[GraphEntityType, List], None]] = None self.store = self.get_source("graph") self._seen_nodes = set() self._infores_catalog: Dict[str, str] = dict() if infores_catalog and exists(infores_catalog): with open(infores_catalog, "r") as irc: for entry in irc: if len(entry): entry = entry.strip() if entry: source, infores = entry.split("\t") self._infores_catalog[source] = infores
[docs] def transform( self, input_args: Dict, output_args: Optional[Dict] = None, inspector: Optional[Callable[[GraphEntityType, List], None]] = None, parallel: int = 1, ) -> None: """ Transform an input source and write to an output sink. If ``output_args`` is not defined then the data is persisted to an in-memory graph. The 'inspector' argument is an optional Callable which the transformer.process() method applies to 'inspect' source records prior to writing them out to the Sink. The first (GraphEntityType) argument of the Callable tags the record as a NODE or an EDGE. The second argument given to the Callable is the current record itself. This Callable is strictly meant to be procedural and should *not* mutate the record. Parameters ---------- input_args: Dict Arguments relevant to your input source output_args: Optional[Dict] Arguments relevant to your output sink ( inspector: Optional[Callable[[GraphEntityType, List], None]] Optional Callable to 'inspect' source records during processing. """ if parallel and parallel > 1: if self._run_parallel(input_args, output_args, inspector, parallel): return # _run_parallel returned False: prerequisites not met, fall through # to sequential processing. sources = [] generators = [] input_format = input_args["format"] prefix_map = input_args.pop("prefix_map", {}) predicate_mappings = input_args.pop("predicate_mappings", {}) node_property_predicates = input_args.pop("node_property_predicates", {}) node_filters = input_args.pop("node_filters", {}) edge_filters = input_args.pop("edge_filters", {}) operations = input_args.pop("operations", []) # Optional process() data stream inspector self.inspector = inspector if input_format in {"neo4j", "arangodb", "graph"}: source = self.get_source(input_format) source.set_prefix_map(prefix_map) source.set_node_filters(node_filters) self.node_filters = source.node_filters self.edge_filters = source.edge_filters source.set_edge_filters(edge_filters) self.node_filters = source.node_filters self.edge_filters = source.edge_filters if "uri" in input_args: default_provenance = input_args["uri"] else: default_provenance = None g = source.parse(default_provenance=default_provenance, **input_args) sources.append(source) generators.append(g) else: filename = input_args.pop("filename", {}) for f in filename: source = self.get_source(input_format) source.set_prefix_map(prefix_map) if isinstance(source, RdfSource): source.set_predicate_mapping(predicate_mappings) source.set_node_property_predicates(node_property_predicates) source.set_node_filters(node_filters) self.node_filters = source.node_filters self.edge_filters = source.edge_filters source.set_edge_filters(edge_filters) self.node_filters = source.node_filters self.edge_filters = source.edge_filters default_provenance = os.path.basename(f) g = source.parse(f, default_provenance=default_provenance, **input_args) sources.append(source) generators.append(g) source_generator = itertools.chain(*generators) if output_args: if self.stream: if output_args["format"] in {"tsv", "csv"}: if "node_properties" not in output_args or "edge_properties" not in output_args: error_type = ErrorType.MISSING_PROPERTY self.log_error( entity=f"{output_args['format']} stream", error_type=error_type, message=f"'node_properties' and 'edge_properties' must be defined for output while" f"streaming. The exported format will be limited to a subset of the columns.", message_level=MessageLevel.WARNING ) sink = self.get_sink(**output_args) if "reverse_prefix_map" in output_args: sink.set_reverse_prefix_map(output_args["reverse_prefix_map"]) if isinstance(sink, RdfSink): if "reverse_predicate_mapping" in output_args: sink.set_reverse_predicate_mapping( output_args["reverse_predicate_mapping"] ) if "property_types" in output_args: sink.set_property_types(output_args["property_types"]) # stream from source to sink self.process(source_generator, sink) sink.finalize() else: # stream from source to intermediate intermediate_sink = GraphSink(self) intermediate_sink.node_properties.update(self.store.node_properties) intermediate_sink.edge_properties.update(self.store.edge_properties) self.process(source_generator, intermediate_sink) for s in sources: intermediate_sink.node_properties.update(s.node_properties) intermediate_sink.edge_properties.update(s.edge_properties) apply_graph_operations(intermediate_sink.graph, operations) # stream from intermediate to output sink intermediate_source = self.get_source("graph") intermediate_source.node_properties.update( intermediate_sink.node_properties ) intermediate_source.edge_properties.update( intermediate_sink.edge_properties ) # Need to propagate knowledge source specifications here? ks_args = dict() for ksf in knowledge_provenance_properties: if ksf in input_args: ks_args[ksf] = input_args[ksf] intermediate_source_generator = intermediate_source.parse( intermediate_sink.graph, **ks_args ) if output_args["format"] in {"tsv", "csv"}: if "node_properties" not in output_args: output_args[ "node_properties" ] = intermediate_source.node_properties log.debug("output_args['node_properties']: " + str(output_args["node_properties"]), file=stderr) if "edge_properties" not in output_args: output_args[ "edge_properties" ] = intermediate_source.edge_properties sink = self.get_sink(**output_args) if "reverse_prefix_map" in output_args: sink.set_reverse_prefix_map(output_args["reverse_prefix_map"]) if isinstance(sink, RdfSink): if "reverse_predicate_mapping" in output_args: sink.set_reverse_predicate_mapping( output_args["reverse_predicate_mapping"] ) if "property_types" in output_args: sink.set_property_types(output_args["property_types"]) else: sink = self.get_sink(**output_args) sink.node_properties.update(intermediate_source.node_properties) sink.edge_properties.update(intermediate_source.edge_properties) self.process(intermediate_source_generator, sink) sink.finalize() self.store.node_properties.update(sink.node_properties) self.store.edge_properties.update(sink.edge_properties) else: # stream from source to intermediate sink = GraphSink(self) self.process(source_generator, sink) sink.node_properties.update(self.store.node_properties) sink.edge_properties.update(self.store.edge_properties) for s in sources: sink.node_properties.update(s.node_properties) sink.edge_properties.update(s.edge_properties) sink.finalize() self.store.node_properties.update(sink.node_properties) self.store.edge_properties.update(sink.edge_properties) apply_graph_operations(sink.graph, operations) # Aggregate the InfoRes catalogs from all sources for s in sources: for k, v in s.get_infores_catalog().items(): self._infores_catalog[k] = v
def _run_parallel( self, input_args: Dict, output_args: Optional[Dict], inspector: Optional[Callable[[GraphEntityType, List], None]], parallel: int, ) -> bool: """ Try to run a parallel multiprocessing transform. Returns True if the parallel run completed; False if prerequisites aren't met (in which case the caller should fall back to sequential). Prerequisites: - ``output_args`` must be set, must name a single output ``filename``, and must use a sink format that is safe to byte-concatenate across part files (currently: ``nt`` without gzip compression). - The source class must implement ``partitions(filename, n)``. - ``inspector`` is not supported in parallel mode (would not run in the parent process). - Exactly one input filename. """ if output_args is None: log.warning("parallel mode requires output_args; falling back to sequential") return False if inspector is not None: log.warning("parallel mode does not support inspector; falling back to sequential") return False out_format = output_args.get("format") out_compression = output_args.get("compression") if out_format != "nt" or out_compression == "gz": log.warning( "parallel mode currently supports only format='nt' without gzip; " "falling back to sequential" ) return False out_filename = output_args.get("filename") if not isinstance(out_filename, str): log.warning("parallel mode requires output_args['filename'] to be a single string path") return False input_format = input_args.get("format") source_cls = SOURCE_MAP.get(input_format) if source_cls is None or not hasattr(source_cls, "partitions"): log.warning( f"parallel mode: source format '{input_format}' does not support " "partitioning; falling back to sequential" ) return False in_filenames = input_args.get("filename") if not isinstance(in_filenames, (list, tuple)) or len(in_filenames) != 1: log.warning( "parallel mode requires exactly one input filename; falling back to sequential" ) return False # Compute partitions using a throwaway source instance. probe = source_cls(self) probe_kwargs = {k: v for k, v in input_args.items() if k != "filename"} partitions = probe.partitions(in_filenames[0], parallel, **probe_kwargs) tmpdir = tempfile.mkdtemp(prefix="kgx_parallel_") try: tasks = [] part_files = [] for i, part in enumerate(partitions): part_path = os.path.join(tmpdir, f"part_{i:04d}.nt") part_files.append(part_path) worker_input = copy.deepcopy(input_args) worker_input["filename"] = list(in_filenames) worker_input.update(part) worker_output = copy.deepcopy(output_args) worker_output["filename"] = part_path tasks.append((worker_input, worker_output)) ctx = multiprocessing.get_context("spawn") with ctx.Pool(parallel) as pool: pool.map(_parallel_worker, tasks) with open(out_filename, "wb") as out_fh: for pf in part_files: with open(pf, "rb") as part_fh: shutil.copyfileobj(part_fh, out_fh, length=4 * 1024 * 1024) # Free disk eagerly: part files can be tens of GB each. os.unlink(pf) finally: shutil.rmtree(tmpdir, ignore_errors=True) return True
[docs] def get_infores_catalog(self): """ Return catalog of Information Resource mappings aggregated from all Transformer associated sources """ return self._infores_catalog
[docs] def process(self, source: Generator, sink: Sink) -> None: """ This method is responsible for reading from ``source`` and writing to ``sink`` by calling the relevant methods based on the incoming data. .. note:: The streamed data must not be mutated. Parameters ---------- source: Generator A generator from a Source sink: kgx.sink.sink.Sink An instance of Sink """ for rec in source: if rec: log.debug("length of rec", len(rec), "rec", rec) if len(rec) == 4: # infer an edge record write_edge = True if "subject_category" in self.edge_filters: if rec[0] in self._seen_nodes: write_edge = True else: write_edge = False if "object_category" in self.edge_filters: if rec[1] in self._seen_nodes: if "subject_category" in self.edge_filters: if write_edge: write_edge = True else: write_edge = True else: write_edge = False if write_edge: if self.inspector: self.inspector(GraphEntityType.EDGE, rec) sink.write_edge(rec[-1]) else: # infer a node record if "category" in self.node_filters: self._seen_nodes.add(rec[0]) if self.inspector: self.inspector(GraphEntityType.NODE, rec) # last element of rec is the node properties sink.write_node(rec[-1])
[docs] def save(self, output_args: Dict) -> None: """ Save data from the in-memory store to a desired sink. Parameters ---------- output_args: Dict Arguments relevant to your output sink """ if not self.store: raise Exception("self.store is empty.") source = self.store source.node_properties.update(self.store.node_properties) source.edge_properties.update(self.store.edge_properties) source_generator = source.parse(self.store.graph) if "node_properties" not in output_args: output_args["node_properties"] = source.node_properties if "edge_properties" not in output_args: output_args["edge_properties"] = source.edge_properties sink = self.get_sink(**output_args) sink.node_properties.update(source.node_properties) sink.edge_properties.update(source.edge_properties) if "reverse_prefix_map" in output_args: sink.set_reverse_prefix_map(output_args["reverse_prefix_map"]) if isinstance(sink, RdfSink): if "reverse_predicate_mapping" in output_args: sink.set_reverse_predicate_mapping( output_args["reverse_predicate_mapping"] ) if "property_types" in output_args: sink.set_property_types(output_args["property_types"]) self.process(source_generator, sink) sink.finalize()
[docs] def get_source(self, format: str) -> Source: """ Get an instance of Source that corresponds to a given format. Parameters ---------- format: str The input store format Returns ------- Source: An instance of kgx.source.Source """ if format in SOURCE_MAP: s = SOURCE_MAP[format] return s(self) else: raise TypeError(f"{format} in an unrecognized format")
[docs] def get_sink(self, **kwargs: Dict) -> Sink: """ Get an instance of Sink that corresponds to a given format. Parameters ---------- kwargs: Dict Arguments required for initializing an instance of Sink Returns ------- Sink: An instance of kgx.sink.Sink """ if kwargs["format"] in SINK_MAP: s = SINK_MAP[kwargs["format"]] return s(self, **kwargs) else: raise TypeError(f"{kwargs['format']} in an unrecognized format")