Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions csrc/fp_quantizer/fp_quantize_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@ constexpr int warps = threads / 32;
template <int _mantisa_bits, int q_mantisa_bits, int stochastic_rounding>
__device__ void round(uint32_t& mantisa, uint32_t& dst_exponent, curandStatePhilox4_32_10_t* state)
{
constexpr uint32_t mantisa_mask = (1U << (_mantisa_bits - q_mantisa_bits)) - 1;
constexpr uint32_t mantisa_mask = (1U << (_mantisa_bits - q_mantisa_bits)) - 1;
uint32_t offset = stochastic_rounding ? (curand_poisson(state, 10) & mantisa_mask)
: 1U << (_mantisa_bits - q_mantisa_bits - 1);
: 1U << (_mantisa_bits - q_mantisa_bits - 1);
mantisa += offset;
dst_exponent += (((mantisa & ~mantisa_mask) == (1U << _mantisa_bits)) ? 1 : 0);
dst_exponent += (((mantisa & ~mantisa_mask) == (1U << _mantisa_bits)) ? 1 : 0);
}

template <int _mantisa_bits, int _exponent_bits, int q_mantisa_bits, int q_exponent_bits>
Expand Down Expand Up @@ -80,6 +83,7 @@ __global__ void apply_quantization(T* val,
constexpr uint32_t _mantisa_mask = (1 << _mantisa_bits) - 1;
constexpr uint32_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
constexpr uint32_t _sign_mask = 1U << (_mantisa_bits + _exponent_bits);
constexpr uint32_t _sign_mask = 1U << (_mantisa_bits + _exponent_bits);
// CG helpers
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
Expand Down Expand Up @@ -230,11 +234,11 @@ __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int
constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;

constexpr int quantized_bits = q_mantisa_bits + q_exponent_bits + 1;
constexpr int q_exponent_bits = total_q_bits - mantisa_bits - 1;
constexpr uint16_t _mantisa_mask = (1 << q_mantisa_bits) - 1;
constexpr uint16_t _exponent_mask = ((1 << q_exponent_bits) - 1) << q_mantisa_bits;
constexpr uint16_t _sign_mask = 1U << (q_mantisa_bits + q_exponent_bits);
constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
constexpr uint16_t _sign_mask = 1U << (_mantisa_bits + _exponent_bits);
const uint32_t g_index = (tidx / group_size);
const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
const uint8_t* load_base_ptr =
Expand Down Expand Up @@ -409,11 +413,11 @@ __global__ void apply_selective_dequantization(uint8_t* val,
constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size;
int input_index = index * total_num_elements + tidx;
constexpr int quantized_bits = q_mantisa_bits + q_exponent_bits + 1;
constexpr int q_exponent_bits = total_q_bits - mantisa_bits - 1;
constexpr uint16_t _mantisa_mask = (1 << q_mantisa_bits) - 1;
constexpr uint16_t _exponent_mask = ((1 << q_exponent_bits) - 1) << q_mantisa_bits;
constexpr uint16_t _sign_mask = 1U << (q_mantisa_bits + q_exponent_bits);
constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
constexpr uint16_t _sign_mask = 1U << (_mantisa_bits + _exponent_bits);
const uint32_t g_index = (input_index / group_size);
const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
const uint8_t* load_base_ptr =
Expand Down
15 changes: 10 additions & 5 deletions deepspeed/runtime/sequence_parallel/ulysses_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,14 +491,19 @@ def register_with_transformers(
local_seq_length = seq_length // mpu.get_sequence_parallel_world_size()
global_seq_length = seq_length

arch_cfg = hf_model_config.get_text_config()

uattn = UlyssesSPAttentionHF(
attn=core_attn_function,
batch_size=micro_batch_size,
attn_head_count=hf_model_config.num_attention_heads,
attn_head_size=getattr(hf_model_config, "head_dim",
hf_model_config.hidden_size // hf_model_config.num_attention_heads),
kv_head_count=hf_model_config.num_key_value_heads,
num_hidden_layers=hf_model_config.num_hidden_layers,
attn_head_count=arch_cfg.num_attention_heads,
attn_head_size=getattr(
arch_cfg,
"head_dim",
arch_cfg.hidden_size // arch_cfg.num_attention_heads,
),
kv_head_count=arch_cfg.num_key_value_heads,
num_hidden_layers=arch_cfg.num_hidden_layers,
process_group=mpu.get_sequence_parallel_group(),
seq_length_is_variable=seq_length_is_variable,
local_seq_length=local_seq_length,
Expand Down
Loading