Factors inference is slow (3 seconds/token) on A100 GPU #1110

Open
AmitMY opened this issue May 18, 2024 · 6 comments

AmitMY commented May 18, 2024

My use case calls for splitting each input token into 5 streams and each output token into 8.
That is, the input has a token plus 4 factors (SignWriting), and the output has a token plus 7 factors (VQ model).

  • The input factor vocabularies are small (16, 24, 504, 504)
  • The output factor vocabularies are large (1008 each)

I created factored files for an example sentence: M|c0|r0|p518|p518 S2ff|c0|r0|p482|p483
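Sockeye reads factors as parallel files, one per stream, with the same number of space-separated tokens on each line: the token stream goes to `--input` and the four factor streams to `--input-factors`. A minimal sketch of producing those files from `|`-joined tokens like the example, assuming one sentence per line; `split_factored_line` and `write_factor_files` are hypothetical helpers, not part of Sockeye.

```python
from pathlib import Path


def split_factored_line(line: str, num_streams: int) -> list[str]:
    """Split 'tok|f1|f2|f3|f4 ...' into num_streams space-separated lines."""
    streams: list[list[str]] = [[] for _ in range(num_streams)]
    for token in line.split():
        parts = token.split("|")
        if len(parts) != num_streams:
            raise ValueError(f"expected {num_streams} fields in {token!r}, got {len(parts)}")
        for stream, part in zip(streams, parts):
            stream.append(part)
    return [" ".join(stream) for stream in streams]


def write_factor_files(lines: list[str], out_dir: Path, num_streams: int) -> list[Path]:
    """Write source_0.txt (tokens) through source_{n-1}.txt (factors)."""
    out_dir.mkdir(parents=True, exist_ok=True)
    per_stream: list[list[str]] = [[] for _ in range(num_streams)]
    for line in lines:
        for i, stream_line in enumerate(split_factored_line(line, num_streams)):
            per_stream[i].append(stream_line)
    paths = []
    for i, stream_lines in enumerate(per_stream):
        path = out_dir / f"source_{i}.txt"
        path.write_text("\n".join(stream_lines) + "\n", encoding="utf-8")
        paths.append(path)
    return paths


example = "M|c0|r0|p518|p518 S2ff|c0|r0|p482|p483"
print(split_factored_line(example, num_streams=5))
# ['M S2ff', 'c0 c0', 'r0 r0', 'p518 p482', 'p518 p483']
```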
I then attempted to translate it with:

  • beam-size=1
  • max-output-length=10 (even though realistic outputs are closer to ~100 tokens)

```shell
python -m sockeye.translate --models "$MODEL_DIR/unconstrained/model" \
  --input "$MODEL_DIR/unconstrained/sample/source_0.txt" \
  --input-factors "$MODEL_DIR/unconstrained/sample/source_1.txt" "$MODEL_DIR/unconstrained/sample/source_2.txt" "$MODEL_DIR/unconstrained/sample/source_3.txt" "$MODEL_DIR/unconstrained/sample/source_4.txt" \
  --output-type "translation_with_factors" \
  --max-output-length 10 \
  --beam-size=1
```

And the output is:

[INFO:__main__] Processed 1 lines. Total time: 29.1466, sec/sent: 29.1466, sent/sec: 0.0343

Why would translating a single sentence on an A100 GPU, with a small model and no beam search, be this slow?
Is there a way to profile the decoding step function?


The full output is:

[INFO:sockeye.utils] Sockeye: 3.1.38, commit , path /data/amoryo/conda/envs/sockeye/lib/python3.11/site-packages/sockeye/__init__.py
[INFO:sockeye.utils] PyTorch: 1.13.1+cu117 (/data/amoryo/conda/envs/sockeye/lib/python3.11/site-packages/torch/__init__.py)
[INFO:sockeye.utils] Command: /data/amoryo/conda/envs/sockeye/lib/python3.11/site-packages/sockeye/translate.py --beam-size=1 --models /shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model --input /shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_0.txt --input-factors /shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_1.txt /shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_2.txt /shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_3.txt /shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_4.txt --output-type translation_with_factors --max-output-length 10
[INFO:sockeye.utils] Arguments: Namespace(config=None, input='/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_0.txt', input_factors=['/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_1.txt', '/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_2.txt', '/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_3.txt', '/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/sample/source_4.txt'], json_input=False, output=None, models=['/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model'], checkpoints=None, nbest_size=1, beam_size=1, greedy=False, beam_search_stop='all', batch_size=1, chunk_size=None, sample=None, seed=None, ensemble_mode='linear', bucket_width=10, max_input_length=None, max_output_length_num_stds=2, max_output_length=10, restrict_lexicon=None, restrict_lexicon_topk=None, skip_nvs=False, nvs_thresh=0.5, strip_unknown_words=False, prevent_unk=False, output_type='translation_with_factors', length_penalty_alpha=1.0, length_penalty_beta=0.0, brevity_penalty_type='none', brevity_penalty_weight=1.0, brevity_penalty_constant_length_ratio=0.0, dtype=None, clamp_to_dtype=False, device_id=0, use_cpu=False, env=None, tf32=True, quiet=False, quiet_secondary_workers=False, no_logfile=False, loglevel='INFO', loglevel_secondary_workers='INFO', knn_index=None, knn_lambda=0.8)
[INFO:sockeye.utils] CUDA: allow tf32 (float32 but with 10 bits precision)
[INFO:__main__] Translate Device: cuda:0
[INFO:sockeye.utils] CUDA: allow tf32 (float32 but with 10 bits precision)
[INFO:sockeye.model] Loading 1 model(s) from ['/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model'] ...
[INFO:sockeye.vocab] Vocabulary (664 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.src.0.json"
[INFO:sockeye.vocab] Vocabulary (16 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.src.1.json"
[INFO:sockeye.vocab] Vocabulary (24 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.src.2.json"
[INFO:sockeye.vocab] Vocabulary (504 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.src.3.json"
[INFO:sockeye.vocab] Vocabulary (504 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.src.4.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.0.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.1.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.2.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.3.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.4.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.5.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.6.json"
[INFO:sockeye.vocab] Vocabulary (1008 words) loaded from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/vocab.trg.7.json"
[INFO:sockeye.model] Model version: 3.1.38
[INFO:sockeye.model] Loaded model config from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/config"
[INFO:sockeye.model] ModelConfig(config_data=DataConfig(data_statistics=DataStatistics(num_sents=562592, num_discarded=315, num_tokens_source=1687776, num_tokens_target=57851323, num_unks_source=0, num_unks_target=0, max_observed_len_source=3, max_observed_len_target=513, size_vocab_source=664, size_vocab_target=1008, length_ratio_mean=34.276659343419496, length_ratio_std=14.879263574710128, buckets=[(8, 8), (16, 16), (24, 24), (32, 32), (40, 40), (48, 48), (56, 56), (64, 64), (72, 72), (80, 80), (88, 88), (96, 96), (104, 104), (112, 112), (120, 120), (128, 128), (136, 136), (144, 144), (152, 152), (160, 160), (168, 168), (176, 176), (184, 184), (192, 192), (200, 200), (208, 208), (216, 216), (224, 224), (232, 232), (240, 240), (248, 248), (256, 256), (264, 264), (272, 272), (280, 280), (288, 288), (296, 296), (304, 304), (312, 312), (320, 320), (328, 328), (336, 336), (344, 344), (352, 352), (360, 360), (368, 368), (376, 376), (384, 384), (392, 392), (400, 400), (408, 408), (416, 416), (424, 424), (432, 432), (440, 440), (448, 448), (456, 456), (464, 464), (472, 472), (480, 480), (488, 488), (496, 496), (504, 504), (512, 512), (513, 513)], num_sents_per_bucket=[3, 1, 19, 498, 2678, 10120, 27611, 47985, 57944, 61020, 53583, 44696, 41032, 34378, 30526, 26818, 22831, 18533, 16196, 12340, 9853, 8577, 5816, 4717, 4510, 3553, 2376, 2354, 1616, 1359, 2422, 1121, 749, 856, 604, 410, 405, 468, 282, 224, 157, 190, 162, 102, 131, 142, 74, 61, 55, 69, 57, 41, 43, 28, 33, 28, 32, 20, 16, 14, 15, 18, 9, 10, 1], average_len_target_per_bucket=[2.0, 15.0, 22.47368421052631, 29.895582329317254, 37.32449589245704, 45.26413043478288, 52.99141646445262, 60.725705949775836, 68.54556123153431, 76.37490986561775, 84.53625216953161, 92.42469124753944, 100.45917820237935, 108.42053057187677, 116.43959247854299, 124.43127750018597, 132.41382331041123, 140.36896347056646, 148.29741911583167, 156.2551863857384, 164.24865523190755, 172.1120438381717, 180.24810866574907, 188.2524909900362, 
195.9474501108644, 204.0509428651839, 212.37668350168337, 219.66864910790179, 228.24566831683163, 236.7424576894776, 243.63831544178396, 252.09901873327397, 260.41655540720996, 268.0724299065416, 276.22516556291396, 284.2926829268292, 291.80493827160535, 300.95726495726507, 307.92907801418494, 315.73660714285717, 324.45859872611453, 332.38421052631577, 340.08024691358037, 348.57843137254906, 356.75572519083966, 364.0845070422536, 372.59459459459464, 380.78688524590166, 387.92727272727274, 396.3623188405797, 403.9824561403508, 412.4146341463415, 420.4418604651163, 428.28571428571433, 435.8181818181818, 443.99999999999994, 452.18750000000006, 460.04999999999995, 467.625, 476.8571428571429, 483.33333333333326, 493.2222222222223, 499.7777777777777, 508.70000000000005, 513.0], length_ratio_stats_per_bucket=[(0.6666666666666666, 0.0), (5.0, 0.0), (7.4912280701754375, 0.6611032870672552), (9.965194109772419, 0.6773854309889832), (12.441498630819025, 0.7323565209289744), (15.088043478260888, 0.7228695573061236), (17.663805488150842, 0.7477814938148047), (20.24190198325863, 0.7557906951613651), (22.84852041051123, 0.76260899086021), (25.458303288539526, 0.7429680382689798), (28.178750723177036, 0.7725985330583485), (30.808230415846378, 0.7612824824502872), (33.48639273412596, 0.7637406751055946), (36.14017685729224, 0.7633158033079295), (38.813197492847344, 0.748047045959258), (41.47709250006236, 0.7784719198398958), (44.137941103470446, 0.7532059978984165), (46.78965449018851, 0.7594518647223019), (49.43247303861048, 0.7731868633658496), (52.08506212857909, 0.7602201079711013), (54.74955174396979, 0.7721951407116094), (57.37068127939054, 0.7726670130819834), (60.08270288858333, 0.7589571502358675), (62.75083033001211, 0.7573209651045687), (65.31581670362168, 0.8010609138746856), (68.0169809550613, 0.7586793214732782), (70.79222783389453, 0.761428337943305), (73.22288303596699, 0.8127171243097352), (76.08188943894403, 0.7619360833015993), (78.91415256315932, 
0.788264227982759), (81.21277181392782, 0.7263146698773092), (84.03300624442463, 0.7446248339832199), (86.80551846906988, 0.7602374250419729), (89.35747663551402, 0.8121537001280452), (92.07505518763794, 0.7250101472791833), (94.76422764227645, 0.7737001598415769), (97.26831275720153, 0.8207845494821012), (100.31908831908838, 0.7219173089327346), (102.64302600472807, 0.7770804608693789), (105.2455357142857, 0.8349566989974272), (108.15286624203819, 0.7994074644452088), (110.79473684210528, 0.7417982858094194), (113.36008230452681, 0.8037791354444079), (116.19281045751636, 0.7503612412672215), (118.91857506361325, 0.6905291735552149), (121.36150234741785, 0.7817289110491428), (124.19819819819818, 0.7472048392700509), (126.92896174863385, 0.7419630195625642), (129.30909090909094, 0.8571848343387725), (132.12077294685983, 0.7634378859958186), (134.6608187134503, 0.7050119324061157), (137.4715447154471, 0.8030281772270442), (140.14728682170545, 0.8665543325570307), (142.76190476190473, 0.8060148038005153), (145.27272727272725, 0.7761362712039772), (148.0, 0.7182430061427759), (150.72916666666663, 0.7702700644723462), (153.35, 0.8463975950396387), (155.87499999999997, 0.7252872993970579), (158.95238095238093, 0.6884205854667195), (161.11111111111106, 0.7568616162633894), (164.4074074074074, 0.733295921230492), (166.59259259259258, 0.6241592424575069), (169.56666666666666, 0.8950481054731718), (171.0, 0.0)]), max_seq_len_source=513, max_seq_len_target=513, num_source_factors=5, num_target_factors=8, eop_id=-1), vocab_source_size=664, vocab_target_size=1008, config_embed_source=EmbeddingConfig(vocab_size=664, num_embed=512, dropout=0.5, num_factors=5, factor_configs=[FactorConfig(vocab_size=16, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=24, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=504, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=504, num_embed=512, combine='sum', 
share_embedding=False)], allow_sparse_grad=False), config_embed_target=EmbeddingConfig(vocab_size=1008, num_embed=512, dropout=0.5, num_factors=8, factor_configs=[FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False), FactorConfig(vocab_size=1008, num_embed=512, combine='sum', share_embedding=False)], allow_sparse_grad=False), config_encoder=TransformerConfig(model_size=512, attention_heads=8, feed_forward_num_hidden=2048, act_type='relu', num_layers=6, dropout_attention=0.2, dropout_act=0.2, dropout_prepost=0.2, positional_embedding_type='fixed', preprocess_sequence='n', postprocess_sequence='dr', max_seq_len_source=513, max_seq_len_target=513, decoder_type='transformer', block_prepended_cross_attention=False, use_lhuc=False, depth_key_value=512, use_glu=False), config_decoder=TransformerConfig(model_size=512, attention_heads=8, feed_forward_num_hidden=2048, act_type='relu', num_layers=6, dropout_attention=0.2, dropout_act=0.2, dropout_prepost=0.2, positional_embedding_type='fixed', preprocess_sequence='n', postprocess_sequence='dr', max_seq_len_source=513, max_seq_len_target=513, decoder_type='transformer', block_prepended_cross_attention=False, use_lhuc=False, depth_key_value=512, use_glu=False), config_length_task=None, weight_tying_type='trg_softmax', lhuc=False, dtype='float32', neural_vocab_selection=None, neural_vocab_selection_block_loss=False)
[INFO:sockeye.model] Loaded params from "/shares/volk.cl.uzh/amoryo/signwriting-animation/models/unconstrained/model/params.best" to "cuda:0"
[INFO:sockeye.model] Model dtype: torch.float32
[INFO:sockeye.model] 1 model(s) loaded in 1.3760s
[INFO:sockeye.beam_search] Enabled skipping softmax for a single model and greedy decoding.
[INFO:sockeye.inference] Translator (1 model(s) beam_size=1 algorithm=BeamSearch, beam_search_stop=all max_input_length=512 nbest_size=1 ensemble_mode=None max_batch_size=1 dtype=torch.float32 skip_nvs=False nvs_thresh=0.5)
[INFO:__main__] Translating...
/data/amoryo/conda/envs/sockeye/lib/python3.11/site-packages/torch/jit/_trace.py:976: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.
  module._c._create_method_from_trace(
/data/amoryo/conda/envs/sockeye/lib/python3.11/site-packages/torch/nn/modules/module.py:1194: UserWarning: FALLBACK path has been taken inside: runCudaFusionGroup. This is an indication that codegen Failed for some reason.
To debug try disable codegen fallback path via setting the env variable `export PYTORCH_NVFUSER_DISABLE=fallback`
 (Triggered internally at ../torch/csrc/jit/codegen/cuda/manager.cpp:331.)
  return forward_call(*input, **kwargs)
266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418 266|507|468|505|698|460|419|418
[INFO:__main__] Processed 1 lines. Total time: 29.1466, sec/sent: 29.1466, sent/sec: 0.0343

Besides the fact that the output repeats the same token over and over, it is in the expected format.
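The `translation_with_factors` line above holds one space-separated entry per output position, each being the target token followed by its 7 factors, joined by `|`. A small sketch (the helper names are hypothetical) for splitting that back into 8 streams and flagging output in which every position is identical, as happened here:

```python
def parse_positions(line: str) -> list[list[str]]:
    """'a|b|c d|e|f' -> [['a', 'b', 'c'], ['d', 'e', 'f']]"""
    return [token.split("|") for token in line.split()]


def to_streams(line: str) -> list[list[str]]:
    """Transpose positions into streams: target tokens first, then each factor."""
    positions = parse_positions(line)
    if not positions:
        return []
    width = len(positions[0])
    if any(len(p) != width for p in positions):
        raise ValueError("tokens carry different numbers of factors")
    return [[p[i] for p in positions] for i in range(width)]


def is_degenerate(line: str) -> bool:
    """True when the output has 2+ positions and all of them are identical."""
    positions = parse_positions(line)
    return len(positions) > 1 and all(p == positions[0] for p in positions)


output = " ".join(["266|507|468|505|698|460|419|418"] * 10)  # the output line above
streams = to_streams(output)
print(len(streams), streams[0][:3])  # 8 ['266', '266', '266']
print(is_degenerate(output))  # True
```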

mjdenkowski (Contributor) commented

Hi Amit,

That's a good question. I don't know that anyone has tested Sockeye with that many factors.

One hypothesis is that the factor code switches between Python, C++, and GPU execution, so looping over more factors leads to a greater slowdown. @fhieber may have more information about decoding with factors.

For profiling, you could take a look at the PyTorch Profiler.
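A minimal sketch of the PyTorch Profiler applied to a decoding loop, using a toy stand-in for the decoder (assumption: one 512-d layer plus 8 output projections of size 1008, sizes taken from the log above, not Sockeye's actual modules). Profiling Sockeye itself would mean wrapping its translation call in the same `profile(...)` context. `record_function` labels the region in both the summary table and the trace that `export_chrome_trace` writes for chrome://tracing; the label `decode_step` is arbitrary.

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

# Toy stand-in for one decoder step (NOT Sockeye's decoder): a 512-d layer
# followed by 8 output layers of size 1008, mirroring the sizes in the log.
device = "cuda" if torch.cuda.is_available() else "cpu"
hidden = torch.nn.Linear(512, 512).to(device)
output_layers = torch.nn.ModuleList(torch.nn.Linear(512, 1008) for _ in range(8)).to(device)
x = torch.randn(1, 512, device=device)

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with torch.no_grad(), profile(activities=activities, record_shapes=True) as prof:
    for _ in range(10):  # ten greedy steps, as with --max-output-length 10
        with record_function("decode_step"):
            h = torch.relu(hidden(x))
            logits = [layer(h) for layer in output_layers]

# Sorting by self CPU time separates host-side overhead from GPU kernel time.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
prof.export_chrome_trace("trace.json")  # open in chrome://tracing
```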

Best,
Michael

AmitMY (Author) commented May 18, 2024

Thanks!

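One possible improvement I see: instead of applying the factor output layers one at a time, as in
https://github.com/awslabs/sockeye/blob/main/sockeye/model.py#L665C13-L665C79

run the multiplications in parallel:

```python
# Note: per the PyTorch docs, torch.jit.fork only runs tasks asynchronously
# under TorchScript; in eager Python each forked call executes synchronously.
futures = [torch.jit.fork(fol, decoder_out) for fol in self.factor_output_layers]
outputs += [torch.jit.wait(fut) for fut in futures]
```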
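A different way to attack the same per-factor cost, not proposed in this thread and not something Sockeye implements as far as this thread shows, is to stack the 7 factor output layers into one `Linear` and split its output, so each step issues one matrix multiply instead of seven. A minimal sketch with the sizes from the model config (512 in, 1008 out per factor); the variable names are hypothetical.

```python
import torch

torch.manual_seed(0)
model_size, vocab, num_factors = 512, 1008, 7  # sizes from the model config
factor_layers = torch.nn.ModuleList(
    torch.nn.Linear(model_size, vocab) for _ in range(num_factors)
)

# Stack the per-factor weights once, so each decoding step runs a single
# matmul instead of num_factors separate ones.
fused = torch.nn.Linear(model_size, vocab * num_factors)
with torch.no_grad():
    fused.weight.copy_(torch.cat([layer.weight for layer in factor_layers], dim=0))
    fused.bias.copy_(torch.cat([layer.bias for layer in factor_layers], dim=0))

decoder_out = torch.randn(3, model_size)  # e.g. 3 hypotheses in a batch

with torch.no_grad():
    looped = [layer(decoder_out) for layer in factor_layers]
    fused_out = list(fused(decoder_out).split(vocab, dim=-1))

# Both paths give the same per-factor logits (up to float rounding).
for a, b in zip(looped, fused_out):
    assert torch.allclose(a, b, atol=1e-5)
```

The stacking only needs to happen once after loading the parameters; whether it pays off on a GPU depends on how much of the step time is launch overhead, which the profile in this thread would have to confirm.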

Also, as a side note: during decoding, it seems that target factors are not embedded:
https://github.com/awslabs/sockeye/blob/main/sockeye/model.py#L654
Am I missing something?
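For context on this question: the model config in the log sets `combine='sum'` for every target `FactorConfig`, which, assuming the usual meaning of that option, gives each factor its own embedding table and adds the factor vectors to the token embedding. If the previous step's target factors were embedded during decoding, the next step's input would look roughly like this sketch. `SumFactorEmbedding` is illustrative only, not Sockeye's embedding class, and it does not settle whether Sockeye's decoder does this.

```python
import torch


class SumFactorEmbedding(torch.nn.Module):
    """One embedding table per stream (token + factors); vectors are summed."""

    def __init__(self, vocab_sizes: list[int], num_embed: int = 512) -> None:
        super().__init__()
        self.tables = torch.nn.ModuleList(
            torch.nn.Embedding(size, num_embed) for size in vocab_sizes
        )

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # ids: (batch, seq_len, num_streams), with stream 0 = the target token.
        vectors = [table(ids[..., i]) for i, table in enumerate(self.tables)]
        return torch.stack(vectors, dim=0).sum(dim=0)


# Target side of this model: token + 7 factors, 1008 entries each.
embed = SumFactorEmbedding([1008] * 8)
prev_step = torch.zeros(1, 1, 8, dtype=torch.long)  # previous output position
print(embed(prev_step).shape)  # torch.Size([1, 1, 512])
```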

AmitMY (Author) commented May 19, 2024

With the --use-cpu flag, I get:

[INFO:__main__] Processed 1 lines. Total time: 1.6748, sec/sent: 1.6748, sent/sec: 0.5971

Compared with the A100 GPU:

[INFO:__main__] Processed 1 lines. Total time: 29.1466, sec/sent: 29.1466, sent/sec: 0.0343
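The GPU log shows `torch.jit` tracing (the TracerWarning) and an nvFuser fallback during translation, so part of a single-sentence GPU time may be one-off start-up work rather than per-token cost; this thread does not establish how much. A generic sketch for separating the two, assuming the work can be wrapped in a zero-argument callable; `benchmark` is a hypothetical helper, not a Sockeye function.

```python
import time
from typing import Any, Callable

try:
    import torch  # optional: only used to synchronize CUDA before timing
except ImportError:
    torch = None


def _sync() -> None:
    # CUDA kernels run asynchronously; wait for them so the clock is accurate.
    if torch is not None and torch.cuda.is_available():
        torch.cuda.synchronize()


def benchmark(fn: Callable[[], Any], warmup: int = 2, runs: int = 5) -> dict[str, float]:
    """Return the first-call time and the steady-state mean, in seconds."""
    _sync()
    start = time.perf_counter()
    fn()  # first call: includes any one-off tracing or compilation
    _sync()
    first = time.perf_counter() - start

    for _ in range(warmup):  # extra untimed calls
        fn()
    _sync()

    start = time.perf_counter()
    for _ in range(runs):
        fn()
    _sync()
    return {"first_call_s": first, "steady_state_s": (time.perf_counter() - start) / runs}


# Placeholder workload; swap in the real translation call.
result = benchmark(lambda: sum(i * i for i in range(10_000)))
print(result)
```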

AmitMY (Author) commented May 21, 2024

Since the CPU time seems to be huge, here are the profiler's CPU and CUDA totals:

Self CPU time total: 18.575s
Self CUDA time total: 27.326ms

Profile output:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                forward        96.79%       17.978s        97.03%       18.023s     667.534ms       0.000us         0.00%      12.625ms     467.593us            27  
                                           aten::linear         0.06%      11.700ms         2.21%     410.557ms     896.413us       0.000us         0.00%      13.302ms      29.044us           458  
                                           aten::matmul         0.01%       1.582ms         2.09%     388.146ms       1.578ms       0.000us         0.00%       7.126ms      28.967us           246  
                                               aten::mm         0.05%       9.657ms         2.08%     385.787ms       1.568ms       6.661ms        24.38%       7.126ms      28.967us           246  
                                               cudaFree         2.01%     373.293ms         2.01%     373.293ms     186.647ms     112.000us         0.41%     112.000us      56.000us             2  
                                aten::repeat_interleave         0.03%       5.207ms         0.17%      32.192ms     185.011us     398.000us         1.46%       6.604ms      37.954us           174  
                                       cudaLaunchKernel         0.13%      24.834ms         0.13%      24.834ms       8.611us       1.284ms         4.70%       1.284ms       0.445us          2884  
                                          aten::dropout         0.12%      22.784ms         0.12%      22.784ms      69.463us       0.000us         0.00%       0.000us       0.000us           328  
                                       aten::layer_norm         0.08%      15.266ms         0.11%      20.798ms     127.595us       0.000us         0.00%       1.614ms       9.902us           163  
                                            aten::slice         0.08%      14.190ms         0.08%      14.244ms      11.840us       0.000us         0.00%       0.000us       0.000us          1203  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
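A quick reading of the totals above, as plain arithmetic on the reported figures (no new measurements): CUDA time is about 0.15% of the profiled CPU time, and nearly all of that CPU time is spent inside the 27 `forward` calls. This locates the time but does not say which code inside `forward` is slow.

```python
# Figures copied from the profiler summary and table above.
self_cpu_total_s = 18.575
self_cuda_total_s = 27.326e-3
forward_self_cpu_s = 17.978
forward_cpu_total_s = 18.023
forward_calls = 27

gpu_share = self_cuda_total_s / self_cpu_total_s
forward_share = forward_self_cpu_s / self_cpu_total_s
avg_forward_ms = forward_cpu_total_s / forward_calls * 1e3

print(f"CUDA time as a share of CPU time: {gpu_share:.2%}")           # 0.15%
print(f"forward's share of self CPU time: {forward_share:.1%}")       # 96.8%
print(f"average CPU time per forward call: {avg_forward_ms:.1f} ms")  # 667.5 ms
```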

Here is the profiler trace, which can be opened in chrome://tracing:
trace.json

AmitMY (Author) commented May 23, 2024

With torch 2.3.0, on GPU:

[INFO:__main__] Processed 1 lines. Total time: 2.8488, sec/sent: 2.8488, sent/sec: 0.3510

On CPU:

[INFO:__main__] Processed 1 lines. Total time: 1.6967, sec/sent: 1.6967, sent/sec: 0.5894
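Side by side, using only the totals reported in this thread (the per-token figure uses the 10 output positions of the first run): the upgrade makes the GPU run about 10 times faster, yet it is still slower than the CPU run for this single sentence.

```python
# Seconds for the single test sentence, as reported in this thread.
gpu_torch113 = 29.1466  # A100, PyTorch 1.13.1 (first log)
cpu_earlier = 1.6748    # --use-cpu run from May 19
gpu_torch230 = 2.8488   # A100, torch 2.3.0
cpu_torch230 = 1.6967   # CPU, torch 2.3.0
output_tokens = 10      # positions in the first run's output

print(f"first GPU run: {gpu_torch113 / output_tokens:.2f} s/token")        # 2.91
print(f"GPU speedup from the upgrade: {gpu_torch113 / gpu_torch230:.1f}x")  # 10.2x
print(f"GPU/CPU time ratio on 2.3.0: {gpu_torch230 / cpu_torch230:.2f}")   # 1.68
print(f"GPU/CPU time ratio before: {gpu_torch113 / cpu_earlier:.1f}")      # 17.4
```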

Why is Sockeye restricted to PyTorch 1.x?

mjdenkowski (Contributor) commented

The torch version in Sockeye's requirements.txt (currently torch>=1.10.0,<1.14.0) indicates the latest version of PyTorch that Sockeye is officially tested with.

If you change the line to just torch, you can test Sockeye with the current version of PyTorch.
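A simplified check of why the torch 2.3.0 used above falls outside that pin while the 1.13.1 in the original log falls inside it. `satisfies_pin` is a hypothetical helper doing plain tuple comparison, not full PEP 440 matching (the `packaging` library implements the real rules).

```python
import re


def release_tuple(version: str) -> tuple[int, ...]:
    """'1.13.1+cu117' -> (1, 13, 1). Drops the local '+...' part and any
    non-numeric suffix; a simplification of PEP 440, not a full parser."""
    numbers = []
    for piece in version.split("+", 1)[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        numbers.append(int(match.group()))
    return tuple(numbers)


def satisfies_pin(version: str, lower: str = "1.10.0", upper: str = "1.14.0") -> bool:
    """Mirror of the specifier torch>=1.10.0,<1.14.0 for plain release versions."""
    return release_tuple(lower) <= release_tuple(version) < release_tuple(upper)


print(satisfies_pin("1.13.1+cu117"))  # True: the version in the first log
print(satisfies_pin("2.3.0"))         # False: outside the current pin
```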

Best,
Michael
