
Unable to do hyperparameter search on sentence-pair classification (regression) task #1561

Open
Gregory5949 opened this issue Jan 7, 2024 · 0 comments


Describe the bug
Hello, dear developers of simpletransformers!

Every sweep run fails with an error like 'wandb: ERROR Run 4i4jmyqn errored: ValueError('Target size (torch.Size([2])) must be the same as input size (torch.Size([2, 2]))')' (runs xg0wdljm and 4i4jmyqn both errored this way). It occurs when I run the script from https://simpletransformers.ai/docs/tips-and-tricks/ ('6. Putting it all together') adapted for sentence-pair classification (regression). Since the error only surfaces through the wandb agent's log, I couldn't pinpoint the class that raises it.
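For what it's worth, the message matches PyTorch's shape check in F.binary_cross_entropy_with_logits, which fires when per-example scalar targets meet a [batch, num_labels] logit tensor. A minimal plain-PyTorch sketch (my guess at the mechanism, not actual simpletransformers code) reproduces the exact message:

import torch
import torch.nn.functional as F

# [batch=2, num_labels=2] logits, as a 2-output classification head would produce
logits = torch.randn(2, 2)
# [batch=2] targets, one scalar per example, as in regression
targets = torch.tensor([1.0, 0.0])

# Raises: ValueError: Target size (torch.Size([2])) must be the same as
# input size (torch.Size([2, 2]))
F.binary_cross_entropy_with_logits(logits, targets)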

To Reproduce

import logging

import pandas as pd

import wandb
from simpletransformers.classification import (
    ClassificationArgs,
    ClassificationModel,
)

sweep_config = {
    "method": "bayes",  # grid, random
    "metric": {"name": "train_loss", "goal": "minimize"},
    "parameters": {
        "num_train_epochs": {"values": [2, 3, 5]},
        "learning_rate": {"min": 5e-5, "max": 4e-4},
    },
}

sweep_id = wandb.sweep(sweep_config, project="Simple Sweep")

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)


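# Preparing train data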
train_data = [
    [
        "Aragorn was the heir of Isildur",
        "Gimli fought with a battle axe",
        1,
    ],
    [
        "Frodo was the heir of Isildur",
        "Legolas was an expert archer",
        0,
    ],
]
train_df = pd.DataFrame(train_data)
train_df.columns = ["text_a", "text_b", "labels"]

# Preparing eval data
eval_data = [
    [
        "Theoden was the king of Rohan",
        "Gimli's preferred weapon was a battle axe",
        1,
    ],
    [
        "Merry was the king of Rohan",
        "Legolas was taller than Gimli",
        0,
    ],
]
eval_df = pd.DataFrame(eval_data)
eval_df.columns = ["text_a", "text_b", "labels"]


model_args = ClassificationArgs()

model_args.save_best_model = True
model_args.reprocess_input_data = True
model_args.regression = True
model_args.use_early_stopping = True
model_args.early_stopping_delta = 0.01
model_args.early_stopping_metric = "mae"
model_args.early_stopping_metric_minimize = True
model_args.early_stopping_patience = 5
model_args.evaluate_during_training_steps = 100
model_args.overwrite_output_dir = True
model_args.evaluate_during_training = True
model_args.manual_seed = 42
model_args.wandb_project = "Simple Sweep"

def train():
    # Initialize a new wandb run
    wandb.init()

    # Create a TransformerModel
    model = ClassificationModel(
        "bert",
        "deeppavlov/rubert-base-cased",
        use_cuda=True,
        args=model_args,
        sweep_config=wandb.config,
    )

    # Train the model
    model.train_model(train_df, eval_df=eval_df)

    # Evaluate the model
    model.eval_model(eval_df=eval_df)

    # Sync wandb
    wandb.join()


wandb.agent(sweep_id, train)
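In case it helps with triage: my (unverified) guess is that the classification head is still built with the default num_labels=2 while the regression targets are per-example scalars. The regression section of the simpletransformers docs passes num_labels=1 alongside regression=True, so a variant like the sketch below might avoid the shape mismatch; I haven't confirmed it inside a sweep, and the float labels are my assumption:

# Workaround sketch (untested in a sweep): build a single-output head for
# regression, as the simpletransformers regression docs do; using float
# labels (1.0/0.0) in train_df/eval_df is my assumption.
model = ClassificationModel(
    "bert",
    "deeppavlov/rubert-base-cased",
    num_labels=1,                # one output per example for regression
    use_cuda=True,
    args=model_args,             # model_args.regression = True, as above
    sweep_config=wandb.config,
)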

Expected behavior
I would expect the sweep to run the same way as the original script at https://simpletransformers.ai/docs/tips-and-tricks/ ('6. Putting it all together').
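For comparison, as far as I can tell the docs version that does run is a plain binary classification setup on single sentences without the regression flag, roughly like this (the exact example sentences are from memory):

# Docs-style data sketch: binary classification on single sentences,
# integer labels, default num_labels=2.
train_df = pd.DataFrame(
    [
        ["Aragorn was the heir of Isildur", 1],
        ["Frodo was the heir of Isildur", 0],
    ],
    columns=["text", "labels"],
)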

Screenshots
[screenshot attached: Screenshot 2024-01-07 at 16 24 22]

Desktop (please complete the following information):

  • Google Colab

Sincerely,
Grigory
