-
Notifications
You must be signed in to change notification settings - Fork 610
[Torch] Fold aten rounding ops on splat constants. #4359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This commit teaches the folding methods of `AtenFloor`, `AtenCeil`, `AtenRound`, and `AtenTruc` to constant-fold roundings when the operand is a splat `DenseElementsAttr`.
58550c3 to
175d2a2
Compare
|
@sahas3 @zjgarvey @vivekkhandelwal1 can you please take a look at this? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice change, looks good to me for most part.
| } | ||
|
|
||
| // Common helper for splat-only rounding-based folders. | ||
| static OpFoldResult foldSplatRounding(ValueTensorType resultType, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think naming it foldFloatSplatWithRounding will be more appropriate since it only handles float data.
| return {}; | ||
|
|
||
| auto outShaped = resultType.toBuiltinTensor(); | ||
| if (!outShaped.hasStaticShape()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is static shape a requirement?
| @export | ||
| @annotate_args([None]) | ||
| def forward(self): | ||
| return torch.ops.aten.round(self.const) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am guessing the other ops are already covered in e2e tests?
| // CHECK: %[[C:.*]] = torch.vtensor.literal(dense<-1.000000e+00> : tensor<2x2xf32>) | ||
| // CHECK: return %[[C]] | ||
| func.func @torch.aten.ceil$fold() -> !torch.vtensor<[2,2],f32> { | ||
| %cst = torch.vtensor.literal(dense<-1.100000e+00> : tensor<2x2xf32>) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a +ve value too for completeness?
| // CHECK: %[[C:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<3x4xf32>) | ||
| // CHECK: return %[[C]] | ||
| func.func @torch.aten.floor$fold() -> !torch.vtensor<[3,4],f32> { | ||
| %cst = torch.vtensor.literal(dense<1.900000e+00> : tensor<3x4xf32>) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly adding a negative value will give full coverage
| // NaNs and infs are dealt with consistently with torch, so side-effects | ||
| // can be discarded. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clarification: Can you elaborate on what you meant by "NaNs and infs are dealt with consistently with torch" ?
This commit teaches the folding methods of
AtenFloor,AtenCeil,AtenRound, andAtenTructo constant-fold roundings when the operand is a splatDenseElementsAttr.