@@ -138,8 +138,9 @@ def get_running_torch_version():
138138class AssistantCLI :
139139 """Collection of handy CLI commands."""
140140
141- DEVICE_ACCELERATOR = os .environ .get ("ACCELERATOR" , "cpu" ).lower ()
142- DATASET_FOLDER = os .environ .get ("PATH_DATASETS" , "_datasets" ).lower ()
141+ _LOCAL_ACCELERATOR = "cpu,gpu" if get_running_cuda_version () else "cpu"
142+ DEVICE_ACCELERATOR = os .environ .get ("ACCELERATOR" , _LOCAL_ACCELERATOR ).lower ()
143+ DATASETS_FOLDER = os .environ .get ("PATH_DATASETS" , "_datasets" )
143144 DRY_RUN = bool (int (os .environ .get ("DRY_RUN" , 0 )))
144145 _META_REQUIRED_FIELDS = ("title" , "author" , "license" , "description" )
145146 _SKIP_DIRS = (
@@ -154,7 +155,7 @@ class AssistantCLI:
154155 )
155156 _META_FILE_REGEX = ".meta.{yaml,yml}"
156157 _META_PIP_KEY = "pip__"
157- _META_ACCEL_DEFAULT = ( "CPU" , )
158+ _META_ACCEL_DEFAULT = _LOCAL_ACCELERATOR . split ( "," )
158159
159160 # Map directory names to tag names. Note that dashes will be replaced with spaces in rendered tags in the docs.
160161 _DIR_TO_TAG = {
@@ -270,16 +271,15 @@ def _parse_requirements(folder: str) -> Tuple[str, str]:
270271
271272 @staticmethod
272273 def _bash_download_data (folder : str ) -> List [str ]:
273- """Generate sequence of commands fro optional downloading dataset specified in the meta file.
274+ """Generate sequence of commands for optional downloading dataset specified in the meta file.
274275
275276 Args:
276277 folder: path to the folder with python script, meta and artefacts
277278 """
278- cmd = ["HERE=$PWD" , f"cd { AssistantCLI .DATASET_FOLDER } " ]
279279 meta = AssistantCLI ._load_meta (folder )
280280 datasets = meta .get ("datasets" , {})
281281 data_kaggle = datasets .get ("kaggle" , [])
282- cmd + = [f"python -m kaggle competitions download -c { name } " for name in data_kaggle ]
282+ cmd = [f"python -m kaggle competitions download -c { name } " for name in data_kaggle ]
283283 files = [f"{ name } .zip" for name in data_kaggle ]
284284 data_web = datasets .get ("web" , [])
285285 cmd += [f"wget { web } --progress=bar:force:noscroll --tries=3" for web in data_web ]
@@ -289,11 +289,11 @@ def _bash_download_data(folder: str) -> List[str]:
289289 if ext not in AssistantCLI ._EXT_ARCHIVE :
290290 continue
291291 if ext in AssistantCLI ._EXT_ARCHIVE_ZIP :
292- cmd += [f"unzip -o { fn } -d { name } { UNZIP_PROGRESS_BAR } " ]
292+ cmd += [f"unzip -o { fn } -d { AssistantCLI . DATASETS_FOLDER } / { name } / { UNZIP_PROGRESS_BAR } " ]
293293 else :
294294 cmd += [f"tar -zxvf { fn } --overwrite" ]
295295 cmd += [f"rm { fn } " ]
296- cmd += ["ls -l" , "cd $HERE " ]
296+ cmd += [f"tree -L 2 { AssistantCLI . DATASETS_FOLDER } " ]
297297 return cmd
298298
299299 @staticmethod
0 commit comments