diff --git a/csrc/fp_quantizer/fp_quantize_impl.cu b/csrc/fp_quantizer/fp_quantize_impl.cu
index 8b1913e1588f..ca5724eaa4f6 100644
--- a/csrc/fp_quantizer/fp_quantize_impl.cu
+++ b/csrc/fp_quantizer/fp_quantize_impl.cu
@@ -230,11 +230,11 @@ __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int
     constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
     int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;
 
-    constexpr int quantized_bits = q_mantisa_bits + q_exponent_bits + 1;
-    constexpr int q_exponent_bits = total_q_bits - mantisa_bits - 1;
-    constexpr uint16_t _mantisa_mask = (1 << q_mantisa_bits) - 1;
-    constexpr uint16_t _exponent_mask = ((1 << q_exponent_bits) - 1) << q_mantisa_bits;
-    constexpr uint16_t _sign_mask = 1U << (q_mantisa_bits + q_exponent_bits);
+    constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
+    constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
+    constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
+    constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
+    constexpr uint16_t _sign_mask = 1U << (_mantisa_bits + _exponent_bits);
     const uint32_t g_index = (tidx / group_size);
     const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
     const uint8_t* load_base_ptr =
@@ -409,11 +409,11 @@ __global__ void apply_selective_dequantization(uint8_t* val,
     constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
     int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size;
     int input_index = index * total_num_elements + tidx;
-    constexpr int quantized_bits = q_mantisa_bits + q_exponent_bits + 1;
-    constexpr int q_exponent_bits = total_q_bits - mantisa_bits - 1;
-    constexpr uint16_t _mantisa_mask = (1 << q_mantisa_bits) - 1;
-    constexpr uint16_t _exponent_mask = ((1 << q_exponent_bits) - 1) << q_mantisa_bits;
-    constexpr uint16_t _sign_mask = 1U << (q_mantisa_bits + q_exponent_bits);
+    constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
+    constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
+    constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
+    constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
+    constexpr uint16_t _sign_mask = 1U << (_mantisa_bits + _exponent_bits);
     const uint32_t g_index = (input_index / group_size);
     const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
     const uint8_t* load_base_ptr =
diff --git a/deepspeed/runtime/sequence_parallel/ulysses_sp.py b/deepspeed/runtime/sequence_parallel/ulysses_sp.py
index d59edfa9b6bf..413921c2090c 100644
--- a/deepspeed/runtime/sequence_parallel/ulysses_sp.py
+++ b/deepspeed/runtime/sequence_parallel/ulysses_sp.py
@@ -491,14 +491,19 @@ def register_with_transformers(
         local_seq_length = seq_length // mpu.get_sequence_parallel_world_size()
         global_seq_length = seq_length
 
+        arch_cfg = hf_model_config.get_text_config()
+
         uattn = UlyssesSPAttentionHF(
             attn=core_attn_function,
             batch_size=micro_batch_size,
-            attn_head_count=hf_model_config.num_attention_heads,
-            attn_head_size=getattr(hf_model_config, "head_dim",
-                                   hf_model_config.hidden_size // hf_model_config.num_attention_heads),
-            kv_head_count=hf_model_config.num_key_value_heads,
-            num_hidden_layers=hf_model_config.num_hidden_layers,
+            attn_head_count=arch_cfg.num_attention_heads,
+            attn_head_size=getattr(
+                arch_cfg,
+                "head_dim",
+                arch_cfg.hidden_size // arch_cfg.num_attention_heads,
+            ),
+            kv_head_count=arch_cfg.num_key_value_heads,
+            num_hidden_layers=arch_cfg.num_hidden_layers,
             process_group=mpu.get_sequence_parallel_group(),
             seq_length_is_variable=seq_length_is_variable,
             local_seq_length=local_seq_length,