Implement Nodes to compute text embeddings
This commit is contained in:
parent
72765532d3
commit
49239e7e25
9 changed files with 505 additions and 25 deletions
|
|
@ -12,6 +12,7 @@ logger = logging.getLogger("knack-transform")
|
|||
|
||||
class TransformContext:
|
||||
"""Context object containing the dataframe for transformation."""
|
||||
# Possibly add a dict for the context to give more Information
|
||||
|
||||
def __init__(self, df: pd.DataFrame):
|
||||
self.df = df
|
||||
|
|
@ -153,7 +154,6 @@ class ParallelPipeline:
|
|||
logger.info(f"Pipeline has {len(stages)} execution stage(s)")
|
||||
|
||||
results = {}
|
||||
contexts = {None: initial_context} # Track contexts from each node
|
||||
errors = []
|
||||
|
||||
ExecutorClass = ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
|
||||
|
|
@ -213,6 +213,7 @@ def create_default_pipeline(device: str = "cpu",
|
|||
Configured ParallelPipeline
|
||||
"""
|
||||
from author_node import NerAuthorNode, FuzzyAuthorNode
|
||||
from embeddings_node import TextEmbeddingNode, UmapNode
|
||||
|
||||
pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False)
|
||||
|
||||
|
|
@ -236,17 +237,24 @@ def create_default_pipeline(device: str = "cpu",
|
|||
name='FuzzyAuthorNode'
|
||||
))
|
||||
|
||||
pipeline.add_node(NodeConfig(
|
||||
node_class=TextEmbeddingNode,
|
||||
node_kwargs={
|
||||
'device': device,
|
||||
'model_path': os.environ.get('MINILM_MODEL_PATH')
|
||||
},
|
||||
dependencies=[],
|
||||
name='TextEmbeddingNode'
|
||||
))
|
||||
|
||||
pipeline.add_node(NodeConfig(
|
||||
node_class=UmapNode,
|
||||
node_kwargs={},
|
||||
dependencies=['TextEmbeddingNode'],
|
||||
name='UmapNode'
|
||||
))
|
||||
|
||||
# TODO: Create Node to compute Text Embeddings and UMAP.
|
||||
# TODO: Create Node to pre-compute data based on visuals to reduce load time.
|
||||
|
||||
# TODO: Add more nodes here as they are implemented
|
||||
# Example:
|
||||
# pipeline.add_node(NodeConfig(
|
||||
# node_class=EmbeddingNode,
|
||||
# node_kwargs={'device': device},
|
||||
# dependencies=[], # Runs after AuthorNode
|
||||
# name='EmbeddingNode'
|
||||
# ))
|
||||
|
||||
# pipeline.add_node(NodeConfig(
|
||||
# node_class=UMAPNode,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue