Skip to content

Commit

Permalink
Tune API
Browse files Browse the repository at this point in the history
  • Loading branch information
yury committed Jan 3, 2024
1 parent a9e91b9 commit f259c22
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion cidre/src/mps/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ mod tensor;
pub use tensor::Tensor;

mod operation;
pub use operation::Operation;
pub use operation::Op;

mod memory_ops;
pub use memory_ops::VariableOp;
Expand Down
6 changes: 3 additions & 3 deletions cidre/src/mps/graph/memory_ops.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{arc, cf, define_obj_type, mps, mps::graph, ns, objc};

define_obj_type!(pub VariableOp(graph::Operation));
define_obj_type!(pub VariableOp(graph::Op));

impl VariableOp {
#[objc::msg_send(shape)]
Expand Down Expand Up @@ -80,13 +80,13 @@ mod tests {
assert_eq!(1, tensor.as_type_ref().retain_count());
// this will crash, since we released graph. Same crash will be in Swift too.
// We may add lifetime to tensor
// assert_eq!("mps_placeholder", tensor.operation().name().to_string());
// assert_eq!("mps_placeholder", tensor.op().name().to_string());
}
#[test]
pub fn basics() {
let gr = graph::Graph::new();
let tensor = gr.placeholder_with_shape(None, mps::DataType::F32, None);
assert_eq!("mps_placeholder", tensor.operation().name().to_string());
assert_eq!("mps_placeholder", tensor.op().name().to_string());
assert!(tensor.shape().is_none());
}
}
7 changes: 5 additions & 2 deletions cidre/src/mps/graph/operation.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use crate::{define_obj_type, mps::graph, ns, objc};

define_obj_type!(pub Operation(ns::Id));
define_obj_type!(
#[doc(alias = "MPSGraphOperation")]
pub Op(ns::Id)
);

impl Operation {
impl Op {
#[objc::msg_send(inputTensors)]
pub fn input_tensors(&self) -> &ns::Array<graph::Tensor>;

Expand Down
2 changes: 1 addition & 1 deletion cidre/src/mps/graph/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ impl Tensor {
pub fn data_type(&self) -> mps::DataType;

#[objc::msg_send(operation)]
pub fn operation(&self) -> &graph::Operation;
pub fn op(&self) -> &graph::Op;
}

0 comments on commit f259c22

Please sign in to comment.