diff --git a/archai/discrete_search/search_spaces/nlp/tfpp/ops/fftconv_.py b/archai/discrete_search/search_spaces/nlp/tfpp/ops/fftconv_.py
index 1faaf077..85bbeceb 100644
--- a/archai/discrete_search/search_spaces/nlp/tfpp/ops/fftconv_.py
+++ b/archai/discrete_search/search_spaces/nlp/tfpp/ops/fftconv_.py
@@ -7,7 +7,10 @@
 import torch.nn.functional as F
 from einops import rearrange
 
-from fftconv import fftconv_fwd, fftconv_bwd
+try:
+    from fftconv import fftconv_fwd, fftconv_bwd
+except ImportError:
+    raise ImportError("`fftconv` module not found. Please run `pip install \"git+https://github.com/HazyResearch/H3.git#egg=fftconv&subdirectory=csrc/fftconv\"`.")
 
 @torch.jit.script
 def _mul_sum(y, q):
diff --git a/setup.py b/setup.py
index 99b9923f..37edfc5c 100644
--- a/setup.py
+++ b/setup.py
@@ -17,7 +17,6 @@ dependencies = [
     "einops",
     "flake8>=5.0.4",
     "flash-attn",
-    "fftconv @ git+https://github.com/HazyResearch/H3.git#egg=fftconv&subdirectory=csrc/fftconv",
     "gorilla>=0.4.0",
     "h5py",
     "hyperopt",
@@ -86,7 +85,7 @@ extras_require["nlp"] = filter_dependencies(
 )
 
 extras_require["deepspeed"] = filter_dependencies("deepspeed", "mlflow")
-extras_require["flash-attn"] = filter_dependencies("flash-attn", "fftconv")
+extras_require["flash-attn"] = filter_dependencies("flash-attn")
 extras_require["xformers"] = filter_dependencies("xformers")
 
 extras_require["docs"] = filter_dependencies(