Makes transformer script executable via CLI

quorploop 2026-01-27 20:19:05 +01:00
parent 8fae350b34
commit 7c2e34906e
11 changed files with 648 additions and 37 deletions


@@ -40,7 +40,7 @@ class TextEmbeddingNode(TransformNode):
     of posts.
     """
     def __init__(self,
-                 model_name: str = "thenlper/gte-small",
+                 model_name: str = "thenlper/gte-large",
                  model_path: str = None,
                  device: str = "cpu"):
         """Initialize the ExampleNode.
@@ -64,8 +64,12 @@ class TextEmbeddingNode(TransformNode):
         model_source = None
         if self.model_path:
             if os.path.exists(self.model_path):
-                model_source = self.model_path
-                logger.info(f"Loading GTE model from local path: {self.model_path}")
+                # Check if it's a valid model directory
+                if os.path.exists(os.path.join(self.model_path, 'config.json')):
+                    model_source = self.model_path
+                    logger.info(f"Loading GTE model from local path: {self.model_path}")
+                else:
+                    logger.warning(f"GTE_MODEL_PATH '{self.model_path}' found but missing config.json; Falling back to hub model {self.model_name}")
             else:
                 logger.warning(f"GTE_MODEL_PATH '{self.model_path}' not found; Falling back to hub model {self.model_name}")
@@ -73,12 +77,17 @@ class TextEmbeddingNode(TransformNode):
             model_source = self.model_name
             logger.info(f"Loading GTE model from the hub: {self.model_name}")
-        if self.device == "cuda" and torch.cuda.is_available():
-            self.model = SentenceTransformer(model_source).to('cuda', dtype=torch.float16)
-        elif self.device == "mps" and torch.backends.mps.is_available():
-            self.model = SentenceTransformer(model_source).to('mps', dtype=torch.float16)
-        else:
-            self.model = SentenceTransformer(model_source)
+        try:
+            if self.device == "cuda" and torch.cuda.is_available():
+                self.model = SentenceTransformer(model_source).to('cuda', dtype=torch.float16)
+            elif self.device == "mps" and torch.backends.mps.is_available():
+                self.model = SentenceTransformer(model_source).to('mps', dtype=torch.float16)
+            else:
+                self.model = SentenceTransformer(model_source)
+            logger.info(f"Successfully loaded GTE model from: {model_source}")
+        except Exception as e:
+            logger.error(f"Failed to load GTE model from {model_source}: {e}")
+            raise
 
     def _process_data(self, df: pd.DataFrame) -> pd.DataFrame:
         """Process the input dataframe.