
How to do time series classification? #50

Open
shivin9 opened this issue Apr 11, 2024 · 5 comments
Labels
FAQ Frequently asked question

Comments

@shivin9

shivin9 commented Apr 11, 2024

Hi all,

Thanks for open sourcing this library.

I am working on the task of classifying numeric, multivariate time series. I wanted to know how I can use Chronos to achieve that.

Thanks!

@abdulfatir
Contributor

Hi @shivin9! Thanks for your interest. The Chronos project currently focuses on forecasting. The best way to use Chronos for tasks such as classification, clustering, and anomaly detection is an open research question. You can check the README for an example of extracting Chronos' encoder embeddings; you should be able to use these embeddings for a downstream task like classification.

Someone gave it a go for classification. Please check their example here (and also read the comments for the complete context). Note that this is a quick attempt and there are definitely better ways to approach this.
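
For reference, a minimal sketch (not from this thread) of what that could look like: extract encoder embeddings with ChronosPipeline.embed, mean-pool them over time, and fit a simple classifier on top. The variables series and labels are assumed inputs, and the scikit-learn logistic regression is just one possible choice.

import torch
from chronos import ChronosPipeline
from sklearn.linear_model import LogisticRegression

# assumed inputs: `series` is a list of 1-D torch tensors of raw values,
# `labels` is a matching sequence of integer class ids
pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    device_map="cpu",
    torch_dtype=torch.float32,
)

features = []
for context in series:
    # embed() returns (embeddings, tokenizer_state); embeddings have shape
    # [batch, seq_len, d_model], so mean-pool over the time dimension
    embeddings, _ = pipeline.embed(context)
    features.append(embeddings.mean(dim=1).squeeze(0).detach().float().numpy())

clf = LogisticRegression(max_iter=1000).fit(features, labels)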

@shivin9
Author

shivin9 commented Apr 11, 2024

Thanks!

@lostella lostella changed the title Using Chronos to do TS Classification How to do time series classification with Chronos? Apr 12, 2024
@lostella lostella added the FAQ Frequently asked question label Apr 12, 2024
@lostella lostella changed the title How to do time series classification with Chronos? How to do time series classification? Apr 12, 2024
@mshooter

I am trying to do classification as well from the example given; however, I get an error when I try to run it on GPU.

I use the same code as in the README to extract the embeddings, and the input tensor is on the GPU as well; however, I get the error:

  File ".../envs/digidogs/lib/python3.8/site-packages/chronos/chronos.py", line 150, in input_transform
    torch.bucketize(
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument boundaries in method wrapper_CUDA_Tensor_bucketize)
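
Based on the traceback, the bucket boundaries used by torch.bucketize appear to live on CPU while the context tensor is on cuda:0. One possible workaround (not confirmed in this thread) is to tokenize on CPU and move the resulting tensors to the GPU afterwards, roughly:

# assumed: `context` is the raw-valued input tensor and `pipeline` is a loaded ChronosPipeline
token_ids, attention_mask, scale = pipeline.tokenizer.input_transform(context.cpu())
token_ids = token_ids.to("cuda")
attention_mask = attention_mask.to("cuda")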

@lostella
Contributor

@mshooter could you include a self-contained snippet that reproduces the issue?

@LinusOstlund

Here's my idea, leveraging Hugging Face's transformers API. It's perhaps a bit hacky, and it requires a basic understanding of PyTorch, e.g., setting up data loaders.

import torch
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification
from chronos import ChronosPipeline

dataset = ...       # dataset of some kind
train_loader = ...  # a data loader of some kind, yielding (x, y) batches
num_epochs = 1
device = "cuda" if torch.cuda.is_available() else "cpu"

# use the HuggingFace transformers API to attach a new classification head
# NOTE: setting num_labels = 1 means regression
num_labels = 2
checkpoint = "amazon/chronos-t5-tiny"
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint, num_labels=num_labels
).to(device)

# not really interested in the pipeline...
pipeline = ChronosPipeline.from_pretrained(
    checkpoint,
    device_map=device,
    torch_dtype=torch.float16,
)

# ... but I snatch the tokenizer 🦹🏼‍♂️
tokenizer = pipeline.tokenizer

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model.train()
num_training_steps = num_epochs * len(train_loader)
progress_bar = tqdm(total=num_training_steps, desc="Training")

for epoch in range(num_epochs):
    losses = []
    for i, batch in enumerate(train_loader):
        x, y = batch  # x: raw series values, y: class ids

        # tokenize on CPU, then move everything to the device
        token_ids, attention_mask, scale = tokenizer.input_transform(x)

        # NOTE: y must be torch.int64 so the HF API treats it as class labels.
        # Required for classification!
        y = y.to(device).to(torch.int64)
        token_ids = token_ids.to(device)
        attention_mask = attention_mask.to(device)

        outputs = model(
            input_ids=token_ids, attention_mask=attention_mask, labels=y
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.item())
        progress_bar.update(1)