@@ -60,6 +60,16 @@ void resize_view_node(
 }
 }
 
+void resize_to_dim_order_copy_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef in = args.at(1).refs.at(0);
+  const std::vector<int64_t> in_sizes = graph->sizes_of(in);
+  graph->virtual_resize(out, in_sizes);
+}
+
 void add_view_node(
     ComputeGraph& graph,
     ValueRef in,
@@ -98,6 +108,11 @@ void add_view_copy_buffer_node(
   std::string kernel_name = "view_buffer";
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
+  bool all_contiguous = graph.is_contiguous_buffer_tensor(in) &&
+      graph.is_contiguous_buffer_tensor(out);
+
+  int32_t all_contiguous_int = all_contiguous ? 1 : 0;
+
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
@@ -110,7 +125,41 @@ void add_view_copy_buffer_node(
       // Push Constants
       {},
       // Specialization Constants
+      {all_contiguous_int},
+      // Resize Args
+      resize_args,
+      // Resizing Logic
+      resize_fn));
+}
+
+void add_view_copy_convert_buffer_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef out,
+    const std::vector<ValueRef>& resize_args,
+    const ExecuteNode::ResizeFunction& resize_fn) {
+  std::string kernel_name = "view_convert_buffer";
+  add_dtype_suffix(kernel_name, graph.dtype_of(in));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  bool all_contiguous = graph.is_contiguous_buffer_tensor(in) &&
+      graph.is_contiguous_buffer_tensor(out);
+
+  int32_t all_contiguous_int = all_contiguous ? 1 : 0;
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      default_pick_global_wg_size,
+      default_pick_local_wg_size,
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Parameter Buffers
+      {graph.buffer_meta_ubo(out), graph.buffer_meta_ubo(in)},
+      // Push Constants
       {},
+      // Specialization Constants
+      {all_contiguous_int},
       // Resize Args
       resize_args,
       // Resizing Logic
@@ -132,8 +181,38 @@ void view(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   return add_view_node(graph, in, sizes, out);
 }
 
+void to_dim_order_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int args_idx = 0;
+  const ValueRef in = args.at(args_idx++);
+  const ValueRef dtype = args.at(args_idx++);
+  (void)dtype;
+  const ValueRef layout = args.at(args_idx++);
+  (void)layout;
+  const ValueRef device = args.at(args_idx++);
+  (void)device;
+  const ValueRef pin_memory = args.at(args_idx++);
+  (void)pin_memory;
+  const ValueRef non_blocking = args.at(args_idx++);
+  (void)non_blocking;
+  const ValueRef dim_order = args.at(args_idx++);
+  (void)dim_order;
+
+  const ValueRef out = args.at(args_idx++);
+
+  VK_CHECK_COND(graph.is_buffer_storage(in) && graph.is_buffer_storage(out));
+
+  if (graph.dtype_of(in) == graph.dtype_of(out)) {
+    return add_view_copy_buffer_node(
+        graph, in, out, {}, resize_to_dim_order_copy_node);
+  }
+
+  return add_view_copy_convert_buffer_node(
+      graph, in, out, {}, resize_to_dim_order_copy_node);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.view_copy.default, view);
+  VK_REGISTER_OP(dim_order_ops._to_dim_order_copy.default, to_dim_order_copy);
 }
 
 } // namespace vkcompute
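
The new `all_contiguous` specialization constant presumably lets the `view_buffer` and `view_convert_buffer` shaders skip layout-aware index math when both the input and the output are contiguous buffer tensors; because it is a specialization constant rather than a push constant, the branch can be resolved when the compute pipeline is built. Below is a minimal CPU-side sketch of the two paths such a flag selects between: `BufferDesc`, `physical_offset`, and `copy_elements` are names invented for this illustration (not the vkcompute API), and the `OutT`/`InT` template parameters stand in for the dtype conversion done by the `view_convert_buffer` variant.

#include <cstdint>
#include <vector>

// Hypothetical strided-buffer description used only for this sketch.
struct BufferDesc {
  std::vector<int64_t> sizes;   // logical sizes, outermost dim first
  std::vector<int64_t> strides; // per-dim element strides into the flat buffer
};

// Map a flat logical element index to a physical buffer offset via strides.
int64_t physical_offset(int64_t logical_idx, const BufferDesc& d) {
  int64_t offset = 0;
  for (int64_t dim = static_cast<int64_t>(d.sizes.size()) - 1; dim >= 0; --dim) {
    offset += (logical_idx % d.sizes[dim]) * d.strides[dim];
    logical_idx /= d.sizes[dim];
  }
  return offset;
}

// CPU analogue of the copy: a flat element-wise copy when both tensors are
// contiguous, otherwise an index remap through both layouts.
template <typename OutT, typename InT>
void copy_elements(
    OutT* out,
    const BufferDesc& out_desc,
    const InT* in,
    const BufferDesc& in_desc,
    int64_t numel,
    bool all_contiguous) {
  for (int64_t i = 0; i < numel; ++i) {
    if (all_contiguous) {
      out[i] = static_cast<OutT>(in[i]); // fast path: same flat index
    } else {
      out[physical_offset(i, out_desc)] =
          static_cast<OutT>(in[physical_offset(i, in_desc)]);
    }
  }
}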