Add the new operators in cmake flags files. (#524)
* add the new operators in cmake flags files. * remove the extra change
This commit is contained in:
Родитель
4842e9d6ae
Коммит
247d34e30b
|
@ -16,11 +16,11 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
_is_librosa_avaliable = False
|
_is_librosa_available = False
|
||||||
try:
|
try:
|
||||||
import librosa
|
import librosa
|
||||||
|
|
||||||
_is_librosa_avaliable = True
|
_is_librosa_available = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -119,7 +119,7 @@ class TestAudio(unittest.TestCase):
|
||||||
def test_stft_norm_torch(self):
|
def test_stft_norm_torch(self):
|
||||||
audio_pcm = self.test_pcm
|
audio_pcm = self.test_pcm
|
||||||
wlen = 400
|
wlen = 400
|
||||||
# intesting bug in torch.stft, if there is 2-D input with batch size 1, it will generate a different
|
# interesting bug in torch.stft, if there is 2-D input with batch size 1, it will generate a different
|
||||||
# result with some spark points in the spectrogram.
|
# result with some spark points in the spectrogram.
|
||||||
expected = torch.stft(torch.from_numpy(audio_pcm),
|
expected = torch.stft(torch.from_numpy(audio_pcm),
|
||||||
400, 160, wlen, torch.from_numpy(np.hanning(wlen).astype(np.float32)),
|
400, 160, wlen, torch.from_numpy(np.hanning(wlen).astype(np.float32)),
|
||||||
|
@ -131,7 +131,7 @@ class TestAudio(unittest.TestCase):
|
||||||
actual = actual[0]
|
actual = actual[0]
|
||||||
np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)
|
np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
@unittest.skipIf(not _is_librosa_avaliable, "librosa is not available")
|
@unittest.skipIf(not _is_librosa_available, "librosa is not available")
|
||||||
def test_mel_filter_bank(self):
|
def test_mel_filter_bank(self):
|
||||||
expected = librosa.filters.mel(n_fft=400, n_mels=80, sr=16000)
|
expected = librosa.filters.mel(n_fft=400, n_mels=80, sr=16000)
|
||||||
actual = util.mel_filterbank(400, 80, 16000)
|
actual = util.mel_filterbank(400, 80, 16000)
|
||||||
|
|
|
@ -24,8 +24,9 @@ CMAKE_FLAG_TO_OPS = {
|
||||||
],
|
],
|
||||||
"OCOS_ENABLE_GPT2_TOKENIZER": [
|
"OCOS_ENABLE_GPT2_TOKENIZER": [
|
||||||
"BpeDecoder",
|
"BpeDecoder",
|
||||||
"ClipTokenizer",
|
"CLIPTokenizer",
|
||||||
"GPT2Tokenizer",
|
"GPT2Tokenizer",
|
||||||
|
"RobertaTokenizer"
|
||||||
],
|
],
|
||||||
"OCOS_ENABLE_MATH": [
|
"OCOS_ENABLE_MATH": [
|
||||||
"SegmentExtraction",
|
"SegmentExtraction",
|
||||||
|
@ -41,6 +42,7 @@ CMAKE_FLAG_TO_OPS = {
|
||||||
],
|
],
|
||||||
"OCOS_ENABLE_SPM_TOKENIZER": [
|
"OCOS_ENABLE_SPM_TOKENIZER": [
|
||||||
"SentencepieceTokenizer",
|
"SentencepieceTokenizer",
|
||||||
|
"SentencepieceDecoder"
|
||||||
],
|
],
|
||||||
"OCOS_ENABLE_TF_STRING": [
|
"OCOS_ENABLE_TF_STRING": [
|
||||||
"MaskedFill",
|
"MaskedFill",
|
||||||
|
@ -67,6 +69,17 @@ CMAKE_FLAG_TO_OPS = {
|
||||||
"OCOS_ENABLE_WORDPIECE_TOKENIZER": [
|
"OCOS_ENABLE_WORDPIECE_TOKENIZER": [
|
||||||
"WordpieceTokenizer",
|
"WordpieceTokenizer",
|
||||||
],
|
],
|
||||||
|
"OCOS_ENABLE_AUDIO": [
|
||||||
|
"AudioDecoder"
|
||||||
|
],
|
||||||
|
"OCOS_ENABLE_DLIB": [
|
||||||
|
"Inverse",
|
||||||
|
"StftNorm"
|
||||||
|
],
|
||||||
|
"OCOS_ENABLE_TRIE_TOKENIZER": [
|
||||||
|
"TrieTokenizer",
|
||||||
|
"TrieDetokenizer"
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче