
How to specify a single GPU #20

Open
MonkeyTB opened this issue Apr 17, 2023 · 11 comments
Labels
wontfix This will not be worked on

Comments

@MonkeyTB


When there are multiple GPUs, all of them are loaded by default. I have tried
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

args={'use_lora': True, 'output_dir': args.output_dir, "max_length": args.max_length, 'n_gpu': 0},
as well as
cuda_device = 0
but none of them work. Could you please advise how to set up training or prediction on a single GPU?

if use_cuda:
    if torch.cuda.is_available():
        if cuda_device == -1:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device(f"cuda:{cuda_device}")
    else:
        raise ValueError(
            "'use_cuda' set to True when cuda is unavailable."
            "Make sure CUDA is available or set `use_cuda=False`."
        )
else:
    self.device = "cpu"
logger.debug(f"Device: {self.device}")

2023-04-17 02:18:32.967 | DEBUG | chatglm.chatglm_model:__init__:92 - Device: cuda:0

The log above shows the expected device, but the model is still loaded onto multiple GPUs.

@shibing624
Owner

I'll fix it; it's caused by device=auto being used.
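
For context, the difference comes down to how the model is placed at load time. Below is a minimal sketch using Hugging Face transformers with an accelerate-style device_map; the model name and arguments are illustrative and may not match the repo's actual loading code:

from transformers import AutoModel

# device_map="auto" shards the weights across every visible GPU;
# an explicit map such as {"": 0} keeps the whole model on cuda:0.
# (Model name and arguments are illustrative, not the repo's exact call.)
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b",
    trust_remote_code=True,
    device_map={"": 0},
)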

@shibing624
Owner

8c82cda done

@MonkeyTB
Author

8c82cda done

Got it, thanks.

@MonkeyTB
Author

8c82cda done

I tested this, and it still doesn't work:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA TITAN RTX    On   | 00000000:3B:00.0 Off |                  N/A |
| 41%   52C    P2    63W / 280W |   5764MiB / 24220MiB |     15%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA TITAN RTX    On   | 00000000:AF:00.0 Off |                  N/A |
| 41%   45C    P2    60W / 280W |   5370MiB / 24220MiB |     21%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA TITAN RTX    On   | 00000000:D8:00.0 Off |                  N/A |
| 41%   54C    P2    72W / 280W |  13930MiB / 24220MiB |     28%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

What I observe is that the model does load onto the specified GPU, but during prediction all of the GPUs participate in the computation.
My guess is that the data is being distributed across devices, but when I added log statements in chatglm_utils.py nothing was printed, as if that code path was never entered 😓. Could you advise how I should go about debugging this?

@MonkeyTB
Author

Following up on my comment above, here is the predict code I looked at:

def predict(self, sentences, keep_prompt=False, max_length=None, **kwargs):
    """
    Performs predictions on a list of text.
    Args:
        sentences: A python list of text (str) to be sent to the model for prediction.
        keep_prompt: Whether to keep the prompt in the generated text.
        max_length: The maximum length of the generated text.
    Returns:
        preds: A python list of the generated sequences.
    """  # noqa: ignore flake8"

    if not self.lora_loaded:
        self.load_lora()
    if self.args.fp16:
        self.model.half()
    self.model.eval()

    all_outputs = []
    # Batching
    for batch in tqdm(
            [
                sentences[i: i + self.args.eval_batch_size]
                for i in range(0, len(sentences), self.args.eval_batch_size)
            ],
            desc="Generating outputs",
            disable=self.args.silent,
    ):
        inputs = self.tokenizer(batch, padding=True, return_tensors='pt').to(self.device)
        gen_kwargs = {
            "max_new_tokens": max_length if max_length else self.args.max_length,
            "num_beams": self.args.num_beams,
            "do_sample": self.args.do_sample,
            "top_p": self.args.top_p,
            "temperature": self.args.temperature,
            "eos_token_id": self.tokenizer.eos_token_id,
            **kwargs
        }
        outputs = self.model.generate(**inputs, **gen_kwargs)
        for idx, (prompt_text, generated_sequence) in enumerate(zip(batch, outputs)):
            # Decode text
            text = self.tokenizer.decode(generated_sequence)
            prompt_len = len(prompt_text)
            gen_text = text[prompt_len:]
            gen_text = self.process_response(gen_text)
            if keep_prompt:
                total_sequence = prompt_text + gen_text
            else:
                total_sequence = gen_text
            all_outputs.append(total_sequence)
    return all_outputs

I went through this code, and it doesn't seem related to the issue I described above. I'm not sure how to debug this further 😓.
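
One quick check, assuming the underlying model is a standard transformers model (attribute names are illustrative if the repo wraps it differently), is to look at where the weights actually live after loading:

# hf_device_map is only populated when the model was loaded with a device_map.
print(getattr(self.model, "hf_device_map", None))

# If more than one device shows up here, the weights are sharded across GPUs,
# and generate() will naturally touch all of them during prediction.
print({p.device for p in self.model.parameters()})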

@shibing624
Owner

export CUDA_VISIBLE_DEVICES=0

This makes only one GPU visible; export CUDA_VISIBLE_DEVICES=0,1,2 makes multiple GPUs visible. This is the officially recommended approach.
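
Note that CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialized. A minimal sketch of the in-Python equivalent of the export above:

import os

# Must run before torch (or anything else that initializes CUDA) is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

print(torch.cuda.device_count())  # should now report 1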

@Xu-Chen

Xu-Chen commented Apr 20, 2023

I'd like to ask: when training a T5 model, how can I use multiple GPUs?

@shibing624
Owner

T5 does not support multi-GPU training yet.

@MonkeyTB
Author

export CUDA_VISIBLE_DEVICES=0

This makes only one GPU visible; export CUDA_VISIBLE_DEVICES=0,1,2 makes multiple GPUs visible. This is the officially recommended approach.

I tested this approach and it didn't work.
After tinkering with it over the last few days, I found that a specific device_id also needs to be given when loading the lora model; with that, single-GPU works, and for multi-GPU, auto is fine.
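
A minimal sketch of what that can look like with peft; the base model name and the lora path are placeholders, and the repo's actual load_lora code may differ:

import torch
from transformers import AutoModel
from peft import PeftModel

device = torch.device("cuda:0")  # explicit device instead of device_map="auto"

# Base model pinned to one GPU (model name is illustrative).
base_model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True
).half().to(device)

# Attach the LoRA adapter and keep it on the same device ("path/to/lora" is a placeholder).
model = PeftModel.from_pretrained(base_model, "path/to/lora").to(device)
model.eval()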

@shibing624
Owner

OK, I'll add the device argument for lora loading.

@shibing624 shibing624 pinned this issue May 11, 2023

stale bot commented Dec 27, 2023

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. (Due to prolonged inactivity, the bot has automatically closed this issue; feel free to open a new question if needed.)

@stale stale bot added the wontfix This will not be worked on label Dec 27, 2023