forked from lukaszett/Knack-Scraper
Makes the transformer script executable via the CLI
This commit is contained in:
parent
8fae350b34
commit
7c2e34906e
11 changed files with 648 additions and 37 deletions
|
|
@ -40,7 +40,7 @@ class TextEmbeddingNode(TransformNode):
|
|||
of posts.
|
||||
"""
|
||||
def __init__(self,
|
||||
model_name: str = "thenlper/gte-small",
|
||||
model_name: str = "thenlper/gte-large",
|
||||
model_path: str = None,
|
||||
device: str = "cpu"):
|
||||
"""Initialize the ExampleNode.
|
||||
|
|
@ -64,8 +64,12 @@ class TextEmbeddingNode(TransformNode):
|
|||
model_source = None
|
||||
if self.model_path:
|
||||
if os.path.exists(self.model_path):
|
||||
model_source = self.model_path
|
||||
logger.info(f"Loading GTE model from local path: {self.model_path}")
|
||||
# Check if it's a valid model directory
|
||||
if os.path.exists(os.path.join(self.model_path, 'config.json')):
|
||||
model_source = self.model_path
|
||||
logger.info(f"Loading GTE model from local path: {self.model_path}")
|
||||
else:
|
||||
logger.warning(f"GTE_MODEL_PATH '{self.model_path}' found but missing config.json; Falling back to hub model {self.model_name}")
|
||||
else:
|
||||
logger.warning(f"GTE_MODEL_PATH '{self.model_path}' not found; Falling back to hub model {self.model_name}")
|
||||
|
||||
|
|
@ -73,12 +77,17 @@ class TextEmbeddingNode(TransformNode):
|
|||
model_source = self.model_name
|
||||
logger.info(f"Loading GTE model from the hub: {self.model_name}")
|
||||
|
||||
if self.device == "cuda" and torch.cuda.is_available():
|
||||
self.model = SentenceTransformer(model_source).to('cuda', dtype=torch.float16)
|
||||
elif self.device == "mps" and torch.backends.mps.is_available():
|
||||
self.model = SentenceTransformer(model_source).to('mps', dtype=torch.float16)
|
||||
else:
|
||||
self.model = SentenceTransformer(model_source)
|
||||
try:
|
||||
if self.device == "cuda" and torch.cuda.is_available():
|
||||
self.model = SentenceTransformer(model_source).to('cuda', dtype=torch.float16)
|
||||
elif self.device == "mps" and torch.backends.mps.is_available():
|
||||
self.model = SentenceTransformer(model_source).to('mps', dtype=torch.float16)
|
||||
else:
|
||||
self.model = SentenceTransformer(model_source)
|
||||
logger.info(f"Successfully loaded GTE model from: {model_source}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load GTE model from {model_source}: {e}")
|
||||
raise
|
||||
|
||||
def _process_data(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Process the input dataframe.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue