You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
gentoo-overlay/sci-libs/miopen/files/miopen-5.7.1-fix-miopendriv...

75 lines
2.8 KiB

Fix uninitialized variable (gemm_desc.dataType) in the MIOpenDriver gemm driver and restore the gemmfp16 base argument for testing
Upstream bug: https://github.com/ROCmSoftwarePlatform/MIOpen/issues/2505
--- a/driver/driver.hpp
+++ b/driver/driver.hpp
@@ -141,7 +141,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
printf("Usage: ./driver *base_arg* *other_args*\n");
printf("Supported Base Arguments: conv[fp16|int8|bfp16], CBAInfer[fp16], "
"pool[fp16], lrn[fp16], "
- "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm, ctc, dropout[fp16], "
+ "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], "
"tensorop[fp16], reduce[fp16,fp64]\n");
exit(0); // NOLINT (concurrency-mt-unsafe)
}
@@ -160,7 +160,7 @@ inline std::string ParseBaseArg(int argc, char* argv[])
arg != "CBAInfer" && arg != "CBAInferfp16" && arg != "pool" && arg != "poolfp16" &&
arg != "lrn" && arg != "lrnfp16" && arg != "activ" && arg != "activfp16" &&
arg != "softmax" && arg != "softmaxfp16" && arg != "bnorm" && arg != "bnormfp16" &&
- arg != "rnn" && arg != "rnnfp16" && arg != "gemm" /*&& arg != "gemmfp16"*/ && arg != "ctc" &&
+ arg != "rnn" && arg != "rnnfp16" && arg != "gemm" && arg != "gemmfp16" && arg != "ctc" &&
arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "tensoropfp16" &&
arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "--version")
{
--- a/driver/gemm_driver.hpp
+++ b/driver/gemm_driver.hpp
@@ -207,6 +207,19 @@ int GemmDriver<T>::GetandSetData()
gemm_desc.strideB = gemm_desc.k * gemm_desc.n;
gemm_desc.strideC = gemm_desc.m * gemm_desc.n;
+ if constexpr (std::is_same_v<T, float>)
+ {
+ gemm_desc.dataType = miopenFloat;
+ }
+ else if constexpr (std::is_same_v<T, float16>)
+ {
+ gemm_desc.dataType = miopenHalf;
+ }
+ else
+ {
+ static_assert(!"unsupported type");
+ }
+
return (0);
}
@@ -230,9 +243,9 @@ int GemmDriver<T>::AllocateBuffersAndCopy()
a = std::vector<T>(a_sz);
b = std::vector<T>(b_sz);
#if GEMM_DRIVER_DEBUG
- c = std::vector<T>(c_sz, 1.);
+ c = std::vector<T>(c_sz, static_cast<T>(1.));
#else
- c = std::vector<T>(c_sz, 0.);
+ c = std::vector<T>(c_sz, static_cast<T>(0.));
#endif
chost = c;
--- a/driver/main.cpp
+++ b/driver/main.cpp
@@ -125,11 +125,10 @@ int main(int argc, char* argv[])
{
drv = new GemmDriver<float>();
}
-// TODO half is not supported in gemm
-// else if(base_arg == "gemmfp16")
-// {
-// drv = new GemmDriver<float16>();
-// }
+ else if(base_arg == "gemmfp16")
+ {
+ drv = new GemmDriver<float16>();
+ }
#endif
else if(base_arg == "bnorm")
{