Implement Nodes to compute text embeddings

This commit is contained in:
quorploop 2025-12-24 17:58:23 +01:00
parent 72765532d3
commit 49239e7e25
9 changed files with 505 additions and 25 deletions

View file

@@ -12,6 +12,7 @@ logger = logging.getLogger("knack-transform")
class TransformContext:
"""Context object containing the dataframe for transformation."""
# Possibly add a dict to the context to carry additional information
def __init__(self, df: pd.DataFrame):
self.df = df
@@ -153,7 +154,6 @@ class ParallelPipeline:
logger.info(f"Pipeline has {len(stages)} execution stage(s)")
results = {}
contexts = {None: initial_context} # Track contexts from each node
errors = []
ExecutorClass = ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
@@ -213,6 +213,7 @@ def create_default_pipeline(device: str = "cpu",
Configured ParallelPipeline
"""
from author_node import NerAuthorNode, FuzzyAuthorNode
from embeddings_node import TextEmbeddingNode, UmapNode
pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False)
@@ -236,17 +237,24 @@ def create_default_pipeline(device: str = "cpu",
name='FuzzyAuthorNode'
))
pipeline.add_node(NodeConfig(
node_class=TextEmbeddingNode,
node_kwargs={
'device': device,
'model_path': os.environ.get('MINILM_MODEL_PATH')
},
dependencies=[],
name='TextEmbeddingNode'
))
pipeline.add_node(NodeConfig(
node_class=UmapNode,
node_kwargs={},
dependencies=['TextEmbeddingNode'],
name='UmapNode'
))
# TODO: Create Node to compute Text Embeddings and UMAP.
# TODO: Create Node to pre-compute data based on visuals to reduce load time.
# TODO: Add more nodes here as they are implemented
# Example:
# pipeline.add_node(NodeConfig(
# node_class=EmbeddingNode,
# node_kwargs={'device': device},
# dependencies=[], # Runs after AuthorNode
# name='EmbeddingNode'
# ))
# pipeline.add_node(NodeConfig(
# node_class=UMAPNode,