import logging
from typing import Dict, List, Literal, Optional, Tuple
import networkx as nx
import numpy as np
import pandas as pd
from annoy import AnnoyIndex
from tqdm.auto import tqdm
# Module-level logger; the verbose code paths in this file report progress through it.
logger = logging.getLogger(__name__)
# Metric names accepted for SemanticNetwork's ``metric`` argument, which is
# handed to the AnnoyIndex constructor when the index is built.
MetricType = Literal["angular", "euclidean", "manhattan", "hamming", "dot"]
[docs]
def to_pandas(graph: nx.Graph) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Export a NetworkX graph to pandas DataFrames.

    This function operates on any NetworkX graph, making it useful for
    analyzing graphs from SemanticNetwork or any other NetworkX graph.

    Args:
        graph: NetworkX graph to export

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing:
            - nodes (pd.DataFrame): One row per node. The node identifier is
              stored in the 'node_id' column (the frame keeps pandas' default
              integer index); the remaining columns are the node attributes
              found in the graph. An empty graph yields an empty frame that
              still has the 'node_id' column.
            - edges (pd.DataFrame): Edge list with columns 'source', 'target',
              and any edge attributes (e.g., 'weight'). A graph without edges
              yields an empty frame with the 'source' and 'target' columns.

    Examples:
        >>> # Export any NetworkX graph
        >>> import networkx as nx
        >>> from semnet import to_pandas
        >>>
        >>> # Create or load any graph
        >>> G = nx.karate_club_graph()
        >>> nodes, edges = to_pandas(G)
        >>> # Use with SemanticNetwork
        >>> network = SemanticNetwork(thresh=0.8)
        >>> graph = network.fit_transform(embeddings, labels=docs)
        >>> nodes, edges = to_pandas(graph)
        >>> # Export a subgraph
        >>> subgraph = graph.subgraph([0, 1, 2])
        >>> sub_nodes, sub_edges = to_pandas(subgraph)
    """
    # One record per node: the identifier first, then whatever attributes the node carries.
    node_rows = [{"node_id": node, **attrs} for node, attrs in graph.nodes(data=True)]
    if node_rows:
        nodes = pd.DataFrame(node_rows)
    else:
        # Keep the schema stable for empty graphs, mirroring the empty-edge case below.
        nodes = pd.DataFrame(columns=["node_id"])

    if graph.number_of_edges() > 0:
        edges = nx.to_pandas_edgelist(graph)
    else:
        # Create empty DataFrame with expected columns if no edges
        edges = pd.DataFrame(columns=["source", "target"])
    return nodes, edges
[docs]
class SemanticNetwork:
    """
    A semantic network builder for creating graphs from document embeddings.

    This class follows the scikit-learn pattern with fit() and transform() methods.
    Users must provide pre-computed embeddings during the fit process.
    The fitting process builds an approximate nearest neighbor index from embeddings.
    The transformation process constructs a graph where edges represent semantic similarity.

    Key Methods:
        fit(): Build the similarity index from provided embeddings
        transform(): Construct and return a networkx object
        fit_transform(): Combined fit and transform in one step
        to_pandas(): Export graph structure to pandas DataFrames for analysis

    Attributes:
        metric: Distance metric for the Annoy index
        n_trees: Number of trees for the Annoy index
        thresh: Similarity threshold for connecting documents
        top_k: Maximum neighbors to check per document
        search_k: Optional Annoy ``search_k`` value used when querying neighbors;
            None leaves Annoy's own default in effect
        verbose: Whether to show progress bars and detailed logging
        is_fitted_: Whether the model has been fitted
        embeddings_: Document embeddings array (available after fitting)
        index_: Annoy index for similarity search (available after fitting)
    """
def __init__(
    self,
    metric: MetricType = "angular",
    n_trees: int = 10,
    thresh: float = 0.3,
    top_k: int = 20,
    search_k: Optional[int] = None,
    verbose: bool = False,
) -> None:
    """
    Store the configuration and reset all fitted state.

    No index is built here; the constructor only records hyperparameters.

    Args:
        metric: Name of the Annoy distance metric. One of:
            - 'angular': cosine-style distance; the default, and the usual
              choice for normalized embeddings such as sentence-transformer output.
            - 'euclidean': L2 distance, for embeddings whose magnitude matters.
            - 'manhattan': L1 distance; less common, sometimes used for sparse data.
            - 'hamming': Hamming distance for binary vectors.
            - 'dot': dot-product based ranking. Use with caution.
        n_trees: Size of the Annoy forest. More trees generally means better
            recall at the cost of build time and memory. Rough guidance:
            - 10 (default): reasonable balance for most uses
            - 50+: large collections (100k+ documents)
            - 100+: when accuracy matters most
        thresh: Minimum similarity for two documents to be linked; a pair is
            connected when its similarity is >= thresh. Larger values give
            sparser networks:
            - 0.1-0.2: dense, keeps weak similarities
            - 0.2-0.4: moderate selectivity
            - 0.5+: only strong similarities
        top_k: Upper bound on the neighbors inspected per document while
            building the network. Should comfortably exceed the expected degree:
            - 10-20: good starting point (default: 20)
            - 20-50: more thorough search for medium datasets
            - 50-100: large or sparse datasets
            - 100+: only when maximum connectivity is wanted
        search_k: Value forwarded to Annoy as ``search_k`` (number of nodes
            inspected per query). None keeps Annoy's default.
        verbose: Show progress bars and emit detailed log messages during
            fit/transform.
    """
    # Nothing has been fitted yet: no embeddings, no index.
    self.is_fitted_ = False
    self.embeddings_: Optional[np.ndarray] = None
    self.index_: Optional[AnnoyIndex] = None

    # Per-document labels and extra node attributes; populated later.
    self._labels: Optional[List[str]] = None
    self._node_data: Optional[Dict] = None

    # Index construction settings.
    self.metric, self.n_trees = metric, n_trees
    # Neighbor search / edge selection settings.
    self.thresh, self.top_k, self.search_k = thresh, top_k, search_k
    # Diagnostics.
    self.verbose = verbose
[docs]
def fit(
    self,
    embeddings: np.ndarray,
) -> "SemanticNetwork":
    """
    Build the index from document embeddings.

    The embeddings are stored on the estimator and an Annoy index is built
    from them for fast nearest neighbor search.

    Args:
        embeddings: Pre-computed embeddings with shape (n_docs, embedding_dim).
            Anything ``numpy.asarray`` can turn into a 2-D array is accepted.

    Returns:
        self: Returns the fitted estimator

    Raises:
        ValueError: If ``embeddings`` is not two-dimensional
    """
    matrix = np.asarray(embeddings)
    # The index builder reads shape[1] as the vector dimension, so anything
    # that is not (n_docs, embedding_dim) is rejected here with a clear message.
    if matrix.ndim != 2:
        raise ValueError(
            "embeddings must be a 2-D array of shape (n_docs, embedding_dim); "
            f"got shape {matrix.shape}"
        )

    # A failed (re)fit must not leave the estimator flagged as fitted.
    self.is_fitted_ = False
    self.embeddings_ = matrix
    if self.verbose:
        logger.info("Using provided embeddings with shape: %s", matrix.shape)
        logger.info("Fitting SemanticNetwork on %d documents", len(matrix))

    # Build the vector index
    self._build_vector_index()
    self.is_fitted_ = True
    if self.verbose:
        logger.info("Fitting complete")
    return self
def _build_vector_index(self) -> AnnoyIndex:
    """
    Create the Annoy index over the stored embeddings.

    Every embedding row is added under its row number, then the forest is
    built with ``self.n_trees`` trees.

    Returns:
        The built Annoy index (the same object that is stored in ``self.index_``)

    Raises:
        ValueError: If no embeddings have been stored yet
    """
    vectors = self.embeddings_
    if vectors is None:
        raise ValueError("Embeddings not found. Please provide embeddings in fit() method.")

    chatty = self.verbose
    # The second axis of the embedding matrix is the vector dimension.
    ann = AnnoyIndex(vectors.shape[1], self.metric)  # type: ignore
    self.index_ = ann

    numbered = enumerate(vectors)
    if chatty:
        logger.info(
            f"Building Annoy index with {self.n_trees} trees for {len(self.embeddings_)} embeddings"
        )
        # Wrap the plain enumeration in a progress bar when running verbosely.
        numbered = tqdm(
            numbered,
            total=len(vectors),
            desc="Adding embeddings to index",
        )

    for position, vector in numbered:
        ann.add_item(position, vector)

    if chatty:
        logger.info("Building index trees...")
    ann.build(self.n_trees)
    if chatty:
        logger.info("Vector index built successfully")
    return ann
def _get_pairwise_similarities(
    self, thresh: float, top_k: int, search_k: Optional[int] = None
) -> pd.DataFrame:
    """
    Find pairwise similarities between documents above a threshold.

    For every document the Annoy index is asked for its nearest neighbors,
    the distances Annoy reports are converted to similarities, and pairs
    whose similarity is at least ``thresh`` are kept.

    Distance-to-similarity conversion per metric:
        - angular: ``1 - d**2 / 2`` (cosine similarity)
        - dot: the value Annoy reports is the inner product itself and is used as is
        - hamming: ``1 - d / embedding_dim`` (Annoy reports the number of
          differing positions, so it is normalised by the vector length)
        - euclidean, manhattan: ``1 / (1 + d)`` heuristic

    Args:
        thresh: Similarity threshold for including edges
        top_k: Maximum number of neighbors to check per document
        search_k: Optional parameter for Annoy index search_k, controlling the
            number of nodes to inspect during search

    Returns:
        DataFrame of similarities with columns: source_idx, target_idx, weight,
        source_name, target_name. The columns are present even when no pair
        passes the threshold.

    Raises:
        ValueError: If embeddings or index haven't been built yet
        ValueError: If no labels have been stored
        ValueError: If top_k is negative
    """
    if self.embeddings_ is None or self.index_ is None:
        raise ValueError(
            "Embeddings or index not found. Please provide embeddings in fit() method and run _build_vector_index() first."
        )
    if self._labels is None:
        raise ValueError("No training documents found. Call fit() first.")
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")
    if self.verbose:
        logger.info(
            "Finding pairwise similarities with threshold %s, checking top %s neighbors",
            thresh,
            top_k,
        )

    n_items = len(self.embeddings_)
    embedding_dim = self.embeddings_.shape[1]
    metric = self.metric

    # Request one extra result because the query item is normally returned
    # among its own neighbors; never ask for more items than the index holds.
    n_query = min(top_k + 1, n_items)

    # Only pass search_k if explicitly set (annoy doesn't accept None)
    query_kwargs = {"include_distances": True}
    if search_k is not None:
        query_kwargs["search_k"] = search_k

    sources = range(n_items)
    if self.verbose:
        sources = tqdm(sources, desc="Finding similarities")

    results = []
    for idx_source in sources:
        neighbor_ids, distances = self.index_.get_nns_by_item(
            idx_source, n_query, **query_kwargs
        )
        # Drop the self-match by index rather than by position: with duplicate
        # vectors or the 'dot' metric the query item is not guaranteed to be
        # the first result. Then cap at top_k real neighbors.
        candidates = [
            (idx_target, dist)
            for idx_target, dist in zip(neighbor_ids, distances)
            if idx_target != idx_source
        ][:top_k]

        for idx_target, dist in candidates:
            # Convert distance to similarity based on metric
            if metric == "angular":
                # angular_dist = sqrt(2 * (1 - cos_sim))
                # Therefore: cos_sim = 1 - (dist^2 / 2)
                similarity = 1 - (dist**2) / 2
            elif metric == "dot":
                # Annoy negates the dot product internally for ranking but
                # reports the plain dot product to callers.
                similarity = dist
            elif metric == "hamming":
                # Annoy reports the count of differing positions; scale it
                # to [0, 1] by the vector length before inverting.
                similarity = 1 - dist / embedding_dim
            else:  # euclidean, manhattan
                # No direct cosine relationship; use common heuristic
                similarity = 1 / (1 + dist)

            # Only include if above threshold
            if similarity >= thresh:
                results.append(
                    {
                        "source_idx": idx_source,
                        "target_idx": idx_target,
                        "weight": similarity,
                        "source_name": self._labels[idx_source],
                        "target_name": self._labels[idx_target],
                    }
                )

    # Name the columns explicitly so an empty result keeps the documented schema.
    neighbor_data = pd.DataFrame(
        results,
        columns=["source_idx", "target_idx", "weight", "source_name", "target_name"],
    )
    if self.verbose:
        logger.info(
            "Found %d similarity pairs above threshold %s", len(neighbor_data), thresh
        )
    return neighbor_data
def _build_graph(self, neighbor_data: pd.DataFrame) -> nx.Graph:
    """
    Build a NetworkX graph from pairwise similarities.

    Creates an undirected graph where:
        - Nodes are the integer positions of the stored labels; each node has a
          'label' attribute plus any extra attributes from the stored node data
        - Edges come from ``neighbor_data`` and carry a 'weight' attribute
          holding the similarity

    Args:
        neighbor_data: DataFrame of pairwise similarities with at least the
            columns 'source_idx', 'target_idx' and 'weight' (an empty frame
            without columns is also accepted)

    Returns:
        The constructed NetworkX graph

    Raises:
        ValueError: If training data hasn't been set

    Note:
        The graph includes all documents as nodes, even if they have no
        similarities above threshold. Node data given as a dict is merged into
        the node attributes; any other value is stored under the 'value' key.
    """
    if self._labels is None:
        raise ValueError("No training documents found. Call fit() first.")
    if self.verbose:
        logger.info("Building graph from %d similarity edges", len(neighbor_data))

    # Instantiate undirected graph
    G = nx.Graph()

    # Add all nodes with their attributes
    node_data = self._node_data if self._node_data is not None else {}
    for i, label in enumerate(self._labels):
        attrs = {"label": label}
        # Add custom node data if provided for this specific node
        if i in node_data:
            extra = node_data[i]
            if isinstance(extra, dict):
                attrs.update(extra)
            else:
                # Bare values have no attribute name of their own.
                attrs["value"] = extra
        G.add_node(i, **attrs)

    # Add all edges in one call instead of building a Series per row with
    # iterrows(); the length guard also covers a frame that has no columns.
    if len(neighbor_data) > 0:
        G.add_weighted_edges_from(
            zip(
                neighbor_data["source_idx"].tolist(),
                neighbor_data["target_idx"].tolist(),
                neighbor_data["weight"].tolist(),
            ),
            weight="weight",
        )

    if self.verbose:
        num_components = nx.number_connected_components(G)
        logger.info(
            "Built graph with %d nodes, %d edges, %d components",
            G.number_of_nodes(),
            G.number_of_edges(),
            num_components,
        )
    return G