Common Compute APIs
round_to
Description
Clips data to the range [min_value, max_value] by comparing each element of data with min_value and max_value. An element between min_value and max_value is kept unchanged. An element less than min_value is replaced by min_value, and an element greater than max_value is replaced by max_value.
The API is defined in python/site-packages/te/lang/cce/te_compute/common.py in the ATC installation path.
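To check results on the host, the element-wise rule can be modeled in plain Python. This is a minimal reference sketch, not the TBE implementation. The function name `round_to_reference` is hypothetical, it works on a flat list instead of a tvm.tensor, and the guard against min_value > max_value is an addition of this sketch, because the behavior of round_to in that case is not described here.

```python
from typing import List


def round_to_reference(data: List[float], max_value: float, min_value: float) -> List[float]:
    """Hypothetical host-side model of round_to: clip each element to [min_value, max_value]."""
    # Guard added in this sketch only; the API's behavior for an inverted range is unspecified.
    if min_value > max_value:
        raise ValueError("min_value must not exceed max_value")
    result = []
    for x in data:
        if x < min_value:
            result.append(min_value)  # below the range: replaced by min_value
        elif x > max_value:
            result.append(max_value)  # above the range: replaced by max_value
        else:
            result.append(x)          # inside the range: kept unchanged
    return result


print(round_to_reference([-1.0, 2.5, 7.0], max_value=3.0, min_value=2.0))
```

The argument order (data, max_value, min_value) mirrors the prototype of te.lang.cce.round_to. With the range [2.0, 3.0], the call above yields [2.0, 2.5, 3.0].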
Restrictions
If the data types differ, max_value and min_value are converted to the data type of data during computation.
The supported data types are float16, float32, int8, uint8, and int32. Inputs of type int8, uint8, and int32 are converted to float16.
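Because int8, uint8, and int32 inputs are converted to float16, it helps to keep the precision of that format in mind. float16 represents every integer exactly only up to 2048 in magnitude, and its largest finite value is 65504. The sketch below round-trips an integer through IEEE 754 half precision using Python's `struct` format `'e'`. It illustrates the float16 format only and is not a model of the device's conversion path. The helper name `through_float16` is hypothetical.

```python
import struct


def through_float16(value: int) -> float:
    """Hypothetical helper: round-trip an integer through IEEE 754 half precision."""
    # struct format 'e' rounds to the nearest float16 (ties to even)
    # and raises OverflowError for values too large to represent.
    packed = struct.pack("<e", float(value))
    return struct.unpack("<e", packed)[0]


for v in (100, 2048, 2049, 4097):
    print(v, "->", through_float16(v))
```

Here 100 and 2048 come back unchanged, while 2049 becomes 2048.0 and 4097 becomes 4096.0, so int32 data with large magnitudes may not pass through the float16 conversion intact.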
Prototype
te.lang.cce.round_to(data, max_value, min_value)
Parameters
- data: a tvm.tensor for the input
- max_value: a scalar for the maximum value of the target range
- min_value: a scalar for the minimum value of the target range
Returns
res_tensor: a tvm.tensor for the result
Example
```python
from te import tvm
import te.lang.cce

shape = (1024, 1024)
input_dtype = "float16"
data = tvm.placeholder(shape, name="data", dtype=input_dtype)
# Target range is [2, 3]: max_value must not be smaller than min_value.
max_value = tvm.const(3, dtype=input_dtype)
min_value = tvm.const(2, dtype=input_dtype)
res = te.lang.cce.round_to(data, max_value, min_value)
```
cast_to
Description
Converts the tensor data to the data type specified by dtype.
The API is defined in python/site-packages/te/lang/cce/te_compute/common.py in the ATC installation path.
Restrictions
The supported conversions are listed below.

| Source Data Type | Destination Data Type |
|---|---|
| float32 | float16 |
| float32 | int8 |
| float32 | uint8 |
| float16 | float32 |
| float16 | int8 |
| float16 | uint8 |
| float16 | int32 |
| int8 | float16 |
| int8 | uint8 |
| int32 | float16 |
| int32 | int8 |
| int32 | uint8 |
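A host-side pre-check built from this table can reject an unsupported conversion before any operator code is built. The sketch below is hypothetical: `SUPPORTED_CASTS` and `is_supported_cast` are not part of the te API, and the set simply transcribes the twelve (source, destination) pairs in the table.

```python
# Supported (source, destination) pairs, transcribed from the cast_to restrictions table.
SUPPORTED_CASTS = frozenset({
    ("float32", "float16"), ("float32", "int8"), ("float32", "uint8"),
    ("float16", "float32"), ("float16", "int8"), ("float16", "uint8"), ("float16", "int32"),
    ("int8", "float16"), ("int8", "uint8"),
    ("int32", "float16"), ("int32", "int8"), ("int32", "uint8"),
})


def is_supported_cast(src_dtype: str, dst_dtype: str) -> bool:
    """Hypothetical pre-check: True only if the table lists src_dtype -> dst_dtype."""
    return (src_dtype, dst_dtype) in SUPPORTED_CASTS


print(is_supported_cast("float16", "float32"))  # listed in the table
print(is_supported_cast("int8", "int32"))       # not listed in the table
```

Note that uint8 appears only as a destination type in the table, never as a source. How cast_to itself reacts to an unlisted pair is not described here, so this check is only a convenience for failing early.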
Prototype
te.lang.cce.cast_to(data, dtype, f1628IntegerFlag=True)
Parameters
- data: a tvm.tensor for the input tensor
- dtype: a string specifying the destination data type, for example "float32"
- f1628IntegerFlag: a bool, True by default. Set it to True if the fractional part of the data before conversion is 0, and to False if the fractional part is not 0.
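When the input values are available on the host, for example while preparing test data, the value of f1628IntegerFlag can be derived from them. The helper `choose_f1628_integer_flag` below is hypothetical. It only applies the rule stated for the parameter: True if every value has a zero fractional part, otherwise False.

```python
from typing import Iterable


def choose_f1628_integer_flag(values: Iterable[float]) -> bool:
    """Hypothetical helper: True only if every value has a zero fractional part."""
    # float.is_integer() is True for whole numbers such as 3.0 and False for 2.5
    # (and also False for inf and nan).
    return all(float(v).is_integer() for v in values)


print(choose_f1628_integer_flag([1.0, -3.0, 42.0]))
print(choose_f1628_integer_flag([1.0, 2.5]))
```

The first call returns True and the second returns False. The result could then be passed as `te.lang.cce.cast_to(data, "int32", f1628IntegerFlag=flag)` for a float16 input, since float16 to int32 is a listed conversion.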
Returns
res_tensor: a tvm.tensor for the result tensor
Example
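```python
from te import tvm
import te.lang.cce

shape = (1024, 1024)
input_dtype = "float16"
data = tvm.placeholder(shape, name="data", dtype=input_dtype)
# float16 -> float32 is one of the conversions listed in the Restrictions table.
res = te.lang.cce.cast_to(data, "float32")
```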