Skip to content

Commit

Permalink
Show edges between pinned nodes to other nodes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 661022580
  • Loading branch information
Google AI Edge authored and copybara-github committed Aug 8, 2024
1 parent 4ee14be commit d164412
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 135 deletions.
13 changes: 13 additions & 0 deletions src/ui/src/components/visualizer/webgl_renderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,19 @@ export class WebglRenderer implements OnInit, OnDestroy {
};
}

// Used by tests only.
getNodeIoChipScreenPositionRelativeToCenter(nodeId: string): Point {
const node = this.curModelGraph.nodesById[nodeId];
const x = this.getNodeX(node) + 5;
const y = this.getNodeY(node) - 3;
const pos = this.webglRendererThreejsService.convertScenePosToScreen(x, y);
const container = this.container.nativeElement;
return {
x: Math.floor(pos.x - container.clientWidth / 2),
y: Math.floor(pos.y - container.clientHeight / 2),
};
}

// Used by tests only.
getNodeExpandIconPositionRelativeToCenter(nodeId: string): Point {
const node = this.curModelGraph.nodesById[nodeId] as GroupNode;
Expand Down
320 changes: 185 additions & 135 deletions src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -357,19 +357,6 @@ export class WebglRendererIoHighlightService {
selectedNode.namespace,
);

// Find the existing edge in the common namespace that connects two
// nodes n1 and n2 where n1 contains `sourceNode` and n2 contains
// `node`.
const renderedEdge = this.findEdgeConnectingTwoNodesInNamespace(
commonNamespace,
sourceNode.id,
selectedNode.id,
);
if (!renderedEdge) {
continue;
}
renderedEdges.push(renderedEdge);

// Go from the given node to all its ns ancestors, find the last collapsed
// node before reaching the given namespace. If all ancestor nodes are
// expanded, return the given node.
Expand All @@ -385,74 +372,96 @@ export class WebglRendererIoHighlightService {
}
inputsByHighlightedNode[highlightedNode.id].push(sourceNode);

// Find the existing edge in the common namespace that connects two
// nodes n1 and n2 where n1 contains `sourceNode` and n2 contains
// `node`.
const renderedEdge = this.findEdgeConnectingTwoNodesInNamespace(
commonNamespace,
sourceNode.id,
selectedNode.id,
);

// Start to construct an edge from the source node to the selected node.
//
const points: Point[] = [];

// Add a point from the highlighted node that connects to the first
// point of the rendered edge above.
const renderedEdgeFromNode =
this.webglRenderer.curModelGraph.nodesById[renderedEdge.fromNodeId];
if (renderedEdge.fromNodeId !== highlightedNode.id) {
const renderedEdgeStartX =
renderedEdge.points[0].x + (renderedEdgeFromNode.globalX || 0);
const renderedEdgeStartY =
renderedEdge.points[0].y + (renderedEdgeFromNode.globalY || 0);
const startPt = this.getBestAnchorPointOnNode(
renderedEdgeStartX,
renderedEdgeStartY,
highlightedNode,
if (renderedEdge) {
renderedEdges.push(renderedEdge);

// Add a point from the highlighted node that connects to the first
// point of the rendered edge above.
const renderedEdgeFromNode =
this.webglRenderer.curModelGraph.nodesById[renderedEdge.fromNodeId];
if (renderedEdge.fromNodeId !== highlightedNode.id) {
const renderedEdgeStartX =
renderedEdge.points[0].x + (renderedEdgeFromNode.globalX || 0);
const renderedEdgeStartY =
renderedEdge.points[0].y + (renderedEdgeFromNode.globalY || 0);
const startPt = this.getBestAnchorPointOnNode(
renderedEdgeStartX,
renderedEdgeStartY,
highlightedNode,
);
points.push({
x: startPt.x - (highlightedNode.globalX || 0),
y: startPt.y - (highlightedNode.globalY || 0),
});
}

// Add the points in rendered edge.
points.push(
...renderedEdge.points.map((pt) => {
return {
x:
pt.x -
(highlightedNode.globalX || 0) +
(renderedEdgeFromNode.globalX || 0),
y:
pt.y -
(highlightedNode.globalY || 0) +
(renderedEdgeFromNode.globalY || 0),
};
}),
);
points.push({
x: startPt.x - (highlightedNode.globalX || 0),
y: startPt.y - (highlightedNode.globalY || 0),
});
}

// Add the points in rendered edge.
points.push(
...renderedEdge.points.map((pt) => {
return {
x:
pt.x -
(highlightedNode.globalX || 0) +
(renderedEdgeFromNode.globalX || 0),
y:
pt.y -
(highlightedNode.globalY || 0) +
(renderedEdgeFromNode.globalY || 0),
};
}),
);

// Add a point from the selected node that connects to the last point of
// the rendered edge.
if (renderedEdge.toNodeId !== this.webglRenderer.selectedNodeId) {
const renderedEdgeLastX =
renderedEdge.points[renderedEdge.points.length - 1].x +
(renderedEdgeFromNode.globalX || 0);
const renderedEdgeLastY =
renderedEdge.points[renderedEdge.points.length - 1].y +
(renderedEdgeFromNode.globalY || 0);
const endPt = this.getBestAnchorPointOnNode(
renderedEdgeLastX,
renderedEdgeLastY,
selectedNode,
// Add a point from the selected node that connects to the last point of
// the rendered edge.
if (renderedEdge.toNodeId !== this.webglRenderer.selectedNodeId) {
const renderedEdgeLastX =
renderedEdge.points[renderedEdge.points.length - 1].x +
(renderedEdgeFromNode.globalX || 0);
const renderedEdgeLastY =
renderedEdge.points[renderedEdge.points.length - 1].y +
(renderedEdgeFromNode.globalY || 0);
const endPt = this.getBestAnchorPointOnNode(
renderedEdgeLastX,
renderedEdgeLastY,
selectedNode,
);
points.push({
x: endPt.x - (highlightedNode.globalX || 0),
y: endPt.y - (highlightedNode.globalY || 0),
});
}
} else if (
isGroupNode(highlightedNode) ||
(isOpNode(highlightedNode) && !highlightedNode.hideInLayout)
) {
points.push(
...this.getDirectEdgeBetweenNodes(highlightedNode, selectedNode),
);
points.push({
x: endPt.x - (highlightedNode.globalX || 0),
y: endPt.y - (highlightedNode.globalY || 0),
});
}

// Use these points to form an edge and add it as an overlay edge.
overlayEdges.push({
id: `overlay_${highlightedNode.id}___${selectedNode.id}`,
fromNodeId: highlightedNode.id,
toNodeId: selectedNode.id,
points,
type: 'incoming',
});
if (points.length > 0) {
overlayEdges.push({
id: `overlay_${highlightedNode.id}___${selectedNode.id}`,
fromNodeId: highlightedNode.id,
toNodeId: selectedNode.id,
points,
type: 'incoming',
});
}
}

return {
Expand Down Expand Up @@ -492,19 +501,6 @@ export class WebglRendererIoHighlightService {
selectedNode.namespace,
);

// Find the existing edge in the common namespace that connects two
// nodes n1 and n2 where n1 contains `sourceNode` and n2 contains
// `node`.
const renderedEdge = this.findEdgeConnectingTwoNodesInNamespace(
commonNamespace,
selectedNode.id,
targetNode.id,
);
if (!renderedEdge) {
continue;
}
renderedEdges.push(renderedEdge);

// Go from the given node to all its ns ancestors, find the last
// collapsed node before reaching the given namespace, and style it with
// the given class. If all ancestor nodes are expanded, style the given
Expand All @@ -521,64 +517,84 @@ export class WebglRendererIoHighlightService {
}
outputsByHighlightedNode[highlightedNode.id].push(targetNode);

// Find the existing edge in the common namespace that connects two
// nodes n1 and n2 where n1 contains `sourceNode` and n2 contains
// `node`.
const renderedEdge = this.findEdgeConnectingTwoNodesInNamespace(
commonNamespace,
selectedNode.id,
targetNode.id,
);

// Start to construct an edge from the selected node to target node.
//
const points: Point[] = [];

// Add a point from the selected node that connects to the first point
// of the rendered edge.
const renderedEdgeFromNode =
this.webglRenderer.curModelGraph.nodesById[renderedEdge.fromNodeId];
if (renderedEdge.fromNodeId !== this.webglRenderer.selectedNodeId) {
const renderedEdgeStartX =
renderedEdge.points[0].x + (renderedEdgeFromNode.globalX || 0);
const renderedEdgeStartY =
renderedEdge.points[0].y + (renderedEdgeFromNode.globalY || 0);
const endPt = this.getBestAnchorPointOnNode(
renderedEdgeStartX,
renderedEdgeStartY,
selectedNode,
if (renderedEdge) {
renderedEdges.push(renderedEdge);

// Add a point from the selected node that connects to the first point
// of the rendered edge.
const renderedEdgeFromNode =
this.webglRenderer.curModelGraph.nodesById[renderedEdge.fromNodeId];
if (renderedEdge.fromNodeId !== this.webglRenderer.selectedNodeId) {
const renderedEdgeStartX =
renderedEdge.points[0].x + (renderedEdgeFromNode.globalX || 0);
const renderedEdgeStartY =
renderedEdge.points[0].y + (renderedEdgeFromNode.globalY || 0);
const endPt = this.getBestAnchorPointOnNode(
renderedEdgeStartX,
renderedEdgeStartY,
selectedNode,
);
points.push({
x: endPt.x - (selectedNode.globalX || 0),
y: endPt.y - (selectedNode.globalY || 0),
});
}

// Add the points in rendered edge.
points.push(
...renderedEdge.points.map((pt) => {
return {
x:
pt.x -
(selectedNode.globalX || 0) +
(renderedEdgeFromNode.globalX || 0),
y:
pt.y -
(selectedNode.globalY || 0) +
(renderedEdgeFromNode.globalY || 0),
};
}),
);
points.push({
x: endPt.x - (selectedNode.globalX || 0),
y: endPt.y - (selectedNode.globalY || 0),
});
}

// Add the points in rendered edge.
points.push(
...renderedEdge.points.map((pt) => {
return {
x:
pt.x -
(selectedNode.globalX || 0) +
(renderedEdgeFromNode.globalX || 0),
y:
pt.y -
(selectedNode.globalY || 0) +
(renderedEdgeFromNode.globalY || 0),
};
}),
);

// Add a point from the highlighted node that connects to the first
// point of the rendered edge above.
if (renderedEdge.toNodeId !== highlightedNode.id) {
const renderedEdgeLastX =
renderedEdge.points[renderedEdge.points.length - 1].x +
(renderedEdgeFromNode.globalX || 0);
const renderedEdgeLastY =
renderedEdge.points[renderedEdge.points.length - 1].y +
(renderedEdgeFromNode.globalY || 0);
const startPt = this.getBestAnchorPointOnNode(
renderedEdgeLastX,
renderedEdgeLastY,
highlightedNode,
// Add a point from the highlighted node that connects to the first
// point of the rendered edge above.
if (renderedEdge.toNodeId !== highlightedNode.id) {
const renderedEdgeLastX =
renderedEdge.points[renderedEdge.points.length - 1].x +
(renderedEdgeFromNode.globalX || 0);
const renderedEdgeLastY =
renderedEdge.points[renderedEdge.points.length - 1].y +
(renderedEdgeFromNode.globalY || 0);
const startPt = this.getBestAnchorPointOnNode(
renderedEdgeLastX,
renderedEdgeLastY,
highlightedNode,
);
points.push({
x: startPt.x - (selectedNode.globalX || 0),
y: startPt.y - (selectedNode.globalY || 0),
});
}
} else if (
isGroupNode(highlightedNode) ||
(isOpNode(highlightedNode) && !highlightedNode.hideInLayout)
) {
points.push(
...this.getDirectEdgeBetweenNodes(selectedNode, highlightedNode),
);
points.push({
x: startPt.x - (selectedNode.globalX || 0),
y: startPt.y - (selectedNode.globalY || 0),
});
}

// Use these points to form an edge and add it as an overlay edge.
Expand Down Expand Up @@ -727,6 +743,40 @@ export class WebglRendererIoHighlightService {
);
}

private getDirectEdgeBetweenNodes(
startNode: ModelNode,
endNode: ModelNode,
): Point[] {
const points: Point[] = [];

const startX = startNode.globalX || 0;
const startY = startNode.globalY || 0;
const startWidth = startNode.width || 0;
const startHeight = startNode.height || 0;
const endX = endNode.globalX || 0;
const endY = endNode.globalY || 0;
const endWidth = endNode.width || 0;
const endHeight = endNode.height || 0;

const startAnchorX = startX + startWidth / 2;
const startAnchorY = endY > startY ? startY + startHeight : startY;
const endAnchorX = endX + endWidth / 2;
const endAnchorY = endY > startY ? endY : endY + endHeight;

points.push(
{
x: startAnchorX + (startNode.x || 0) - startX,
y: startAnchorY + (startNode.y || 0) - startY,
},
{
x: endAnchorX + (endNode.x || 0) - startX,
y: endAnchorY + (endNode.y || 0) - startY,
},
);

return points;
}

private getBestAnchorPointOnNode(
startX: number,
startY: number,
Expand Down

0 comments on commit d164412

Please sign in to comment.