11#!/usr/bin/env python
22# -*- coding: utf-8; -*-
33
4- # Copyright (c) 2023 Oracle and/or its affiliates.
4+ # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
55# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66"""This module requires oracle-ads>=2.6.8
77"""
@@ -341,13 +341,15 @@ def prepare_cmd(self, launch_args: list = None, prefix=""):
341341 launch_args = []
342342 # Append launch cmd args specified by the user.
343343 if self .launch_cmd :
344- if not self .launch_cmd . startswith ( self . LAUNCHER ) :
345- raise ValueError (
346- f"Command not supported: ' { self . launch_cmd } '. "
347- f"The command should start with '{ self .LAUNCHER } '."
348- )
344+ if self .LAUNCHER :
345+ if not self . launch_cmd . startswith ( self . LAUNCHER ):
346+ raise ValueError (
347+ f"Command not supported: '{ self .launch_cmd } '. "
348+ )
349349
350- launch_args .append (self .launch_cmd [len (self .LAUNCHER ) + 1 :])
350+ launch_args .append (self .launch_cmd [len (self .LAUNCHER ) + 1 :])
351+ else :
352+ launch_args .append (self .launch_cmd )
351353 else :
352354 launch_args .append (self .get_cmd_with_entrypoint_and_args ())
353355
@@ -673,7 +675,51 @@ def run(self):
673675 self .run_deepspeed_worker ()
674676
675677
678+ class GenericRunner (TorchRunner , DeepSpeedRunner ):
679+ """Runner for running command other than ``torchrun``, ``deepspeed`` or ``accelerate``."""
680+
681+ LAUNCHER = ""
682+
683+ def use_deepspeed (self ) -> bool :
684+ """Indicate if DeepSpeed is used."""
685+ if os .environ .get (CONST_ENV_DEEPSPEED ):
686+ return True
687+ return False
688+
689+ def set_env_var (self ):
690+ """Set default environment variables."""
691+ defaults = {
692+ "WORLD_SIZE" : self .node_count ,
693+ "MASTER_ADDR" : self .host_ip ,
694+ "MASTER_PORT" : self .RDZV_PORT ,
695+ }
696+ for k , v in defaults .items ():
697+ if k not in os .environ :
698+ os .environ [k ] = str (v )
699+
700+ def run (self ):
701+ """Runs the user's command.
702+ Note that for TorchRunner or DeepSpeedRunner,
703+ we automatically add arguments for some settings,
704+ like the number of nodes and the host node address.
705+
706+ This generic runner does not modify the command specified by the user.
707+ User needs to make sure the command can work on all nodes.
708+ User may use the environment variables in the command.
709+ """
710+ self .set_env_var ()
711+ if self .use_deepspeed ():
712+ if self .is_host :
713+ self .run_deepspeed_host ()
714+ else :
715+ self .run_deepspeed_worker ()
716+ else :
717+ self .time_cmd (cmd = self .prepare_cmd (prefix = self .env_ld_preload ()))
718+
719+
676720class AccelerateRunner (TorchRunner , DeepSpeedRunner ):
721+ """Runner for HuggingFace Accelerate."""
722+
677723 # accelerate launch will add main_process_port for deepspeed cmd even if it is not needed.
678724 # https://github.com/huggingface/accelerate/blob/70920895e80f78d96d8f91e0beeb3ebdb8e5e5d6/src/accelerate/utils/launch.py#L233
679725 DEFAULT_ARGS = [
@@ -704,11 +750,18 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
704750 self .main_process_ip = None
705751
706752 def use_deepspeed (self ):
707- return os .environ .get (CONST_ENV_DEEPSPEED ) or self .launch_cmd_contains (
753+ """Indicate if DeepSpeed is used."""
754+ # Accelerate support using DeepSpeed by adding the "--use_deepspeed" argument.
755+ if os .environ .get (CONST_ENV_DEEPSPEED ) or self .launch_cmd_contains (
708756 "use_deepspeed"
709- )
757+ ):
758+ return True
759+ return False
710760
711761 def accelerate_args (self ):
762+ """Gets the default arguments for the accelerate command.
763+ The value of the default arguments are assigned in ``__init__()``.
764+ """
712765 args = []
713766 for arg in self .DEFAULT_ARGS :
714767 arg_val = getattr (self , arg , None )
@@ -720,6 +773,7 @@ def accelerate_args(self):
720773 return args
721774
722775 def run_with_torchrun (self ):
776+ """Runs the job with torchrun."""
723777 launch_args = self .accelerate_args ()
724778 for arg in self .TORCHRUN_ARGS :
725779 if not self .launch_cmd_contains (arg ):
@@ -728,6 +782,7 @@ def run_with_torchrun(self):
728782 self .time_cmd (cmd = cmd )
729783
730784 def run_with_deepspeed (self ):
785+ """Runs the job with DeepSpeed."""
731786 if self .is_host :
732787 launch_args = self .accelerate_args ()
733788 if self .num_machines > 1 :
@@ -758,6 +813,8 @@ def main():
758813 runner_class = DeepSpeedRunner
759814 elif launch_cmd .startswith ("accelerate " ):
760815 runner_class = AccelerateRunner
816+ else :
817+ runner_class = GenericRunner
761818
762819 runner = runner_class ()
763820 runner : Runner
0 commit comments