# 1. Exact string match (with partial credit for keyword overlap)

import re

def compute_score(data_source: str, solution: str, ground_truth: str, extra_info: dict) -> float:
    """Score a solution against the ground truth on a three-level scale.

    Args:
        data_source: Not used by this scorer.
        solution: Candidate answer text.
        ground_truth: Reference answer text.
        extra_info: Not used by this scorer.

    Returns:
        1.0 when the two strings are identical, 0.5 when
        ``has_keyword_overlap`` reports a match, and 0.0 otherwise.
    """
    # Identical strings earn full credit; no keyword analysis is needed.
    if solution == ground_truth:
        return 1.0
    # Otherwise fall back to partial credit based on shared keywords.
    if has_keyword_overlap(solution, ground_truth):
        return 0.5
    return 0.0

def has_keyword_overlap(text1: str, text2: str, threshold: float = 0.3) -> bool:
    """Report whether two texts share a large enough fraction of distinct words.

    Both texts are lower-cased and split into word tokens (runs of word
    characters). The number of distinct tokens common to both is divided by
    the size of the larger token set, and that ratio is compared with
    ``threshold``.

    Args:
        text1: First text.
        text2: Second text.
        threshold: Ratio that must be strictly exceeded for a match.

    Returns:
        True when the overlap ratio is greater than ``threshold``. False
        when neither text contains any word token, since there is nothing
        to overlap.
    """
    words1 = set(re.findall(r'\b\w+\b', text1.lower()))
    words2 = set(re.findall(r'\b\w+\b', text2.lower()))

    largest = max(len(words1), len(words2))
    # Guard the division: punctuation-only or empty inputs yield no tokens.
    if largest == 0:
        return False

    return len(words1 & words2) / largest > threshold

# 2. Token-set overlap (Jaccard similarity)

import re

def compute_score(data_source: str, solution: str, ground_truth: str, extra_info: dict) -> float:
    """Score a solution by the Jaccard similarity of its token set.

    Both texts are passed through ``tokenize`` and reduced to sets of
    distinct tokens; the score is ``|intersection| / |union|`` of those sets.

    Args:
        data_source: Not used by this scorer.
        solution: Candidate answer text.
        ground_truth: Reference answer text.
        extra_info: Not used by this scorer.

    Returns:
        A value in [0.0, 1.0]. Two texts that both produce no tokens are
        treated as identical and score 1.0.
    """
    solution_tokens = set(tokenize(solution))
    truth_tokens = set(tokenize(ground_truth))

    # Compare token sets with each other. Passing the raw ``ground_truth``
    # string to the set operations would iterate it character by character.
    combined = solution_tokens | truth_tokens
    if not combined:
        return 1.0

    return len(solution_tokens & truth_tokens) / len(combined)

def tokenize(text: str) -> list[str]:
    """Split text into lower-cased word tokens.

    Args:
        text: Input text.

    Returns:
        Every run of word characters in ``text.lower()``, in order of
        appearance; an empty list when there are none.
    """
    # The built-in generic ``list[str]`` (Python 3.9+) is used because
    # ``typing.List`` is not imported in this file.
    return re.findall(r'\b\w+\b', text.lower())

# 3. String similarity comparison (normalized Levenshtein distance)

def compute_score(data_source: str, solution: str, ground_truth: str, extra_info: dict) -> float:
    """Score a solution by its normalized edit-distance similarity.

    Args:
        data_source: Not used by this scorer.
        solution: Candidate answer text.
        ground_truth: Reference answer text.
        extra_info: Not used by this scorer.

    Returns:
        ``1 - distance / longest_length`` clamped to [0.0, 1.0], where
        ``distance`` comes from ``levenshtein_distance``. Two empty strings
        score 1.0.
    """
    edit_distance = levenshtein_distance(solution, ground_truth)
    longest = max(len(solution), len(ground_truth))

    # Both strings empty: nothing to normalize by, treat as a perfect match.
    if longest == 0:
        return 1.0

    score = 1 - (edit_distance / longest)
    # Keep the result inside the unit interval.
    return max(0.0, min(score, 1.0))

def levenshtein_distance(s1: str, s2: str) -> int:
    """Return the Levenshtein edit distance between two strings.

    The table is filled one row at a time, keeping only the previous row.
    The longer string drives the outer loop so each row has
    ``len(shorter) + 1`` entries.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        The minimum number of single-character insertions, deletions and
        substitutions needed to turn one string into the other.
    """
    # Order the operands instead of recursing with swapped arguments.
    if len(s1) >= len(s2):
        outer, inner = s1, s2
    else:
        outer, inner = s2, s1

    # Distance from the empty prefix of ``outer`` to each prefix of ``inner``.
    prev = list(range(len(inner) + 1))
    for row_index, outer_ch in enumerate(outer, start=1):
        curr = [row_index]
        for col, inner_ch in enumerate(inner):
            mismatch = 0 if outer_ch == inner_ch else 1
            curr.append(
                min(
                    prev[col + 1] + 1,     # step from the cell above
                    curr[col] + 1,         # step from the cell to the left
                    prev[col] + mismatch,  # diagonal: match or substitution
                )
            )
        prev = curr

    return prev[-1]
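# NOTE: stray web-page navigation text ("在线咨询" / "online consultation",
# "体验中心" / "experience center") was pasted here; it is kept as a comment
# so that the module can be imported.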