This commit is contained in:
simple321vip
2026-05-26 15:59:18 +00:00
commit da07b1f453
553 changed files with 152998 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
---
description: Specific model architectures and tools — image segmentation (Segment Anything / SAM) and audio generation (AudioCraft / MusicGen). Additional model skills (CLIP, Stable Diffusion, Whisper, LLaVA) are available as optional skills.
---

View File

@@ -0,0 +1,568 @@
---
name: audiocraft-audio-generation
description: "AudioCraft: MusicGen text-to-music, AudioGen text-to-sound."
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
platforms: [linux, macos]
metadata:
hermes:
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
---
# AudioCraft: Audio Generation
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
## When to use AudioCraft
**Use AudioCraft when:**
- Need to generate music from text descriptions
- Creating sound effects and environmental audio
- Building music generation applications
- Need melody-conditioned music generation
- Want stereo audio output
- Require controllable music generation with style transfer
**Key features:**
- **MusicGen**: Text-to-music generation with melody conditioning
- **AudioGen**: Text-to-sound effects generation
- **EnCodec**: High-fidelity neural audio codec
- **Multiple model sizes**: Small (300M) to Large (3.3B)
- **Stereo support**: Full stereo audio generation
- **Style conditioning**: MusicGen-Style for reference-based generation
**Use alternatives instead:**
- **Stable Audio**: For longer commercial music generation
- **Bark**: For text-to-speech with music/sound effects
- **Riffusion**: For spectogram-based music generation
- **OpenAI Jukebox**: For raw audio generation with lyrics
## Quick start
### Installation
```bash
# From PyPI
pip install audiocraft
# From GitHub (latest)
pip install git+https://github.com/facebookresearch/audiocraft.git
# Or use HuggingFace Transformers
pip install transformers torch torchaudio
```
### Basic text-to-music (AudioCraft)
```python
import torchaudio
from audiocraft.models import MusicGen
# Load model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Set generation parameters
model.set_generation_params(
duration=8, # seconds
top_k=250,
temperature=1.0
)
# Generate from text
descriptions = ["happy upbeat electronic dance music with synths"]
wav = model.generate(descriptions)
# Save audio
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
```
### Using HuggingFace Transformers
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import scipy
# Load model and processor
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
model.to("cuda")
# Generate music
inputs = processor(
text=["80s pop track with bassy drums and synth"],
padding=True,
return_tensors="pt"
).to("cuda")
audio_values = model.generate(
**inputs,
do_sample=True,
guidance_scale=3,
max_new_tokens=256
)
# Save
sampling_rate = model.config.audio_encoder.sampling_rate
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
```
### Text-to-sound with AudioGen
```python
from audiocraft.models import AudioGen
# Load AudioGen
model = AudioGen.get_pretrained('facebook/audiogen-medium')
model.set_generation_params(duration=5)
# Generate sound effects
descriptions = ["dog barking in a park with birds chirping"]
wav = model.generate(descriptions)
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
```
## Core concepts
### Architecture overview
```
AudioCraft Architecture:
┌──────────────────────────────────────────────────────────────┐
│ Text Encoder (T5) │
│ │ │
│ Text Embeddings │
└────────────────────────┬─────────────────────────────────────┘
┌────────────────────────▼─────────────────────────────────────┐
│ Transformer Decoder (LM) │
│ Auto-regressively generates audio tokens │
│ Using efficient token interleaving patterns │
└────────────────────────┬─────────────────────────────────────┘
┌────────────────────────▼─────────────────────────────────────┐
│ EnCodec Audio Decoder │
│ Converts tokens back to audio waveform │
└──────────────────────────────────────────────────────────────┘
```
### Model variants
| Model | Size | Description | Use Case |
|-------|------|-------------|----------|
| `musicgen-small` | 300M | Text-to-music | Quick generation |
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
### Generation parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `duration` | 8.0 | Length in seconds (1-120) |
| `top_k` | 250 | Top-k sampling |
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
| `temperature` | 1.0 | Sampling temperature |
| `cfg_coef` | 3.0 | Classifier-free guidance |
## MusicGen usage
### Text-to-music generation
```python
from audiocraft.models import MusicGen
import torchaudio
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# Configure generation
model.set_generation_params(
duration=30, # Up to 30 seconds
top_k=250, # Sampling diversity
top_p=0.0, # 0 = use top_k only
temperature=1.0, # Creativity (higher = more varied)
cfg_coef=3.0 # Text adherence (higher = stricter)
)
# Generate multiple samples
descriptions = [
"epic orchestral soundtrack with strings and brass",
"chill lo-fi hip hop beat with jazzy piano",
"energetic rock song with electric guitar"
]
# Generate (returns [batch, channels, samples])
wav = model.generate(descriptions)
# Save each
for i, audio in enumerate(wav):
torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
```
### Melody-conditioned generation
```python
from audiocraft.models import MusicGen
import torchaudio
# Load melody model
model = MusicGen.get_pretrained('facebook/musicgen-melody')
model.set_generation_params(duration=30)
# Load melody audio
melody, sr = torchaudio.load("melody.wav")
# Generate with melody conditioning
descriptions = ["acoustic guitar folk song"]
wav = model.generate_with_chroma(descriptions, melody, sr)
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
```
### Stereo generation
```python
from audiocraft.models import MusicGen
# Load stereo model
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
model.set_generation_params(duration=15)
descriptions = ["ambient electronic music with wide stereo panning"]
wav = model.generate(descriptions)
# wav shape: [batch, 2, samples] for stereo
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
```
### Audio continuation
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
# Load audio to continue
import torchaudio
audio, sr = torchaudio.load("intro.wav")
# Process with text and audio
inputs = processor(
audio=audio.squeeze().numpy(),
sampling_rate=sr,
text=["continue with a epic chorus"],
padding=True,
return_tensors="pt"
)
# Generate continuation
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
```
## MusicGen-Style usage
### Style-conditioned generation
```python
from audiocraft.models import MusicGen
# Load style model
model = MusicGen.get_pretrained('facebook/musicgen-style')
# Configure generation with style
model.set_generation_params(
duration=30,
cfg_coef=3.0,
cfg_coef_beta=5.0 # Style influence
)
# Configure style conditioner
model.set_style_conditioner_params(
eval_q=3, # RVQ quantizers (1-6)
excerpt_length=3.0 # Style excerpt length
)
# Load style reference
style_audio, sr = torchaudio.load("reference_style.wav")
# Generate with text + style
descriptions = ["upbeat dance track"]
wav = model.generate_with_style(descriptions, style_audio, sr)
```
### Style-only generation (no text)
```python
# Generate matching style without text prompt
model.set_generation_params(
duration=30,
cfg_coef=3.0,
cfg_coef_beta=None # Disable double CFG for style-only
)
wav = model.generate_with_style([None], style_audio, sr)
```
## AudioGen usage
### Sound effect generation
```python
from audiocraft.models import AudioGen
import torchaudio
model = AudioGen.get_pretrained('facebook/audiogen-medium')
model.set_generation_params(duration=10)
# Generate various sounds
descriptions = [
"thunderstorm with heavy rain and lightning",
"busy city traffic with car horns",
"ocean waves crashing on rocks",
"crackling campfire in forest"
]
wav = model.generate(descriptions)
for i, audio in enumerate(wav):
torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
```
## EnCodec usage
### Audio compression
```python
from audiocraft.models import CompressionModel
import torch
import torchaudio
# Load EnCodec
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
# Load audio
wav, sr = torchaudio.load("audio.wav")
# Ensure correct sample rate
if sr != 32000:
resampler = torchaudio.transforms.Resample(sr, 32000)
wav = resampler(wav)
# Encode to tokens
with torch.no_grad():
encoded = model.encode(wav.unsqueeze(0))
codes = encoded[0] # Audio codes
# Decode back to audio
with torch.no_grad():
decoded = model.decode(codes)
torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
```
## Common workflows
### Workflow 1: Music generation pipeline
```python
import torch
import torchaudio
from audiocraft.models import MusicGen
class MusicGenerator:
def __init__(self, model_name="facebook/musicgen-medium"):
self.model = MusicGen.get_pretrained(model_name)
self.sample_rate = 32000
def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
self.model.set_generation_params(
duration=duration,
top_k=250,
temperature=temperature,
cfg_coef=cfg
)
with torch.no_grad():
wav = self.model.generate([prompt])
return wav[0].cpu()
def generate_batch(self, prompts, duration=30):
self.model.set_generation_params(duration=duration)
with torch.no_grad():
wav = self.model.generate(prompts)
return wav.cpu()
def save(self, audio, path):
torchaudio.save(path, audio, sample_rate=self.sample_rate)
# Usage
generator = MusicGenerator()
audio = generator.generate(
"epic cinematic orchestral music",
duration=30,
temperature=1.0
)
generator.save(audio, "epic_music.wav")
```
### Workflow 2: Sound design batch processing
```python
import json
from pathlib import Path
from audiocraft.models import AudioGen
import torchaudio
def batch_generate_sounds(sound_specs, output_dir):
"""
Generate multiple sounds from specifications.
Args:
sound_specs: list of {"name": str, "description": str, "duration": float}
output_dir: output directory path
"""
model = AudioGen.get_pretrained('facebook/audiogen-medium')
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True)
results = []
for spec in sound_specs:
model.set_generation_params(duration=spec.get("duration", 5))
wav = model.generate([spec["description"]])
output_path = output_dir / f"{spec['name']}.wav"
torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)
results.append({
"name": spec["name"],
"path": str(output_path),
"description": spec["description"]
})
return results
# Usage
sounds = [
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
]
results = batch_generate_sounds(sounds, "sound_effects/")
```
### Workflow 3: Gradio demo
```python
import gradio as gr
import torch
import torchaudio
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')
def generate_music(prompt, duration, temperature, cfg_coef):
model.set_generation_params(
duration=duration,
temperature=temperature,
cfg_coef=cfg_coef
)
with torch.no_grad():
wav = model.generate([prompt])
# Save to temp file
path = "temp_output.wav"
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
return path
demo = gr.Interface(
fn=generate_music,
inputs=[
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
],
outputs=gr.Audio(label="Generated Music"),
title="MusicGen Demo"
)
demo.launch()
```
## Performance optimization
### Memory optimization
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Clear cache between generations
torch.cuda.empty_cache()
# Generate shorter durations
model.set_generation_params(duration=10) # Instead of 30
# Use half precision
model = model.half()
```
### Batch processing efficiency
```python
# Process multiple prompts at once (more efficient)
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
wav = model.generate(descriptions) # Single batch
# Instead of
for desc in descriptions:
wav = model.generate([desc]) # Multiple batches (slower)
```
### GPU memory requirements
| Model | FP32 VRAM | FP16 VRAM |
|-------|-----------|-----------|
| musicgen-small | ~4GB | ~2GB |
| musicgen-medium | ~8GB | ~4GB |
| musicgen-large | ~16GB | ~8GB |
## Common issues
| Issue | Solution |
|-------|----------|
| CUDA OOM | Use smaller model, reduce duration |
| Poor quality | Increase cfg_coef, better prompts |
| Generation too short | Check max duration setting |
| Audio artifacts | Try different temperature |
| Stereo not working | Use stereo model variant |
## References
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
## Resources
- **GitHub**: https://github.com/facebookresearch/audiocraft
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen

View File

@@ -0,0 +1,666 @@
# AudioCraft Advanced Usage Guide
## Fine-tuning MusicGen
### Custom dataset preparation
```python
import os
import json
from pathlib import Path
import torchaudio
def prepare_dataset(audio_dir, output_dir, metadata_file):
"""
Prepare dataset for MusicGen fine-tuning.
Directory structure:
output_dir/
├── audio/
│ ├── 0001.wav
│ ├── 0002.wav
│ └── ...
└── metadata.json
"""
output_dir = Path(output_dir)
audio_output = output_dir / "audio"
audio_output.mkdir(parents=True, exist_ok=True)
# Load metadata (format: {"path": "...", "description": "..."})
with open(metadata_file) as f:
metadata = json.load(f)
processed = []
for idx, item in enumerate(metadata):
audio_path = Path(audio_dir) / item["path"]
# Load and resample to 32kHz
wav, sr = torchaudio.load(str(audio_path))
if sr != 32000:
resampler = torchaudio.transforms.Resample(sr, 32000)
wav = resampler(wav)
# Convert to mono if stereo
if wav.shape[0] > 1:
wav = wav.mean(dim=0, keepdim=True)
# Save processed audio
output_path = audio_output / f"{idx:04d}.wav"
torchaudio.save(str(output_path), wav, sample_rate=32000)
processed.append({
"path": str(output_path.relative_to(output_dir)),
"description": item["description"],
"duration": wav.shape[1] / 32000
})
# Save processed metadata
with open(output_dir / "metadata.json", "w") as f:
json.dump(processed, f, indent=2)
print(f"Processed {len(processed)} samples")
return processed
```
### Fine-tuning with dora
```bash
# AudioCraft uses dora for experiment management
# Install dora
pip install dora-search
# Clone AudioCraft
git clone https://github.com/facebookresearch/audiocraft.git
cd audiocraft
# Create config for fine-tuning
cat > config/solver/musicgen/finetune.yaml << 'EOF'
defaults:
- musicgen/musicgen_base
- /model: lm/musicgen_lm
- /conditioner: cond_base
solver: musicgen
autocast: true
autocast_dtype: float16
optim:
epochs: 100
batch_size: 4
lr: 1e-4
ema: 0.999
optimizer: adamw
dataset:
batch_size: 4
num_workers: 4
train:
- dset: your_dataset
root: /path/to/dataset
valid:
- dset: your_dataset
root: /path/to/dataset
checkpoint:
save_every: 10
keep_every_states: null
EOF
# Run fine-tuning
dora run solver=musicgen/finetune
```
### LoRA fine-tuning
```python
from peft import LoraConfig, get_peft_model
from audiocraft.models import MusicGen
import torch
# Load base model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Get the language model component
lm = model.lm
# Configure LoRA
lora_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
lora_dropout=0.05,
bias="none"
)
# Apply LoRA
lm = get_peft_model(lm, lora_config)
lm.print_trainable_parameters()
```
## Multi-GPU Training
### DataParallel
```python
import torch
import torch.nn as nn
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Wrap LM with DataParallel
if torch.cuda.device_count() > 1:
model.lm = nn.DataParallel(model.lm)
model.to("cuda")
```
### DistributedDataParallel
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def train(rank, world_size):
setup(rank, world_size)
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.lm = model.lm.to(rank)
model.lm = DDP(model.lm, device_ids=[rank])
# Training loop
# ...
dist.destroy_process_group()
```
## Custom Conditioning
### Adding new conditioners
```python
from audiocraft.modules.conditioners import BaseConditioner
import torch
class CustomConditioner(BaseConditioner):
"""Custom conditioner for additional control signals."""
def __init__(self, dim, output_dim):
super().__init__(dim, output_dim)
self.embed = torch.nn.Linear(dim, output_dim)
def forward(self, x):
return self.embed(x)
def tokenize(self, x):
# Tokenize input for conditioning
return x
# Use with MusicGen
from audiocraft.models.builders import get_lm_model
# Modify model config to include custom conditioner
# This requires editing the model configuration
```
### Melody conditioning internals
```python
from audiocraft.models import MusicGen
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
import torch
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# Access chroma extractor
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
# Manual chroma extraction
def extract_chroma(audio, sr):
"""Extract chroma features from audio."""
import librosa
# Compute chroma
chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)
return torch.from_numpy(chroma).float()
# Use extracted chroma for conditioning
chroma = extract_chroma(melody_audio, sample_rate)
```
## EnCodec Deep Dive
### Custom compression settings
```python
from audiocraft.models import CompressionModel
import torch
# Load EnCodec
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
# Access codec parameters
print(f"Sample rate: {encodec.sample_rate}")
print(f"Channels: {encodec.channels}")
print(f"Cardinality: {encodec.cardinality}") # Codebook size
print(f"Num codebooks: {encodec.num_codebooks}")
print(f"Frame rate: {encodec.frame_rate}")
# Encode with specific bandwidth
# Lower bandwidth = more compression, lower quality
encodec.set_target_bandwidth(6.0) # 6 kbps
audio = torch.randn(1, 1, 32000) # 1 second
encoded = encodec.encode(audio)
decoded = encodec.decode(encoded[0])
```
### Streaming encoding
```python
import torch
from audiocraft.models import CompressionModel
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
def encode_streaming(audio_stream, chunk_size=32000):
"""Encode audio in streaming fashion."""
all_codes = []
for chunk in audio_stream:
# Ensure chunk is right shape
if chunk.dim() == 1:
chunk = chunk.unsqueeze(0).unsqueeze(0)
with torch.no_grad():
codes = encodec.encode(chunk)[0]
all_codes.append(codes)
return torch.cat(all_codes, dim=-1)
def decode_streaming(codes_stream, output_stream):
"""Decode codes in streaming fashion."""
for codes in codes_stream:
with torch.no_grad():
audio = encodec.decode(codes)
output_stream.write(audio.cpu().numpy())
```
## MultiBand Diffusion
### Using MBD for enhanced quality
```python
from audiocraft.models import MusicGen, MultiBandDiffusion
# Load MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# Load MultiBand Diffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()
model.set_generation_params(duration=10)
# Generate with standard decoder
descriptions = ["epic orchestral music"]
wav_standard = model.generate(descriptions)
# Generate tokens and use MBD decoder
with torch.no_grad():
# Get tokens
gen_tokens = model.generate_tokens(descriptions)
# Decode with MBD
wav_mbd = mbd.tokens_to_wav(gen_tokens)
# Compare quality
print(f"Standard shape: {wav_standard.shape}")
print(f"MBD shape: {wav_mbd.shape}")
```
## API Server Deployment
### FastAPI server
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import torchaudio
from audiocraft.models import MusicGen
import io
import base64
app = FastAPI()
# Load model at startup
model = None
@app.on_event("startup")
async def load_model():
global model
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(duration=10)
class GenerateRequest(BaseModel):
prompt: str
duration: float = 10.0
temperature: float = 1.0
cfg_coef: float = 3.0
class GenerateResponse(BaseModel):
audio_base64: str
sample_rate: int
duration: float
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
if model is None:
raise HTTPException(status_code=500, detail="Model not loaded")
try:
model.set_generation_params(
duration=min(request.duration, 30),
temperature=request.temperature,
cfg_coef=request.cfg_coef
)
with torch.no_grad():
wav = model.generate([request.prompt])
# Convert to bytes
buffer = io.BytesIO()
torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
buffer.seek(0)
audio_base64 = base64.b64encode(buffer.read()).decode()
return GenerateResponse(
audio_base64=audio_base64,
sample_rate=32000,
duration=wav.shape[-1] / 32000
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health():
return {"status": "ok", "model_loaded": model is not None}
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
```
### Batch processing service
```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
import torch
from audiocraft.models import MusicGen
class MusicGenService:
def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
self.model = MusicGen.get_pretrained(model_name)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.lock = asyncio.Lock()
async def generate_async(self, prompt, duration=10):
"""Async generation with thread pool."""
loop = asyncio.get_event_loop()
def _generate():
with torch.no_grad():
self.model.set_generation_params(duration=duration)
return self.model.generate([prompt])
# Run in thread pool
wav = await loop.run_in_executor(self.executor, _generate)
return wav[0].cpu()
async def generate_batch_async(self, prompts, duration=10):
"""Process multiple prompts concurrently."""
tasks = [self.generate_async(p, duration) for p in prompts]
return await asyncio.gather(*tasks)
# Usage
service = MusicGenService()
async def main():
prompts = ["jazz piano", "rock guitar", "electronic beats"]
results = await service.generate_batch_async(prompts)
return results
```
## Integration Patterns
### LangChain tool
```python
from langchain.tools import BaseTool
import torch
import torchaudio
from audiocraft.models import MusicGen
import tempfile
class MusicGeneratorTool(BaseTool):
name = "music_generator"
description = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
def __init__(self):
super().__init__()
self.model = MusicGen.get_pretrained('facebook/musicgen-small')
self.model.set_generation_params(duration=15)
def _run(self, description: str) -> str:
with torch.no_grad():
wav = self.model.generate([description])
# Save to temp file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
return f"Generated music saved to: {f.name}"
async def _arun(self, description: str) -> str:
return self._run(description)
```
### Gradio with advanced controls
```python
import gradio as gr
import torch
import torchaudio
from audiocraft.models import MusicGen
models = {}
def load_model(model_size):
if model_size not in models:
model_name = f"facebook/musicgen-{model_size}"
models[model_size] = MusicGen.get_pretrained(model_name)
return models[model_size]
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
model = load_model(model_size)
model.set_generation_params(
duration=duration,
temperature=temperature,
cfg_coef=cfg_coef,
top_k=top_k
)
with torch.no_grad():
wav = model.generate([prompt])
# Save
path = "output.wav"
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
return path
demo = gr.Interface(
fn=generate,
inputs=[
gr.Textbox(label="Prompt", lines=3),
gr.Slider(1, 30, value=10, label="Duration (s)"),
gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
gr.Slider(50, 500, value=250, step=50, label="Top-K"),
gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
],
outputs=gr.Audio(label="Generated Music"),
title="MusicGen Advanced",
allow_flagging="never"
)
demo.launch(share=True)
```
## Audio Processing Pipeline
### Post-processing chain
```python
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
class AudioPostProcessor:
def __init__(self, sample_rate=32000):
self.sample_rate = sample_rate
def normalize(self, audio, target_db=-14.0):
"""Normalize audio to target loudness."""
rms = torch.sqrt(torch.mean(audio ** 2))
target_rms = 10 ** (target_db / 20)
gain = target_rms / (rms + 1e-8)
return audio * gain
def fade_in_out(self, audio, fade_duration=0.1):
"""Apply fade in/out."""
fade_samples = int(fade_duration * self.sample_rate)
# Create fade curves
fade_in = torch.linspace(0, 1, fade_samples)
fade_out = torch.linspace(1, 0, fade_samples)
# Apply fades
audio[..., :fade_samples] *= fade_in
audio[..., -fade_samples:] *= fade_out
return audio
def apply_reverb(self, audio, decay=0.5):
"""Apply simple reverb effect."""
impulse = torch.zeros(int(self.sample_rate * 0.5))
impulse[0] = 1.0
impulse[int(self.sample_rate * 0.1)] = decay * 0.5
impulse[int(self.sample_rate * 0.2)] = decay * 0.25
# Convolve
audio = torch.nn.functional.conv1d(
audio.unsqueeze(0),
impulse.unsqueeze(0).unsqueeze(0),
padding=len(impulse) // 2
).squeeze(0)
return audio
def process(self, audio):
"""Full processing pipeline."""
audio = self.normalize(audio)
audio = self.fade_in_out(audio)
return audio
# Usage with MusicGen
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(duration=10)
wav = model.generate(["chill ambient music"])
processor = AudioPostProcessor()
wav_processed = processor.process(wav[0].cpu())
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
```
## Evaluation
### Audio quality metrics
```python
import torch
from audiocraft.metrics import CLAPTextConsistencyMetric
from audiocraft.data.audio import audio_read
def evaluate_generation(audio_path, text_prompt):
"""Evaluate generated audio quality."""
# Load audio
wav, sr = audio_read(audio_path)
# CLAP consistency (text-audio alignment)
clap_metric = CLAPTextConsistencyMetric()
clap_score = clap_metric.compute(wav, [text_prompt])
return {
"clap_score": clap_score,
"duration": wav.shape[-1] / sr
}
# Batch evaluation
def evaluate_batch(generations):
"""Evaluate multiple generations."""
results = []
for gen in generations:
result = evaluate_generation(gen["path"], gen["prompt"])
result["prompt"] = gen["prompt"]
results.append(result)
# Aggregate
avg_clap = sum(r["clap_score"] for r in results) / len(results)
return {
"individual": results,
"average_clap": avg_clap
}
```
## Model Comparison
### MusicGen variants benchmark
| Model | CLAP Score | Generation Time (10s) | VRAM |
|-------|------------|----------------------|------|
| musicgen-small | 0.35 | ~5s | 2GB |
| musicgen-medium | 0.42 | ~15s | 4GB |
| musicgen-large | 0.48 | ~30s | 8GB |
| musicgen-melody | 0.45 | ~15s | 4GB |
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
### Prompt engineering tips
```python
# Good prompts - specific and descriptive
good_prompts = [
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
"funky disco groove with slap bass, brass section, and rhythmic guitar"
]
# Bad prompts - too vague
bad_prompts = [
"nice music",
"song",
"good beat"
]
# Structure: [mood] [genre] with [instruments] at [tempo/style]
```

View File

@@ -0,0 +1,504 @@
# AudioCraft Troubleshooting Guide
## Installation Issues
### Import errors
**Error**: `ModuleNotFoundError: No module named 'audiocraft'`
**Solutions**:
```bash
# Install from PyPI
pip install audiocraft
# Or from GitHub
pip install git+https://github.com/facebookresearch/audiocraft.git
# Verify installation
python -c "from audiocraft.models import MusicGen; print('OK')"
```
### FFmpeg not found
**Error**: `RuntimeError: ffmpeg not found`
**Solutions**:
```bash
# Ubuntu/Debian
sudo apt-get install ffmpeg
# macOS
brew install ffmpeg
# Windows (using conda)
conda install -c conda-forge ffmpeg
# Verify
ffmpeg -version
```
### PyTorch CUDA mismatch
**Error**: `RuntimeError: CUDA error: no kernel image is available`
**Solutions**:
```bash
# Check CUDA version
nvcc --version
python -c "import torch; print(torch.version.cuda)"
# Install matching PyTorch
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
# For CUDA 11.8
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
```
### xformers issues
**Error**: `ImportError: xformers` related errors
**Solutions**:
```bash
# Install xformers for memory efficiency
pip install xformers
# Or disable xformers
export AUDIOCRAFT_USE_XFORMERS=0
# In Python
import os
os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
from audiocraft.models import MusicGen
```
## Model Loading Issues
### Out of memory during load
**Error**: `torch.cuda.OutOfMemoryError` during model loading
**Solutions**:
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Force CPU loading first
import torch
device = "cpu"
model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
model = model.to("cuda")
# Use HuggingFace with device_map
from transformers import MusicgenForConditionalGeneration
model = MusicgenForConditionalGeneration.from_pretrained(
"facebook/musicgen-small",
device_map="auto"
)
```
### Download failures
**Error**: Connection errors or incomplete downloads
**Solutions**:
```python
# Set cache directory
import os
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"
# Or for HuggingFace
os.environ["HF_HOME"] = "/path/to/hf_cache"
# Resume download
from huggingface_hub import snapshot_download
snapshot_download("facebook/musicgen-small", resume_download=True)
# Use local files
model = MusicGen.get_pretrained('/local/path/to/model')
```
### Wrong model type
**Error**: Loading wrong model for task
**Solutions**:
```python
# For text-to-music: use MusicGen
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# For text-to-sound: use AudioGen
from audiocraft.models import AudioGen
model = AudioGen.get_pretrained('facebook/audiogen-medium')
# For melody conditioning: use melody variant
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# For stereo: use stereo variant
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
```
## Generation Issues
### Empty or silent output
**Problem**: Generated audio is silent or very quiet
**Solutions**:
```python
import torch
# Check output
wav = model.generate(["upbeat music"])
print(f"Shape: {wav.shape}")
print(f"Max amplitude: {wav.abs().max().item()}")
print(f"Mean amplitude: {wav.abs().mean().item()}")
# If too quiet, normalize
def normalize_audio(audio, target_db=-14.0):
rms = torch.sqrt(torch.mean(audio ** 2))
target_rms = 10 ** (target_db / 20)
gain = target_rms / (rms + 1e-8)
return audio * gain
wav_normalized = normalize_audio(wav)
```
### Poor quality output
**Problem**: Generated music sounds bad or noisy
**Solutions**:
```python
# Use larger model
model = MusicGen.get_pretrained('facebook/musicgen-large')
# Adjust generation parameters
model.set_generation_params(
duration=15,
top_k=250, # Increase for more diversity
temperature=0.8, # Lower for more focused output
cfg_coef=4.0 # Increase for better text adherence
)
# Use better prompts
# Bad: "music"
# Good: "upbeat electronic dance music with synthesizers and punchy drums"
# Try MultiBand Diffusion
from audiocraft.models import MultiBandDiffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()
tokens = model.generate_tokens(["prompt"])
wav = mbd.tokens_to_wav(tokens)
```
### Generation too short
**Problem**: Audio shorter than expected
**Solutions**:
```python
# Check duration setting
model.set_generation_params(duration=30) # Set before generate
# Verify in generation
print(f"Duration setting: {model.generation_params}")
# Check output shape
wav = model.generate(["prompt"])
actual_duration = wav.shape[-1] / 32000
print(f"Actual duration: {actual_duration}s")
# Note: max duration is typically 30s
```
### Melody conditioning fails
**Error**: Issues with melody-conditioned generation
**Solutions**:
```python
import torchaudio
from audiocraft.models import MusicGen
# Load melody model (not base model)
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# Load and prepare melody
melody, sr = torchaudio.load("melody.wav")
# Resample to model sample rate if needed
if sr != 32000:
resampler = torchaudio.transforms.Resample(sr, 32000)
melody = resampler(melody)
# Ensure correct shape [batch, channels, samples]
if melody.dim() == 1:
melody = melody.unsqueeze(0).unsqueeze(0)
elif melody.dim() == 2:
melody = melody.unsqueeze(0)
# Convert stereo to mono
if melody.shape[1] > 1:
melody = melody.mean(dim=1, keepdim=True)
# Generate with melody
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
```
## Memory Issues
### CUDA out of memory
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
**Solutions**:
```python
import torch
# Clear cache before generation
torch.cuda.empty_cache()
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Reduce duration
model.set_generation_params(duration=10) # Instead of 30
# Generate one at a time
for prompt in prompts:
wav = model.generate([prompt])
save_audio(wav)
torch.cuda.empty_cache()
# Use CPU for very large generations
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
```
### Memory leak during batch processing
**Problem**: Memory grows over time
**Solutions**:
```python
import gc
import torch
def generate_with_cleanup(model, prompts):
results = []
for prompt in prompts:
with torch.no_grad():
wav = model.generate([prompt])
results.append(wav.cpu())
# Cleanup
del wav
gc.collect()
torch.cuda.empty_cache()
return results
# Use context manager
with torch.inference_mode():
wav = model.generate(["prompt"])
```
## Audio Format Issues
### Wrong sample rate
**Problem**: Audio plays at wrong speed
**Solutions**:
```python
import torchaudio
# MusicGen outputs at 32kHz
sample_rate = 32000
# AudioGen outputs at 16kHz
sample_rate = 16000
# Always use correct rate when saving
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)
# Resample if needed
resampler = torchaudio.transforms.Resample(32000, 44100)
wav_resampled = resampler(wav)
```
### Stereo/mono mismatch
**Problem**: Wrong number of channels
**Solutions**:
```python
# Check model type
print(f"Audio channels: {wav.shape}")
# Mono: [batch, 1, samples]
# Stereo: [batch, 2, samples]
# Convert mono to stereo
if wav.shape[1] == 1:
wav_stereo = wav.repeat(1, 2, 1)
# Convert stereo to mono
if wav.shape[1] == 2:
wav_mono = wav.mean(dim=1, keepdim=True)
# Use stereo model for stereo output
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
```
### Clipping and distortion
**Problem**: Audio has clipping or distortion
**Solutions**:
```python
import torch
# Check for clipping
max_val = wav.abs().max().item()
print(f"Max amplitude: {max_val}")
# Normalize to prevent clipping
if max_val > 1.0:
wav = wav / max_val
# Apply soft clipping
def soft_clip(x, threshold=0.9):
return torch.tanh(x / threshold) * threshold
wav_clipped = soft_clip(wav)
# Lower temperature during generation
model.set_generation_params(temperature=0.7) # More controlled
```
## HuggingFace Transformers Issues
### Processor errors
**Error**: Issues with MusicgenProcessor
**Solutions**:
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration
# Load matching processor and model
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
# Ensure inputs are on same device
inputs = processor(
text=["prompt"],
padding=True,
return_tensors="pt"
).to("cuda")
# Check processor configuration
print(processor.tokenizer)
print(processor.feature_extractor)
```
### Generation parameter errors
**Error**: Invalid generation parameters
**Solutions**:
```python
# HuggingFace uses different parameter names
audio_values = model.generate(
**inputs,
do_sample=True, # Enable sampling
guidance_scale=3.0, # CFG (not cfg_coef)
max_new_tokens=256, # Token limit (not duration)
temperature=1.0
)
# Calculate tokens from duration
# ~50 tokens per second
duration_seconds = 10
max_tokens = duration_seconds * 50
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
```
## Performance Issues
### Slow generation
**Problem**: Generation takes too long
**Solutions**:
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Reduce duration
model.set_generation_params(duration=10)
# Use GPU
model.to("cuda")
# Enable flash attention if available
# (requires compatible hardware)
# Batch multiple prompts
prompts = ["prompt1", "prompt2", "prompt3"]
wav = model.generate(prompts) # Single batch is faster than loop
# Use compile (PyTorch 2.0+)
model.lm = torch.compile(model.lm)
```
### CPU fallback
**Problem**: Generation running on CPU instead of GPU
**Solutions**:
```python
import torch
# Check CUDA availability
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
# Explicitly move to GPU
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.to("cuda")
# Verify model device
print(f"Model device: {next(model.lm.parameters()).device}")
```
## Common Error Messages
| Error | Cause | Solution |
|-------|-------|----------|
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |
## Getting Help
1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
2. **HuggingFace Forums**: https://discuss.huggingface.co
3. **Paper**: https://arxiv.org/abs/2306.05284
### Reporting Issues
Include:
- Python version
- PyTorch version
- CUDA version
- AudioCraft version: `pip show audiocraft`
- Full error traceback
- Minimal reproducible code
- Hardware (GPU model, VRAM)

View File

@@ -0,0 +1,506 @@
---
name: segment-anything-model
description: "SAM: zero-shot image segmentation via points, boxes, masks."
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [segment-anything, transformers>=4.30.0, torch>=1.7.0]
platforms: [linux, macos, windows]
metadata:
hermes:
tags: [Multimodal, Image Segmentation, Computer Vision, SAM, Zero-Shot]
---
# Segment Anything Model (SAM)
Comprehensive guide to using Meta AI's Segment Anything Model for zero-shot image segmentation.
## When to use SAM
**Use SAM when:**
- Need to segment any object in images without task-specific training
- Building interactive annotation tools with point/box prompts
- Generating training data for other vision models
- Need zero-shot transfer to new image domains
- Building object detection/segmentation pipelines
- Processing medical, satellite, or domain-specific images
**Key features:**
- **Zero-shot segmentation**: Works on any image domain without fine-tuning
- **Flexible prompts**: Points, bounding boxes, or previous masks
- **Automatic segmentation**: Generate all object masks automatically
- **High quality**: Trained on 1.1 billion masks from 11 million images
- **Multiple model sizes**: ViT-B (fastest), ViT-L, ViT-H (most accurate)
- **ONNX export**: Deploy in browsers and edge devices
**Use alternatives instead:**
- **YOLO/Detectron2**: For real-time object detection with classes
- **Mask2Former**: For semantic/panoptic segmentation with categories
- **GroundingDINO + SAM**: For text-prompted segmentation
- **SAM 2**: For video segmentation tasks
## Quick start
### Installation
```bash
# From GitHub
pip install git+https://github.com/facebookresearch/segment-anything.git
# Optional dependencies
pip install opencv-python pycocotools matplotlib
# Or use HuggingFace transformers
pip install transformers
```
### Download checkpoints
```bash
# ViT-H (largest, most accurate) - 2.4GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
# ViT-L (medium) - 1.2GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
# ViT-B (smallest, fastest) - 375MB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
```
### Basic usage with SamPredictor
```python
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")
# Create predictor
predictor = SamPredictor(sam)
# Set image (computes embeddings once)
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)
# Predict with point prompts
input_point = np.array([[500, 375]]) # (x, y) coordinates
input_label = np.array([1]) # 1 = foreground, 0 = background
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True # Returns 3 mask options
)
# Select best mask
best_mask = masks[np.argmax(scores)]
```
### HuggingFace Transformers
```python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
# Load model and processor
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model.to("cuda")
# Process image with point prompt
image = Image.open("image.jpg")
input_points = [[[450, 600]]] # Batch of points
inputs = processor(image, input_points=input_points, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
# Generate masks
with torch.no_grad():
outputs = model(**inputs)
# Post-process masks to original size
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)
```
## Core concepts
### Model architecture
<!-- ascii-guard-ignore -->
```
SAM Architecture:
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ Image Encoder │────▶│ Prompt Encoder │────▶│ Mask Decoder │
│ (ViT) │ │ (Points/Boxes) │ │ (Transformer) │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │ │
Image Embeddings Prompt Embeddings Masks + IoU
(computed once) (per prompt) predictions
```
<!-- ascii-guard-ignore-end -->
### Model variants
| Model | Checkpoint | Size | Speed | Accuracy |
|-------|------------|------|-------|----------|
| ViT-H | `vit_h` | 2.4 GB | Slowest | Best |
| ViT-L | `vit_l` | 1.2 GB | Medium | Good |
| ViT-B | `vit_b` | 375 MB | Fastest | Good |
### Prompt types
| Prompt | Description | Use Case |
|--------|-------------|----------|
| Point (foreground) | Click on object | Single object selection |
| Point (background) | Click outside object | Exclude regions |
| Bounding box | Rectangle around object | Larger objects |
| Previous mask | Low-res mask input | Iterative refinement |
## Interactive segmentation
### Point prompts
```python
# Single foreground point
input_point = np.array([[500, 375]])
input_label = np.array([1])
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True
)
# Multiple points (foreground + background)
input_points = np.array([[500, 375], [600, 400], [450, 300]])
input_labels = np.array([1, 1, 0]) # 2 foreground, 1 background
masks, scores, logits = predictor.predict(
point_coords=input_points,
point_labels=input_labels,
multimask_output=False # Single mask when prompts are clear
)
```
### Box prompts
```python
# Bounding box [x1, y1, x2, y2]
input_box = np.array([425, 600, 700, 875])
masks, scores, logits = predictor.predict(
box=input_box,
multimask_output=False
)
```
### Combined prompts
```python
# Box + points for precise control
masks, scores, logits = predictor.predict(
point_coords=np.array([[500, 375]]),
point_labels=np.array([1]),
box=np.array([400, 300, 700, 600]),
multimask_output=False
)
```
### Iterative refinement
```python
# Initial prediction
masks, scores, logits = predictor.predict(
point_coords=np.array([[500, 375]]),
point_labels=np.array([1]),
multimask_output=True
)
# Refine with additional point using previous mask
masks, scores, logits = predictor.predict(
point_coords=np.array([[500, 375], [550, 400]]),
point_labels=np.array([1, 0]), # Add background point
mask_input=logits[np.argmax(scores)][None, :, :], # Use best mask
multimask_output=False
)
```
## Automatic mask generation
### Basic automatic segmentation
```python
from segment_anything import SamAutomaticMaskGenerator
# Create generator
mask_generator = SamAutomaticMaskGenerator(sam)
# Generate all masks
masks = mask_generator.generate(image)
# Each mask contains:
# - segmentation: binary mask
# - bbox: [x, y, w, h]
# - area: pixel count
# - predicted_iou: quality score
# - stability_score: robustness score
# - point_coords: generating point
```
### Customized generation
```python
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=32, # Grid density (more = more masks)
pred_iou_thresh=0.88, # Quality threshold
stability_score_thresh=0.95, # Stability threshold
crop_n_layers=1, # Multi-scale crops
crop_n_points_downscale_factor=2,
min_mask_region_area=100, # Remove tiny masks
)
masks = mask_generator.generate(image)
```
### Filtering masks
```python
# Sort by area (largest first)
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
# Filter by predicted IoU
high_quality = [m for m in masks if m['predicted_iou'] > 0.9]
# Filter by stability score
stable_masks = [m for m in masks if m['stability_score'] > 0.95]
```
## Batched inference
### Multiple images
```python
# Process multiple images efficiently
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]
all_masks = []
for image in images:
predictor.set_image(image)
masks, _, _ = predictor.predict(
point_coords=np.array([[500, 375]]),
point_labels=np.array([1]),
multimask_output=True
)
all_masks.append(masks)
```
### Multiple prompts per image
```python
# Process multiple prompts efficiently (one image encoding)
predictor.set_image(image)
# Batch of point prompts
points = [
np.array([[100, 100]]),
np.array([[200, 200]]),
np.array([[300, 300]])
]
all_masks = []
for point in points:
masks, scores, _ = predictor.predict(
point_coords=point,
point_labels=np.array([1]),
multimask_output=True
)
all_masks.append(masks[np.argmax(scores)])
```
## ONNX deployment
### Export model
```bash
python scripts/export_onnx_model.py \
--checkpoint sam_vit_h_4b8939.pth \
--model-type vit_h \
--output sam_onnx.onnx \
--return-single-mask
```
### Use ONNX model
```python
import onnxruntime
# Load ONNX model
ort_session = onnxruntime.InferenceSession("sam_onnx.onnx")
# Run inference (image embeddings computed separately)
masks = ort_session.run(
None,
{
"image_embeddings": image_embeddings,
"point_coords": point_coords,
"point_labels": point_labels,
"mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
"has_mask_input": np.array([0], dtype=np.float32),
"orig_im_size": np.array([h, w], dtype=np.float32)
}
)
```
## Common workflows
### Workflow 1: Annotation tool
```python
import cv2
# Load model
predictor = SamPredictor(sam)
predictor.set_image(image)
def on_click(event, x, y, flags, param):
if event == cv2.EVENT_LBUTTONDOWN:
# Foreground point
masks, scores, _ = predictor.predict(
point_coords=np.array([[x, y]]),
point_labels=np.array([1]),
multimask_output=True
)
# Display best mask
display_mask(masks[np.argmax(scores)])
```
### Workflow 2: Object extraction
```python
def extract_object(image, point):
"""Extract object at point with transparent background."""
predictor.set_image(image)
masks, scores, _ = predictor.predict(
point_coords=np.array([point]),
point_labels=np.array([1]),
multimask_output=True
)
best_mask = masks[np.argmax(scores)]
# Create RGBA output
rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
rgba[:, :, :3] = image
rgba[:, :, 3] = best_mask * 255
return rgba
```
### Workflow 3: Medical image segmentation
```python
# Process medical images (grayscale to RGB)
medical_image = cv2.imread("scan.png", cv2.IMREAD_GRAYSCALE)
rgb_image = cv2.cvtColor(medical_image, cv2.COLOR_GRAY2RGB)
predictor.set_image(rgb_image)
# Segment region of interest
masks, scores, _ = predictor.predict(
box=np.array([x1, y1, x2, y2]), # ROI bounding box
multimask_output=True
)
```
## Output format
### Mask data structure
```python
# SamAutomaticMaskGenerator output
{
"segmentation": np.ndarray, # H×W binary mask
"bbox": [x, y, w, h], # Bounding box
"area": int, # Pixel count
"predicted_iou": float, # 0-1 quality score
"stability_score": float, # 0-1 robustness score
"crop_box": [x, y, w, h], # Generation crop region
"point_coords": [[x, y]], # Input point
}
```
### COCO RLE format
```python
from pycocotools import mask as mask_utils
# Encode mask to RLE
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
rle["counts"] = rle["counts"].decode("utf-8")
# Decode RLE to mask
decoded_mask = mask_utils.decode(rle)
```
## Performance optimization
### GPU memory
```python
# Use smaller model for limited VRAM
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
# Process images in batches
# Clear CUDA cache between large batches
torch.cuda.empty_cache()
```
### Speed optimization
```python
# Use half precision
sam = sam.half()
# Reduce points for automatic generation
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=16, # Default is 32
)
# Use ONNX for deployment
# Export with --return-single-mask for faster inference
```
## Common issues
| Issue | Solution |
|-------|----------|
| Out of memory | Use ViT-B model, reduce image size |
| Slow inference | Use ViT-B, reduce points_per_side |
| Poor mask quality | Try different prompts, use box + points |
| Edge artifacts | Use stability_score filtering |
| Small objects missed | Increase points_per_side |
## References
- **[Advanced Usage](references/advanced-usage.md)** - Batching, fine-tuning, integration
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
## Resources
- **GitHub**: https://github.com/facebookresearch/segment-anything
- **Paper**: https://arxiv.org/abs/2304.02643
- **Demo**: https://segment-anything.com
- **SAM 2 (Video)**: https://github.com/facebookresearch/segment-anything-2
- **HuggingFace**: https://huggingface.co/facebook/sam-vit-huge

View File

@@ -0,0 +1,589 @@
# Segment Anything Advanced Usage Guide
## SAM 2 (Video Segmentation)
### Overview
SAM 2 extends SAM to video segmentation with streaming memory architecture:
```bash
pip install git+https://github.com/facebookresearch/segment-anything-2.git
```
### Video segmentation
```python
from sam2.build_sam import build_sam2_video_predictor
predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "sam2_hiera_large.pt")
# Initialize with video
predictor.init_state(video_path="video.mp4")
# Add prompt on first frame
predictor.add_new_points(
frame_idx=0,
obj_id=1,
points=[[100, 200]],
labels=[1]
)
# Propagate through video
for frame_idx, masks in predictor.propagate_in_video():
# masks contains segmentation for all tracked objects
process_frame(frame_idx, masks)
```
### SAM 2 vs SAM comparison
| Feature | SAM | SAM 2 |
|---------|-----|-------|
| Input | Images only | Images + Videos |
| Architecture | ViT + Decoder | Hiera + Memory |
| Memory | Per-image | Streaming memory bank |
| Tracking | No | Yes, across frames |
| Models | ViT-B/L/H | Hiera-T/S/B+/L |
## Grounded SAM (Text-Prompted Segmentation)
### Setup
```bash
pip install groundingdino-py
pip install git+https://github.com/facebookresearch/segment-anything.git
```
### Text-to-mask pipeline
```python
from groundingdino.util.inference import load_model, predict
from segment_anything import sam_model_registry, SamPredictor
import cv2
# Load Grounding DINO
grounding_model = load_model("groundingdino_swint_ogc.pth", "GroundingDINO_SwinT_OGC.py")
# Load SAM
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
def text_to_mask(image, text_prompt, box_threshold=0.3, text_threshold=0.25):
"""Generate masks from text description."""
# Get bounding boxes from text
boxes, logits, phrases = predict(
model=grounding_model,
image=image,
caption=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold
)
# Generate masks with SAM
predictor.set_image(image)
masks = []
for box in boxes:
# Convert normalized box to pixel coordinates
h, w = image.shape[:2]
box_pixels = box * np.array([w, h, w, h])
mask, score, _ = predictor.predict(
box=box_pixels,
multimask_output=False
)
masks.append(mask[0])
return masks, boxes, phrases
# Usage
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks, boxes, phrases = text_to_mask(image, "person . dog . car")
```
## Batched Processing
### Efficient multi-image processing
```python
import torch
from segment_anything import SamPredictor, sam_model_registry
class BatchedSAM:
def __init__(self, checkpoint, model_type="vit_h", device="cuda"):
self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
self.sam.to(device)
self.predictor = SamPredictor(self.sam)
self.device = device
def process_batch(self, images, prompts):
"""Process multiple images with corresponding prompts."""
results = []
for image, prompt in zip(images, prompts):
self.predictor.set_image(image)
if "point" in prompt:
masks, scores, _ = self.predictor.predict(
point_coords=prompt["point"],
point_labels=prompt["label"],
multimask_output=True
)
elif "box" in prompt:
masks, scores, _ = self.predictor.predict(
box=prompt["box"],
multimask_output=False
)
results.append({
"masks": masks,
"scores": scores,
"best_mask": masks[np.argmax(scores)]
})
return results
# Usage
batch_sam = BatchedSAM("sam_vit_h_4b8939.pth")
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]
prompts = [{"point": np.array([[100, 100]]), "label": np.array([1])} for _ in range(10)]
results = batch_sam.process_batch(images, prompts)
```
### Parallel automatic mask generation
```python
from concurrent.futures import ThreadPoolExecutor
from segment_anything import SamAutomaticMaskGenerator
def generate_masks_parallel(images, num_workers=4):
"""Generate masks for multiple images in parallel."""
# Note: Each worker needs its own model instance
def worker_init():
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
return SamAutomaticMaskGenerator(sam)
generators = [worker_init() for _ in range(num_workers)]
def process_image(args):
idx, image = args
generator = generators[idx % num_workers]
return generator.generate(image)
with ThreadPoolExecutor(max_workers=num_workers) as executor:
results = list(executor.map(process_image, enumerate(images)))
return results
```
## Custom Integration
### FastAPI service
```python
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
import numpy as np
import cv2
import io
app = FastAPI()
# Load model once
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cuda")
predictor = SamPredictor(sam)
class PointPrompt(BaseModel):
x: int
y: int
label: int = 1
@app.post("/segment/point")
async def segment_with_point(
file: UploadFile = File(...),
points: list[PointPrompt] = []
):
# Read image
contents = await file.read()
nparr = np.frombuffer(contents, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Set image
predictor.set_image(image)
# Prepare prompts
point_coords = np.array([[p.x, p.y] for p in points])
point_labels = np.array([p.label for p in points])
# Generate masks
masks, scores, _ = predictor.predict(
point_coords=point_coords,
point_labels=point_labels,
multimask_output=True
)
best_idx = np.argmax(scores)
return {
"mask": masks[best_idx].tolist(),
"score": float(scores[best_idx]),
"all_scores": scores.tolist()
}
@app.post("/segment/auto")
async def segment_automatic(file: UploadFile = File(...)):
contents = await file.read()
nparr = np.frombuffer(contents, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
return {
"num_masks": len(masks),
"masks": [
{
"bbox": m["bbox"],
"area": m["area"],
"predicted_iou": m["predicted_iou"],
"stability_score": m["stability_score"]
}
for m in masks
]
}
```
### Gradio interface
```python
import gradio as gr
import numpy as np
# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
def segment_image(image, evt: gr.SelectData):
"""Segment object at clicked point."""
predictor.set_image(image)
point = np.array([[evt.index[0], evt.index[1]]])
label = np.array([1])
masks, scores, _ = predictor.predict(
point_coords=point,
point_labels=label,
multimask_output=True
)
best_mask = masks[np.argmax(scores)]
# Overlay mask on image
overlay = image.copy()
overlay[best_mask] = overlay[best_mask] * 0.5 + np.array([255, 0, 0]) * 0.5
return overlay
with gr.Blocks() as demo:
gr.Markdown("# SAM Interactive Segmentation")
gr.Markdown("Click on an object to segment it")
with gr.Row():
input_image = gr.Image(label="Input Image", interactive=True)
output_image = gr.Image(label="Segmented Image")
input_image.select(segment_image, inputs=[input_image], outputs=[output_image])
demo.launch()
```
## Fine-Tuning SAM
### LoRA fine-tuning (experimental)
```python
from peft import LoraConfig, get_peft_model
from transformers import SamModel
# Load model
model = SamModel.from_pretrained("facebook/sam-vit-base")
# Configure LoRA
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["qkv"], # Attention layers
lora_dropout=0.1,
bias="none",
)
# Apply LoRA
model = get_peft_model(model, lora_config)
# Training loop (simplified)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for batch in dataloader:
outputs = model(
pixel_values=batch["pixel_values"],
input_points=batch["input_points"],
input_labels=batch["input_labels"]
)
# Custom loss (e.g., IoU loss with ground truth)
loss = compute_loss(outputs.pred_masks, batch["gt_masks"])
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
### MedSAM (Medical imaging)
```python
# MedSAM is a fine-tuned SAM for medical images
# https://github.com/bowang-lab/MedSAM
from segment_anything import sam_model_registry, SamPredictor
import torch
# Load MedSAM checkpoint
medsam = sam_model_registry["vit_b"](checkpoint="medsam_vit_b.pth")
medsam.to("cuda")
predictor = SamPredictor(medsam)
# Process medical image
# Convert grayscale to RGB if needed
medical_image = cv2.imread("ct_scan.png", cv2.IMREAD_GRAYSCALE)
rgb_image = np.stack([medical_image] * 3, axis=-1)
predictor.set_image(rgb_image)
# Segment with box prompt (common for medical imaging)
masks, scores, _ = predictor.predict(
box=np.array([x1, y1, x2, y2]),
multimask_output=False
)
```
## Advanced Mask Processing
### Mask refinement
```python
import cv2
from scipy import ndimage
def refine_mask(mask, kernel_size=5, iterations=2):
"""Refine mask with morphological operations."""
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
# Close small holes
closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel, iterations=iterations)
# Remove small noise
opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel, iterations=iterations)
return opened.astype(bool)
def fill_holes(mask):
"""Fill holes in mask."""
filled = ndimage.binary_fill_holes(mask)
return filled
def remove_small_regions(mask, min_area=100):
"""Remove small disconnected regions."""
labeled, num_features = ndimage.label(mask)
sizes = ndimage.sum(mask, labeled, range(1, num_features + 1))
# Keep only regions larger than min_area
mask_clean = np.zeros_like(mask)
for i, size in enumerate(sizes, 1):
if size >= min_area:
mask_clean[labeled == i] = True
return mask_clean
```
### Mask to polygon conversion
```python
import cv2
def mask_to_polygons(mask, epsilon_factor=0.01):
"""Convert binary mask to polygon coordinates."""
contours, _ = cv2.findContours(
mask.astype(np.uint8),
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE
)
polygons = []
for contour in contours:
epsilon = epsilon_factor * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
polygon = approx.squeeze().tolist()
if len(polygon) >= 3: # Valid polygon
polygons.append(polygon)
return polygons
def polygons_to_mask(polygons, height, width):
"""Convert polygons back to binary mask."""
mask = np.zeros((height, width), dtype=np.uint8)
for polygon in polygons:
pts = np.array(polygon, dtype=np.int32)
cv2.fillPoly(mask, [pts], 1)
return mask.astype(bool)
```
### Multi-scale segmentation
```python
def multiscale_segment(image, predictor, point, scales=[0.5, 1.0, 2.0]):
"""Generate masks at multiple scales and combine."""
h, w = image.shape[:2]
masks_all = []
for scale in scales:
# Resize image
new_h, new_w = int(h * scale), int(w * scale)
scaled_image = cv2.resize(image, (new_w, new_h))
scaled_point = (point * scale).astype(int)
# Segment
predictor.set_image(scaled_image)
masks, scores, _ = predictor.predict(
point_coords=scaled_point.reshape(1, 2),
point_labels=np.array([1]),
multimask_output=True
)
# Resize mask back
best_mask = masks[np.argmax(scores)]
original_mask = cv2.resize(best_mask.astype(np.uint8), (w, h)) > 0.5
masks_all.append(original_mask)
# Combine masks (majority voting)
combined = np.stack(masks_all, axis=0)
final_mask = np.sum(combined, axis=0) >= len(scales) // 2 + 1
return final_mask
```
## Performance Optimization
### TensorRT acceleration
```python
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
def export_to_tensorrt(onnx_path, engine_path, fp16=True):
"""Convert ONNX model to TensorRT engine."""
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open(onnx_path, 'rb') as f:
if not parser.parse(f.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
return None
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30 # 1GB
if fp16:
config.set_flag(trt.BuilderFlag.FP16)
engine = builder.build_engine(network, config)
with open(engine_path, 'wb') as f:
f.write(engine.serialize())
return engine
```
### Memory-efficient inference
```python
class MemoryEfficientSAM:
def __init__(self, checkpoint, model_type="vit_b"):
self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
self.sam.eval()
self.predictor = None
def __enter__(self):
self.sam.to("cuda")
self.predictor = SamPredictor(self.sam)
return self
def __exit__(self, *args):
self.sam.to("cpu")
torch.cuda.empty_cache()
def segment(self, image, points, labels):
self.predictor.set_image(image)
masks, scores, _ = self.predictor.predict(
point_coords=points,
point_labels=labels,
multimask_output=True
)
return masks, scores
# Usage with context manager (auto-cleanup)
with MemoryEfficientSAM("sam_vit_b_01ec64.pth") as sam:
masks, scores = sam.segment(image, points, labels)
# CUDA memory freed automatically
```
## Dataset Generation
### Create segmentation dataset
```python
import json
def generate_dataset(images_dir, output_dir, mask_generator):
"""Generate segmentation dataset from images."""
annotations = []
for img_path in Path(images_dir).glob("*.jpg"):
image = cv2.imread(str(img_path))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Generate masks
masks = mask_generator.generate(image)
# Filter high-quality masks
good_masks = [m for m in masks if m["predicted_iou"] > 0.9]
# Save annotations
for i, mask_data in enumerate(good_masks):
annotation = {
"image_id": img_path.stem,
"mask_id": i,
"bbox": mask_data["bbox"],
"area": mask_data["area"],
"segmentation": mask_to_rle(mask_data["segmentation"]),
"predicted_iou": mask_data["predicted_iou"],
"stability_score": mask_data["stability_score"]
}
annotations.append(annotation)
# Save dataset
with open(output_dir / "annotations.json", "w") as f:
json.dump(annotations, f)
return annotations
```

View File

@@ -0,0 +1,484 @@
# Segment Anything Troubleshooting Guide
## Installation Issues
### CUDA not available
**Error**: `RuntimeError: CUDA not available`
**Solutions**:
```python
# Check CUDA availability
import torch
print(torch.cuda.is_available())
print(torch.version.cuda)
# Install PyTorch with CUDA
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
# If CUDA works but SAM doesn't use it
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cuda") # Explicitly move to GPU
```
### Import errors
**Error**: `ModuleNotFoundError: No module named 'segment_anything'`
**Solutions**:
```bash
# Install from GitHub
pip install git+https://github.com/facebookresearch/segment-anything.git
# Or clone and install
git clone https://github.com/facebookresearch/segment-anything.git
cd segment-anything
pip install -e .
# Verify installation
python -c "from segment_anything import sam_model_registry; print('OK')"
```
### Missing dependencies
**Error**: `ModuleNotFoundError: No module named 'cv2'` or similar
**Solutions**:
```bash
# Install all optional dependencies
pip install opencv-python pycocotools matplotlib onnxruntime onnx
# For pycocotools on Windows
pip install pycocotools-windows
```
## Model Loading Issues
### Checkpoint not found
**Error**: `FileNotFoundError: checkpoint file not found`
**Solutions**:
```bash
# Download correct checkpoint
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
# Verify file integrity
md5sum sam_vit_h_4b8939.pth
# Expected: a7bf3b02f3ebf1267aba913ff637d9a2
# Use absolute path
sam = sam_model_registry["vit_h"](checkpoint="/full/path/to/sam_vit_h_4b8939.pth")
```
### Model type mismatch
**Error**: `KeyError: 'unexpected key in state_dict'`
**Solutions**:
```python
# Ensure model type matches checkpoint
# vit_h checkpoint → vit_h model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
# vit_l checkpoint → vit_l model
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
# vit_b checkpoint → vit_b model
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
```
### Out of memory during load
**Error**: `CUDA out of memory` during model loading
**Solutions**:
```python
# Use smaller model
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
# Load to CPU first, then move
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cpu")
torch.cuda.empty_cache()
sam.to("cuda")
# Use half precision
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam = sam.half()
sam.to("cuda")
```
## Inference Issues
### Image format errors
**Error**: `ValueError: expected input to have 3 channels`
**Solutions**:
```python
import cv2
# Ensure RGB format
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # BGR to RGB
# Convert grayscale to RGB
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
# Handle RGBA
if image.shape[2] == 4:
image = image[:, :, :3] # Drop alpha channel
```
### Coordinate errors
**Error**: `IndexError: index out of bounds` or incorrect mask location
**Solutions**:
```python
# Ensure points are (x, y) not (row, col)
# x = column index, y = row index
point = np.array([[x, y]]) # Correct
# Verify coordinates are within image bounds
h, w = image.shape[:2]
assert 0 <= x < w and 0 <= y < h, "Point outside image"
# For bounding boxes: [x1, y1, x2, y2]
box = np.array([x1, y1, x2, y2])
assert x1 < x2 and y1 < y2, "Invalid box coordinates"
```
### Empty or incorrect masks
**Problem**: Masks don't match expected object
**Solutions**:
```python
# Try multiple prompts
input_points = np.array([[x1, y1], [x2, y2]])
input_labels = np.array([1, 1]) # Multiple foreground points
# Add background points
input_points = np.array([[obj_x, obj_y], [bg_x, bg_y]])
input_labels = np.array([1, 0]) # 1=foreground, 0=background
# Use box prompt for large objects
box = np.array([x1, y1, x2, y2])
masks, scores, _ = predictor.predict(box=box, multimask_output=False)
# Combine box and point
masks, scores, _ = predictor.predict(
point_coords=np.array([[center_x, center_y]]),
point_labels=np.array([1]),
box=np.array([x1, y1, x2, y2]),
multimask_output=True
)
# Check scores and select best
print(f"Scores: {scores}")
best_mask = masks[np.argmax(scores)]
```
### Slow inference
**Problem**: Prediction takes too long
**Solutions**:
```python
# Use smaller model
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
# Reuse image embeddings
predictor.set_image(image) # Compute once
for point in points:
masks, _, _ = predictor.predict(...) # Fast, reuses embeddings
# Reduce automatic generation points
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=16, # Default is 32
)
# Use ONNX for deployment
# Export: python scripts/export_onnx_model.py --return-single-mask
```
## Automatic Mask Generation Issues
### Too many masks
**Problem**: Generating thousands of overlapping masks
**Solutions**:
```python
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=16, # Reduce from 32
pred_iou_thresh=0.92, # Increase from 0.88
stability_score_thresh=0.98, # Increase from 0.95
box_nms_thresh=0.5, # More aggressive NMS
min_mask_region_area=500, # Remove small masks
)
```
### Too few masks
**Problem**: Missing objects in automatic generation
**Solutions**:
```python
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=64, # Increase density
pred_iou_thresh=0.80, # Lower threshold
stability_score_thresh=0.85, # Lower threshold
crop_n_layers=2, # Add multi-scale
min_mask_region_area=0, # Keep all masks
)
```
### Small objects missed
**Problem**: Automatic generation misses small objects
**Solutions**:
```python
# Use crop layers for multi-scale detection
mask_generator = SamAutomaticMaskGenerator(
model=sam,
crop_n_layers=2,
crop_n_points_downscale_factor=1, # Don't reduce points in crops
min_mask_region_area=10, # Very small minimum
)
# Or process image patches
def segment_with_patches(image, patch_size=512, overlap=64):
h, w = image.shape[:2]
all_masks = []
for y in range(0, h, patch_size - overlap):
for x in range(0, w, patch_size - overlap):
patch = image[y:y+patch_size, x:x+patch_size]
masks = mask_generator.generate(patch)
# Offset masks to original coordinates
for m in masks:
m['bbox'][0] += x
m['bbox'][1] += y
# Offset segmentation mask too
all_masks.extend(masks)
return all_masks
```
## Memory Issues
### CUDA out of memory
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
**Solutions**:
```python
# Use smaller model
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
# Clear cache between images
torch.cuda.empty_cache()
# Process images sequentially, not batched
for image in images:
predictor.set_image(image)
masks, _, _ = predictor.predict(...)
torch.cuda.empty_cache()
# Reduce image size
max_size = 1024
h, w = image.shape[:2]
if max(h, w) > max_size:
scale = max_size / max(h, w)
image = cv2.resize(image, (int(w*scale), int(h*scale)))
# Use CPU for large batch processing
sam.to("cpu")
```
### RAM out of memory
**Problem**: System runs out of RAM
**Solutions**:
```python
# Process images one at a time
for img_path in image_paths:
image = cv2.imread(img_path)
masks = process_image(image)
save_results(masks)
del image, masks
gc.collect()
# Use generators instead of lists
def generate_masks_lazy(image_paths):
for path in image_paths:
image = cv2.imread(path)
masks = mask_generator.generate(image)
yield path, masks
```
## ONNX Export Issues
### Export fails
**Error**: Various export errors
**Solutions**:
```bash
# Install correct ONNX version
pip install onnx==1.14.0 onnxruntime==1.15.0
# Use correct opset version
python scripts/export_onnx_model.py \
--checkpoint sam_vit_h_4b8939.pth \
--model-type vit_h \
--output sam.onnx \
--opset 17
```
### ONNX runtime errors
**Error**: `ONNXRuntimeError` during inference
**Solutions**:
```python
import onnxruntime
# Check available providers
print(onnxruntime.get_available_providers())
# Use CPU provider if GPU fails
session = onnxruntime.InferenceSession(
"sam.onnx",
providers=['CPUExecutionProvider']
)
# Verify input shapes
for input in session.get_inputs():
print(f"{input.name}: {input.shape}")
```
## HuggingFace Integration Issues
### Processor errors
**Error**: Issues with SamProcessor
**Solutions**:
```python
from transformers import SamModel, SamProcessor
# Use matching processor and model
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
# Ensure input format
input_points = [[[x, y]]] # Nested list for batch dimension
inputs = processor(image, input_points=input_points, return_tensors="pt")
# Post-process correctly
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)
```
## Quality Issues
### Jagged mask edges
**Problem**: Masks have rough, pixelated edges
**Solutions**:
```python
import cv2
from scipy import ndimage
def smooth_mask(mask, sigma=2):
"""Smooth mask edges."""
# Gaussian blur
smooth = ndimage.gaussian_filter(mask.astype(float), sigma=sigma)
return smooth > 0.5
def refine_edges(mask, kernel_size=5):
"""Refine mask edges with morphological operations."""
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
# Close small gaps
closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
# Open to remove noise
opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel)
return opened.astype(bool)
```
### Incomplete segmentation
**Problem**: Mask doesn't cover entire object
**Solutions**:
```python
# Add multiple points
input_points = np.array([
[obj_center_x, obj_center_y],
[obj_left_x, obj_center_y],
[obj_right_x, obj_center_y],
[obj_center_x, obj_top_y],
[obj_center_x, obj_bottom_y]
])
input_labels = np.array([1, 1, 1, 1, 1])
# Use bounding box
masks, _, _ = predictor.predict(
box=np.array([x1, y1, x2, y2]),
multimask_output=False
)
# Iterative refinement
mask_input = None
for point in points:
masks, scores, logits = predictor.predict(
point_coords=point.reshape(1, 2),
point_labels=np.array([1]),
mask_input=mask_input,
multimask_output=False
)
mask_input = logits
```
## Common Error Messages
| Error | Cause | Solution |
|-------|-------|----------|
| `CUDA out of memory` | GPU memory full | Use smaller model, clear cache |
| `expected 3 channels` | Wrong image format | Convert to RGB |
| `index out of bounds` | Invalid coordinates | Check point/box bounds |
| `checkpoint not found` | Wrong path | Use absolute path |
| `unexpected key` | Model/checkpoint mismatch | Match model type |
| `invalid box coordinates` | x1 > x2 or y1 > y2 | Fix box format |
## Getting Help
1. **GitHub Issues**: https://github.com/facebookresearch/segment-anything/issues
2. **HuggingFace Forums**: https://discuss.huggingface.co
3. **Paper**: https://arxiv.org/abs/2304.02643
### Reporting Issues
Include:
- Python version
- PyTorch version: `python -c "import torch; print(torch.__version__)"`
- CUDA version: `python -c "import torch; print(torch.version.cuda)"`
- SAM model type (vit_b/l/h)
- Full error traceback
- Minimal reproducible code